## **Import libraries:**

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset
from torchvision.datasets import OxfordIIITPet
import torchvision.transforms.v2 as T

## **Check GPU:**

In [2]:
# Select the compute device: CUDA when PyTorch reports it as available,
# otherwise the CPU. The chosen device is printed for a quick sanity check.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cpu


## **Architecture:**

In [4]:
class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution -> batch-norm -> ReLU stages.

    The first stage maps ``in_channels`` to ``out_channels``; the second keeps
    ``out_channels``. Kernel size 3 with stride 1 and padding 1 leaves the
    spatial size unchanged.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        stage_in = in_channels
        for _ in range(2):
            layers.extend(
                [
                    # The convolutions carry no bias term; each is followed
                    # by a BatchNorm2d layer.
                    nn.Conv2d(
                        stage_in,
                        out_channels,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        bias=False,
                    ),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(inplace=True),
                ]
            )
            stage_in = out_channels
        self.net = nn.Sequential(*layers)

    def forward(self, X):
        """Apply both conv stages to ``X`` and return the result."""
        return self.net(X)


class Encoder(nn.Module):
    """Contracting path: a chain of ``DoubleConv`` stages, each followed by 2x2 max-pooling.

    Args:
        channels: Sequence of channel widths; stage ``i`` maps
            ``channels[i]`` to ``channels[i + 1]``.
    """

    def __init__(self, channels):
        super().__init__()
        stages = [
            DoubleConv(width_in, width_out)
            for width_in, width_out in zip(channels[:-1], channels[1:])
        ]
        self.blocks = nn.ModuleList(stages)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, X):
        """Run every stage and collect the pre-pooling activations.

        Returns:
            A tuple ``(X, skips)`` where ``X`` is the pooled output of the
            last stage and ``skips`` lists each stage's output before
            pooling, in stage order.
        """
        skips = []
        for stage in self.blocks:
            features = stage(X)
            skips.append(features)
            # Pooling happens after the skip is recorded, so skips keep the
            # stage's own resolution.
            X = self.pool(features)

        return X, skips

class Decoder(nn.Module):
    """Expanding path: upsample, concatenate the matching encoder skip, refine.

    Args:
        channels: Sequence of channel widths. Stage ``i`` upsamples from
            ``channels[i]`` to ``channels[i + 1]`` channels with a stride-2
            transposed convolution. The skip tensor merged at stage ``i`` is
            assumed to carry ``channels[i + 1]`` channels as well, which is
            how ``Unet`` in this file pairs its encoder and decoder widths.
    """

    def __init__(self, channels):
        super().__init__()
        stage_widths = list(zip(channels[:-1], channels[1:]))
        # After upsampling, X has `width_out` channels and is concatenated
        # with a skip of the same width, so each conv block receives
        # 2 * width_out input channels (not 2 * width_in).
        self.blocks = nn.ModuleList(
            [DoubleConv(width_out * 2, width_out) for _, width_out in stage_widths]
        )
        self.up_conv = nn.ModuleList(
            [nn.ConvTranspose2d(width_in, width_out, 2, 2) for width_in, width_out in stage_widths]
        )

    def forward(self, X, skip_connection):
        """Decode ``X`` using the encoder skips, deepest skip first.

        Args:
            X: Feature map coming from the bottleneck.
            skip_connection: Encoder activations in encoder order; the last
                element is consumed by the first decoder stage.

        Returns:
            The feature map produced by the last decoder stage.
        """
        for idx, (current_block, current_up_conv) in enumerate(zip(self.blocks, self.up_conv)):
            X = current_up_conv(X)
            skip = skip_connection[-(idx + 1)]
            # Max-pooling floors odd sizes, so the upsampled map can be
            # smaller than the skip; resize to the skip's spatial size so the
            # concatenation is always valid.
            if X.shape[-2:] != skip.shape[-2:]:
                X = F.interpolate(X, size=skip.shape[-2:], mode="bilinear", align_corners=False)
            X = torch.cat([X, skip], dim=1)
            X = current_block(X)

        return X

