# COMPUTER ASSIGNMENT 04



## U-net for image segmentation

Download train, train_mask, test from the following link:

https://www.dropbox.com/sh/e8j991bsd269fcq/AACARtYIoYnQaydahUlo22NFa?dl=0

and extract them to the current directory.


In class, we talked about U-net for image segmentation.

This assignment is intended to 
- help you better understand the concept of U-net for image segmentation 
- help you get started with designing networks in PyTorch, including loading data, network design, loss function, training and testing.

You should 
 -  Implement the U-net with the following architecture.
 ![U-net](U-net_architecture.png)
 -  Write the function dice_coeff(input, target) for evaluation.
 -  Load the training dataset and the testing dataset.
 Note that you should rescale the images to a smaller size; otherwise training on a CPU is impractically slow.
 -  Train your network for a few epochs.
 -  Test your model by feeding in a new image from the testing dataset. Plot your result.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import sys
import os
from optparse import OptionParser
import numpy as np
from torch import optim
from PIL import Image
from torch.autograd import Function, Variable
import matplotlib.pyplot as plt
from torchvision import transforms
%matplotlib inline

### [ TODO 1 ] First define the following layers to be used later
- **Conv2d + BatchNorm2d + ReLU** as **single conv2d layer**,
- **MaxPool2d + single conv2d layer** as **down layer**,
- **Upsample + single conv2d layer** as **up layer**,
- **Conv2d** as **outconv layer**

You can check out the documentation in this link to understand how to use the methods called in the provided template:

 https://pytorch.org/docs/stable/nn.html
 
  ![single_conv](single_conv_layer.png)
  ![down_layer](down_layer.png)
  ![up_layer](Up_layer.png)
  ![outconv_layer](outconv_layer.png)
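
As a quick sanity check of how these building blocks change tensor shapes, here is a minimal sketch on a dummy input. The 32 by 32 input size and the use of `padding=1` are assumptions made for illustration; they are not taken from the figures above.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 32, 32)  # dummy batch: 1 image, 3 channels, 32x32 (assumed size)

# A 3x3 convolution without padding trims one pixel from every border
conv_valid = nn.Conv2d(3, 16, kernel_size=3)
# With padding=1 a 3x3 convolution keeps the spatial size unchanged
conv_same = nn.Conv2d(3, 16, kernel_size=3, padding=1)

# Conv2d + BatchNorm2d + ReLU, i.e. the "single conv2d layer" pattern
block = nn.Sequential(conv_same, nn.BatchNorm2d(16), nn.ReLU(inplace=True))

shape_valid = tuple(conv_valid(x).shape)
features = block(x)
shape_block = tuple(features.shape)

pooled = nn.MaxPool2d(kernel_size=2, stride=2)(features)  # halves height and width
shape_pooled = tuple(pooled.shape)

upsampled = nn.Upsample(scale_factor=2)(pooled)  # doubles height and width
shape_up = tuple(upsampled.shape)

print(shape_valid, shape_block, shape_pooled, shape_up)
```

Printing the shapes after each layer is a cheap way to catch channel or size mismatches before assembling the full network.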
  

In [7]:
################################################ [TODO] ###################################################
# DEFINE SINGLE_CONV CLASS
class single_conv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(single_conv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1), # padding=1 keeps the spatial size, so the output mask matches the input size
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True) # Save memory
        )

    def forward(self, x):
        x = self.conv(x)
        return x


################################################ [TODO] ###################################################
# DEFINE DOWN CLASS
class down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(down, self).__init__()
        self.down = nn.MaxPool2d(kernel_size=2, stride=2) # use nn.MaxPool2d( )
        self.conv = single_conv(in_ch, out_ch) # use previously defined single_conv
    def forward(self, x):
        x = self.down(x)
        x = self.conv(x)
        return x
    

################################################ [TODO] ###################################################
# DEFINE UP CLASS
# Note that this class will not only upsample x1, but also concatenate up-sampled x1 with x2 to generate the final output

class up(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(up, self).__init__()       
        self.up = nn.Upsample(scale_factor=2) # use nn.Upsample( )
        self.conv = single_conv(in_ch, out_ch) # use previously defined single_conv

    def forward(self, x1, x2):
        # This part is tricky, so we provide it for you
        # First we upsample x1
        x1 = self.up(x1)
            
        # Note that x2 and x1 may not have the same spatial size.
        # For example, downsampling a 25 by 25 feature map (x2) gives a 12 by 12 map,
        # and upsampling that map gives x1 at 24 by 24.
        # We therefore zero-pad x1 by one row and one column so that x1 and x2 have the same size.
        
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, (diffX // 2, diffX - diffX//2,
                        diffY // 2, diffY - diffY//2))
        
        # Now we concatenate x2 channels with x1 channels
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x

################################################ [TODO] ###################################################
# DEFINE OUTCONV CLASS
class outconv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(outconv, self).__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1) # Use nn.Conv2d( ) only, since we do not need batch norm and relu at this layer; a 1x1 kernel keeps the spatial size

    def forward(self, x):
        x = self.conv(x)
        return x
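
The padding step in `up.forward` can be checked in isolation. The sketch below reuses the 25 by 25 example from the comments; the channel count of 16 and the random tensors are placeholders chosen for illustration.

```python
import torch
import torch.nn.functional as F

x2 = torch.randn(1, 16, 25, 25)  # skip-connection feature map with an odd size
x1 = torch.randn(1, 16, 24, 24)  # upsampled map: 25 -> 12 after pooling -> 24 after upsampling

diffY = x2.size(2) - x1.size(2)  # height difference
diffX = x2.size(3) - x1.size(3)  # width difference

# F.pad takes (left, right, top, bottom) for the last two dimensions
x1_padded = F.pad(x1, (diffX // 2, diffX - diffX // 2,
                       diffY // 2, diffY - diffY // 2))

# After padding, the two maps can be concatenated along the channel dimension
merged = torch.cat([x2, x1_padded], dim=1)
print(x1_padded.shape, merged.shape)
```

With a difference of one pixel, the integer division places the extra zero row at the bottom and the extra zero column on the right.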

In [8]:
################################################ [TODO] ###################################################
# Build your network with predefined classes: single_conv, up, down, outconv
import torch.nn.functional as F
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes):
        super(UNet, self).__init__()
        self.inc = single_conv(n_channels, 16) # conv2d +  batchnorm + relu
        self.down1 = down(16, 32)         # maxpool2d + conv2d + batchnorm + relu
        self.down2 = down(32, 32)         # maxpool2d + conv2d + batchnorm + relu

        self.up1 = up(64, 16)             # upsample + pad + conv2d + batchnorm + relu
        self.up2 = up(32, 16)             # upsample + pad + conv2d + batchnorm + relu

        self.outc = outconv(16, 1)        # conv2d

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)

        x = self.up1(x3, x2)
        x = self.up2(x, x1)

        x = self.outc(x)
        return torch.sigmoid(x) # F.sigmoid is deprecated in favor of torch.sigmoid
    

### [ TODO 2 ] Define evaluation function:
Based on what we learned in class, the Dice coefficient is defined as 
![dice.png](dice.png)
When evaluating the Dice coefficient on predicted segmentation masks, we can approximate the intersection of A and B by multiplying the prediction and target masks element-wise and then summing the resulting matrix.

To quantify the areas of A and B, some researchers use the simple sum of the mask values, whereas others prefer the sum of squares. 
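
The two ways of measuring the areas can be compared with a small stand-alone sketch. The helper name `soft_dice` and its `squared` flag are hypothetical and are not part of the assignment template; `eps` is added to the denominator only, as an assumed guard against division by zero.

```python
import torch

def soft_dice(pred, target, squared=False, eps=1e-4):
    """Dice coefficient 2*|A and B| / (|A| + |B|) for one prediction/target pair."""
    pred = pred.contiguous().view(-1).float()
    target = target.contiguous().view(-1).float()
    # Intersection: element-wise product summed over all pixels
    inter = torch.dot(pred, target)
    if squared:
        # Area measured by the sum of squares
        denom = torch.sum(pred ** 2) + torch.sum(target ** 2)
    else:
        # Area measured by the simple sum
        denom = torch.sum(pred) + torch.sum(target)
    return 2 * inter / (denom + eps)

pred = torch.tensor([[1., 1.], [0., 0.]])    # two predicted foreground pixels
target = torch.tensor([[1., 0.], [0., 0.]])  # one true foreground pixel

d_simple = soft_dice(pred, target)
d_squared = soft_dice(pred, target, squared=True)
d_perfect = soft_dice(target, target)
print(d_simple.item(), d_squared.item(), d_perfect.item())
```

For binary masks the two variants agree, because squaring 0 or 1 changes nothing; they differ only when the prediction holds probabilities between 0 and 1.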



In [10]:
################################################ [TODO] ###################################################
# define dice coefficient 
class DiceCoeff(Function):
    """Dice coeff for one pair of input image and target image"""
    def forward(self, input, target):
        self.save_for_backward(input, target)
        eps = 0.0001 # in case union = 0
        ################################################ [TODO] ###################################################
        # Calculate intersection and union. 
        # You can convert the input image into a vector with input.contiguous().view(-1)
        # Then use torch.dot(A, B) to calculate the intersection.
        # Use torch.sum(A) to get the sum.
        self.inter = torch.sum(input * target) # same value as torch.dot on the flattened tensors
        self.union = torch.sum(input) + torch.sum(target) + eps
        # Calculate DICE: 2 * intersection / (sum of the two areas)
        d = 2 * self.inter / self.union
        return d


################################################ [TODO] ###################################################
# Calculate dice coefficients for batches
def dice_coeff(input, target):
    """Dice coeff for batches"""
    s = torch.FloatTensor(1).zero_()
    
    # Open question: is this function graded, or is it used only for testing?
    
    # For each pair of input and target, call DiceCoeff().forward(input, target) to calculate dice coefficient
    # Then average
    for i, c in enumerate(zip(input, target)):
        s = s + DiceCoeff().forward(c[0], c[1]) # forward is an instance method, so create an instance first
    s = s / (i + 1)
    return s

### Load data ids. 

Split them into training and validation sets. A validation percentage of 0.05 means that 5% of the training dataset is held out for validation.

You can try different percentages.

This part has been done for you.

In [11]:
def get_ids(dir):
    """Returns a list of the ids in the directory"""
    return (f[:-4] for f in os.listdir(dir))


def split_ids(ids, n=2):
    """Split each id in n, creating n tuples (id, k) for each id"""
    return ((id, i)  for id in ids for i in range(n))


def split_train_val(dataset, val_percent=0.05):
    
    dataset = list(dataset)
    length = len(dataset)
    n = int(length * val_percent)
    random.shuffle(dataset)
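    # Index from the front so that n == 0 yields an empty validation set instead of an empty training set
    return {'train': dataset[:length - n], 'val': dataset[length - n:]}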




In [15]:
dir_img = 'train/'
dir_mask = 'train_masks/'
dir_checkpoint = 'checkpoints/'
os.makedirs(dir_checkpoint, exist_ok=True) # torch.save fails if the directory does not exist

# Get ids of training dataset
ids = get_ids(dir_img)
ids = split_ids(ids)
# iddataset consists of iddataset['train'] and iddataset['val']
# You can get all the ids of the images in the training dataset with the following code:
# for id, pos in iddataset['train']:
#    print(id)
# You will need this in the get_imgs_and_masks() function below,
# or you can load the images in your own way
iddataset = split_train_val(ids, 0.05)
# Get the number of training samples
N_train = len(iddataset['train'])


### [ TODO 3 ] & [ TODO 4 ] Load images and start training your network

In [None]:
# You might need to use these functions in the following steps
# hwc_to_chw: Convert images from Height*Width*Channels to Channels*Height*Width
def hwc_to_chw(img):
    return np.transpose(img, axes=[2, 0, 1])

# normalize: normalize from 255 to 1
def normalize(x):
    return x / 255

# batch:  Yields lists by batch
def batch(iterable, batch_size):
    b = []
    for i, t in enumerate(iterable):
        b.append(t)
        if (i + 1) % batch_size == 0:
            yield b
            b = []
    if len(b) > 0:
        yield b
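
To illustrate the layout conversion and normalization that these helpers perform, here is a minimal sketch on a synthetic image and mask. The image size, the scale factor of 0.25 and the 0/255 mask encoding are assumptions made for illustration; the real dataset may differ.

```python
import numpy as np
from PIL import Image

# Synthetic stand-ins for one image/mask pair (assumed: RGB image, single-channel 0/255 mask)
img = Image.fromarray(np.random.randint(0, 256, (64, 96, 3), dtype=np.uint8))
mask = Image.fromarray(((np.random.rand(64, 96) > 0.5) * 255).astype(np.uint8))

scale = 0.25
# PIL sizes are given as (width, height)
new_size = (int(img.width * scale), int(img.height * scale))

img_small = np.array(img.resize(new_size), dtype=np.float32)    # Height*Width*Channels
mask_small = np.array(mask.resize(new_size), dtype=np.float32)  # Height*Width

# Same operations as hwc_to_chw() and normalize()
img_chw = np.transpose(img_small, axes=[2, 0, 1]) / 255  # Channels*Height*Width in [0, 1]
mask_01 = mask_small / 255

print(img_chw.shape, mask_01.shape)
```

Note that PIL reports sizes as (width, height) while NumPy arrays are indexed as (height, width), which is a common source of transposed images.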


In [91]:
################################################ [TODO] ###################################################
# This function returns rescaled images and masks
# Note that:
#  - The shape of images should be Channels*rescaled_Height*rescaled_Width 
#  - Pixel values should be normalized to [0,1]
def get_imgs_and_masks(ids, dir_img, dir_mask, scale):
    """Return all the (img, mask) pairs"""
    # Read in the image and rescale it according to the scale factor
    # You can use Image.open() to read images
    # Image is a package of PIL. Check https://pillow.readthedocs.io/en/stable/reference/Image.html for more details
    imgs = ...
    rescaled_imgs = ...

    # Convert images from Height*Width*Channels to Channels*Height*Width
    # you can use hwc_to_chw() 
    imgs_switched = ...
    # Then normalize switched images to [0,1]
    # you can use normalize()
    imgs_normalized = ...
    # Read in the mask and rescale it according to the scale factor
    masks = ...
    rescaled_masks = ...

    return zip(imgs_normalized, rescaled_masks)


################################################ [TODO] ###################################################
# This function is used to evaluate the network after each epoch of training
# Input: network and validation dataset
# Output: average dice_coeff
def eval_net(net, dataset):
    # set net mode to evaluation
    net.eval()
    tot = 0
    for i, b in enumerate(dataset):
        img = b[0]
        true_mask = b[1]
        ################################################ [TODO] ###################################################
        # convert numpy array img and true_mask to torch tensor
        # (.float() guards against float64 arrays, which the float32 network weights would reject)
        img = torch.from_numpy(img).float().unsqueeze(0)
        true_mask = torch.from_numpy(true_mask).float().unsqueeze(0)
        # Feed in the image to get predicted mask
        mask_pred = net(img)[0]
        # For all pixels in predicted mask, set them to 1 if larger than 0.5. Otherwise set them to 0
        mask_pred = (mask_pred > 0.5).float()
        # calculate dice_coeff()
        # note that you should add all the dice_coeff in validation/testing dataset together 
        tot += dice_coeff(mask_pred, true_mask).item()
    # Return average dice_coeff()
    return tot / (i + 1)
  

In [92]:
################################################ [TODO] ###################################################
# Create a UNET object. Input channels = 3, output channels = 1
net = ...

In [None]:
################################################ [TODO] ###################################################
# Specify number of epochs, image scale factor, batch size and learning rate
epochs = ... # e.g., 10
img_scale = ... # e.g., 1/16
batch_size = ... # e.g., 50
lr = ...        # e.g., 0.01

################################################ [TODO] ###################################################
# Define an optimizer for your model.
# PyTorch has a built-in package called optim. The most commonly used methods are already supported.
# Here we use stochastic gradient descent to optimize
# For usage of SGD, you can read https://pytorch.org/docs/stable/_modules/torch/optim/sgd.html
# You can also use Adam as the optimizer
# For usage of Adam, you can read https://www.programcreek.com/python/example/92667/torch.optim.Adam

optimizer = optim.SGD(...)
#OR optimizer = optim.Adam(...)



#suggested parameter settings: momentum=0.9, weight_decay=0.0005

# The loss function we use is binary cross entropy: nn.BCELoss()
criterion = nn.BCELoss()
# note that although we want to use DICE for evaluation, we use BCELoss for training in this example

################################################ [TODO] ###################################################
# Start training
for epoch in range(epochs):
    print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
    net.train()
    ################################################ [TODO] ###################################################
    # Load images and masks for training and validation
    train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask, img_scale)
    val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask, img_scale)

    epoch_loss = 0
    
    for i, b in enumerate(batch(train, batch_size)):
        ################################################ [TODO] ###################################################
        # Get images and masks from each batch
        imgs = ...
        true_masks = ...
        ################################################ [TODO] ###################################################
        # Convert images and masks from numpy to torch tensor with torch.from_numpy
        imgs = ...
        true_masks = ...

        ################################################ [TODO] ###################################################
        # Feed your images into the network
        masks_pred = ...
        # Flatten the predicted masks and true masks. For example, A_flat = A.view(-1)
        masks_probs_flat = ...
        true_masks_flat = ...
        ################################################ [TODO] ###################################################
        # Calculate the loss by comparing the predicted masks vector and true masks vector
        # And sum the losses together 
        loss = criterion(...)
        epoch_loss += loss.item()

        print('{0:.4f} --- loss: {1:.6f}'.format(i * batch_size / N_train, loss.item()))

        # optimizer.zero_grad() clears x.grad for every parameter x in the optimizer. 
        # It’s important to call this before loss.backward(), otherwise you’ll accumulate the gradients from multiple passes.
        optimizer.zero_grad()
        # loss.backward() computes dloss/dx for every parameter x which has requires_grad=True. 
        # These are accumulated into x.grad for every parameter x
        # x.grad += dloss/dx
        loss.backward()
        # optimizer.step updates the value of x using the gradient x.grad. 
        # x += -lr * x.grad
        optimizer.step()

    # i is the index of the last batch, so the number of batches is i + 1
    print('Epoch finished ! Loss: {}'.format(epoch_loss / (i + 1)))
    ################################################ [TODO] ###################################################
    # Perform validation with eval_net()
    val_dice = ...
    print('Validation Dice Coeff: {}'.format(val_dice))
    # Save the model after each epoch
    torch.save(net.state_dict(),
                dir_checkpoint + 'CP{}.pth'.format(epoch + 1))
    print('Checkpoint {} saved !'.format(epoch + 1))
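
The order of `zero_grad()`, `backward()` and `step()` described in the comments above can be tried on a toy problem. In this sketch the one-layer `toy_net`, the random data and the learning rate are hypothetical stand-ins, not the assignment's U-Net or dataset.

```python
import torch
import torch.nn as nn
from torch import optim

torch.manual_seed(0)

# Hypothetical stand-in for the U-Net: a single 1x1 convolution followed by a sigmoid
toy_net = nn.Sequential(nn.Conv2d(3, 1, kernel_size=1), nn.Sigmoid())
optimizer = optim.SGD(toy_net.parameters(), lr=0.1, momentum=0.9, weight_decay=0.0005)
criterion = nn.BCELoss()

imgs = torch.rand(4, 3, 8, 8)                        # random batch of "images"
true_masks = (torch.rand(4, 1, 8, 8) > 0.5).float()  # random binary "masks"

losses = []
for step in range(3):
    masks_pred = toy_net(imgs)
    # Flatten predictions and targets before comparing them
    loss = criterion(masks_pred.view(-1), true_masks.view(-1))
    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss.backward()        # accumulate dloss/dx into x.grad
    optimizer.step()       # update the parameters from x.grad
    losses.append(loss.item())

print(losses)
```

BCELoss expects probabilities in [0, 1], which is why the toy network, like the U-Net above, ends with a sigmoid.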


### [ TODO 5 ] Load one image from the testing dataset and plot the output mask

In [120]:
################################################ [TODO] ###################################################
# Define a function for prediction/testing
def predict_img(net,
                img,
                scale_factor=0.5,
                out_threshold=0.5):
    # set the mode of your network to evaluation
    net.eval()
    ################################################ [TODO] ###################################################
    # get the height and width of your image
    img_height = ...
    img_width = ...
    ################################################ [TODO] ###################################################
    # resize the image according to the scale factor
    img = ...
    # Normalize the image by dividing by 255
    img = ...
    # convert from Height*Width*Channel TO Channel*Height*Width
    img = ...
    # convert numpy array to torch tensor 
    X_img = ...
    
    with torch.no_grad():
        ################################################ [TODO] ###################################################
        # predict the masks 
        output_img = ...
        out_probs = output_img.squeeze(0)
        # Rescale to its original size
        out_probs = ...
        # convert to numpy array
        out_mask_np = ...

    # Pixels above out_threshold become True (foreground); all others become False
    return out_mask_np > out_threshold
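
One possible way to rescale a predicted probability map back to the original resolution is `F.interpolate`; this is only a sketch with a random tensor in place of the network output, and the sizes and bilinear mode are assumptions. Resizing through PIL or torchvision transforms would also work.

```python
import torch
import torch.nn.functional as F

out_probs = torch.rand(1, 1, 16, 24)  # stand-in for the network output, shape (N, C, H, W)
orig_h, orig_w = 64, 96               # assumed original image size

# Bilinear interpolation back to the original resolution
resized = F.interpolate(out_probs, size=(orig_h, orig_w),
                        mode="bilinear", align_corners=False)

# Drop the batch and channel dimensions and convert to a numpy array
out_mask_np = resized.squeeze().numpy()

# Threshold the probabilities to obtain a boolean mask
mask = out_mask_np > 0.5
print(mask.shape, mask.dtype)
```

Thresholding after resizing, rather than before, keeps the mask edges smoother because the interpolation acts on probabilities instead of hard 0/1 values.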

In [None]:
################################################ [TODO] ###################################################
# Load an image from testing dataset
test_img = ...
    
################################################ [TODO] ###################################################
# Predict the mask
mask = predict_img(net=net,
                    img=test_img,
                    scale_factor=...,
                    out_threshold=...)


### Plot original image and mask image

In [None]:
# Plot original images and masks
...
plt.show()
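
A side-by-side layout is one simple way to compare the input with the prediction. The sketch below uses random arrays as placeholders for the test image and the predicted mask.

```python
import numpy as np
import matplotlib.pyplot as plt

image = np.random.rand(64, 96, 3)    # stand-in RGB image with values in [0, 1]
mask = np.random.rand(64, 96) > 0.5  # stand-in boolean mask

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(image)
axes[0].set_title("Input image")
# Convert the boolean mask to floats and show it in grayscale
axes[1].imshow(mask.astype(float), cmap="gray")
axes[1].set_title("Predicted mask")
for ax in axes:
    ax.axis("off")  # hide the axis ticks around each picture
```

In a notebook with `%matplotlib inline` the figure is displayed when the cell finishes; in a script, call `plt.show()` afterwards.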
