In [1]:
# Image Segmentation in PyTorch using a U-Net Model

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader as dataloader
import torchvision.models as models

# You'll need to install albumentations!
import albumentations as A
from albumentations.pytorch import ToTensorV2

import time
import os
import random
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import trange, tqdm
from PIL import Image, ImageOps
import copy
import pandas as pd

from Trainer import ModelTrainer
from Datasets import CUB200

ModuleNotFoundError: No module named 'Trainer'

In [None]:
# Number of samples per mini-batch
batch_size = 4

# Number of full passes (epochs) over the training set
num_epochs = 32

# Learning rate handed to the optimizer
learning_rate = 1e-4

# The dataset has to be downloaded manually from
# https://www.kaggle.com/datasets/wenewone/cub2002011
# Unzip it and rename the folder to cub_200
# Directory the dataset is read from
data_set_root = "../datasets/cub_200"

# Side length (in pixels) that images are resized/cropped to
image_size = 128

In [None]:
# Checkpoint settings: whether to resume, and the directory / name the trainer is given
start_from_checkpoint = False

model_name = 'UNet_CUB'
save_dir = '../data/Models'

In [None]:
# Use the GPU with index gpu_indx when CUDA is available, otherwise fall back to the CPU
gpu_indx = 0
if torch.cuda.is_available():
    device = torch.device(gpu_indx)
else:
    device = torch.device('cpu')

In [None]:
# Only keep the augmentations if the transforms augment both the image and the
# bounding boxes together (the dataset class has to be adapted to match!)

def _normalize_step():
    # Build a fresh Normalize for each pipeline.
    # NOTE(review): these look like the usual ImageNet channel statistics - confirm.
    return A.Normalize(mean=[0.485, 0.456, 0.406],
                       std=[0.229, 0.224, 0.225])


def _bbox_settings():
    # COCO-format boxes whose labels are passed under the 'class_labels' field
    return A.BboxParams(format='coco',
                        min_area=0, min_visibility=0.0,
                        label_fields=['class_labels'])


# Training pipeline: resize, random crop, then random geometric and colour augmentations
train_steps = [
    A.SmallestMaxSize(max_size=image_size),
    A.RandomCrop(height=image_size, width=image_size),
    A.HorizontalFlip(p=0.5),
    A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30, p=0.5),
    A.RGBShift(r_shift_limit=25, g_shift_limit=25, b_shift_limit=25, p=0.5),
    A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.5),
    _normalize_step(),
    ToTensorV2(),
]
train_transform = A.Compose(train_steps, bbox_params=_bbox_settings())

# Evaluation pipeline: resize and centre crop only, no random augmentation
eval_steps = [
    A.SmallestMaxSize(max_size=image_size),
    A.CenterCrop(height=image_size, width=image_size),
    _normalize_step(),
    ToTensorV2(),
]
transform = A.Compose(eval_steps, bbox_params=_bbox_settings())

In [None]:
# nn.Module metric returning one IoU value per sample in a batch of masks
class MaskIOU(nn.Module):
    """Intersection-over-union between predicted and target segmentation masks."""

    def mask_intersection_over_union(self, pred_bbox, target_bbox):
        """
        Compute the IoU of two batches of masks.

        Despite the parameter names, both arguments are masks, not boxes: they are
        multiplied element-wise and summed over dims 1 and 2, so one value per
        entry of dim 0 is returned.
        """
        reduce_dims = [1, 2]

        # Area of each mask on its own
        pred_area = pred_bbox.sum(dim=reduce_dims)
        target_area = target_bbox.sum(dim=reduce_dims)

        # Overlap: positions that are non-zero in both masks
        overlap = (pred_bbox * target_bbox).sum(dim=reduce_dims)

        # Union = both areas minus the overlap; the small constant avoids a
        # division by zero when both masks are empty
        union = pred_area + target_area - overlap + 1e-5

        return overlap / union

    def forward(self, predictions, data):
        """
        predictions: raw output of the model; the argmax over dim 1 is used as the mask
        data: list of batch data, index 0 is the input image and index 1 is the target mask
        """
        pred_mask = predictions.argmax(1)

        # Move the target onto the same device as the prediction before comparing
        target_mask = data[1].to(pred_mask.device)

        return self.mask_intersection_over_union(pred_mask, target_mask)

