In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
from torchvision.transforms import transforms
from tqdm import tqdm

## Download the dataset

In [2]:
# Fetch the segmentation dataset repository into the current working directory
# (creates a 'Human-Segmentation-Dataset' folder there).
!git clone https://github.com/VikramShenoy97/Human-Segmentation-Dataset

Cloning into 'Human-Segmentation-Dataset'...
remote: Enumerating objects: 596, done.[K
remote: Total 596 (delta 0), reused 0 (delta 0), pack-reused 596 (from 1)[K
Receiving objects: 100% (596/596), 13.60 MiB | 34.99 MiB/s, done.
Resolving deltas: 100% (7/7), done.


## Prepare dataset and DataLoader

In [3]:
class SegmentDataset(Dataset):
  """Paired image / binary-mask dataset for segmentation.

  Every file in ``image_dir`` whose extension is in ``valid_ext`` is one
  sample. Its mask is looked up in ``mask_dir`` under the same base name with
  a ``.png`` extension.

  Args:
    image_dir: Directory containing the input images.
    mask_dir: Directory containing one ``<name>.png`` mask per image.
    image_transform: Optional callable applied to the RGB PIL image.
    mask_transform: Optional callable applied to the single-channel ('L') PIL
      mask. Its result is thresholded with ``> 0.5`` and cast with
      ``.float()``, so it must return an object supporting both (e.g. a
      tensor); a raw PIL image does not.
  """

  def __init__(self, image_dir, mask_dir, image_transform=None, mask_transform=None):
    super().__init__()

    self.image_dir = image_dir
    self.mask_dir = mask_dir
    self.image_transform = image_transform
    self.mask_transform = mask_transform

    self.valid_ext = {'.png', '.jpeg', '.jpg', '.gif'}
    # Compare extensions case-insensitively so files such as 'photo.JPG' are
    # not silently dropped, and sort because os.listdir() returns entries in
    # arbitrary order (index -> sample mapping would differ between runs).
    self.images = sorted(
        f for f in os.listdir(image_dir)
        if os.path.splitext(f)[1].lower() in self.valid_ext
    )

  def __len__(self):
    """Return the number of image files discovered in ``image_dir``."""
    return len(self.images)

  def __getitem__(self, idx):
    """Load sample ``idx`` and return ``(image, mask)``.

    The mask is binarised: values above 0.5 become 1.0, everything else 0.0.
    """
    file_name = self.images[idx]
    image_path = os.path.join(self.image_dir, file_name)
    name, _ = os.path.splitext(file_name)
    mask_path = os.path.join(self.mask_dir, f'{name}.png')

    # convert() returns a new, fully loaded image, so the underlying file can
    # be closed as soon as the with-block ends.
    with Image.open(image_path) as image_file:
      image = image_file.convert('RGB')
    with Image.open(mask_path) as mask_file:
      mask = mask_file.convert('L')

    if self.image_transform is not None:
      image = self.image_transform(image)
    if self.mask_transform is not None:
      mask = self.mask_transform(mask)

    mask = (mask > 0.5).float()

    return image, mask

In [4]:
# Preprocessing pipelines, keyed by the kind of input they are applied to
_target_size = (512, 512)

# Images: resize, convert to tensor, then normalise each channel
_image_pipeline = transforms.Compose([
    transforms.Resize(_target_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Masks: resize and convert to tensor only (no normalisation)
_mask_pipeline = transforms.Compose([
    transforms.Resize(_target_size),
    transforms.ToTensor(),
])

transform = {'images': _image_pipeline, 'mask': _mask_pipeline}

In [5]:
# Build the training dataset from the image and ground-truth mask folders
images_dir = '/content/Human-Segmentation-Dataset/Training_Images'
mask_dir = '/content/Human-Segmentation-Dataset/Ground_Truth'

train_dataset = SegmentDataset(
    image_dir=images_dir,
    mask_dir=mask_dir,
    image_transform=transform['images'],
    mask_transform=transform['mask'],
)

In [6]:
# Look at the first sample to confirm the per-sample tensor shapes
image, mask = train_dataset[0]
print(image.shape, mask.shape)

torch.Size([3, 512, 512]) torch.Size([1, 512, 512])


In [7]:
# Number of CPUs reported by the OS; the DataLoader below uses it as its worker count
os.cpu_count()

2

In [8]:
# Create a dataloader
# os.cpu_count() returns None when the CPU count cannot be determined, while
# DataLoader requires an integer num_workers; fall back to 0 (load batches in
# the main process) in that case.
train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=os.cpu_count() or 0,
    pin_memory=True
)

In [9]:
# Pull a single batch to confirm the batched tensor shapes
first_batch = next(iter(train_loader))
image, mask = first_batch
print(image.shape, mask.shape)

torch.Size([8, 3, 512, 512]) torch.Size([8, 1, 512, 512])


## UNet Architecture
![](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png)

## Build the architecture

In [10]:
class DoubleConv(nn.Module):
  """Two 3x3 convolutions (stride 1, padding 1), each followed by an in-place ReLU.

  With this kernel/stride/padding combination the spatial size of the input
  is preserved; only the channel count changes from ``in_channels`` to
  ``out_channels``.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()

    def conv3x3(channels_in):
      # Both convolutions share every setting except their input width
      return nn.Conv2d(channels_in, out_channels, kernel_size=3, stride=1, padding=1)

    layers = [
        conv3x3(in_channels),
        nn.ReLU(inplace=True),
        conv3x3(out_channels),
        nn.ReLU(inplace=True),
    ]
    self.features = nn.Sequential(*layers)

  def forward(self, x):
    """Apply conv -> ReLU -> conv -> ReLU to ``x``."""
    activated = self.features(x)
    return activated

In [11]:
# Encoder stage (down-sampling path)
class DownSample(nn.Module):
  """Encoder stage: a DoubleConv followed by 2x2 max pooling with stride 2.

  ``forward`` returns a pair ``(features, pooled)``: the feature map before
  pooling and the pooled version of that same map.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()

    self.double_conv = DoubleConv(in_channels, out_channels)
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

  def forward(self, x):
    """Return ``(pre-pool features, pooled features)`` for input ``x``."""
    features = self.double_conv(x)
    return features, self.pool(features)

In [12]:
# Upsample or decoder part
class UpSample(nn.Module):
  """Decoder stage: 2x transposed-conv upsampling, skip concatenation, DoubleConv.

  Args:
    in_channels: Channels of the tensor coming from the previous (deeper)
      stage. The transposed convolution halves this to ``in_channels // 2``.
    out_channels: Channels produced by the final DoubleConv.

  The DoubleConv is built for ``in_channels`` input channels, so the skip
  tensor passed to ``forward`` must carry ``in_channels - in_channels // 2``
  channels for the concatenation to fit.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()

    self.up_sample = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=2, stride=2, padding=0)
    self.double_conv = DoubleConv(in_channels, out_channels)

  def forward(self, skip, pooled):
    """Upsample ``pooled``, concatenate it with ``skip`` on dim 1, and convolve."""
    pooled = self.up_sample(pooled)

    # A 2x2/stride-2 max pool floors odd sizes, so after upsampling the map can
    # be smaller than the skip map and torch.cat would raise. Pad the
    # upsampled map to the skip's height/width; when the sizes already agree
    # this branch is skipped and behaviour is unchanged.
    diff_h = skip.size(-2) - pooled.size(-2)
    diff_w = skip.size(-1) - pooled.size(-1)
    if diff_h != 0 or diff_w != 0:
      # Pad order is (left, right, top, bottom) for the last two dimensions
      pooled = nn.functional.pad(
          pooled,
          [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
      )

    x = torch.cat([pooled, skip], dim=1)
    return self.double_conv(x)

In [13]:
## UNet: four encoder stages, a bottleneck, four decoder stages, 1x1 output conv
class UNet(nn.Module):
  """U-shaped encoder/decoder network.

  Args:
    in_channels: Number of channels of the input tensor.
    num_classes: Number of channels produced by the final 1x1 convolution.

  ``forward`` returns the raw output of that final convolution; no
  activation is applied inside the model.
  """

  def __init__(self, in_channels, num_classes):
    super().__init__()

    # Encoder: channel width doubles at every stage
    self.downsample_1 = DownSample(in_channels, 64)
    self.downsample_2 = DownSample(64, 128)
    self.downsample_3 = DownSample(128, 256)
    self.downsample_4 = DownSample(256, 512)

    self.bottle_neck = DoubleConv(512, 1024)

    # Decoder: channel width halves at every stage
    self.upsample_1 = UpSample(1024, 512)
    self.upsample_2 = UpSample(512, 256)
    self.upsample_3 = UpSample(256, 128)
    self.upsample_4 = UpSample(128, 64)

    self.out = nn.Conv2d(64, num_classes, kernel_size=1)

  def forward(self, x):
    """Run ``x`` through the encoder, bottleneck and decoder."""
    encoders = (self.downsample_1, self.downsample_2, self.downsample_3, self.downsample_4)
    decoders = (self.upsample_1, self.upsample_2, self.upsample_3, self.upsample_4)

    # Collect the pre-pool feature maps; the decoder consumes them last-in, first-out
    skips = []
    for encoder in encoders:
      skip, x = encoder(x)
      skips.append(skip)

    x = self.bottle_neck(x)

    for decoder in decoders:
      x = decoder(skips.pop(), x)

    return self.out(x)

### Test with model outputs

In [14]:
# Separate model instance for the shape check below: 3 input channels, 1 output channel
test_model = UNet(in_channels=3, num_classes=1)

In [15]:
# Take the first sample and give each tensor a leading batch dimension of size 1
image, mask = (tensor.unsqueeze(0) for tensor in train_dataset[0])

print(image.shape, mask.shape)

torch.Size([1, 3, 512, 512]) torch.Size([1, 1, 512, 512])


In [16]:
# Forward pass on the one-image batch; the displayed shape lets us compare it with the mask's shape
outputs = test_model(image)
outputs.shape

torch.Size([1, 1, 512, 512])

## Define custom loss function

In [17]:
class DiceLoss(nn.Module):
  """Soft Dice loss: ``1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``.

  Args:
    smooth: Small constant added to numerator and denominator so the ratio is
      defined (and equal to 1) when both inputs sum to zero.

  ``forward`` flattens both tensors, so the score is computed over every
  element of the batch at once. ``outputs`` is used as given - no sigmoid is
  applied here.
  """

  def __init__(self, smooth=1e-6):
    super().__init__()
    self.smooth = smooth

  def forward(self, outputs, targets):
    """Return the scalar Dice loss between ``outputs`` and ``targets``."""
    # reshape() instead of view(): view() raises on non-contiguous tensors
    outputs_flat = outputs.reshape(-1)
    targets_flat = targets.reshape(-1)

    overlap = (outputs_flat * targets_flat).sum()
    total = outputs_flat.sum() + targets_flat.sum()

    # smooth is added once on each side. Adding it inside the doubled term
    # would put 2 * smooth in the numerator, giving dice = 2 (loss = -1) for
    # two empty masks instead of dice = 1 (loss = 0).
    dice = (2 * overlap + self.smooth) / (total + self.smooth)

    return 1 - dice

class BCEwithDiceLoss(nn.Module):
  """Combined loss on raw logits: ``0.5 * BCEWithLogits + Dice``.

  The BCE term receives the logits directly; the Dice term receives their
  sigmoid.
  """

  def __init__(self):
    super().__init__()

    self.bce = nn.BCEWithLogitsLoss()
    self.dice = DiceLoss()

  def forward(self, outputs, targets):
    """Return the weighted sum of both terms for logits ``outputs``."""
    probabilities = torch.sigmoid(outputs)

    bce_term = self.bce(outputs, targets)
    dice_term = self.dice(probabilities, targets)

    return 0.5 * bce_term + dice_term

## Prepare for model training

In [18]:
# Use CUDA when torch reports it as available, otherwise fall back to the CPU
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')
device

device(type='cuda')

In [19]:
# Model to be trained (3 input channels, 1 output channel), moved to the selected device
model = UNet(in_channels=3, num_classes=1).to(device)

In [20]:
# Loss function and optimiser used by the training loop
learning_rate = 0.001

criterion = BCEwithDiceLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [21]:
# Training loop
epochs = 40

# Resolve the checkpoint location and create its directory before training:
# torch.save() does not create missing parent directories, and failing only
# after the last epoch would discard the trained weights.
# NOTE(review): the path is under /content/drive - this assumes Google Drive
# is already mounted there; confirm before running.
model_path = f'/content/drive/MyDrive/Computer_vision/model/unet-e{epochs}.pth'
os.makedirs(os.path.dirname(model_path), exist_ok=True)

for epoch in range(epochs):
  epoch_loss = 0.0
  model.train()
  print(f'Epochs: [{epoch + 1}/{epochs}]')
  for images, masks in tqdm(train_loader, desc='Training...'):
    images, masks = images.to(device), masks.to(device)
    optimizer.zero_grad()
    outputs = model(images)
    loss = criterion(outputs, masks)
    loss.backward()
    optimizer.step()

    # .item() yields a plain Python number, so only the value is accumulated
    epoch_loss += loss.item()
  # Mean of the per-batch losses for this epoch
  t_loss = epoch_loss / len(train_loader)
  print(f'Loss: {t_loss}\n')

# Only the parameters (state_dict) are stored, not the full module object
torch.save(model.state_dict(), model_path)
print(f'Model saved successfully at {model_path}')

Epochs: [1/40]


Training...: 100%|██████████| 37/37 [00:56<00:00,  1.52s/it]


Loss: 1.0840602823205896

Epochs: [2/40]


Training...: 100%|██████████| 37/37 [00:56<00:00,  1.51s/it]


Loss: 0.9197878740929268

Epochs: [3/40]


Training...: 100%|██████████| 37/37 [00:56<00:00,  1.53s/it]


Loss: 0.8768935090786701

Epochs: [4/40]


Training...: 100%|██████████| 37/37 [00:56<00:00,  1.53s/it]


Loss: 0.8218710760812502

Epochs: [5/40]


Training...: 100%|██████████| 37/37 [00:56<00:00,  1.53s/it]


Loss: 0.821382332492519

Epochs: [6/40]


Training...: 100%|██████████| 37/37 [00:56<00:00,  1.53s/it]


Loss: 0.8098406308406109

Epochs: [7/40]


Training...: 100%|██████████| 37/37 [00:56<00:00,  1.53s/it]


Loss: 0.8113910620276993

Epochs: [8/40]


Training...: 100%|██████████| 37/37 [00:56<00:00,  1.53s/it]


Loss: 0.7763610280848838

Epochs: [9/40]


Training...: 100%|██████████| 37/37 [00:56<00:00,  1.54s/it]


Loss: 0.8766430278082151

Epochs: [10/40]


Training...: 100%|██████████| 37/37 [00:56<00:00,  1.53s/it]


Loss: 0.9077884313222524

Epochs: [11/40]


Training...: 100%|██████████| 37/37 [00:56<00:00,  1.52s/it]


Loss: 0.8065987777065586

Epochs: [12/40]


Training...: 100%|██████████| 37/37 [00:56<00:00,  1.52s/it]


Loss: 0.7741387067614375

Epochs: [13/40]


Training...: 100%|██████████| 37/37 [00:56<00:00,  1.52s/it]


Loss: 1.0645109946663316

Epochs: [14/40]


Training...: 100%|██████████| 37/37 [00:56<00:00,  1.52s/it]


Loss: 0.868608524670472

Epochs: [15/40]


Training...: 100%|██████████| 37/37 [00:56<00:00,  1.52s/it]


Loss: 1.2154983346526687

Epochs: [16/40]


Training...: 100%|██████████| 37/37 [00:56<00:00,  1.52s/it]


Loss: 0.9551028335416639

Epochs: [17/40]


Training...: 100%|██████████| 37/37 [00:56<00:00,  1.51s/it]


Loss: 0.8623671257818067

Epochs: [18/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.801344185262113

Epochs: [19/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.800298665020917

Epochs: [20/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.7690222698288995

Epochs: [21/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.7508769808588801

Epochs: [22/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.7620666639224903

Epochs: [23/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.7757325961783126

Epochs: [24/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.7780594600213541

Epochs: [25/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.7523463046228563

Epochs: [26/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.7422319521775117

Epochs: [27/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.7410471342705391

Epochs: [28/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.74975194963249

Epochs: [29/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.759456362273242

Epochs: [30/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.7430727304639043

Epochs: [31/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.7358468722652745

Epochs: [32/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.7465294567314354

Epochs: [33/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.7287348138319479

Epochs: [34/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.729543606977205

Epochs: [35/40]


Training...: 100%|██████████| 37/37 [00:55<00:00,  1.49s/it]


Loss: 0.718365000711905

Epochs: [36/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.7153560470890354

Epochs: [37/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.7099773319991859

Epochs: [38/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.7154588071075646

Epochs: [39/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.7122282385826111

Epochs: [40/40]


Training...: 100%|██████████| 37/37 [00:54<00:00,  1.48s/it]


Loss: 0.7091305497530345

Model saved successfully at /content/drive/MyDrive/Computer_vision/model/unet-e40.pth


## Model Inference