## **Import libs:**

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision.tv_tensors import Image, Mask
from torchvision.datasets import OxfordIIITPet
import torchvision.transforms.v2 as T

In [2]:
EPOCHS = 100
BATCHSIZE = 4
LR = 1e-4

## **Check Gpu:**

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


##**Architecture:**

In [4]:
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias= False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias= False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, X):
        return self.net(X)


class Encoder(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.blocks = nn.ModuleList([DoubleConv(channels[i], channels[i+1]) for i in range (len(channels) - 1)])
        self.pool = nn.MaxPool2d(2, 2)

    def forward(self, X):
        skips = []
        for current_block in (self.blocks):
            X = current_block(X)
            skips.append(X)
            X = self.pool(X)

        return X, skips

class Decoder(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.blocks = nn.ModuleList([DoubleConv(channels[i + 1] * 2, channels[i+1]) for i in range (len(channels) - 1 )])
        self.up_conv = nn.ModuleList([nn.ConvTranspose2d(channels[i], channels[i + 1], 2, 2) for i in range(len(channels) - 1)])

    def forward(self, X, skip_connection):
        for idx, (current_block, current_up_conv) in enumerate(zip(self.blocks, self.up_conv)):
            X = current_up_conv(X)
            X = torch.cat([X, skip_connection[-(idx + 1)]], dim=1)
            X = current_block(X)

        return X

class Unet(nn.Module):
    def __init__(self, nb_class):
        super(Unet, self).__init__()
        self.encoder = Encoder([3, 64, 128, 256])
        self.bottleneck = DoubleConv(256, 512)
        self.decoder = Decoder([512, 256, 128, 64])
        self.head = nn.Conv2d(64, nb_class, 1)

    def forward(self, X):
        X, skip_connection = self.encoder(X)
        X = self.bottleneck(X)
        X = self.decoder(X, skip_connection)
        X = self.head(X)
        return X


model= Unet(3).to(device)
print(model)


Unet(
  (encoder): Encoder(
    (blocks): ModuleList(
      (0): DoubleConv(
        (net): Sequential(
          (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (5): ReLU(inplace=True)
        )
      )
      (1): DoubleConv(
        (net): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=Tr

##**DataGenerator:**

In [5]:
train_transform = T.Compose([
    T.Resize((512, 512)),
    T.RandomHorizontalFlip(p = 0.5),
    T.RandomRotation(degrees = 15),
])

val_transform = T.Compose([
    T.Resize((512, 512)),
])

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import OxfordIIITPet
import torchvision.transforms.v2 as T

class OxfordPetSegmentation(Dataset):
  def __init__(self, root, split = "trainval", transforms = None):
      self.dataset = OxfordIIITPet(root=root, download=True, split = split, target_types = ["segmentation", "category"])
      self.transforms = transforms
      self.classes = self.dataset.classes

  def __len__(self):
      return len(self.dataset)

  def __getitem__(self, idx):
      image, (mask, label) = self.dataset[idx]

      mask_np = np.array(mask)
      animal_pixels = (mask_np == 1) | (mask_np == 2)
      multiclass_mask = np.zeros_like(mask_np, dtype=np.uint8)

      breed_name = self.classes[label].lower()
      if "cat" in breed_name:
        multiclass_mask[animal_pixels] = 1
      else:
        multiclass_mask[animal_pixels] = 2

      image_dp = Image(image)
      mask_dp = Mask(multiclass_mask)

      # Apply geometric transforms to datapoints
      if self.transforms:
          # Transforms that work on (dp.Image, dp.Mask) will keep them synchronized
          image_dp, mask_dp = self.transforms(image_dp, mask_dp)

      image_tensor = T.ToDtype(torch.float32, scale=True)(image_dp)
      mask_tensor = torch.as_tensor(mask_dp, dtype=torch.long)

      return image_tensor, mask_tensor

In [16]:
train_datset = OxfordPetSegmentation(root = "data", split = "trainval", transforms = train_transform)
val_datset = OxfordPetSegmentation(root = "data", split = "test", transforms = val_transform)

In [17]:
img, mask = train_datset[0]
print(img.shape)    # [3, 512, 512]
print(mask.shape)   # [512, 512]
print(torch.unique(mask))  # tensor([0, 1, 2])


torch.Size([3, 512, 512])
torch.Size([512, 512])
tensor([0, 2])


In [18]:
train_loader = DataLoader(train_datset, batch_size = BATCHSIZE, shuffle = True, num_workers= 4)
val_loader = DataLoader(val_datset, batch_size = BATCHSIZE, shuffle = True, num_workers= 4)

##**Training:**

In [19]:
# Loss: CrossEntropyLoss
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = LR)

In [20]:
def train_one_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    running_loss = 0.0

    for images, masks in dataloader:
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
    return running_loss / len(dataloader)

In [21]:
@torch.no_grad()
def validate_one_epoch(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0

    for images, masks in dataloader:
        images = images.to(device)
        masks = masks.to(device)

        outputs = model(images)
        loss = criterion(outputs, masks)

        running_loss += loss.item()
    return running_loss / len(dataloader)

In [23]:
for epoch in range(EPOCHS):
    train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device)
    val_loss = validate_one_epoch(model, val_loader, criterion, device)
    print(f"Epoch {epoch+1}/{EPOCHS} - Train Loss: {train_loss:.4f} - Val Loss: {val_loss:.4f})")

torch.save(model.state_dict(), "model.pth")

model.load_state_dict(torch.load("model.pth"))
model.eval()

KeyboardInterrupt: 