# Liver Segmentation with ConvDeconv architecture

### In this notebook, we'll load a pretrained ConvDeconv model and predict on test images.
#### Training code at the end of the notebook lets you train your own ConvDeconv model

First, the needed imports.

In [None]:
import numpy as np
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import SimpleITK as sitk
import scipy.misc as misc
import scipy.ndimage as snd
import imageio
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
%matplotlib inline

## Utility Functions:

#### display_image_label_and_output:
    A matplotlib helper that plots three panels side by side: the image, the image with its label overlaid, and the network's output with the label overlaid.

In [None]:
def display_image_label_and_output(image, label, output):
    """Show a slice, its label and a prediction in one three-panel figure.

    Panels, left to right:
      1. ``image`` in grayscale;
      2. ``image`` in grayscale with ``label`` drawn on top at 50% opacity;
      3. ``output`` in grayscale with ``label`` drawn on top at 50% opacity.

    Args:
        image: array accepted by ``plt.imshow`` (the input slice).
        label: mask overlaid semi-transparently on panels 2 and 3.
        output: the network's predicted mask.
    """
    # (subplot position, background array, overlay the label?)
    panels = ((1, image, False), (2, image, True), (3, output, True))

    plt.figure()
    for position, background, with_label in panels:
        plt.subplot(1, 3, position)
        plt.imshow(background, cmap='gray')
        if with_label:
            plt.imshow(label, alpha=0.5)
    plt.show()

#### predict_on_test_data:
    Given a model and a number of files, runs the model on that many test images and displays each result with the function above.

In [None]:
def _min_max_normalise(arr):
    """Linearly rescale ``arr`` to [0, 1] and return it as float32.

    A constant array (max == min) maps to all zeros instead of dividing by
    zero, which would produce NaNs (e.g. a label slice with no foreground).
    """
    arr = np.float32(arr)
    value_range = arr.max() - arr.min()
    if value_range == 0:
        return np.zeros_like(arr)
    return (arr - arr.min()) / value_range


def predict_on_test_data(model, n_files=20, test_dir='test_images'):
    """Run ``model`` on up to ``n_files`` test slices and display the results.

    Image files are the entries of ``test_dir`` whose name contains ``'img'``.
    The matching label file has ``'img'`` replaced by ``'label'`` in its name.
    Files are processed in sorted order, so repeated calls show the same slices
    (``os.listdir`` order is arbitrary).

    The model is switched to eval mode for the predictions, so BatchNorm uses
    its running statistics rather than per-image ones. Its previous
    train/eval mode is restored afterwards.

    Args:
        model: network mapping a (1, 1, H, W) float tensor to per-class scores
            of shape (1, C, H, W).
        n_files: maximum number of image/label pairs to process.
        test_dir: directory containing the test images and labels.
    """
    image_names = sorted(f for f in os.listdir(test_dir) if 'img' in f)[:n_files]

    # Feed inputs to wherever the model's weights live (CPU or GPU).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = model.training
    model.eval()
    try:
        for name in image_names:
            img_path = os.path.join(test_dir, name)
            label_path = os.path.join(test_dir, name.replace('img', 'label'))

            img_arr = _min_max_normalise(imageio.imread(img_path))
            # Labels become 0/1 masks (values below the maximum truncate to 0).
            label_arr = np.uint8(_min_max_normalise(imageio.imread(label_path)))

            # The network expects NCHW input: add batch and channel axes.
            inputs = torch.from_numpy(img_arr[None, None, :, :]).to(device)
            with torch.no_grad():
                scores = model(inputs)
            _, predicted = torch.max(scores, 1)
            output_arr = predicted[0].cpu().numpy()

            display_image_label_and_output(img_arr, label_arr, output_arr)
    finally:
        model.train(was_training)

## ConvDeconv Network architecture

### nn.Sequential
    A sequential container. Modules will be added to it in the order they are passed in the constructor

### nn.Conv2d

    Applies a 2D convolution over an input signal composed of several input planes.
    stride controls the stride for the cross-correlation, a single number or a tuple.
    padding controls the number of implicit zero-padding points added to both sides of each spatial dimension.
    dilation controls the spacing between the kernel points; also known as the à trous algorithm.

