In [None]:
import os
import sys
sys.path.append(os.path.join(os.pardir, os.pardir))

import math
import random

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import torchvision.transforms.functional as F_transforms
import torchinfo

import pytorch_lightning as pl
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

#from amlutils.task3 import load_zipped_pickle, visualize_segmentation

RANDOM_SEED = 42

# Seed each RNG the notebook touches (Python, NumPy, PyTorch), then let Lightning
# seed them again and also set up seeding for dataloader workers (workers=True).
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(RANDOM_SEED)
pl.seed_everything(RANDOM_SEED, workers=True)

In [None]:
# Directory holding the data files, relative to the notebook's working directory.
# Adjust this path if the data lives elsewhere on your machine.
data_dir = os.path.join(os.path.pardir, 'data')

In [None]:
import pickle
import gzip

import matplotlib.pyplot as plt

def load_zipped_pickle(filename):
    '''
    Load the object pickled inside a gzip-compressed file.

    Args:
        filename (str): Path to the gzip-compressed pickle file.

    Returns:
        The unpickled object.

    Note:
        pickle.load can execute arbitrary code; only load files from a trusted source.
    '''
    with gzip.open(filename, 'rb') as compressed_file:
        return pickle.load(compressed_file)

def visualize_segmentation(ax, image, segmentation, segmentation_opacity=0.5):
    '''
    Draw `image` on `ax` and overlay `segmentation` on top of it.

    Args:
        ax: Matplotlib Axes to draw on.
        image: Background image passed to ax.imshow.
        segmentation: Mask drawn over the image with ax.imshow.
        segmentation_opacity (float): Alpha of the overlay (default 0.5).
    '''
    background, overlay = image, segmentation
    ax.imshow(background)
    ax.imshow(overlay, alpha=segmentation_opacity)


In [None]:
def crop_and_concat(conv_out, upconv_out):
    '''
    Merge a skip connection into the expansive path.

    `conv_out` is center-cropped to the spatial size of `upconv_out`
    (`upconv_out.shape[2:]`). The two tensors are then concatenated along
    dim 1, with the cropped skip connection first.

    Args:
        conv_out (torch.Tensor): Contracting-path features, expected (batch, C1, H1, W1).
        upconv_out (torch.Tensor): Up-convolved features, expected (batch, C2, H2, W2).

    Returns:
        torch.Tensor: Expected shape (batch, C1 + C2, H2, W2).
    '''
    target_size = upconv_out.shape[2:]
    cropped_skip = F_transforms.center_crop(conv_out, target_size)
    return torch.concat((cropped_skip, upconv_out), dim=1)

def print_shape(var_name: str, var):
    '''
    Shape-tracing hook called throughout UNet.forward.

    Currently silent: it returns None without printing anything. Uncomment the
    print below to trace tensor shapes while debugging.
    '''
    # print(f'{var_name}.shape = {var.shape}')
    return None

def get_num_incoming_nodes(tensor):
    '''
    Return N, the number of incoming connections (fan-in) of one output unit.

    For a Conv2d weight of shape (out_channels, in_channels, kH, kW) this is
    in_channels * kH * kW. For example, a 3x3 convolution over 64 feature
    channels gives 9 * 64 = 576, the example used in the U-Net paper. Every
    dimension after the first is counted, so a Linear weight (out, in) gives `in`.

    Raises:
        ValueError: If `tensor` has fewer than 2 dimensions (e.g. a bias vector).
    '''
    if len(tensor.shape) < 2:
        raise ValueError(
            f'expected a weight tensor with at least 2 dimensions, got shape {tuple(tensor.shape)}'
        )
    # Dim 0 indexes the output units, so it is excluded from the fan-in.
    return math.prod(tensor.shape[1:])

def initialize_conv_weights(weights):
    '''
    He-normal initialization, in place.

    Draws from N(0, std^2) with std = sqrt(2 / N), where
    N = get_num_incoming_nodes(weights).

    Args:
        weights (torch.Tensor): Weight tensor to overwrite in place.
    '''
    std = math.sqrt(2 / get_num_incoming_nodes(weights))
    nn.init.normal_(weights, mean=0.0, std=std)

def initialize_conv_bias(bias):
    '''Set every element of a bias tensor to zero, in place.'''
    nn.init.zeros_(bias)

