### Set up Google Drive and unzip the training data


In [None]:
# Mount Google Drive in Colab at /content/drive.
# Later cells read the dataset archive, val.txt and checkpoints through the
# relative path 'drive/My Drive/...', so the working directory is presumably
# /content (the Colab default) -- confirm if running elsewhere.

from google.colab import drive
drive.mount('/content/drive')

In [2]:
# Unzip training data from drive

!unzip -q 'drive/My Drive/VOCdevkit.zip'

### Import Libraries

In [3]:
import random
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image,ImageFilter, ImageEnhance
import torch.optim as optim
import time

### Handle Training Data

In [4]:
# Locations inside the extracted VOCdevkit archive.
img_dir = 'VOCdevkit/VOC2007/JPEGImages/'
out_dir = 'VOCdevkit/VOC2007/SegmentationClass/'

# Every image and mask is resized to this (H, W) before entering the network.
img_size = (128,128)

# Seed Python, NumPy and PyTorch RNGs (augmentation, shuffling, weight init)
# so runs are repeatable. cuDNN kernels may still be nondeterministic on GPU.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [5]:
# Collect every label id that occurs in any mask, to inspect the raw label set
# before it is collapsed to background / object / boundary in the Dataset.
all_labels = []

for mask_name in os.listdir(out_dir):
    # Context manager closes the file handle PIL keeps open for lazy loading.
    with Image.open(os.path.join(out_dir, mask_name)) as mask:
        all_labels += np.unique(np.array(mask)).tolist()


print(set(all_labels))

num_class = 3  # background (0), object (1) and boundary (2) after remapping
print("number of classes including background and boundary - ", num_class)

{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 255}
number of classes excluding background and boundary -  3


### Data Augmentation

In [6]:

def random_blur(img):
    """Box-blur a training image with probability 0.6.

    The blur radius is picked uniformly from {0.5, 1, 1.5, 2}.

    Args:
        img: PIL image to (maybe) blur.

    Returns:
        The blurred image, or ``img`` unchanged when the blur is skipped.
    """
    skip = random.random() < 0.4
    if not skip:
        radius = random.choice([0.5, 1, 1.5, 2])
        img = img.filter(ImageFilter.BoxBlur(radius=radius))
    return img

def random_color(img):
    """Randomly jitter brightness, contrast, saturation and hue with probability 0.7.

    Brightness, contrast and saturation factors are drawn from [0.5, 2.0] and
    the hue shift from [-0.25, 0.25] by ``transforms.ColorJitter``.

    Args:
        img: PIL image to (maybe) jitter.

    Returns:
        The jittered image, or ``img`` unchanged when jitter is skipped.
    """
    if random.random() >= 0.3:
        jitter = transforms.ColorJitter(
            brightness=(0.5, 2.0),
            contrast=(0.5, 2.0),
            saturation=(0.5, 2.0),
            hue=(-0.25, 0.25),
        )
        img = jitter(img)
    return img


def random_flip(img, out):
    """Mirror an image and its mask left-right together, with probability 0.5.

    Args:
        img: input PIL image.
        out: matching segmentation mask.

    Returns:
        ``(img, out)`` -- both flipped or both untouched, so they stay aligned.
    """
    if random.random() >= 0.5:
        flip = transforms.RandomHorizontalFlip(p=1)
        img, out = flip(img), flip(out)
    return img, out

def random_crop(img, out):
    """Crop the same random window from an image and its mask (probability 0.7).

    The window is 70-100% of the image width and height, placed uniformly at
    random inside the image. Both inputs receive the identical box so they
    stay aligned.

    Args:
        img: input PIL image; its ``size`` is read as ``(width, height)``.
        out: matching segmentation mask.

    Returns:
        ``(img, out)`` cropped, or the inputs unchanged when no crop is applied.
    """
    if random.random() < 0.3:
        return img, out

    w, h = img.size
    crop_w = random.uniform(0.7 * w, w)
    crop_h = random.uniform(0.7 * h, h)
    x0 = random.uniform(0, w - crop_w)
    y0 = random.uniform(0, h - crop_h)

    box = (x0, y0, x0 + crop_w, y0 + crop_h)
    return img.crop(box), out.crop(box)

### Dataset definition

