# U-Net Implementation

In [1]:
import pickle
import gzip
import numpy as np
import os
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

In [2]:
# CUDA availability is detected, but the device is deliberately pinned to the
# CPU below; the commented-out line is the CUDA-aware alternative.
use_cuda = torch.cuda.is_available()
#device = torch.device("cuda:0" if use_cuda else "cpu")
device = torch.device("cpu")

### Helper functions

In [3]:
def load_zipped_pickle(filename):
    """Read a gzip-compressed pickle file and return the object it contains.

    Args:
        filename: Path of the gzip-compressed pickle file to read.

    Returns:
        The unpickled Python object.
    """
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with gzip.open(filename, 'rb') as stream:
        payload = pickle.load(stream)
    return payload

In [4]:
def save_zipped_pickle(obj, filename):
    """Pickle ``obj`` with protocol 2 into a gzip-compressed file.

    Args:
        obj: The object to serialise.
        filename: Destination path; an existing file is overwritten.
    """
    with gzip.open(filename, 'wb') as sink:
        pickle.dump(obj, sink, protocol=2)
        
def flatten(dicts):
    """Collect the annotated frames of every sample into flat lists.

    Each element of ``dicts`` is read through the keys ``'video'``,
    ``'frames'``, ``'label'`` and ``'box'``.  For the first three entries of
    ``'frames'`` the matching slice along the last axis of ``'video'`` and
    ``'label'`` is extracted, so every sample contributes three entries to
    each output list.

    Args:
        dicts: Iterable of sample dictionaries as described above.

    Returns:
        A tuple ``(images, boxes, labels)`` of equally long lists:
        ``images`` holds the frames divided by 255 as ``float32``,
        ``boxes`` holds one ``uint8`` copy of the sample's box per frame,
        ``labels`` holds the label slices as ``uint8``.
    """
    images, boxes, labels = [], [], []

    for sample in dicts:
        clip = sample['video']
        annotated = sample['frames']
        masks = sample['label']

        # Exactly three annotated frames are taken from every sample.
        for slot in range(3):
            frame_idx = annotated[slot]

            images.append((clip[:, :, frame_idx] / 255).astype('float32'))
            labels.append(masks[:, :, frame_idx].astype('uint8'))
            # The box is per sample, so it is repeated once per frame.
            boxes.append(sample['box'].astype('uint8'))

    return images, boxes, labels

def resize(images, boxes, labels, size):
    """Resize images and their masks, and stack the images into one batch.

    The intensity images are resampled with Lanczos interpolation.  The boxes
    and labels are masks, so they are resampled with nearest-neighbour
    interpolation: this never invents values that were not present in the
    original mask, whereas a Lanczos kernel interpolates (and rings) across
    mask edges.

    The three input lists are updated in place, element by element, and the
    ``boxes`` and ``labels`` lists are also returned.  After the call every
    entry of ``images`` has two leading singleton axes.

    Args:
        images: List of 2-D image arrays.
        boxes: List of 2-D box masks, at least as long as ``images``.
        labels: List of 2-D label masks, at least as long as ``images``.
        size: Target size passed to ``cv2.resize`` as ``dsize``, i.e.
            ``(width, height)``.

    Returns:
        A tuple ``(batch, boxes, labels)`` where ``batch`` is the resized
        images concatenated along axis 0 (one leading axis per image plus a
        singleton channel axis), and ``boxes`` / ``labels`` are the resized
        mask lists.

    Raises:
        ValueError: If ``images`` is empty (raised by ``np.concatenate``).
    """
    for i in range(len(images)):

        images[i] = cv2.resize(images[i], size, interpolation=cv2.INTER_LANCZOS4)

        # Masks must stay categorical, so no interpolating kernel here.
        boxes[i] = cv2.resize(boxes[i], size, interpolation=cv2.INTER_NEAREST)
        labels[i] = cv2.resize(labels[i], size, interpolation=cv2.INTER_NEAREST)

        # Prepend a batch axis and a channel axis (both of length 1) so the
        # images can be concatenated into the layout the PyTorch layers expect.
        images[i] = images[i][np.newaxis, np.newaxis]

    return np.concatenate(images, axis=0), boxes, labels

### Load data

In [5]:
# Load the gzip-compressed pickles holding the training and test data.
# NOTE: these files are unpickled, so they must come from a trusted source.
train_data = load_zipped_pickle("data/train.pkl")
test_data = load_zipped_pickle("data/test.pkl")

In [6]:
# Target size handed to resize(), which forwards it to cv2.resize as dsize
# (width, height).
size = (384, 384)

### Preprocessing

In [7]:
# Extract the annotated frames of the training data, then resize them;
# after resize() `images` is a single stacked array rather than a list.
images, boxes, labels = flatten(train_data)
images, boxes, labels = resize(images, boxes, labels, size)

# TODO: add data augmentation, denoising, other preprocessing steps?

### Model

