# Unet Implementation

In [114]:
import pickle
import gzip
import numpy as np
import os
import time
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
import segmentation_models_pytorch as smp

In [115]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
device = torch.device("cpu")

### Helper functions

In [116]:
def load_zipped_pickle(filename):
    with gzip.open(filename, 'rb') as f:
        loaded_object = pickle.load(f)
        return loaded_object

In [117]:
def save_zipped_pickle(obj, filename):
    with gzip.open(filename, 'wb') as f:
        pickle.dump(obj, f, 2)
        
def flatten(dicts):
    
    # extract the annotated video frames, their labels
    # and the boxes
    
    images = []
    boxes = []
    labels = []
    
    for i, dic in enumerate(dicts):
        
        video = dic['video']
        frames = dic['frames']
        dic_labels = dic['label']
        
        images.append((video[:,:,frames[0]] / 255).astype('float32'))
        images.append((video[:,:,frames[1]] / 255).astype('float32'))
        images.append((video[:,:,frames[2]] / 255).astype('float32'))
        
        labels.append(dic_labels[:,:,frames[0]].astype('uint8'))
        labels.append(dic_labels[:,:,frames[1]].astype('uint8'))
        labels.append(dic_labels[:,:,frames[2]].astype('uint8'))
        
        boxes.append(dic['box'].astype('uint8'))
        boxes.append(dic['box'].astype('uint8'))
        boxes.append(dic['box'].astype('uint8'))
        
    return images, boxes, labels

def resize(images, boxes, labels, size):
    
    # resize images, boxes and labels

    for i in range(len(images)):
        
        images[i] = cv2.resize(images[i], size, interpolation = cv2.INTER_LANCZOS4)
        boxes[i] = cv2.resize(boxes[i], size, interpolation = cv2.INTER_LANCZOS4)
        labels[i] = cv2.resize(labels[i], size, interpolation = cv2.INTER_LANCZOS4)
        
        # add number of channels (in this case 1) at the front for the
        # right input shape for the pytorch layers
        
        images[i] = np.expand_dims(images[i], axis=0)
        images[i] = np.expand_dims(images[i], axis=0)
        
        labels[i] = np.expand_dims(labels[i], axis=0)
        
    return np.concatenate(images, axis=0), boxes, np.concatenate(labels, axis=0)

### Load data

In [7]:
# load data
train_data = load_zipped_pickle("data/train.pkl")
test_data = load_zipped_pickle("data/test.pkl")

In [8]:
size = (128, 128)

### Preprocessing

In [150]:
images, boxes, labels = flatten(train_data)
images, boxes, labels = resize(images, boxes, labels, size)

train_images = torch.from_numpy(images[194:]).to(device)
labels = torch.from_numpy(labels[194:]).to(device)
# TODO: add data augmentation, denoising, other preprocessing steps?
# TODO: split into train and validation sets

### Model

