In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    """Configurable U-Net built from 2D or 3D convolutions.

    The encoder stacks ``num_layers`` double-convolution blocks, each followed
    by a max-pool of size 2; the channel width starts at ``base_filters`` and
    doubles at every level. A bottleneck block sits at the coarsest level. The
    decoder mirrors the encoder: a stride-2 transposed convolution, channel-wise
    concatenation with the matching encoder output (skip connection), and a
    double convolution. A final 1x1 convolution maps to ``out_channels``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the final 1x1 convolution.
        num_layers: Number of encoder/decoder levels (must be >= 1).
        base_filters: Channel width of the first encoder level.
        activation: One of 'relu', 'leaky_relu', 'elu', 'sigmoid', 'tanh'.
        conv_dim: 2 for Conv2d/MaxPool2d layers, 3 for Conv3d/MaxPool3d layers.

    Raises:
        ValueError: If ``conv_dim`` is not 2 or 3, ``num_layers`` is smaller
            than 1, or ``activation`` is not supported.
    """

    def __init__(self, in_channels, out_channels, num_layers=4, base_filters=64, activation='relu', conv_dim=2):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_layers = num_layers
        self.base_filters = base_filters
        self.activation = activation
        self.conv_dim = conv_dim

        # Choose convolution and pooling layers based on the specified dimension
        if conv_dim == 2:
            self.Conv = nn.Conv2d
            self.ConvTranspose = nn.ConvTranspose2d
            self.MaxPool = nn.MaxPool2d
        elif conv_dim == 3:
            self.Conv = nn.Conv3d
            self.ConvTranspose = nn.ConvTranspose3d
            self.MaxPool = nn.MaxPool3d
        else:
            raise ValueError("conv_dim must be 2 or 3")

        # With no levels the bottleneck width would become fractional; fail early
        # with a clear message instead.
        if num_layers < 1:
            raise ValueError("num_layers must be at least 1")

        # Encoder: level i maps to base_filters * 2**i channels.
        self.enc_convs = nn.ModuleList()
        self.pools = nn.ModuleList()
        for i in range(num_layers):
            in_f = in_channels if i == 0 else base_filters * (2 ** (i - 1))
            out_f = base_filters * (2 ** i)
            self.enc_convs.append(self.double_conv(in_f, out_f))
            self.pools.append(self.MaxPool(2))

        # Bottleneck
        self.bottleneck = self.double_conv(base_filters * (2 ** (num_layers - 1)), base_filters * (2 ** num_layers))

        # Decoder: modules are stored deepest level first, i.e. index 0 consumes
        # the bottleneck output. forward() relies on this ordering.
        self.up_convs = nn.ModuleList()
        self.dec_convs = nn.ModuleList()
        for i in range(num_layers - 1, -1, -1):
            in_f = base_filters * (2 ** (i + 1))
            out_f = base_filters * (2 ** i)
            self.up_convs.append(self.ConvTranspose(in_f, out_f, kernel_size=2, stride=2))
            # After concatenation with the skip (out_f channels) the width is in_f again.
            self.dec_convs.append(self.double_conv(in_f, out_f))

        # Final output layer
        self.final_conv = self.Conv(base_filters, out_channels, kernel_size=1)

    def double_conv(self, in_channels, out_channels):
        """Return two 3x3 (padding 1) convolutions, each followed by the configured activation."""
        activation_func = self.get_activation(self.activation)
        return nn.Sequential(
            self.Conv(in_channels, out_channels, kernel_size=3, padding=1),
            activation_func(),
            self.Conv(out_channels, out_channels, kernel_size=3, padding=1),
            activation_func()
        )

    def get_activation(self, activation):
        """Map an activation name to its ``nn.Module`` class (not an instance).

        Raises:
            ValueError: If ``activation`` is not one of the supported names.
        """
        activations = {
            'relu': nn.ReLU,
            'leaky_relu': nn.LeakyReLU,
            'elu': nn.ELU,
            'sigmoid': nn.Sigmoid,
            'tanh': nn.Tanh,
        }
        try:
            return activations[activation]
        except (KeyError, TypeError):
            # TypeError covers unhashable arguments so callers still get ValueError.
            raise ValueError("Unsupported activation function") from None

    @staticmethod
    def _match_spatial_size(x, reference):
        """Zero-pad ``x`` so its spatial dimensions equal those of ``reference``.

        Max-pooling floors odd sizes, so the upsampled map can be one element
        smaller than the encoder output it is concatenated with.
        """
        if x.shape[2:] == reference.shape[2:]:
            return x
        padding = []
        # F.pad takes (before, after) pairs starting from the LAST dimension.
        for up_size, ref_size in zip(reversed(x.shape[2:]), reversed(reference.shape[2:])):
            diff = ref_size - up_size
            padding.extend([diff // 2, diff - diff // 2])
        return F.pad(x, padding)

    def forward(self, x):
        """Run the network on ``x``, a batched tensor with channels in dimension 1."""
        # Encoder: remember each level's output for the skip connections.
        enc_outs = []
        for enc_conv, pool in zip(self.enc_convs, self.pools):
            x = enc_conv(x)
            enc_outs.append(x)
            x = pool(x)

        # Bottleneck
        x = self.bottleneck(x)

        # Decoder: up_convs/dec_convs are ordered deepest-first, so pair them
        # with the encoder outputs in reverse order.
        for up_conv, dec_conv, skip in zip(self.up_convs, self.dec_convs, reversed(enc_outs)):
            x = up_conv(x)
            x = self._match_spatial_size(x, skip)
            x = torch.cat([x, skip], dim=1)
            x = dec_conv(x)

        return self.final_conv(x)

# Example usage: build a small 3D U-Net and print its layer structure.
# Adjust the values below to your data; alternatives are noted inline.
input_shape = (1, 512, 512)  # only the first entry (channel count) is used here, e.g. (1, 128, 128)
in_channels = input_shape[0]
out_channels = 2  # use 1 for binary segmentation
num_layers = 3  # class default is 4
base_filters = 16  # class default is 64
activation = 'leaky_relu'  # 'relu', 'leaky_relu', 'elu', 'sigmoid', 'tanh'
conv_dim = 3  # 2 for 2D convolutions, 3 for 3D convolutions

model = UNet(
    in_channels=in_channels,
    out_channels=out_channels,
    num_layers=num_layers,
    base_filters=base_filters,
    activation=activation,
    conv_dim=conv_dim,
)
print(model)


UNet(
  (enc_convs): ModuleList(
    (0): Sequential(
      (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (3): LeakyReLU(negative_slope=0.01)
    )
    (1): Sequential(
      (0): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): LeakyReLU(negative_slope=0.01)
      (2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (3): LeakyReLU(negative_slope=0.01)
    )
  )
  (pools): ModuleList(
    (0-1): 2 x MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (bottleneck): Sequential(
    (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (1): LeakyReLU(negative_slope=0.01)
    (2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (3): LeakyReLU(negative_slope=0.01)
  )
 

In [6]:
!pip install torchvision

Collecting torchvision
  Downloading torchvision-0.18.0-cp310-cp310-manylinux1_x86_64.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m19.5 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hCollecting torch==2.3.0
  Downloading torch-2.3.0-cp310-cp310-manylinux1_x86_64.whl (779.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m779.1/779.1 MB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting nvidia-nccl-cu12==2.20.5
  Downloading nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl (176.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m176.2/176.2 MB[0m [31m11.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Collecting triton==2.3.0
  Downloading triton-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (168.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m168.1/168.1 MB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Insta

In [5]:
import os
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import random

class NoisyImageDataset(Dataset):
    """Dataset of (noisy, clean) image pairs built from the image files in a directory.

    Each item loads one image as RGB, applies the optional ``transform`` and
    returns the image with additive Gaussian noise (clamped to [0, 1]) together
    with the clean image.

    Args:
        image_dir: Directory that is scanned (non-recursively) for image files.
        transform: Optional callable applied to the loaded PIL image.
        noise_level: Standard deviation of the Gaussian noise that is added.
    """

    # File extensions treated as images; matching is case-insensitive.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, image_dir, transform=None, noise_level=0.1):
        self.image_dir = image_dir
        # os.listdir returns entries in arbitrary order; sort so that indices
        # (and any seeded split built on them) are stable across runs/machines.
        self.image_paths = sorted(
            os.path.join(image_dir, fname)
            for fname in os.listdir(image_dir)
            if fname.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.transform = transform
        self.noise_level = noise_level

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(noisy_image, image)`` for the file at position ``idx``."""
        image_path = self.image_paths[idx]
        # Close the file handle promptly; convert() yields an independent image.
        with Image.open(image_path) as source:
            image = source.convert("RGB")

        if self.transform:
            image = self.transform(image)

        if isinstance(image, Image.Image):
            # No tensor-producing transform was applied: convert the HxWx3 uint8
            # image to a CxHxW float tensor in [0, 1] so noise can be added.
            image = (
                torch.from_numpy(np.array(image))
                .permute(2, 0, 1)
                .contiguous()
                .float()
                .div(255)
            )

        # Add noise
        noise = torch.randn_like(image) * self.noise_level
        noisy_image = torch.clamp(image + noise, 0, 1)

        return noisy_image, image

# Define transformations: resize to 128x128 and convert to a float tensor
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

# Create the dataset
image_dir = 'path/to/your/images'  # placeholder: point this at your image folder
dataset = NoisyImageDataset(image_dir=image_dir, transform=transform, noise_level=0.1)

# Split the dataset into training (80%) and validation (remaining) sets.
# A seeded generator keeps the train/val membership identical across runs.
split_seed = 42
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)

# Create data loaders (only the training loader shuffles)
batch_size = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


ModuleNotFoundError: No module named 'torchvision'