In [8]:
class pascal_voc_data(Dataset):
    """Paired (image, segmentation mask) dataset over a PASCAL VOC-style layout.

    A sample is included when its file stem (file name without extension)
    appears in ``type_list`` and exists in both directories. Images and masks
    are paired by stem, not by position in two separate listings. Masks are
    collapsed to three labels:

    * 0 = background;
    * 1 = any other class id;
    * 2 = the value 255 (boundary).

    Args:
        img_dir: directory holding the input images.
        out_dir: directory holding the label masks.
        type_list: iterable of file stems to include.
        isTrain: when True, apply the random blur/colour/flip/crop augmentations.
        transform: callable applied to the PIL image. It should resize to
            ``img_size``, because the mask is resized to ``img_size``.

    Raises:
        ValueError: if a requested stem has an image but no mask, or vice versa.
    """

    def __init__(self, img_dir, out_dir,type_list, isTrain, transform):
        super().__init__()
        self.img_dir = img_dir
        self.out_dir = out_dir
        self.type_list = type_list
        self.isTrain = isTrain
        self.transform = transform

        wanted = set(type_list)  # O(1) membership while scanning the directories
        img_files = self._files_by_stem(img_dir, wanted)
        out_files = self._files_by_stem(out_dir, wanted)

        # Refuse to build a dataset where a stem has only an image or only a mask;
        # pairing by sorted position would otherwise silently mismatch samples.
        unmatched = img_files.keys() ^ out_files.keys()
        if unmatched:
            raise ValueError(
                "Error - input images and output masks do not match; "
                f"{len(unmatched)} stem(s) lack a counterpart, e.g. {sorted(unmatched)[:5]}"
            )

        stems = sorted(img_files)
        self.img_names = [os.path.join(img_dir, img_files[stem]) for stem in stems]
        self.out_names = [os.path.join(out_dir, out_files[stem]) for stem in stems]

    @staticmethod
    def _files_by_stem(directory, wanted):
        """Return ``{stem: file_name}`` for entries of ``directory`` whose stem is in ``wanted``."""
        return {
            os.path.splitext(name)[0]: name
            for name in os.listdir(directory)
            if os.path.splitext(name)[0] in wanted
        }

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        ``image`` is whatever ``self.transform`` returns. ``mask`` is an int32
        tensor of size ``img_size`` with values in {0, 1, 2}.
        """
        # The context managers close the file handles PIL keeps open for lazy
        # loading. convert()/copy() force the pixels to load before closing.
        with Image.open(self.img_names[idx]) as img_file:
            # Greyscale/other-mode files would break the 3-channel Normalize.
            img = img_file.convert('RGB')
        with Image.open(self.out_names[idx]) as out_file:
            out = out_file.copy()

        if self.isTrain:
            img = random_blur(img)
            img = random_color(img)
            img, out = random_flip(img, out)
            img, out = random_crop(img, out)

        img = self.transform(img)

        # Nearest-neighbour resizing keeps label ids intact (no blended ids).
        # InterpolationMode needs torchvision >= 0.9.
        out = transforms.Resize(img_size, interpolation=transforms.InterpolationMode.NEAREST)(out)
        out = np.array(out)
        out[(out != 0) & (out != 255)] = 1  # Foreground is 1, background is 0
        out[out == 255] = 2                 # Boundary is 2
        out = torch.from_numpy(out.astype(np.int32))

        return img, out

In [9]:
# Pretrained torchvision models expect images loaded into [0, 1] and then
# normalised with the ImageNet statistics below, see
# https://pytorch.org/docs/master/torchvision/models.html
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)

# Resize -> [0, 1] tensor -> ImageNet normalisation, applied to every input image.
transform = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
])

# Undo the normalisation for display: (x - (-m/s)) / (1/s) == x * s + m.
inv_normalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(imagenet_mean, imagenet_std)],
    std=[1 / s for s in imagenet_std],
)

### Split into training and validation sets

In [10]:
# Validation ids come from val.txt (one image id per line); every other image
# in img_dir is used for training.
with open('drive/My Drive/val.txt', "r") as val_file:
    # Strip whitespace and skip blank lines. Blindly dropping the last split
    # element loses a real id when the file has no trailing newline.
    valid_images = [line.strip() for line in val_file if line.strip()]

valid_set = set(valid_images)  # O(1) membership while scanning the image directory
train_images = sorted(
    os.path.splitext(name)[0]
    for name in os.listdir(img_dir)
    if os.path.splitext(name)[0] not in valid_set
)

print('train_images',len(train_images))
print('valid_images',len(valid_images))

train_images 209
valid_images 213


In [None]:
# For each raw label id, count how many training masks contain it.
train_labels = []

for stem in train_images:
    # Context manager closes the mask file once its pixels are read.
    with Image.open(os.path.join(out_dir, stem + '.png')) as mask:
        train_labels += np.unique(np.array(mask)).tolist()

train_labels = np.array(train_labels)
train_label_ids, train_label_counts = np.unique(train_labels, return_counts=True)
print(train_label_ids)
print(train_label_counts)


In [None]:
# For each raw label id, count how many validation masks contain it.
valid_labels = []

for stem in valid_images:
    # Context manager closes the mask file once its pixels are read.
    with Image.open(os.path.join(out_dir, stem + '.png')) as mask:
        valid_labels += np.unique(np.array(mask)).tolist()

valid_labels = np.array(valid_labels)
valid_label_ids, valid_label_counts = np.unique(valid_labels, return_counts=True)
print(valid_label_ids)
print(valid_label_counts)

In [14]:
# Training samples are augmented (isTrain=True); validation samples are not.
train_dataset = pascal_voc_data(img_dir=img_dir, out_dir=out_dir, type_list=train_images,
                                isTrain=True, transform=transform)
valid_dataset = pascal_voc_data(img_dir=img_dir, out_dir=out_dir, type_list=valid_images,
                                isTrain=False, transform=transform)

### Visualize

In [15]:
def visualise(img, out):
    """Plot the first sample of a batch: image, label map, and the two overlaid.

    Args:
        img: batch of normalised images; ``img[0]`` must be a 3 x H x W tensor.
        out: batch of label maps; ``out[0]`` must be an H x W tensor of class ids.
    """
    image = inv_normalize(img[0].detach().cpu()).permute(1, 2, 0).numpy()
    # Undoing Normalize leaves small values outside [0, 1]; clip explicitly
    # instead of letting imshow warn and clip.
    image = np.clip(image, 0.0, 1.0)
    labels = out[0].detach().cpu().numpy()
    # Fixed colour range, so a class id maps to the same colour in every plot
    # (imshow otherwise rescales to each map's own min/max).
    label_range = dict(vmin=0, vmax=num_class - 1)

    fig, (ax_img, ax_lbl, ax_ovl) = plt.subplots(1, 3, figsize=(15, 5))
    ax_img.imshow(image)
    ax_img.set_title('image')
    ax_lbl.imshow(labels, **label_range)
    ax_lbl.set_title('labels')
    ax_ovl.imshow(image)
    ax_ovl.imshow(labels, alpha=0.5, **label_range)
    ax_ovl.set_title('overlay')
    plt.show()

### Network Definition

In [18]:
class UNET(nn.Module):
    """U-Net encoder/decoder producing per-pixel class logits.

    Four 2x2 max-pool stages go down and four 2x2 transposed-convolution stages
    come back up. Each decoder stage is concatenated with the matching encoder
    features, so input height and width must be divisible by 16 for the skip
    connections to line up. The output has the input's height and width.

    Shape comments are (channels*H*W) for a 128 x 128 input, which is the
    ``img_size`` used in this notebook.

    Args:
        in_channels: number of input channels (3 for RGB images).
        n_classes: number of output logit channels. ``None`` uses the
            notebook-level ``num_class``.

    Layer attribute names and their registration order are unchanged, so
    existing ``net_state_dict`` / ``optimizer_state_dict`` checkpoints still load.
    """

    def __init__(self, in_channels=3, n_classes=None):
        super(UNET, self).__init__()
        if n_classes is None:
            n_classes = num_class

        # Encoder

        # Input - in_channels*128*128
        self.conv1 = nn.Conv2d(in_channels,64,3,1,1)  #64*128*128
        self.conv2 = nn.Conv2d(64,64,3,1,1) #64*128*128
        self.pool1 = nn.MaxPool2d(2,2) #64*64*64

        self.conv3 = nn.Conv2d(64,128,3,1,1) #128*64*64
        self.conv4 = nn.Conv2d(128,128,3,1,1) #128*64*64
        self.pool2 = nn.MaxPool2d(2,2) #128*32*32

        self.conv5 = nn.Conv2d(128,256,3,1,1) #256*32*32
        self.conv6 = nn.Conv2d(256,256,3,1,1) #256*32*32
        self.pool3 = nn.MaxPool2d(2,2) #256*16*16
        # Despite its name, forward() applies drop3 between conv3 and conv4.
        self.drop3 = nn.Dropout(p=0.3)

        self.conv7 = nn.Conv2d(256,512,3,1,1) #512*16*16
        self.conv8 = nn.Conv2d(512,512,3,1,1) #512*16*16
        self.pool4 = nn.MaxPool2d(2,2)  #512*8*8

        self.conv9 = nn.Conv2d(512,1024,3,1,1) #1024*8*8
        self.conv10 = nn.Conv2d(1024,1024,3,1,1) #1024*8*8
        self.drop5 = nn.Dropout(p=0.3)  # bottleneck dropout, between conv9 and conv10

        # Decoder

        self.TConv1 = nn.ConvTranspose2d(1024,512,2,2,0) #512*16*16
        # 1024*16*16 - torch.cat with conv8 output
        self.conv11 = nn.Conv2d(1024,512,3,1,1) #512*16*16
        self.conv12 = nn.Conv2d(512,512,3,1,1) #512*16*16

        self.TConv2 = nn.ConvTranspose2d(512,256,2,2,0) #256*32*32
        # 512*32*32 - torch.cat with conv6 output
        self.conv13 = nn.Conv2d(512,256,3,1,1) #256*32*32
        self.conv14 = nn.Conv2d(256,256,3,1,1) #256*32*32
        self.drop7 = nn.Dropout(p=0.3)  # applied between conv13 and conv14

        self.TConv3 = nn.ConvTranspose2d(256,128,2,2,0) #128*64*64
        # 256*64*64 - torch.cat with conv4 output
        self.conv15 = nn.Conv2d(256,128,3,1,1) #128*64*64
        self.conv16 = nn.Conv2d(128,128,3,1,1) #128*64*64

        self.TConv4 = nn.ConvTranspose2d(128,64,2,2,0) #64*128*128
        # 128*128*128 - torch.cat with conv2 output
        self.conv17 = nn.Conv2d(128,64,3,1,1) #64*128*128
        self.conv18 = nn.Conv2d(64,64,3,1,1) #64*128*128
        self.conv19 = nn.Conv2d(64,n_classes,1,1,0) #n_classes*128*128 (1x1 classifier)


    def forward(self,x):
        """Return raw logits (no softmax) of shape (batch, n_classes, H, W)."""
        # Encoder: keep each stage's output for the skip connections.
        out1 = F.relu(self.conv2((F.relu(self.conv1(x)))))

        in2 = self.pool1(out1)
        out2 = F.relu(self.conv4(self.drop3(F.relu(self.conv3(in2)))))

        in3 = self.pool2(out2)
        out3 = F.relu(self.conv6((F.relu(self.conv5(in3)))))

        in4 = self.pool3(out3)
        out4 = F.relu(self.conv8((F.relu(self.conv7(in4)))))

        # Bottleneck
        in5 = self.pool4(out4)
        in5 = F.relu(self.conv10(self.drop5(F.relu(self.conv9(in5)))))

        # Decoder: upsample, concatenate the matching encoder features, convolve.
        dout4 = F.relu(self.TConv1(in5))
        din4 = torch.cat([out4,dout4], dim=1)
        dout3 = F.relu(self.conv12((F.relu(self.conv11(din4)))))

        dout3 = F.relu(self.TConv2(dout3))
        din3 = torch.cat([out3,dout3],dim=1)
        dout2 = F.relu(self.conv14(self.drop7(F.relu(self.conv13(din3)))))

        dout2 = F.relu(self.TConv3(dout2))
        din2 = torch.cat([out2,dout2], dim=1)
        dout1 = F.relu(self.conv16((F.relu(self.conv15(din2)))))

        dout1 = F.relu(self.TConv4(dout1))
        din1 = torch.cat([out1,dout1], dim=1)

        out = self.conv19(F.relu(self.conv18((F.relu(self.conv17(din1))))))

        return out




### Dice coefficient definition

In [19]:
def get_dice_per_img(preds, targets, eps = 1e-6):
    """Smoothed Dice coefficient between two same-shaped tensors.

    dice = (2 * sum(preds * targets) + eps) / (sum(preds) + sum(targets) + eps)

    Returns the Python float 1.0 when both tensors sum to zero; otherwise a 0-d
    tensor. ``eps`` is added once to numerator and denominator, so a perfect
    overlap gives exactly 1 (``2 * (I + eps) / (S + eps)`` exceeded 1).
    """
    total = preds.sum() + targets.sum()
    if total == 0:
        return 1.0
    return (2.0 * (preds * targets).sum() + eps) / (total + eps)


def get_dice(preds, targets_inp, eps = 1e-6):
    """Mean per-image Dice between predictions and integer label maps.

    Args:
        preds: batch x nChannels x Height x Width tensor (one-hot or per-class scores).
        targets_inp: batch x Height x Width tensor of class ids in [0, nChannels).
        eps: smoothing term forwarded to ``get_dice_per_img``.

    Returns:
        Dice averaged over the batch. For each image, all class channels are
        pooled into one score. NOTE(review): with one-hot preds this equals
        pixel accuracy, not a per-class mean Dice -- confirm this is intended.
    """
    n_channels = preds.shape[1]
    # One-hot encode on the targets' own device, with the class count taken
    # from preds so the channels always line up. The global `device` broke
    # CPU tensors on a GPU host.
    class_ids = torch.arange(n_channels, device=targets_inp.device)
    targets = (targets_inp[..., None] == class_ids).int()  # batch x Height x Width x nChannels
    targets = targets.permute(0, 3, 1, 2)                  # batch x nChannels x Height x Width

    batch = preds.shape[0]
    dice = 0
    for i in range(batch):  # iterates over images, not classes
        dice += get_dice_per_img(preds[i], targets[i], eps)

    return dice / (1.0 * batch)

### Create a new model or resume from a saved checkpoint

In [29]:
# Path of a checkpoint to resume from; '' starts from freshly initialised weights.
load_model = ''

# Model, loss and optimiser.
net = UNET().to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(net.parameters(), lr=0.0001)

# Loop settings and history buffers (overwritten below when resuming).
batch_size = 4
avg_loss_idx = 10          # record/print the mean training loss every N batches
best_valid_score = 100000  # lowest validation loss seen so far
epoch_start = 0
loss_hist, valid_hist, dice_hist = [], [], []

if load_model != '':
    print('loading model ... ')
    checkpoint = torch.load(load_model, map_location=device)

    net.load_state_dict(checkpoint['net_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    loss_hist = checkpoint['loss_hist']
    valid_hist = checkpoint['valid_hist']
    dice_hist = checkpoint['dice_hist']
    best_valid_score = checkpoint['best_valid_score']
    epoch_start = checkpoint['epoch_start']

    net.train()
    print('model loaded ...' )



### Training pipeline

In [None]:


CHECKPOINT_DIR = 'drive/My Drive/saved_models'


def _save_checkpoint(path, next_epoch):
    """Save model/optimizer state and training history to ``path``.

    ``next_epoch`` is stored under 'epoch_start', so resuming from this file
    continues with the first epoch that has not been run yet.
    """
    os.makedirs(os.path.dirname(path), exist_ok=True)  # torch.save does not create directories
    torch.save({
        'net_state_dict': net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss_hist': loss_hist,
        'valid_hist': valid_hist,
        'dice_hist': dice_hist,
        'best_valid_score': best_valid_score,
        'epoch_start': next_epoch,
    }, path)


# Build the loaders once: a shuffling DataLoader draws a new order each time it
# is iterated. Validation is not shuffled -- the loss does not depend on order
# and the visualised batch stays comparable across epochs.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

for epoch in range(epoch_start, epoch_start + 500):
    net.train()
    print('epoch -- ', epoch)

    running_train_loss = 0.0
    for train_idx, (img, label) in enumerate(train_loader):
        pred = net(img.to(device))
        loss = criterion(pred, label.to(device).long())

        # zero the parameter gradients, backpropagate, update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate a Python float. Summing the loss tensor, and storing it in
        # loss_hist, would keep every batch's autograd graph alive.
        running_train_loss += loss.item()
        if (train_idx + 1) % avg_loss_idx == 0:
            mean_loss = running_train_loss / avg_loss_idx
            loss_hist.append(mean_loss)
            print(str(train_idx) + ' -- ' + str(mean_loss))
            running_train_loss = 0.0

    visualise(img, torch.argmax(pred, axis=1))
    visualise(img, label)
    plt.plot(loss_hist)
    plt.title('training loss (mean over every %d batches)' % avg_loss_idx)
    plt.show()

    print('--- Evaluate Valid Data ---')

    net.eval()
    total_valid_loss = 0.0
    with torch.no_grad():
        for valid_img, valid_label in valid_loader:
            valid_pred = net(valid_img.to(device))
            loss = criterion(valid_pred, valid_label.to(device).long())
            # Weight each batch mean by its size so a short last batch is not over-counted.
            total_valid_loss += loss.item() * valid_img.size(0)
    # Mean over all validation images (not divided by the last batch index).
    valid_loss = total_valid_loss / len(valid_loader.dataset)
    valid_hist.append(valid_loss)

    visualise(valid_img, torch.argmax(valid_pred, axis=1))
    plt.plot(valid_hist)
    plt.title('validation loss per epoch')
    plt.show()

    is_best = valid_loss < best_valid_score
    if is_best:
        best_valid_score = valid_loss
    net.train()

    print('--- Save Model ---')
    # epoch + 1: a resumed run starts at the next epoch instead of repeating this one.
    _save_checkpoint(os.path.join(CHECKPOINT_DIR, 'current.pt'), epoch + 1)

    if is_best:
        print('--- Save Best Model ---')
        _save_checkpoint(os.path.join(CHECKPOINT_DIR, 'best.pt'), epoch + 1)

   