In [None]:
class ConvDeconv(nn.Module):
    """Small encoder-decoder producing per-pixel log-probabilities over 2 classes.

    Encoder: two pairs of (3x3 conv + BatchNorm + ReLU), each pair followed by
    2x2 max pooling, so H and W shrink by 4. Decoder: two stride-2 transposed
    convolutions, which grow H and W back by 4, each followed by a 3x3 conv.
    The last conv emits 2 channels, and ``forward`` returns their log-softmax
    over the channel axis, which pairs with ``nn.NLLLoss``. The output matches
    the input's spatial size when H and W are divisible by 4.

    The attribute names and the (conv, batchnorm) layout of every stage
    define the checkpoint keys (``conv1.0.weight``, ``conv1.1.running_mean``,
    ...). Do not change them, or saved ``state_dict``s will no longer load.
    """

    def __init__(self):
        super(ConvDeconv, self).__init__()

        def conv_bn(in_channels, out_channels):
            # 3x3 convolution with padding 1 (keeps H, W) followed by batch norm.
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
            )

        def upconv_bn(in_channels, out_channels):
            # 2x2 transposed convolution, stride 2 (doubles H, W), then batch norm.
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, stride=2, kernel_size=2),
                nn.BatchNorm2d(out_channels),
            )

        # Registration order is kept as-is so parameter order and random
        # initialisation match previously trained models.
        self.conv1 = conv_bn(1, 32)
        self.conv2 = conv_bn(32, 64)
        self.conv3 = conv_bn(64, 128)
        self.conv4 = conv_bn(128, 128)
        self.upconv1 = upconv_bn(128, 128)
        self.conv5 = conv_bn(128, 64)
        self.upconv2 = upconv_bn(64, 32)
        self.conv6 = conv_bn(32, 2)

    def forward(self, x):
        # Encoder: H x W -> H/2 x W/2 -> H/4 x W/4.
        for first, second in ((self.conv1, self.conv2), (self.conv3, self.conv4)):
            x = F.relu(first(x))
            x = F.relu(second(x), inplace=True)
            x = F.max_pool2d(x, 2)

        # Decoder: H/4 x W/4 -> H/2 x W/2 -> H x W.
        x = F.relu(self.upconv1(x))
        x = F.relu(self.conv5(x))
        x = F.relu(self.upconv2(x))
        logits = self.conv6(x)  # no ReLU on the class scores
        return F.log_softmax(logits, dim=1)

## Network with random weights

In [None]:
# Build the network with PyTorch's default random initialisation, print its
# layer summary, and show what an untrained model predicts on 5 test slices.
model = ConvDeconv()
print(model)
predict_on_test_data(model, n_files=5)

## Network loaded with trained weights

In [None]:
# Load pretrained weights. The checkpoint is a dict whose 'state_dict' entry
# holds the ConvDeconv parameters, presumably in the same format the training
# code below writes. map_location='cpu' lets it load on machines without a GPU.
state = torch.load('pretrained_models/conv-deconv_cpu.tar', map_location='cpu')['state_dict']
model = ConvDeconv()
model.load_state_dict(state)
# Inference mode: BatchNorm uses the loaded running statistics instead of
# per-image batch statistics, and does not overwrite them.
model.eval()
predict_on_test_data(model, n_files=5)

### Training Code (Take Home)
Additional requirements: GPU | Additional dependencies: progressbar, h5py