In [None]:
# Define our Datasets
# The dataset has to be downloaded from Kaggle
# https://www.kaggle.com/datasets/wenewone/cub2002011
# Unzip it (and the directories it contains) into the datasets directory
# and rename the top-level directory cub_200

# Training images, passed through the augmenting pipeline
train_data = CUB200(data_set_root, image_size=image_size, transform=train_transform,
                    test_train=0, return_masks=True)

# The same images again, but passed through the deterministic evaluation pipeline.
# Validation samples are taken from this copy so that validation scores are not
# computed on randomly augmented images.
valid_source = CUB200(data_set_root, image_size=image_size, transform=transform,
                      test_train=0, return_masks=True)
assert len(valid_source) == len(train_data), "train/validation source datasets differ in length"

test_data = CUB200(data_set_root, image_size=image_size, transform=transform,
                   test_train=1, return_masks=True)

# Split the training data into train and validation sets (90% training / 10% validation)
validation_split = 0.9

n_train_examples = int(len(train_data)*validation_split)
n_valid_examples = len(train_data) - n_train_examples

# One seeded permutation is shared by both subsets, so they are disjoint and
# reproducible; the first n_train_examples indices train, the rest validate.
split_indices = torch.randperm(len(train_data),
                               generator=torch.Generator().manual_seed(42)).tolist()
train_data = torch.utils.data.Subset(train_data, split_indices[:n_train_examples])
valid_data = torch.utils.data.Subset(valid_source, split_indices[n_train_examples:])

In [None]:
# Custom Unet
class UnetDown(nn.Module):
    """
    Encoder stage: BatchNorm -> ELU -> 3x3 conv -> BatchNorm -> ELU -> 2x2 max-pool -> 3x3 conv.

    Maps input_size channels to output_size channels; the max-pool halves the
    spatial resolution.
    """

    def __init__(self, input_size, output_size):
        super().__init__()

        # Both convolutions are 3x3 with padding 1, so they keep the spatial size
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)

        self.model = nn.Sequential(
            nn.BatchNorm2d(input_size),
            nn.ELU(),
            nn.Conv2d(input_size, output_size, **conv_kwargs),
            nn.BatchNorm2d(output_size),
            nn.ELU(),
            nn.MaxPool2d(2),
            nn.Conv2d(output_size, output_size, **conv_kwargs),
        )

    def forward(self, x):
        return self.model(x)
      

class UnetUp(nn.Module):
    """
    Decoder stage: BatchNorm -> ELU -> 3x3 conv -> BatchNorm -> ELU -> 2x nearest upsample -> 3x3 conv.

    Maps input_size channels to output_size channels; the upsample doubles the
    spatial resolution.
    """

    def __init__(self, input_size, output_size):
        super().__init__()

        # Both convolutions are 3x3 with padding 1, so they keep the spatial size
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)

        self.model = nn.Sequential(
            nn.BatchNorm2d(input_size),
            nn.ELU(),
            nn.Conv2d(input_size, output_size, **conv_kwargs),
            nn.BatchNorm2d(output_size),
            nn.ELU(),
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(output_size, output_size, **conv_kwargs),
        )

    def forward(self, x):
        return self.model(x)
            
         