def initialize_conv_relu_layer(layer):
    '''
    Re-initialize every parameter of `layer` in place.

    Weights go through initialize_conv_weights and biases through
    initialize_conv_bias. Parameters are matched on the last component of
    their dotted name. This handles both a bare layer ('weight', 'bias') and
    nested modules ('0.weight', '0.bias'). Names that merely contain 'weight'
    (e.g. 'weight_orig') are left untouched.
    '''
    for name, param in layer.named_parameters():
        leaf_name = name.rsplit('.', 1)[-1]
        if leaf_name == 'weight':
            initialize_conv_weights(param.data)
        elif leaf_name == 'bias':
            initialize_conv_bias(param.data)

class UNet(pl.LightningModule):
    '''
    U-Net producing per-pixel class logits for single-channel images.

    Architecture based on research paper:

    https://arxiv.org/pdf/1505.04597.pdf

    Unlike the paper's unpadded 3x3 convolutions, the 3x3 convolutions here
    use padding='same'. Only the unpadded 2x2 up-convolutions shrink the
    feature maps, so the output is somewhat smaller than the input. Skip
    connections are center-cropped to match (crop_and_concat), and
    training_step crops the labels the same way.
    '''

    def __init__(self, n_classes=2, learning_rate=1e-4):
        '''
        Build the layers and re-initialize all parameters.

        Args:
            n_classes (int): Number of classes to map to (default=2).
            learning_rate (float): AdamW learning rate (default=1e-4).
        '''
        super().__init__()
        self.learning_rate = learning_rate

        conv_relu = self._conv_relu
        up_conv = self._up_conv

        ## Contracting path. ##
        self.contract_conv1 = conv_relu(1, 64)
        self.contract_conv2 = conv_relu(64, 64)
        self.contract_pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.contract_conv3 = conv_relu(64, 128)
        self.contract_conv4 = conv_relu(128, 128)
        self.contract_pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.contract_conv5 = conv_relu(128, 256)
        self.contract_conv6 = conv_relu(256, 256)
        self.contract_pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.contract_conv7 = conv_relu(256, 512)
        self.contract_conv8 = conv_relu(512, 512)
        self.contract_pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.contract_conv9 = conv_relu(512, 1024)
        self.contract_conv10 = conv_relu(1024, 1024)

        self.contract_dropout = nn.Dropout(0.1)

        ## Expansive path. ##
        # Each level first applies an up-convolution that halves the channels.
        # Two 3x3 convolutions then run over the concatenation of the cropped
        # skip connection and the up-convolved map.
        self.expand_upconv1 = up_conv(1024, 512)
        self.expand_conv1 = conv_relu(1024, 512)
        self.expand_conv2 = conv_relu(512, 512)

        self.expand_upconv2 = up_conv(512, 256)
        self.expand_conv3 = conv_relu(512, 256)
        self.expand_conv4 = conv_relu(256, 256)

        self.expand_upconv3 = up_conv(256, 128)
        self.expand_conv5 = conv_relu(256, 128)
        self.expand_conv6 = conv_relu(128, 128)

        self.expand_upconv4 = up_conv(128, 64)
        self.expand_conv7 = conv_relu(128, 64)
        self.expand_conv8 = conv_relu(64, 64)

        self.expand_final_conv = nn.Conv2d(64, n_classes, kernel_size=1)

        self.initialize_parameters()

    @staticmethod
    def _conv_relu(in_channels, out_channels):
        '''3x3 'same'-padded convolution followed by ReLU; spatial size is preserved.'''
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding='same'),
            nn.ReLU()
        )

    @staticmethod
    def _up_conv(in_channels, out_channels):
        '''2x upsampling followed by an unpadded 2x2 convolution; spatial size s becomes 2*s - 1.'''
        return nn.Sequential(
            nn.Upsample(scale_factor=2),
            nn.Conv2d(in_channels, out_channels, kernel_size=2)
        )

    def initialize_parameters(self):
        '''
        Re-initialize every parameter in place.

        Tensors named '*.weight' go through initialize_conv_weights and
        '*.bias' through initialize_conv_bias. Matching is on the last
        component of the dotted parameter name.
        '''
        for name, param in self.named_parameters():
            leaf_name = name.rsplit('.', 1)[-1]
            if leaf_name == 'weight':
                initialize_conv_weights(param.data)
            elif leaf_name == 'bias':
                initialize_conv_bias(param.data)

    def forward(self, X):
        '''
        Compute per-class logits.

        Args:
            X (torch.Tensor): Expected shape (batch, 1, H, W); the first
                convolution takes exactly one input channel.

        Returns:
            torch.Tensor: Logits of shape (batch, n_classes, H', W'). H' and W'
            are smaller than H and W because the up-convolutions are unpadded.
        '''
        def run(label, layer, inputs):
            # Apply one layer and report its output through the print_shape hook.
            outputs = layer(inputs)
            print_shape(label, outputs)
            return outputs

        def merge(label, skip, upsampled):
            # Crop the skip connection to `upsampled` and concatenate on channels.
            merged = crop_and_concat(skip, upsampled)
            print_shape(label, merged)
            return merged

        print_shape('X', X)

        ## Contracting path. ##
        # skip1..skip4 hold the second conv output of each level for the skip connections.
        out = run('contract_conv1_out', self.contract_conv1, X)
        skip1 = run('contract_conv2_out', self.contract_conv2, out)
        out = run('contract_pool1_out', self.contract_pool1, skip1)

        out = run('contract_conv3_out', self.contract_conv3, out)
        skip2 = run('contract_conv4_out', self.contract_conv4, out)
        out = run('contract_pool2_out', self.contract_pool2, skip2)

        out = run('contract_conv5_out', self.contract_conv5, out)
        skip3 = run('contract_conv6_out', self.contract_conv6, out)
        out = run('contract_pool3_out', self.contract_pool3, skip3)

        out = run('contract_conv7_out', self.contract_conv7, out)
        skip4 = run('contract_conv8_out', self.contract_conv8, out)
        out = run('contract_pool4_out', self.contract_pool4, skip4)

        out = run('contract_conv9_out', self.contract_conv9, out)
        out = run('contract_conv10_out', self.contract_conv10, out)
        out = run('contract_dropout_out', self.contract_dropout, out)

        ## Expansive path. ##
        out = run('expand_upconv1_out', self.expand_upconv1, out)
        out = merge('expand_conv1_in', skip4, out)
        out = run('expand_conv1_out', self.expand_conv1, out)
        out = run('expand_conv2_out', self.expand_conv2, out)

        out = run('expand_upconv2_out', self.expand_upconv2, out)
        out = merge('expand_conv3_in', skip3, out)
        out = run('expand_conv3_out', self.expand_conv3, out)
        out = run('expand_conv4_out', self.expand_conv4, out)

        out = run('expand_upconv3_out', self.expand_upconv3, out)
        out = merge('expand_conv5_in', skip2, out)
        out = run('expand_conv5_out', self.expand_conv5, out)
        out = run('expand_conv6_out', self.expand_conv6, out)

        out = run('expand_upconv4_out', self.expand_upconv4, out)
        out = merge('expand_conv7_in', skip1, out)
        out = run('expand_conv7_out', self.expand_conv7, out)
        out = run('expand_conv8_out', self.expand_conv8, out)

        return run('final_out', self.expand_final_conv, out)

    def configure_optimizers(self):
        '''AdamW over all parameters with the configured learning rate.'''
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer

    def training_step(self, batch, batch_nb):
        '''
        Compute the cross-entropy loss for one batch and log it as 'train_loss'.

        Args:
            batch: Pair (X, y). X is expected as (batch, H, W); a channel axis
                is added here. y should hold integer class indices, presumably
                (batch, H, W). TODO confirm against the data loader.
            batch_nb (int): Index of the batch (unused).

        Returns:
            dict: {'loss': loss} for Lightning to back-propagate.
        '''
        X, y = batch
        X = torch.unsqueeze(X, 1)  # add the single input channel

        y_hat = self(X)
        # The network output is spatially smaller than its input, so crop the
        # labels to the same centred window before computing the loss.
        y = F_transforms.center_crop(y, y_hat.shape[2:])

        loss = F.cross_entropy(y_hat, y)

        self.log('train_loss', loss)
        return {'loss': loss}

