# CS137 Final Project - HuBMAP - Hacking the Human Body

## Imports and environment setup

In [None]:
# If running on Google Colab, mount Google Drive so the project directory is reachable.
# Outside Colab the google.colab package does not exist; skip the mount instead of
# raising ImportError so the notebook still runs top to bottom on a local kernel.
import sys

try:
    from google.colab import drive
except ImportError:
    drive = None

if drive is not None:
    drive.mount('/content/drive')
    # NOTE: you need to use your own path to add the implementation to the python path
    # so you can import functions from implementation.py
    # NOTE(review): the folder name mentions "Assignment1" and no cell visible here
    # imports implementation.py -- confirm this path is still needed for the project.
    sys.path.append('/content/drive/MyDrive/CS137_Assignment1_RobPitkin')

In [8]:
# Setup: imports, matplotlib defaults, module auto-reloading, and the compute device.
import numpy as np
import torch
import matplotlib.pyplot as plt
# NOTE(review): sys is not used in this cell; kept in case later cells rely on it.
import sys

# Render matplotlib figures inline in the notebook output.
# (Comments must stay on their own lines: text after a magic is parsed as arguments.)
%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
# Show images pixel-exact (no smoothing); single-channel images render in grayscale.
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'

# for auto-reloading external modules without restarting the kernel;
# %load_ext must run before %autoreload, and mode 2 reloads all modules before each cell.
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## UNet Helper Implementation

### Code adapted from [milesial/Pytorch-UNet (`unet/unet_parts.py`)](https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py)

In [2]:
class ConvBlock(torch.nn.Module):
    """
    Two stacked (3x3 conv -> BatchNorm -> ReLU) stages, optionally followed by dropout.

    Args:
        num_filters: number of output channels of both convolutions. The input
            channel count is not fixed here; LazyConv2d infers it on the first
            forward pass.
        dropout_p: dropout probability applied after the second ReLU. When it is
            0, no Dropout layer is added at all.

    padding='same' keeps the spatial size (H, W) of the input unchanged.
    """
    def __init__(self, num_filters=64, dropout_p=0):
        super().__init__()

        # NOTE(review): the conv bias is redundant in front of BatchNorm (BN removes
        # any per-channel constant); bias=False would save parameters but would
        # change the parameter counts, so it is left as-is.
        layers = [
            torch.nn.LazyConv2d(num_filters, kernel_size=3, padding='same'),
            torch.nn.BatchNorm2d(num_filters),
            torch.nn.ReLU(),
            torch.nn.LazyConv2d(num_filters, kernel_size=3, padding='same'),
            torch.nn.BatchNorm2d(num_filters),
            torch.nn.ReLU(),
        ]
        # Only append Dropout when requested, so the dropout-free block has exactly
        # the same six-layer layout as before.
        if dropout_p != 0:
            layers.append(torch.nn.Dropout(dropout_p))
        self.conv_block = torch.nn.Sequential(*layers)

    def forward(self, input):
        """Apply the conv stack; output has `num_filters` channels and the input's H/W."""
        return self.conv_block(input)

In [3]:
class DownBlock(torch.nn.Module):
    """
    Encoder stage: 2x2 max pooling (halves H and W, rounding down) followed by a
    ConvBlock that outputs `out_channels` channels.

    Args:
        out_channels: number of channels produced by the ConvBlock.
        dropout_p: dropout probability forwarded to the ConvBlock (0 disables it).
    """
    def __init__(self, out_channels, dropout_p=0):
        super().__init__()

        pool = torch.nn.MaxPool2d(kernel_size=2)
        convs = ConvBlock(out_channels, dropout_p)
        self.down_block = torch.nn.Sequential(pool, convs)

    def forward(self, input):
        """Downsample `input` by pooling, then refine it with the ConvBlock."""
        pooled_then_convolved = self.down_block(input)
        return pooled_then_convolved

In [4]:
class UpBlock(torch.nn.Module):
    """
    Decoder stage: 2x upsampling with a transposed convolution, concatenation with
    the matching encoder skip tensor, then a ConvBlock with `out_channels` channels.

    Args:
        out_channels: channels produced by both the transposed conv and the ConvBlock.
        dropout_p: dropout probability forwarded to the ConvBlock (0 disables it).
    """
    def __init__(self, out_channels, dropout_p=0):
        super().__init__()

        # Input channel counts are inferred lazily on the first forward pass.
        self.up = torch.nn.LazyConvTranspose2d(out_channels, kernel_size=2, stride=2)
        self.conv_block = ConvBlock(out_channels, dropout_p=dropout_p)

    def forward(self, input, skip):
        """
        Upsample `input`, pad it to the spatial size of `skip`, concatenate the two
        along the channel axis (skip first), and apply the ConvBlock.

        Padding logic derived from
        https://github.com/milesial/Pytorch-UNet/blob/master/unet/unet_parts.py
        """
        upsampled = self.up(input)

        # Tensors are (N, C, H, W): dim 2 is height, dim 3 is width.
        gap_h = skip.shape[2] - upsampled.shape[2]
        gap_w = skip.shape[3] - upsampled.shape[3]
        top, left = gap_h // 2, gap_w // 2
        # F.pad takes (left, right, top, bottom) for the last two dims; any odd
        # remainder of the size gap goes to the right/bottom edge.
        padding = [left, gap_w - left, top, gap_h - top]
        upsampled = torch.nn.functional.pad(upsampled, padding)

        merged = torch.cat((skip, upsampled), dim=1)
        return self.conv_block(merged)