In [None]:
class SimpleTrainer(object):
    """Run training and evaluation steps for a model on NumPy batches.

    Each batch is converted to tensors (float inputs, long labels) and moved
    to the device holding the model's parameters. This was previously
    hard-wired to CUDA.

    Args:
        model: module whose output ``loss_fn`` accepts. For ``ConvDeconv``
            with ``nn.NLLLoss`` these are (N, C, H, W) log-probabilities.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        optimizer: ``torch.optim`` optimizer over ``model.parameters()``.
    """

    def __init__(self, model, loss_fn, optimizer):
        self.model = model
        self.optimizer = optimizer
        self.loss_fn = loss_fn

    def _to_tensors(self, inputs, labels):
        """Convert a NumPy batch to (float32 inputs, int64 labels) on the model's device."""
        first_param = next(self.model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
        inputs = torch.from_numpy(inputs).float().to(device)
        labels = torch.from_numpy(labels).long().to(device)
        return inputs, labels

    def forward_backward(self, inputs, labels):
        """Take one optimizer step on a batch and return its loss as a float.

        The returned loss is computed before the parameter update.
        """
        inputs, labels = self._to_tensors(inputs, labels)
        self.optimizer.zero_grad()
        outputs = self.model(inputs)
        loss = self.loss_fn(outputs, labels)
        loss.backward()
        self.optimizer.step()
        # .item() replaces the pre-0.4 ``loss.data[0]``, which fails on 0-dim tensors.
        return loss.item()

    def forward(self, inputs, labels):
        """Return the loss on a batch as a float, without tracking gradients or updating weights."""
        inputs, labels = self._to_tensors(inputs, labels)
        # no_grad replaces the removed ``Variable(..., volatile=True)``.
        with torch.no_grad():
            outputs = self.model(inputs)
            loss = self.loss_fn(outputs, labels)
        return loss.item()

## Prepare training data

The 2D slices are saved in .h5 format. An H5 file is a data file in the Hierarchical Data Format (HDF) and holds multidimensional arrays of scientific data.
Images and labels are stored as two datasets in the h5 file and can be accessed by file_obj\['image'\] and file_obj\['label'\]

We read the images and labels from it, shuffle them, and split them into training and validation sets.

In [None]:
def get_training_data(path='2DLiverSlices_128.h5', n_train=1500, seed=None):
    """Load 2D slices from an HDF5 file, shuffle them and split train/validation.

    NOTE(review): this uses ``h5py``, but the notebook's imports cell does not
    import it. ``import h5py`` must be added there before this runs.

    Args:
        path: HDF5 file with an ``'image'`` and a ``'label'`` dataset whose
            first axis indexes slices.
        n_train: number of shuffled slices used for training. The remainder
            is used for validation.
        seed: optional seed for the shuffle. ``None`` uses NumPy's global RNG,
            as before.

    Returns:
        ``(train_images, train_labels, val_images, val_labels)`` as arrays.

    Raises:
        ValueError: if the image and label datasets hold different numbers
            of slices.
    """
    # The context manager closes the file even if reading a dataset fails.
    with h5py.File(path, 'r') as h5_file:
        images = h5_file['image'][:]
        labels = h5_file['label'][:]

    if images.shape[0] != labels.shape[0]:
        raise ValueError('image/label count mismatch: %d images vs %d labels'
                         % (images.shape[0], labels.shape[0]))

    # Apply the same permutation to both arrays so image/label pairs stay aligned.
    rng = np.random if seed is None else np.random.RandomState(seed)
    randperm = rng.permutation(images.shape[0])
    images = images[randperm]
    labels = labels[randperm]

    train_images, train_labels = images[:n_train], labels[:n_train]
    val_images, val_labels = images[n_train:], labels[n_train:]
    return train_images, train_labels, val_images, val_labels

### Defining the hyper-parameters for the network

In [None]:
# Training hyper-parameters.
EPOCHS = 100             # full passes over the training set
BATCH_SIZE = 48          # slices per optimizer step
PATCH_SIZE = [128, 128]  # (height, width) of each input slice, in pixels

## Initialize the model

## Optimization:
Use the optim package to define an Optimizer that will update the weights of the model for us.

In [None]:
model = ConvDeconv().cuda()
# Adam optimizer; lr is the learning rate, and weight_decay adds a small L2
# penalty on all parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-5)
# The model outputs log-probabilities, so the negative log-likelihood loss
# gives per-pixel cross-entropy. nn.NLLLoss accepts (N, C, H, W) inputs
# directly; nn.NLLLoss2d is deprecated in its favour.
trainer = SimpleTrainer(model, nn.NLLLoss(), optimizer)
train_images, train_labels, val_images, val_labels = get_training_data()

In [None]:
# NOTE(review): needs ``import progressbar`` (listed as an extra dependency
# for this section but never imported).
for i in range(EPOCHS):
    print('Epoch: ' + str(i))

    # train
    model.train()
    train_loss = []
    bar = progressbar.ProgressBar()
    for j in bar(range(0, train_images.shape[0], BATCH_SIZE)):
        image_batch, label_batch = train_images[j: j+BATCH_SIZE], train_labels[j: j+BATCH_SIZE]
        # Add the channel axis Conv2d expects: (N, 1, H, W).
        image_batch = image_batch.reshape(image_batch.shape[0], 1, PATCH_SIZE[0], PATCH_SIZE[1])
        train_loss.append(trainer.forward_backward(image_batch, label_batch))
    print('Train loss: ' + str(np.array(train_loss).mean()))

    # Checkpoint a CPU copy of the weights every epoch. Copying each tensor
    # (rather than model.cpu() ... model.cuda()) leaves the model, and the
    # parameters the optimizer references, on the GPU untouched.
    cpu_state = model.state_dict()
    for key in list(cpu_state):
        cpu_state[key] = cpu_state[key].cpu()
    torch.save({'state_dict': cpu_state}, 'conv-deconv_cpu.tar')

    # validate
    model.eval()
    val_loss = []
    bar = progressbar.ProgressBar()
    for j in bar(range(0, val_images.shape[0], BATCH_SIZE)):
        image_batch, label_batch = val_images[j: j+BATCH_SIZE], val_labels[j: j+BATCH_SIZE]
        image_batch = image_batch.reshape(image_batch.shape[0], 1, PATCH_SIZE[0], PATCH_SIZE[1])
        val_loss.append(trainer.forward(image_batch, label_batch))
    print('Val loss: ' + str(np.array(val_loss).mean()))

### Show results on validation data

In [None]:
# Predict on one validation slice and show it next to the predicted mask.
model.eval()
sample = val_images[3]
# The dtype stored in the HDF5 file is not guaranteed to be float32, so cast
# it to match the model's weights (as SimpleTrainer does). Also add the batch
# and channel axes (NCHW).
inputs = torch.from_numpy(sample.reshape(1, 1, PATCH_SIZE[0], PATCH_SIZE[1])).float().cuda()
with torch.no_grad():
    out = model(inputs)
# Per-pixel class index with the highest log-probability.
out = np.argmax(out.cpu().numpy(), axis=1).reshape(PATCH_SIZE[0], PATCH_SIZE[1])
plt.figure()
plt.imshow(sample, cmap='gray')
plt.figure()
plt.imshow(out)