# Semantic Segmentation with PyTorch

Mount Google Drive in Colab.

In [None]:
from google.colab import drive

# Attach Google Drive under /content/drive; the dataset and outputs below live there.
_DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(_DRIVE_MOUNT_POINT)

In [None]:
# If Google Drive is mounted correctly, the following commands should run without errors:
# list the mount point, change into the CamVid folder, and list its contents.
!ls /content/drive/
%cd "/content/drive/My Drive/CamVid"
!ls

Import the necessary libraries and set the parameters.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models

import numpy as np
import time
import os

from PIL import Image
import pandas as pd

In [None]:
# dataset path
root_dir   = "/content/drive/My Drive/CamVid/"
train_file = os.path.join(root_dir, "train.csv")
val_file   = os.path.join(root_dir, "val.csv")

print("training csv exists:{}".format(os.path.exists(train_file)))
print("validation csv exists:{}".format(os.path.exists(val_file)))

# Folder for the validation visualisations. makedirs also creates missing parent
# folders and is a no-op when the folder already exists.
val_dir = "/content/drive/My Drive/segmentation_output/"
os.makedirs(val_dir, exist_ok=True)

# Parameters
num_class = 11  # 11-class CamVid variant (32 for the original CamVid)
input_h, input_w = 256, 256  # crop size fed to the network
batch_size = 16
epochs = 40
lr = 1e-4
use_gpu = torch.cuda.is_available()

# Running counter used to name the validation images written to val_dir.
global_index = 0

# Per-epoch pixel accuracy and mean IoU, appended by validate().
pixel_acc_list = []
mIOU_list = []

## CamVid Dataset