In [5]:
class OutBlock(torch.nn.Module):
    """
    Output head: a 1x1 convolution that maps every pixel's feature vector to
    `out_channels` raw (un-activated) scores. The input channel count is
    inferred lazily on the first forward pass.
    """
    def __init__(self, out_channels):
        super().__init__()

        self.conv_layer = torch.nn.LazyConv2d(
            out_channels=out_channels,
            kernel_size=1,
            padding='same',
        )

    def forward(self, input):
        """Project features to `out_channels` channels; H and W are unchanged."""
        projected = self.conv_layer(input)
        return projected


## Implementing the UNet Model

In [10]:
class UNetModel(torch.nn.Module):
    """
    UNet encoder-decoder: a ConvBlock plus four DownBlocks (contracting path),
    four UpBlocks with skip connections (expanding path), and a 1x1 OutBlock.

    Args:
        num_classes: number of output channels of the final 1x1 convolution.
        base_filters: channel width of the first stage. Each DownBlock doubles it
            (64, 128, 256, 512, 1024 with the default) and the UpBlocks halve it back.
        dropout_p: dropout probability forwarded to every ConvBlock/DownBlock/UpBlock.
            The default 0 adds no Dropout layers, i.e. the original architecture.

    The input channel count is not fixed here; the Lazy* layers infer it on the
    first forward pass.
    """
    def __init__(self, num_classes=5, base_filters=64, dropout_p=0):
        super(UNetModel, self).__init__()

        if base_filters < 1:
            raise ValueError(f"base_filters must be a positive integer, got {base_filters}")

        # Channel widths per resolution level: [f, 2f, 4f, 8f, 16f].
        w1, w2, w3, w4, w5 = (base_filters * 2 ** level for level in range(5))

        # Contracting path.
        self.conv_block = ConvBlock(w1, dropout_p=dropout_p)
        self.down_block1 = DownBlock(w2, dropout_p=dropout_p)
        self.down_block2 = DownBlock(w3, dropout_p=dropout_p)
        self.down_block3 = DownBlock(w4, dropout_p=dropout_p)
        self.down_block4 = DownBlock(w5, dropout_p=dropout_p)
        # Expanding path; each UpBlock is paired in forward() with the encoder
        # output of the same resolution.
        self.up_block4 = UpBlock(w4, dropout_p=dropout_p)
        self.up_block3 = UpBlock(w3, dropout_p=dropout_p)
        self.up_block2 = UpBlock(w2, dropout_p=dropout_p)
        self.up_block1 = UpBlock(w1, dropout_p=dropout_p)
        self.out_block = OutBlock(num_classes)

    def forward(self, input):
        """
        Run the UNet. The output has `num_classes` channels and the same H/W as
        `input`, because every UpBlock pads its upsampled tensor to the skip
        tensor's size and the first skip keeps the input resolution.
        """
        # Encoder: each DownBlock halves H and W (rounding down).
        x1 = self.conv_block(input)
        x2 = self.down_block1(x1)
        x3 = self.down_block2(x2)
        x4 = self.down_block3(x3)
        x5 = self.down_block4(x4)
        # Decoder: upsample and fuse with the matching encoder feature map.
        out = self.up_block4(x5, x4)
        out = self.up_block3(out, x3)
        out = self.up_block2(out, x2)
        out = self.up_block1(out, x1)
        out = self.out_block(out)
        return out

## Instantiating the model and checking its summary

In [11]:
from torchsummary import summary

model = UNetModel()
model.to(device)
# torchsummary.summary defaults to device="cuda", which breaks on CPU-only runtimes;
# pass the type of the device the model actually lives on ("cuda" or "cpu").
# The dummy forward pass also materializes the Lazy* layers, so the parameter
# counts in the table are real.
summary(model, input_size=(1, 512, 512), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 512, 512]             640
       BatchNorm2d-2         [-1, 64, 512, 512]             128
              ReLU-3         [-1, 64, 512, 512]               0
            Conv2d-4         [-1, 64, 512, 512]          36,928
       BatchNorm2d-5         [-1, 64, 512, 512]             128
              ReLU-6         [-1, 64, 512, 512]               0
         ConvBlock-7         [-1, 64, 512, 512]               0
         MaxPool2d-8         [-1, 64, 256, 256]               0
            Conv2d-9        [-1, 128, 256, 256]          73,856
      BatchNorm2d-10        [-1, 128, 256, 256]             256
             ReLU-11        [-1, 128, 256, 256]               0
           Conv2d-12        [-1, 128, 256, 256]         147,584
      BatchNorm2d-13        [-1, 128, 256, 256]             256
             ReLU-14        [-1, 128, 2