# Model Training

In [None]:
train_set = load_zipped_pickle(os.path.join(data_dir, 'labeled-images.pkl'))

# Only the first labelled sample is used for training below.
X_train, y_train = train_set[0]['Image'], train_set[0]['Label']

def convert_to_tensor(X, data_type=torch.float):
    '''
    Convert an array-like to a tensor of dtype `data_type` with a new leading axis.

    A size-1 axis is inserted at position 0. In this notebook it serves either
    as the sample axis of a one-element TensorDataset (build_data_loader) or,
    after one more unsqueeze, as the channel axis of a single image.

    Args:
        X: Array-like accepted by torch.tensor.
        data_type (torch.dtype): Target dtype (default torch.float).

    Returns:
        torch.Tensor: Shape (1, *X.shape).
    '''
    return torch.tensor(X, dtype=data_type).unsqueeze(0)

def build_data_loader(X, y, batch_size=1, shuffle=True):
    '''
    Wrap one image and its label map in a DataLoader.

    convert_to_tensor adds a leading size-1 axis to both arrays, so the
    TensorDataset holds exactly one (image, label) sample. The image is float
    and the label long, i.e. class indices for cross-entropy.

    Args:
        X: Image array, presumably 2-D (H, W). TODO confirm.
        y: Label array with the same spatial shape as X.
        batch_size (int): Samples per batch (default 1).
        shuffle (bool): Reshuffle each epoch (default True; it cannot change
            the order of a single sample).

    Returns:
        DataLoader yielding (image, label) batches.
    '''
    dataset = TensorDataset(convert_to_tensor(X), convert_to_tensor(y, torch.long))
    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)