In [151]:
class Unet(nn.Module):
    
    def __init__(self, filters=64):
        
        super().__init__()
        
        self.max_pool = nn.MaxPool2d(2)
        
        self.block_enc_1 = self.conv_block(1, filters)
        self.block_enc_2 = self.conv_block(filters, 2*filters)
        self.block_enc_3 = self.conv_block(2*filters, 4*filters)
        
        self.block_inbetween = self.conv_block(4*filters, 8*filters, True)
        
        self.block_dec_1 = self.conv_block(8*filters, 4*filters, True)
        self.block_dec_2 = self.conv_block(4*filters, 2*filters, True)
        
        self.block_last = self.conv_block(2*filters, filters, True, True)
        
    def conv_block(self, channels, filters, enc=False, last=False):
        
        modules = []
        
        modules.append(nn.Conv2d(channels, filters, 3, 1, padding='same'))
        modules.append(nn.ReLU())
        modules.append(nn.Conv2d(filters, filters, 3, 1, padding='same'))
        modules.append(nn.ReLU())

        if enc:
            if not last:
                modules.append(nn.ConvTranspose2d(filters, filters//2, 2, stride=2))
            else:
                modules.append(nn.Conv2d(filters, 1, 1, 1))
                modules.append(nn.Sigmoid())
            
        return nn.Sequential(*modules)
            
    def forward(self, x):
        
        # encoder
        
        x1 = self.block_enc_1(x)
        x2 = self.max_pool(x1)
        
        x3 = self.block_enc_2(x2)
        x4 = self.max_pool(x3)
        
        x5 = self.block_enc_3(x4)
        x6 = self.max_pool(x5)
        
        # between encoder and decoder
        
        x7 = self.block_inbetween(x6)
        
        # decoder
        
        x8 = self.block_dec_1(torch.cat((x7, x5), dim=1))
        x9 = self.block_dec_2(torch.cat((x8, x3), dim=1))
        
        x10 = self.block_last(torch.cat((x9, x1), dim=1))
        
        return x10

### Train

In [152]:
def iou(inputs, targets):       
        
        inputs = inputs.squeeze()
        inputs = torch.round(inputs).to(torch.uint8)
        
        intersection = (inputs * targets).sum(dim=(1, 2))
        total = (inputs + targets).sum(dim=(1, 2))
        union = total - intersection 
        
        IoU = intersection / union
                
        return IoU

In [153]:
def val_loss(model, val_loader):
    # TODO: calculate IoU for validation dataset

IndentationError: expected an indented block (3965556404.py, line 2)

In [181]:
unet = Unet(filters=16).to(device)
optimizer = torch.optim.Adam(unet.parameters(), lr=1e-3)
criterion = smp.losses.JaccardLoss(mode='binary')

In [182]:
train_tensorset = TensorDataset(train_images, labels)
train_loader = DataLoader(train_tensorset, batch_size=1)

### Train Loop

In [183]:
for epoch in range(100):
    
    t = time.time()
    num_samples_epoch = 0
    train_loss_cum = 0
    IoU = torch.Tensor()
    
    for x,y in train_loader:
        
        optimizer.zero_grad()
        
        unet.train()
        output = unet(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
        
        # keep track of training loss
        num_samples_epoch += x.shape[0]
        train_loss_cum += loss * x.shape[0]
        
        # need the median IoU
        unet.eval()
        with torch.no_grad():
            IoU = torch.cat((IoU, iou(x, y)))
        
    train_iou = torch.median(IoU)
    train_loss = train_loss_cum / num_samples_epoch
        
    epoch_duration = time.time() - t
    
    print(f'Epoch {epoch} | Train IoU: {train_iou:.4f} | '
          f'Train loss: {train_loss:.4f} | '
          #f' Validation loss: {val_loss:.4f} | '
          f' Duration {epoch_duration:.2f} sec')

Epoch 0 | Train IoU: 0.0419 | Train loss: 0.9937 |  Duration 0.23 sec
Epoch 1 | Train IoU: 0.0419 | Train loss: 0.9937 |  Duration 0.28 sec
Epoch 2 | Train IoU: 0.0419 | Train loss: 0.9937 |  Duration 0.26 sec
Epoch 3 | Train IoU: 0.0419 | Train loss: 0.9937 |  Duration 0.23 sec
Epoch 4 | Train IoU: 0.0419 | Train loss: 0.9937 |  Duration 0.25 sec
Epoch 5 | Train IoU: 0.0419 | Train loss: 0.9937 |  Duration 0.27 sec
Epoch 6 | Train IoU: 0.0419 | Train loss: 0.9937 |  Duration 0.24 sec
Epoch 7 | Train IoU: 0.0419 | Train loss: 0.9937 |  Duration 0.26 sec
Epoch 8 | Train IoU: 0.0419 | Train loss: 0.9937 |  Duration 0.26 sec
Epoch 9 | Train IoU: 0.0419 | Train loss: 0.9937 |  Duration 0.29 sec
Epoch 10 | Train IoU: 0.0419 | Train loss: 0.9937 |  Duration 0.26 sec
Epoch 11 | Train IoU: 0.0419 | Train loss: 0.9937 |  Duration 0.26 sec
Epoch 12 | Train IoU: 0.0419 | Train loss: 0.9937 |  Duration 0.25 sec
Epoch 13 | Train IoU: 0.0419 | Train loss: 0.9936 |  Duration 0.27 sec
Epoch 14 | Train

In [136]:
res = unet(train_images)

In [142]:
res[0] == 0

tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]]])

In [145]:
score = iou(res, labels)

In [146]:
torch.median(score)

tensor(0.)