In [36]:
class Unet(nn.Module):
    """U-Net style encoder/decoder producing a per-pixel sigmoid map.

    The encoder consists of four blocks (two 3x3 convolutions with ReLU
    followed by 2x2 max pooling), followed by a bottleneck block and four
    decoder blocks.  Every decoder block receives the up-sampled output of
    the previous stage concatenated with the matching encoder features.

    The skip connections use the encoder features from *before* max pooling,
    so they have the same resolution as the decoder tensor they are joined
    with.  The sub-module names and their layer order are kept as they were,
    so ``state_dict`` keys do not change.

    Args:
        in_channels: Number of channels of the input images.
        base_filters: Number of filters of the first encoder block; each
            deeper level doubles it.  Should be even.

    Input/Output:
        ``forward`` takes a tensor of shape ``(N, in_channels, H, W)`` and
        returns a tensor with one channel and values in ``[0, 1]``.  When
        ``H`` and ``W`` are multiples of 16 the output has the same spatial
        size as the input.
    """

    def __init__(self, in_channels=1, base_filters=64):

        super().__init__()

        f = base_filters

        self.block_enc_1 = self.block(in_channels, f, True)
        self.block_enc_2 = self.block(f, 2 * f, True)
        self.block_enc_3 = self.block(2 * f, 4 * f, True)
        self.block_enc_4 = self.block(4 * f, 8 * f, True)

        self.block_inbetween = self.block(8 * f, 16 * f, False)

        # Decoder inputs are the concatenation of the up-sampled features
        # and the skip features, hence twice the skip's channel count.
        self.block_dec_1 = self.block(16 * f, 8 * f, False)
        self.block_dec_2 = self.block(8 * f, 4 * f, False)
        self.block_dec_3 = self.block(4 * f, 2 * f, False)

        self.block_last = self.block(2 * f, f, False, last=True)

    def block(self, channels, filters, enc, last=False):
        """Build one stage of the network as an ``nn.Sequential``.

        Args:
            channels: Number of input channels.
            filters: Number of output channels of the two 3x3 convolutions.
            enc: If true, the stage ends with 2x2 max pooling (encoder).
            last: Only used when ``enc`` is false.  If true the stage ends
                with a 1x1 convolution to one channel and a sigmoid;
                otherwise it ends with a stride-2 transposed convolution
                that halves the channel count and doubles the resolution.

        Returns:
            The assembled ``nn.Sequential``.
        """
        modules = [
            nn.Conv2d(channels, filters, 3, 1, padding='same'),
            nn.ReLU(),
            nn.Conv2d(filters, filters, 3, 1, padding='same'),
            nn.ReLU(),
        ]

        if enc:
            modules.append(nn.MaxPool2d(2))
        elif last:
            modules.append(nn.Conv2d(filters, 1, 1, 1))
            modules.append(nn.Sigmoid())
        else:
            modules.append(nn.ConvTranspose2d(filters, filters // 2, 2, stride=2))

        return nn.Sequential(*modules)

    @staticmethod
    def _encode(enc_block, x):
        """Run an encoder stage; return (features before pooling, pooled)."""
        # The pooling layer is the last module of every encoder stage.
        features = enc_block[:-1](x)
        return features, enc_block[-1](features)

    def crop_and_concat(self, x, y):
        """Center-crop ``y`` to the spatial size of ``x`` and concatenate.

        Args:
            x: Tensor of shape ``(N, C1, H, W)`` defining the target size.
            y: Tensor of shape ``(N, C2, H2, W2)`` with ``H2 >= H`` and
                ``W2 >= W``.

        Returns:
            Tensor of shape ``(N, C1 + C2, H, W)`` with ``x`` first.

        Raises:
            ValueError: If ``y`` is spatially smaller than ``x``; padding it
                silently would hide a resolution mismatch.
        """
        target_h, target_w = x.shape[-2:]
        source_h, source_w = y.shape[-2:]

        if source_h < target_h or source_w < target_w:
            raise ValueError(
                f"cannot crop a {source_h}x{source_w} tensor "
                f"to {target_h}x{target_w}"
            )

        # Height and width are handled separately so non-square inputs work.
        top = (source_h - target_h) // 2
        left = (source_w - target_w) // 2
        cropped = y[..., top:top + target_h, left:left + target_w]

        return torch.cat((x, cropped), dim=1)

    def forward(self, x):
        """Run the network on a batch of shape ``(N, in_channels, H, W)``."""

        # encoder: keep the pre-pooling features of every stage as skips

        skip1, x = self._encode(self.block_enc_1, x)
        skip2, x = self._encode(self.block_enc_2, x)
        skip3, x = self._encode(self.block_enc_3, x)
        skip4, x = self._encode(self.block_enc_4, x)

        # between encoder and decoder

        x = self.block_inbetween(x)

        # decoder: each stage consumes the skip of the same resolution

        x = self.block_dec_1(self.crop_and_concat(x, skip4))
        x = self.block_dec_2(self.crop_and_concat(x, skip3))
        x = self.block_dec_3(self.crop_and_concat(x, skip2))

        return self.block_last(self.crop_and_concat(x, skip1))

### Train

In [37]:
# Instantiate the network on the selected device and move the first ten
# preprocessed images onto it as a tensor.
unet = Unet().to(device)
train_images = torch.from_numpy(images[:10]).to(device)

In [38]:
# Smoke test: run one forward pass and display the output shape.
unet(train_images).shape

torch.Size([10, 1, 384, 384])