train_loader = build_data_loader(X=X_train, y=y_train)  # single-sample training loader

In [None]:
mv_segmenter = UNet()

# No validation loop is defined, so early stopping watches the training loss.
# It stops once the loss drops below 0.001, or after 100 checks without improvement.
early_stopping = EarlyStopping(monitor='train_loss', patience=100, stopping_threshold=0.001)

trainer = pl.Trainer(deterministic=True, callbacks=[early_stopping])
trainer.fit(mv_segmenter, train_loader)

In [None]:
# The file name records the early-stopping threshold used for this run.
trainer.save_checkpoint(filepath='mv-segmenter-u-net-loss-threshold-0.001.ckpt')

In [None]:
# Scratch experiment: inspect a conv + ReLU layer before and after the custom initialization.
conv_relu = nn.Sequential(nn.Conv2d(1, 64, kernel_size=3, padding='same'), nn.ReLU())

# Parameters as created by PyTorch's default initialization.
print('\n'.join(f'{name}: {param}' for name, param in conv_relu.named_parameters()))

# NOTE(review): these helpers duplicate the ones defined next to UNet. Running the
# notebook top to bottom makes these later definitions shadow the earlier ones, so
# keep the two copies in sync or delete this one.
def get_num_incoming_nodes(tensor):
    '''
    Return N, the number of incoming connections (fan-in) of one output unit.

    For a Conv2d weight of shape (out_channels, in_channels, kH, kW) this is
    in_channels * kH * kW. Every dimension after the first is counted.

    Raises:
        ValueError: If `tensor` has fewer than 2 dimensions (e.g. a bias vector).
    '''
    if len(tensor.shape) < 2:
        raise ValueError(
            f'expected a weight tensor with at least 2 dimensions, got shape {tuple(tensor.shape)}'
        )
    # Dim 0 indexes the output units, so it is excluded from the fan-in.
    return math.prod(tensor.shape[1:])

def initialize_conv_weights(weights):
    '''He-normal init in place: N(0, std^2) with std = sqrt(2 / get_num_incoming_nodes(weights)).'''
    std = math.sqrt(2 / get_num_incoming_nodes(weights))
    nn.init.normal_(weights, mean=0.0, std=std)

def initialize_conv_bias(bias):
    '''Set every element of a bias tensor to zero, in place.'''
    nn.init.zeros_(bias)

def initialize_conv_relu_layer(layer):
    '''
    Re-initialize every parameter of `layer` in place.

    Matching is on the last dotted name component, so a bare layer
    ('weight', 'bias') is handled as well as nested modules ('0.weight').
    '''
    for name, param in layer.named_parameters():
        leaf_name = name.rsplit('.', 1)[-1]
        if leaf_name == 'weight':
            initialize_conv_weights(param.data)
        elif leaf_name == 'bias':
            initialize_conv_bias(param.data)

# Apply the custom initialization, then print the parameters again for comparison.
initialize_conv_relu_layer(conv_relu)
print('\n'.join(f'{name}: {param}' for name, param in conv_relu.named_parameters()))