In [None]:
class CamVidDataset(Dataset):
    """CamVid samples listed in a CSV file: cropped, optionally flipped, one-hot encoded.

    Each CSV row holds two paths relative to the module-level ``root_dir``:
    column 0 is the RGB image and column 1 the label image, whose pixel values
    are read as class indices.

    Args:
        csv_file: CSV listing the (image, label) pairs.
        n_class: number of channels of the one-hot target ``'Y'``.
        flip_rate: probability of a horizontal flip (0 disables flipping).
        rand_crop: random ``input_h x input_w`` crop when True, centred crop otherwise.
    """

    def __init__(self, csv_file, n_class=num_class, flip_rate=0.5, rand_crop=True):
        self.data = pd.read_csv(csv_file)
        self.n_class = n_class
        self.new_h = input_h
        self.new_w = input_w
        self.flip_rate = flip_rate
        self.rand_crop = rand_crop

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``{'X': image, 'Y': one-hot target, 'l': label map}`` for row *idx*.

        ``'X'`` is a float tensor (3, new_h, new_w) scaled to [0, 1], ``'Y'`` a float
        tensor (n_class, new_h, new_w) and ``'l'`` a long tensor (new_h, new_w).
        """
        img = Image.open(root_dir + self.data.iloc[idx, 0]).convert('RGB')
        label_image = Image.open(root_dir + self.data.iloc[idx, 1])

        # Crop image and label with the same box so they stay pixel-aligned.
        w, h = img.size
        if self.rand_crop:
            x_offset = int(np.random.randint(0, w - self.new_w + 1))
            y_offset = int(np.random.randint(0, h - self.new_h + 1))
        else:
            x_offset = int((w - self.new_w) / 2)
            y_offset = int((h - self.new_h) / 2)
        box = (x_offset, y_offset, x_offset + self.new_w, y_offset + self.new_h)  # left, top, right, bottom
        img = np.asarray(img.crop(box))            # (H, W, 3)
        label = np.asarray(label_image.crop(box))  # (H, W), assumes a single-channel label image

        # Horizontal flip. Both arrays are still height-first here, so np.fliplr
        # (which reverses axis 1) mirrors the width of the image and the label alike.
        if np.random.sample() < self.flip_rate:
            img = np.fliplr(img)
            label = np.fliplr(label)

        # HWC -> CHW scaled to [0, 1]. copy() gives torch.from_numpy contiguous,
        # writable arrays (fliplr returns a negatively-strided view, and arrays
        # backed by a PIL image may be read-only).
        img = torch.from_numpy((np.transpose(img, (2, 0, 1)) / 255.).copy()).float()
        label = torch.from_numpy(label.copy()).long()

        # One-hot target: channel c is 1 where label == c; pixels whose value lies
        # outside [0, n_class) stay all-zero.
        classes = torch.arange(self.n_class).view(-1, 1, 1)
        target = (label.unsqueeze(0) == classes).float()

        return {'X': img, 'Y': target, 'l': label}

# Build the datasets and their loaders. The training split is augmented (random
# crops and flips) and shuffled; the validation split is centre-cropped, never
# flipped, and served one image per batch.
_loader_workers = 8
train_data = CamVidDataset(csv_file=train_file, flip_rate=0.5, rand_crop=True)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                          num_workers=_loader_workers)
val_data = CamVidDataset(csv_file=val_file, flip_rate=0, rand_crop=False)
val_loader = DataLoader(val_data, batch_size=1, shuffle=False,
                        num_workers=_loader_workers)

## Network Model
### VGG16 Feature Extractor (pretrained)

In [None]:
class Vgg16(nn.Module):
    """VGG16 convolutional trunk that returns one activation per stage.

    The fully connected classifier is dropped. ``forward`` runs ``features`` and
    collects the outputs of layers 3, 8, 15, 22 and 29 (the ReLU right before each
    max-pool). For a 256x256 input their shapes are (64,256,256), (128,128,128),
    (256,64,64), (512,32,32) and (512,16,16).
    """

    # Indices into vgg16().features whose outputs forward() returns.
    _TAP_INDICES = (3, 8, 15, 22, 29)

    def __init__(self, pretrained=True):
        super(Vgg16, self).__init__()
        # Pass ``pretrained`` by keyword: since torchvision 0.13 the model builders
        # take keyword-only arguments and a positional value is only remapped for
        # backwards compatibility (with a deprecation warning).
        self.vggnet = models.vgg16(pretrained=pretrained)
        del self.vggnet.classifier  # Remove fully connected layer to save memory.
        features = list(self.vggnet.features)
        # eval() has no effect on these conv/ReLU/max-pool layers and is overridden
        # by any later .train() call on a parent model.
        self.layers = nn.ModuleList(features).eval()

    def forward(self, x):
        """Return the list of the five tapped stage activations for batch *x*."""
        results = []
        last_tap = self._TAP_INDICES[-1]
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx in self._TAP_INDICES:
                results.append(x)
            if idx == last_tap:
                # Layers after the last tap (the final max-pool) feed nothing we return.
                break
        return results

vgg_model = Vgg16()
# Move to the GPU only when one is available, consistent with the use_gpu
# checks applied to the input tensors in train() and validate().
if use_gpu:
    vgg_model = vgg_model.cuda()
print(vgg_model.layers)

### Encoder-Decoder

In [None]:
class DeConv2d(nn.Module):
    """Learned 2x upsampling: nearest-neighbour upsample followed by a Conv2d.

    The convolution hyper-parameters are forwarded unchanged to ``nn.Conv2d``.
    With ``kernel_size=3, stride=1, padding=1, dilation=1`` the output has twice
    the input's height and width.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride, padding, dilation):
        super(DeConv2d, self).__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=stride,
                           padding=padding, dilation=dilation)
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv = nn.Conv2d(in_channel, out_channel, **conv_kwargs)

    def forward(self, x):
        return self.conv(self.up(x))