class Unet(nn.Module):
    """U-Net style network: encoder, bottleneck, decoder and a 1x1 conv head.

    The encoder takes 3-channel input and uses widths 64/128/256, the
    bottleneck widens to 512, and the decoder is configured with
    512/256/128/64.

    Args:
        nb_class: Number of output channels produced by the head.
    """

    def __init__(self, nb_class):
        super().__init__()
        encoder_widths = [64, 128, 256]
        bottleneck_width = 512
        self.encoder = Encoder([3, *encoder_widths])
        self.bottleneck = DoubleConv(encoder_widths[-1], bottleneck_width)
        self.decoder = Decoder([bottleneck_width, *reversed(encoder_widths)])
        self.head = nn.Conv2d(encoder_widths[0], nb_class, kernel_size=1)

    def forward(self, X):
        """Return the head output for the input batch ``X``."""
        encoded, skip_connection = self.encoder(X)
        decoded = self.decoder(self.bottleneck(encoded), skip_connection)
        return self.head(decoded)


# Three output channels: the segmentation masks built in this file use
# 0 = background, 1 = cat, 2 = dog, so the head needs one logit per class.
model = Unet(3)


## **Data generator:**

In [10]:
# Training-time augmentation pipeline (torchvision transforms v2), applied in
# the order listed.
train_transform = T.Compose(
    [
        # Fixed 512x512 working resolution.
        T.Resize(size=(512, 512)),
        # Mirror left/right with probability 0.5.
        T.RandomHorizontalFlip(p=0.5),
        # Rotate by a random angle drawn from [-15, 15] degrees.
        T.RandomRotation(degrees=15),
        # Convert to a tensor image, then to float32 with value scaling enabled.
        T.ToImage(),
        T.ToDtype(dtype=torch.float32, scale=True),
    ]
)

In [9]:
class OxfordPetSegmentation(Dataset):
    """Oxford-IIIT Pet wrapper yielding ``(image, class_mask)`` pairs.

    The per-pixel mask uses three class ids: ``0`` for everything that is not
    marked as pet in the trimap, ``1`` for cat pixels and ``2`` for dog pixels.

    Args:
        root: Directory where the dataset is stored / downloaded to.
        split: Dataset split forwarded to ``OxfordIIITPet``.
        transforms: Optional callable invoked as ``transforms(image, mask)``
            and expected to return the transformed pair.
    """

    BACKGROUND = 0
    CAT = 1
    DOG = 2
    # Trimap value used for pet (foreground) pixels -- per the dataset's
    # annotation README; the remaining values are background / border.
    _TRIMAP_PET = 1

    def __init__(self, root, split="trainval", transforms=None):
        # "category" is the 37-way breed index, so comparing it with 0 would
        # single out one breed rather than cats. "binary-category" is the
        # species label (0 = cat, 1 = dog).
        # NOTE: requires a torchvision release that supports "binary-category".
        self.dataset = OxfordIIITPet(
            root=root,
            download=True,
            split=split,
            target_types=["segmentation", "binary-category"],
        )
        self.transforms = transforms

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the image and its 0/1/2 class mask for sample ``idx``."""
        image, (trimap, species) = self.dataset[idx]
        trimap = np.asarray(trimap)

        class_id = self.CAT if species == 0 else self.DOG
        # np.where allocates a fresh, writable array, which torch.from_numpy
        # can wrap safely.
        multiclass_mask = np.where(
            trimap == self._TRIMAP_PET, class_id, self.BACKGROUND
        ).astype(np.uint8)
        multiclass_mask = torch.from_numpy(multiclass_mask)

        if self.transforms is not None:
            # Imported here so the class only depends on tv_tensors when
            # transforms are actually used.
            from torchvision import tv_tensors

            # v2 transforms leave plain tensors untouched when a PIL image is
            # also passed; wrapping the mask as a tv_tensors.Mask makes
            # resize / flip / rotation apply to it together with the image.
            image, multiclass_mask = self.transforms(
                image, tv_tensors.Mask(multiclass_mask)
            )

        return image, multiclass_mask