In [None]:
# Scratch check (TODO: remove before sharing): infinity stays infinite under scaling.
math.inf * 2

# Segment Image

In [None]:
# Inference on the first training image. eval() disables dropout and no_grad()
# skips autograd, so the prediction is deterministic and carries no graph.
mv_segmenter.eval()
image = torch.unsqueeze(convert_to_tensor(train_set[0]['Image']), 0)  # (1, 1, H, W) for a 2-D image
with torch.no_grad():
    segmentation = mv_segmenter(image)
segmentation = torch.squeeze(segmentation, 0)  # drop only the batch axis -> (n_classes, H', W')
segmentation.shape

In [None]:
prediction = F.softmax(segmentation, dim=0)
print(prediction.shape)
# Hard per-pixel label: index of the most probable class (squeeze drops any singleton dims).
prediction = prediction.argmax(dim=0).squeeze()
prediction

In [None]:
prediction = F.softmax(segmentation, dim=0)
print(prediction[0, :, :].min())
# Probability of class 1 rounded to a hard 0/1 mask, as a NumPy array.
prediction = prediction[1, :, :].round().detach().numpy()
np.max(prediction)

In [None]:
prediction = F.softmax(segmentation, dim=0)
# Binary float mask: 1.0 where P[class 1] exceeds the hand-picked threshold 0.29, else 0.0.
# (The title of the comparison plot below quotes this same threshold.)
prediction = (prediction[1, :, :] > 0.29).float().numpy()
prediction

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
fig.set_facecolor('white')

# Crop the input image and the ground truth to the prediction's spatial size.
# Use the last two dims. When `prediction` is a 2-D mask (e.g. from the threshold
# cell), prediction.shape[1:] is just (W',), and center_crop would turn that into
# a square (W', W') crop.
crop_size = prediction.shape[-2:]
cropped_image = torch.squeeze(F_transforms.center_crop(image, crop_size)).numpy()
cropped_label = F_transforms.center_crop(torch.tensor(train_set[0]['Label']), crop_size).numpy()

# The 0.29 in the title must match the threshold used to build `prediction`.
visualize_segmentation(ax[0], cropped_image, prediction, segmentation_opacity=0.5)
ax[0].set_title('U-Net Predicted Segmentation (P[Is MV]>0.29)')

visualize_segmentation(ax[1], cropped_image, cropped_label, segmentation_opacity=0.5)
ax[1].set_title('Ground Truth')

plt.tight_layout()
plt.show()

# Inspect Training Set

In [None]:
# Reload the labelled training data (same file as in the training section).
train_set = load_zipped_pickle(filename=os.path.join(data_dir, 'labeled-images.pkl'))

In [None]:
# visualize_segmentation draws on an explicit Axes, which must be passed first.
fig_sample, ax_sample = plt.subplots(figsize=(6, 6))
visualize_segmentation(ax_sample, train_set[0]['Image'], train_set[0]['Label'])
ax_sample.set_title('Training sample 0 with its label overlay')
plt.show()

In [None]:
# Layer-by-layer summary of a freshly built UNet, for a batch of one image with a
# leading channel axis and the spatial size of the first training image.
train_input = torch.from_numpy(train_set[0]['Image']).unsqueeze(0)
torchinfo.summary(UNet(), input_size=(1, *train_input.shape))

In [None]:
# Training image and ground-truth label cropped to the prediction's spatial size.
# visualize_segmentation needs an Axes as its first argument.
crop_size = prediction.shape[-2:]
fig_crop, ax_crop = plt.subplots(figsize=(6, 6))
visualize_segmentation(
    ax_crop,
    F_transforms.center_crop(torch.from_numpy(train_set[0]['Image']), crop_size).numpy(),
    F_transforms.center_crop(torch.from_numpy(train_set[0]['Label']), crop_size).numpy()
)
ax_crop.set_title('Training sample 0 cropped to prediction size')
plt.show()

In [None]:
# Scratch check: take a 60x60 centre crop of the first training image, add a trailing
# axis and concatenate two copies along it. For a 2-D image the shape is (60, 60, 2).
tens = F_transforms.center_crop(torch.from_numpy(train_set[0]['Image']), (60, 60))
tens = tens.unsqueeze(dim=2)
torch.cat([tens, tens], dim=2).shape