class EncoderDecoder(nn.Module):
    """Plain encoder-decoder that decodes only the deepest backbone feature map.

    ``pretrained_net(x)`` must return an indexable sequence of feature maps; only
    element 4 is used and it must have 512 channels. Four DeConv2d stages each
    double the spatial size (each followed by ReLU, then BatchNorm), and a 1x1 conv
    maps the final 64 channels to ``n_class`` logits.
    """

    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)

        up_kwargs = dict(kernel_size=3, stride=1, padding=1, dilation=1)
        self.deconv1 = DeConv2d(512, 512, **up_kwargs)
        self.bn1 = nn.BatchNorm2d(512)
        self.deconv2 = DeConv2d(512, 256, **up_kwargs)
        self.bn2 = nn.BatchNorm2d(256)
        self.deconv3 = DeConv2d(256, 128, **up_kwargs)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv4 = DeConv2d(128, 64, **up_kwargs)
        self.bn4 = nn.BatchNorm2d(64)

        self.classifier = nn.Conv2d(64, n_class, kernel_size=1)

    def forward(self, x):
        feat = self.pretrained_net(x)[4]
        stages = (
            (self.deconv1, self.bn1),
            (self.deconv2, self.bn2),
            (self.deconv3, self.bn3),
            (self.deconv4, self.bn4),
        )
        for deconv, bn in stages:
            feat = bn(self.relu(deconv(feat)))
        return self.classifier(feat)

### Fully Convolution Network (FCN)


In [None]:
class FCN(nn.Module):
    """FCN-style decoder: upsample stage by stage and add the matching encoder map.

    ``pretrained_net(x)`` must return five feature maps ``x0 .. x4`` (shallow to
    deep). Each of x0..x3 must have the same shape as the decoder output it is added
    to, e.g. the Vgg16 extractor's (64,256,256), (128,128,128), (256,64,64),
    (512,32,32) and (512,16,16) for a 256x256 input. Each stage is
    DeConv2d -> ReLU -> BatchNorm followed by an element-wise sum with the skip map.
    A final 1x1 conv produces ``n_class`` logits.
    """

    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)

        up_kwargs = dict(kernel_size=3, stride=1, padding=1, dilation=1)
        self.deconv1 = DeConv2d(512, 512, **up_kwargs)
        self.bn1 = nn.BatchNorm2d(512)
        self.deconv2 = DeConv2d(512, 256, **up_kwargs)
        self.bn2 = nn.BatchNorm2d(256)
        self.deconv3 = DeConv2d(256, 128, **up_kwargs)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv4 = DeConv2d(128, 64, **up_kwargs)
        self.bn4 = nn.BatchNorm2d(64)

        self.classifier = nn.Conv2d(64, n_class, kernel_size=1)

    def forward(self, x):
        x0, x1, x2, x3, x4 = self.pretrained_net(x)
        decoder = (
            (self.deconv1, self.bn1, x3),
            (self.deconv2, self.bn2, x2),
            (self.deconv3, self.bn3, x1),
            (self.deconv4, self.bn4, x0),
        )
        out = x4
        for deconv, bn, skip in decoder:
            out = bn(self.relu(deconv(out)))
            out += skip  # skip connection by in-place summation
        return self.classifier(out)

### U-Net

In [None]:
class UNet(nn.Module):
    """U-Net-style decoder: upsample stage by stage and concatenate the encoder map.

    ``pretrained_net(x)`` must return five feature maps ``x0 .. x4`` (shallow to
    deep) whose spatial sizes match the decoder outputs they are concatenated with.
    Each stage is DeConv2d -> ReLU -> BatchNorm followed by channel-wise
    concatenation with the skip map. A final 1x1 conv produces ``n_class`` logits.
    """

    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)

        up_kwargs = dict(kernel_size=3, stride=1, padding=1, dilation=1)
        # From the second stage on, each input is the previous decoder output
        # concatenated with an equally wide skip map, hence the doubled widths.
        self.deconv1 = DeConv2d(512, 512, **up_kwargs)
        self.bn1 = nn.BatchNorm2d(512)
        self.deconv2 = DeConv2d(2 * 512, 256, **up_kwargs)
        self.bn2 = nn.BatchNorm2d(256)
        self.deconv3 = DeConv2d(2 * 256, 128, **up_kwargs)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv4 = DeConv2d(2 * 128, 64, **up_kwargs)
        self.bn4 = nn.BatchNorm2d(64)

        self.classifier = nn.Conv2d(2 * 64, n_class, kernel_size=1)

    def forward(self, x):
        x0, x1, x2, x3, x4 = self.pretrained_net(x)
        decoder = (
            (self.deconv1, self.bn1, x3),
            (self.deconv2, self.bn2, x2),
            (self.deconv3, self.bn3, x1),
            (self.deconv4, self.bn4, x0),
        )
        out = x4
        for deconv, bn, skip in decoder:
            out = torch.cat([bn(self.relu(deconv(out))), skip], dim=1)
        return self.classifier(out)