class Unet(nn.Module):
    """
    U-Net style encoder/decoder built from four UnetDown and four UnetUp stages.

    Each decoder output is concatenated (along the channel dim) with the encoder
    feature map of the same resolution before entering the next stage.
    Because there are four 2x down-samplings, H and W should be divisible by 16
    for the concatenated feature maps to line up.

    channels_in: number of channels of the input image
    channels_out: number of output channels (one per class), default 2
    """

    def __init__(self, channels_in, channels_out=2):
        super().__init__()

        # Stem: channels_in -> 64, resolution unchanged
        self.conv_in = nn.Conv2d(channels_in, 64,
                                 kernel_size=3, stride=1, padding=1)

        # Encoder: every stage halves H and W
        self.down1 = UnetDown(64, 64)     # 64  -> 64,  H    x W    --> H/2  x W/2
        self.down2 = UnetDown(64, 128)    # 64  -> 128, H/2  x W/2  --> H/4  x W/4
        self.down3 = UnetDown(128, 128)   # 128 -> 128, H/4  x W/4  --> H/8  x W/8
        self.down4 = UnetDown(128, 256)   # 128 -> 256, H/8  x W/8  --> H/16 x W/16

        # Decoder: every stage doubles H and W; inputs after up4 are doubled in
        # width because of the concatenated skip connection
        self.up4 = UnetUp(256, 128)       # 256 -> 128, H/16 x W/16 --> H/8 x W/8
        self.up5 = UnetUp(128 * 2, 128)   # 256 -> 128, H/8  x W/8  --> H/4 x W/4
        self.up6 = UnetUp(128 * 2, 64)    # 256 -> 64,  H/4  x W/4  --> H/2 x W/2
        self.up7 = UnetUp(64 * 2, 64)     # 128 -> 64,  H/2  x W/2  --> H   x W

        # Head: (64 decoder + 64 stem) channels -> channels_out, resolution unchanged
        self.conv_out = nn.Conv2d(64 * 2, channels_out,
                                  kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        stem = self.conv_in(x)                       # 64 x H x W

        enc1 = self.down1(stem)                      # 64 x H/2 x W/2
        enc2 = self.down2(enc1)                      # 128 x H/4 x W/4
        enc3 = self.down3(enc2)                      # 128 x H/8 x W/8
        bottleneck = self.down4(enc3)                # 256 x H/16 x W/16

        dec = self.up4(bottleneck)                   # 128 x H/8 x W/8
        dec = self.up5(torch.cat((dec, enc3), 1))    # cat: 256 x H/8 x W/8 -> 128 x H/4 x W/4
        dec = self.up6(torch.cat((dec, enc2), 1))    # cat: 256 x H/4 x W/4 -> 64 x H/2 x W/2
        dec = self.up7(torch.cat((dec, enc1), 1))    # cat: 128 x H/2 x W/2 -> 64 x H x W

        # Final skip from the stem, activated before the output convolution
        merged = F.elu(torch.cat((dec, stem), 1))    # 128 x H x W
        return self.conv_out(merged)                 # channels_out x H x W
        

In [None]:
# Build the network for 3-channel input images and 2 output channels
unet = Unet(channels_in=3, channels_out=2)

In [None]:
# Gather every trainer option in one place, then construct the trainer.
# The model is moved to the chosen device, cross-entropy is the loss and
# MaskIOU is the evaluation metric.
trainer_kwargs = dict(
    model=unet.to(device),
    output_size=-1,
    device=device,
    loss_fun=nn.CrossEntropyLoss(),
    batch_size=batch_size,
    learning_rate=learning_rate,
    save_dir=save_dir,
    model_name=model_name,
    eval_metric=MaskIOU(),
    start_from_checkpoint=start_from_checkpoint,
)
model_trainer = ModelTrainer(**trainer_kwargs)

In [None]:
# Hand the train / test / validation datasets to the trainer
# (presumably this is where the train_loader / test_loader used below are built - confirm in Trainer)
model_trainer.set_data(train_set=train_data, test_set=test_data, val_set=valid_data)

In [None]:
# Step schedule: the learning rate is multiplied by 0.95 after every scheduler step
lr_scheduler = optim.lr_scheduler.StepLR(model_trainer.optimizer, step_size=1, gamma=0.95)
model_trainer.set_lr_schedule(lr_scheduler)

In [None]:
# Grab one training batch; `mask` is reused by the following cell
images, mask, bbox, labels = next(iter(model_trainer.train_loader))

# Tile (at most) the first 16 images into one grid, rescaled for display
image_grid = torchvision.utils.make_grid(images[:16], normalize=True)

plt.figure(figsize=(20, 10))
# Move the channel axis last, which is the layout imshow expects
_ = plt.imshow(image_grid.numpy().transpose((1, 2, 0)))

In [None]:
# Give the masks a channel axis and cast to float so make_grid can tile them
mask_batch = mask[:16].unsqueeze(1).float()
mask_grid = torchvision.utils.make_grid(mask_batch, normalize=True)

plt.figure(figsize=(20, 10))
_ = plt.imshow(mask_grid.numpy().transpose((1, 2, 0)))

In [None]:
# Let's see how many parameters our model has!
# Sum the element count of every parameter tensor in the model
num_params = sum(weights.numel() for weights in model_trainer.model.parameters())
print("This model has %d (approximately %d Million) Parameters!" % (num_params, num_params//1e6))

In [None]:
# Run the trainer's training routine for num_epochs epochs
model_trainer.run_training(num_epochs=num_epochs)

In [None]:
# Best validation score tracked by the trainer (the metric is MaskIOU, so this is an IoU)
best_valid_iou = model_trainer.best_valid_acc
print("The highest validation IoU was %.2f" % best_valid_iou)

In [None]:
# Training loss values logged by the trainer
loss_history = model_trainer.train_loss_logger

_ = plt.figure(figsize=(10, 5))
_ = plt.plot(loss_history)
_ = plt.title("Training Loss")

In [None]:
_ = plt.figure(figsize=(10, 5))

# Training curve in yellow, validation curve in black; the plot order must
# match the order of the legend entries below
iou_curves = ((model_trainer.train_acc_logger, "y"),
              (model_trainer.val_acc_logger, "k"))
for iou_history, colour in iou_curves:
    _ = plt.plot(iou_history, c=colour)

_ = plt.title("Average IoU")
_ = plt.legend(["Training IoU", "Validation IoU"])

In [None]:
# Take one batch from the test loader; images / mask are reused by the plots below
test_batch = next(iter(model_trainer.test_loader))
images, mask, bbox, labels = test_batch

# Evaluation mode, and no gradient tracking during inference
model_trainer.eval()
with torch.no_grad():
    # Predicted class per pixel = argmax over dim 1; moved back to the CPU for plotting
    pred_out = model_trainer(images.to(device)).argmax(1).cpu()

In [None]:
# Test images (at most the first 16), tiled into a grid and rescaled for display
image_grid = torchvision.utils.make_grid(images[:16], normalize=True)

plt.figure(figsize=(20, 10))
_ = plt.imshow(image_grid.numpy().transpose((1, 2, 0)))

In [None]:
# Ground-truth masks for the same test batch, with a channel axis added for make_grid
mask_batch = mask[:16].unsqueeze(1).float()
mask_grid = torchvision.utils.make_grid(mask_batch, normalize=True)

plt.figure(figsize=(20, 10))
_ = plt.imshow(mask_grid.numpy().transpose((1, 2, 0)))

In [None]:
# Predicted masks for the same test batch, shown in the same layout as the targets
pred_batch = pred_out[:16].unsqueeze(1).float()
pred_grid = torchvision.utils.make_grid(pred_batch, normalize=True)

plt.figure(figsize=(20, 10))
_ = plt.imshow(pred_grid.numpy().transpose((1, 2, 0)))

In [None]:
# Score the model on the test split using the trainer's evaluate function
test_acc = model_trainer.evaluate_model(train_test_val="test")
print("The Test Average IoU is: %.2f" % test_acc)