### PSPNet

In [None]:
class PPM(nn.Module):
    """Pyramid Pooling Module (PSPNet).

    The input is average-pooled to 1x1, 2x2, 3x3 and 6x6 grids. Each pooled map goes
    through a 1x1 conv with ``int(in_channel / 4)`` outputs and a ReLU, and is
    bilinearly upsampled back to the input's own spatial size. All branches are then
    concatenated with the input along the channel axis, giving
    ``in_channel + 4 * int(in_channel / 4)`` channels (1024 for ``in_channel=512``).

    Args:
        in_channel: channels of the input feature map.
        out_channel, kernel_size, stride, padding, dilation: accepted for signature
            compatibility but unused (NOTE(review): the real output width is
            determined by ``in_channel`` as described above).
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride, padding, dilation):
        super().__init__()
        # Pyramid Pooling Module
        self.relu = nn.ReLU(inplace=True)
        self.ppm_channel = in_channel
        self.ppm_psize = [1, 2, 3, 6]
        branch_channel = int(self.ppm_channel / len(self.ppm_psize))

        self.ppm_pool = nn.ModuleList(
            [nn.AdaptiveAvgPool2d((psize, psize)) for psize in self.ppm_psize])
        self.ppm_conv = nn.ModuleList(
            [nn.Conv2d(self.ppm_channel, branch_channel, kernel_size=1) for _ in self.ppm_psize])

    def forward(self, x):
        # Upsample every branch to the input's own size, so any input resolution works.
        size = tuple(x.shape[-2:])
        ppm_list = [x]
        for pool, conv in zip(self.ppm_pool, self.ppm_conv):
            branch = self.relu(conv(pool(x)))
            ppm_list.append(F.interpolate(branch, size=size, mode='bilinear', align_corners=True))
        # Concatenate once, after all branches are computed.
        return torch.cat(ppm_list, 1)

class PSPNet(nn.Module):
    """PSPNet-style model: pyramid pooling on the deepest map, then an FCN-style decoder.

    ``pretrained_net(x)`` must return five feature maps ``x0 .. x4`` (shallow to
    deep), with x4 having 512 channels. The PPM widens x4 to 1024 channels. Each
    decoder stage is DeConv2d -> ReLU -> BatchNorm followed by an element-wise sum
    with the matching skip map. A final 1x1 conv produces ``n_class`` logits.
    """

    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)

        self.ppm = PPM(512, 1024, kernel_size=3, stride=1, padding=1, dilation=1)

        up_kwargs = dict(kernel_size=3, stride=1, padding=1, dilation=1)
        self.deconv1 = DeConv2d(1024, 512, **up_kwargs)
        self.bn1 = nn.BatchNorm2d(512)
        self.deconv2 = DeConv2d(512, 256, **up_kwargs)
        self.bn2 = nn.BatchNorm2d(256)
        self.deconv3 = DeConv2d(256, 128, **up_kwargs)
        self.bn3 = nn.BatchNorm2d(128)
        self.deconv4 = DeConv2d(128, 64, **up_kwargs)
        self.bn4 = nn.BatchNorm2d(64)

        self.classifier = nn.Conv2d(64, n_class, kernel_size=1)

    def forward(self, x):
        x0, x1, x2, x3, x4 = self.pretrained_net(x)
        decoder = (
            (self.deconv1, self.bn1, x3),
            (self.deconv2, self.bn2, x2),
            (self.deconv3, self.bn3, x1),
            (self.deconv4, self.bn4, x0),
        )
        out = self.ppm(x4)
        for deconv, bn, skip in decoder:
            out = bn(self.relu(deconv(out)))
            out += skip  # skip connection by in-place summation
        return self.classifier(out)

# Construct models.

In [None]:
# Pick one segmentation head. All heads wrap the same vgg_model instance, so
# re-running this cell with another head continues from the backbone weights
# trained so far.
# seg_model = EncoderDecoder(pretrained_net=vgg_model, n_class=num_class)
# seg_model = FCN(pretrained_net=vgg_model, n_class=num_class)
# seg_model = UNet(pretrained_net=vgg_model, n_class=num_class)
seg_model = PSPNet(pretrained_net=vgg_model, n_class=num_class)

# Only move to the GPU when one is present, matching the use_gpu checks in
# train() and validate().
if use_gpu:
    seg_model = seg_model.cuda()
# Per-pixel, per-class binary cross-entropy against the one-hot targets ('Y').
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(seg_model.parameters(), lr=lr)

# Training and Validation

In [None]:
def train():
    """Train ``seg_model`` for ``epochs`` epochs, validating after each one.

    Uses the module-level ``train_loader``, ``seg_model``, ``criterion`` and
    ``optimizer``. At the end, prints the best validation mean IoU and pixel
    accuracy recorded in ``mIOU_list`` / ``pixel_acc_list``, together with the
    (0-based) index at which each was achieved.
    """
    for epoch in range(epochs):
        # validate() switches the network to eval mode. Switch back every epoch so
        # BatchNorm layers normalise with batch statistics and keep updating them.
        seg_model.train()
        ts = time.time()
        for step, batch in enumerate(train_loader):
            optimizer.zero_grad()
            inputs = batch['X'].float()
            labels = batch['Y'].float()
            if use_gpu:
                inputs = inputs.cuda()
                labels = labels.cuda()

            outputs = seg_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            if step % 10 == 0:
                print("epoch:{:2}, iter:{:2}, loss: {:.4f}".format(epoch, step, loss.item()))

        print("Finish epoch:{:2}, time elapsed: {:.4f}".format(epoch, time.time() - ts))
        validate()
        print("========================================")

    if not pixel_acc_list or not mIOU_list:
        # Nothing was validated (e.g. epochs == 0), so there is no best score.
        print("No validation results recorded; nothing to summarise.")
        return

    highest_pixel_acc = max(pixel_acc_list)
    highest_mIOU = max(mIOU_list)

    highest_pixel_acc_epoch = pixel_acc_list.index(highest_pixel_acc)
    highest_mIOU_epoch = mIOU_list.index(highest_mIOU)

    print("The highest mIOU is {} and is achieved at epoch-{}".format(highest_mIOU, highest_mIOU_epoch))
    print("The highest pixel accuracy  is {} and is achieved at epoch-{}".format(highest_pixel_acc, highest_pixel_acc_epoch))

In [None]:
def validate():
    """Evaluate ``seg_model`` on ``val_loader`` and record the scores.

    Prints pixel accuracy and mean IoU and appends them to ``pixel_acc_list`` and
    ``mIOU_list``. The prediction for the first validation batch is saved with
    ``save_result``. The model's previous train/eval mode is restored afterwards,
    so BatchNorm layers keep learning when training resumes.
    """
    was_training = seg_model.training
    seg_model.eval()
    total_ious = []
    pixel_accs = []

    try:
        # Evaluation needs no gradients; skipping autograd bookkeeping saves memory.
        with torch.no_grad():
            for step, batch in enumerate(val_loader):  # val_loader uses batch_size=1
                inputs = batch['X'].float()
                if use_gpu:
                    inputs = inputs.cuda()

                output = seg_model(inputs)
                # Predicted class per pixel: argmax over the class (channel) axis -> (N, H, W).
                pred = output.argmax(dim=1).cpu().numpy()

                # only save the 1st image for comparison
                if step == 0:
                    save_result(batch['X'], pred[0])

                N, h, w = pred.shape
                target = batch['l'].cpu().numpy().reshape(N, h, w)
                for p, t in zip(pred, target):
                    total_ious.append(iou(p, t))
                    pixel_accs.append(pixel_acc(p, t))
    finally:
        seg_model.train(was_training)

    if not pixel_accs:
        print("validation set is empty; nothing to score")
        return

    # Per-class IoU averaged over images (NaN entries skipped), then over classes.
    total_ious = np.array(total_ious).T  # n_class * val_len
    ious = np.nanmean(total_ious, axis=1)
    mean_iou = np.nanmean(ious)
    mean_pixel_acc = np.array(pixel_accs).mean()
    print("pix_acc: {:.4f}, meanIoU: {:.4f}".format(mean_pixel_acc, mean_iou))

    pixel_acc_list.append(mean_pixel_acc)
    mIOU_list.append(mean_iou)

def iou(pred, target, n_class=None):
    """Return the per-class intersection-over-union of two label maps.

    Args:
        pred, target: integer arrays of the same shape holding class indices.
        n_class: number of classes to score; defaults to the module-level
            ``num_class``.

    Returns:
        A list with one float per class ``0 .. n_class-1``. A class present in
        neither ``pred`` nor ``target`` gets NaN, so ``np.nanmean`` can skip it.
        A class that is predicted but absent from the target scores 0.0.
    """
    if n_class is None:
        n_class = num_class
    ious = []
    for cls in range(n_class):
        pred_inds = pred == cls
        target_inds = target == cls
        intersection = np.logical_and(pred_inds, target_inds).sum()
        union = np.logical_or(pred_inds, target_inds).sum()
        ious.append(float(intersection) / union if union else float('nan'))
    return ious

def pixel_acc(pred, target):
    """Return the fraction of pixels where ``pred`` equals ``target``."""
    hits = np.sum(pred == target)
    # ``target == target`` is False only for NaN entries, so this counts the
    # comparable pixels (all of them for integer label maps).
    considered = np.sum(target == target)
    return hits / considered

# RGB colour for each predicted class index (designed for the 11-class version;
# indices without an entry are drawn black).
_CAMVID_PALETTE = np.array([
    [128, 128, 128],  # 0
    [128, 0, 0],      # 1
    [192, 192, 128],  # 2
    [128, 64, 128],   # 3
    [0, 0, 192],      # 4
    [128, 128, 0],    # 5
    [192, 128, 128],  # 6
    [64, 64, 128],    # 7
    [64, 0, 128],     # 8
    [64, 64, 0],      # 9
    [0, 128, 192],    # 10
], dtype=np.float64)


def _colorize(label_map):
    """Map an (H, W) array of class indices to an (H, W, 3) float RGB image."""
    label_map = np.asarray(label_map)
    rgb = np.zeros(label_map.shape + (3,))
    known = (label_map >= 0) & (label_map < len(_CAMVID_PALETTE))
    rgb[known] = _CAMVID_PALETTE[label_map[known].astype(np.intp)]
    return rgb


def save_result(input_np, output_np):
    """Save an input image and its colour-coded prediction side by side as a JPEG.

    Args:
        input_np: batch whose first element is a (3, H, W) image with values in
            [0, 1] (a CPU tensor or array; validate() passes ``batch['X']``).
        output_np: (H, W) array of predicted class indices for that image.

    The file is written to ``val_dir`` as ``000.jpg``, ``001.jpg``, ... and the
    module-level ``global_index`` counter is advanced by one.
    """
    global global_index

    # (3, H, W) in [0, 1] -> (H, W, 3) in [0, 255].
    original_im_RGB = np.asarray(input_np[0], dtype=np.float64).transpose(1, 2, 0) * 255.0

    im_seg_RGB = _colorize(output_np)

    # horizontally stack original image and its corresponding segmentation results
    hstack_image = np.hstack((original_im_RGB, im_seg_RGB))
    new_im = Image.fromarray(np.uint8(hstack_image))
    file_name = val_dir + str(global_index).zfill(3) + '.jpg'
    global_index = global_index + 1
    new_im.save(file_name)

# Train

In [None]:
# Run the whole experiment: train for `epochs` epochs, validating after each one.
train()