In [1]:
import glob
import os
import numpy as np
import cv2
import time

import matplotlib
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch import sigmoid
from torch import optim
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.utils.data.dataset import random_split

from utils.vtk import render_scan
from utils.image import load_img, img_to_array, cubify_scan, calc_dist_map
from utils.performance import calc_confusion_matrix

from losses.surface import surface_loss

### Check if GPU is available

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

### Network setup

In [3]:
torch.manual_seed(42)

setup = {
    'struct': 'lateral_ventricle', 'arch': 'Unet', 'loss_fn': 'binary',
    'batch_size': 4, 'filters': 16, 'batch_norm': True,
    'optimizer_fn': 'Adam', 'lr': 0.001, 'threshold': 0.5,
    'input_shape': (192, 256, 1), 'epochs': 50
}

collection_name = 'mindboggle_84_Nx192x256_lateral_ventricle'
dataset_dir = os.path.join('/home/filip/Projekty/ML/datasets', collection_name)

### Create Unet Network

In [4]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(3,3), padding=(1,1)),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=(3,3), padding=(1,1)),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, x):
        x = self.conv(x)
        return x
        

In [5]:
class Unet(nn.Module):
    def __init__(self, n_channels, n_filters, n_classes):
        super().__init__()
        
        self.conv1 = ConvBlock(n_channels, n_filters)
        self.pool1 = nn.MaxPool2d(kernel_size=(2,2))
        
        self.conv2 = ConvBlock(n_filters, n_filters*2)
        self.pool2 = nn.MaxPool2d(kernel_size=(2,2))
        
        self.conv3 = ConvBlock(n_filters*2, n_filters*4)
        self.pool3 = nn.MaxPool2d(kernel_size=(2,2))
        
        self.conv4 = ConvBlock(n_filters*4, n_filters*8)
        self.pool4 = nn.MaxPool2d(kernel_size=(2,2))
        
        self.bridge5 = ConvBlock(n_filters*8, n_filters*16)
        
        self.up6 = nn.ConvTranspose2d(n_filters*16, n_filters*8, kernel_size=(2, 2), stride=(2, 2))
        self.conv6 = ConvBlock(n_filters*16, n_filters*8)
        
        self.up7 = nn.ConvTranspose2d(n_filters*8, n_filters*4, kernel_size=(2, 2), stride=(2, 2))
        self.conv7 = ConvBlock(n_filters*8, n_filters*4)
        
        self.up8 = nn.ConvTranspose2d(n_filters*4, n_filters*2, kernel_size=(2, 2), stride=(2, 2))
        self.conv8 = ConvBlock(n_filters*4, n_filters*2)
        
        self.up9 = nn.ConvTranspose2d(n_filters*2, n_filters, kernel_size=(2, 2), stride=(2, 2))
        self.conv9 = ConvBlock(n_filters*2, n_filters)
        
        self.outputs = nn.Conv2d(n_filters, n_classes, kernel_size=(1,1))
    
    def forward(self, x):        
        conv1 = self.conv1(x)
        pool1 = self.pool1(conv1)
        
        conv2 = self.conv2(pool1)
        pool2 = self.pool2(conv2)
        
        conv3 = self.conv3(pool2)
        pool3 = self.pool3(conv3)
        
        conv4 = self.conv4(pool3)
        pool4 = self.pool4(conv4)
        
        bridge5 = self.bridge5(pool4)
        
        up6 = self.up6(bridge5)
        cat6 = torch.cat([up6, conv4], dim=1)
        conv6 = self.conv6(cat6)
        
        up7 = self.up7(conv6)
        cat7 = torch.cat([up7, conv3], dim=1)
        conv7 = self.conv7(cat7)
        
        up8 = self.up8(conv7)
        cat8 = torch.cat([up8, conv2], dim=1)
        conv8 = self.conv8(cat8)
        
        up9 = self.up9(conv8)
        cat9 = torch.cat([up9, conv1], dim=1)
        conv9 = self.conv9(cat9)
        
        x = self.outputs(conv9)
        
        return sigmoid(x)

### Load dataset

In [6]:
class Dataset2d(Dataset):
    def __init__(self, X_tensor, y_tensor):
        self.X = X_tensor
        self.y = y_tensor
        
    def __getitem__(self, index):
        return (self.X[index], self.y[index])
    
    def __len__(self):
        return len(self.X)        

In [7]:
X_files = sorted(
    glob.glob(
        os.path.join(dataset_dir, 'train', 'images', '*.png')
    )[:]
)
y_files = sorted(
    glob.glob(
        os.path.join(dataset_dir, 'train', 'labels', '*.png')
    )[:]
)

X_images = np.stack([img_to_array(load_img(filename)) for filename in X_files])
y_images = np.stack([img_to_array(load_img(filename)) for filename in y_files])

### Create dataset loader

In [8]:
X_train, X_valid, y_train, y_valid = train_test_split(X_images, y_images, test_size=0.2, random_state=1)

X_train = torch.from_numpy(X_train)
y_train = torch.from_numpy(y_train)
X_valid = torch.from_numpy(X_valid)
y_valid = torch.from_numpy(y_valid)

train_dataset = Dataset2d(X_train, y_train)
val_dataset = Dataset2d(X_valid, y_valid)

train_loader = DataLoader(dataset=train_dataset, batch_size=setup['batch_size'], shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=setup['batch_size'])

### Run deep learning

In [9]:
model = Unet(n_channels=1, n_filters=16, n_classes=1).to(device)

In [10]:
loss_fn = surface_loss
optimizer = optim.Adam(model.parameters(), lr=setup['lr'])

In [None]:
losses = []
val_losses = []

for epoch in range(setup['epochs']):
    print(f"Epoch {epoch + 1} / {setup['epochs']}")
    start = time.time()
    batch_losses = []
    batch_val_losses = []
    
    for X_batch, y_batch in train_loader:
        X_batch = X_batch.to(device)
        y_batch = y_batch.to(device)
        
        model.train()

        y_hat = model(X_batch)
          
        # calc distance and pass it to surface loss
        y_batch_numpy = y_batch.cpu().numpy()
        y_dist = np.array([calc_dist_map(y) for y in y_batch_numpy]).astype(np.float32)
        y_dist = torch.from_numpy(y_dist).to(device)

        loss = loss_fn(y_hat, y_dist)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

        batch_losses.append(loss.item())
        
    with torch.no_grad():
        for X_batch, y_batch in val_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)
            
            model.eval()
            
            y_hat = model(X_batch)
            
            val_loss = loss_fn(y_hat, y_batch)
            batch_val_losses.append(val_loss.item())
        
        
    losses.append(np.mean(batch_losses))
    val_losses.append(np.mean(batch_val_losses))
    
    print(f'Time per epoch: {(time.time() - start):.3f} seconds')
    print(f'Train. loss:', np.mean(batch_losses))
    print(f'Valid. loss:', np.mean(batch_val_losses))
    print(f'---------------------------------------------------')
    

Epoch 1 / 50
Time per epoch: 54.395 seconds
Train. loss: 6.915209076709244
Valid. loss: 0.005713049112068069
---------------------------------------------------
Epoch 2 / 50
Time per epoch: 54.067 seconds
Train. loss: 0.5821487852513911
Valid. loss: 0.006149502332142043
---------------------------------------------------
Epoch 3 / 50
Time per epoch: 54.472 seconds
Train. loss: 0.18628107695925308
Valid. loss: 0.002138805359971875
---------------------------------------------------
Epoch 4 / 50
Time per epoch: 54.762 seconds
Train. loss: 0.07733907402090683
Valid. loss: 0.002834622932676857
---------------------------------------------------
Epoch 5 / 50
Time per epoch: 54.785 seconds
Train. loss: 0.04084676160320307
Valid. loss: 0.0027389867442956703
---------------------------------------------------
Epoch 6 / 50
Time per epoch: 55.749 seconds
Train. loss: 0.021889229676767458
Valid. loss: 0.0025590235946442736
---------------------------------------------------
Epoch 7 / 50
Time per 

In [None]:
plt.plot(range(setup['epochs']), losses, val_losses)
plt.ylabel('loss')
plt.xlabel('epoch')

### Predictions

In [None]:
X_files = sorted(
    glob.glob(
        os.path.join(dataset_dir, 'test', 'images', '*.png')
    )[:]
)
y_files = sorted(
    glob.glob(
        os.path.join(dataset_dir, 'test', 'labels', '*.png')
    )[:]
)

X_images = np.stack([img_to_array(load_img(filename)) for filename in X_files])
y_images = np.stack([img_to_array(load_img(filename)) for filename in y_files])

X_test = torch.from_numpy(X_images)
y_test = torch.from_numpy(y_images)

test_dataset = Dataset2d(X_test, y_test)

test_loader = DataLoader(dataset=test_dataset, batch_size=setup['batch_size'])

In [None]:
scan_preds = list()
scan_mask = list()

for X_batch, y_batch in test_loader:
    X_batch = X_batch.to(device)
    y_batch = y_batch.to(device)

    with torch.no_grad():
        model.eval()

        preds = model(X_batch)
        preds = preds.cpu().numpy()
        preds = (preds > setup['threshold']).astype(np.uint8)
        scan_preds.append(preds.squeeze())
        
        mask = y_batch.cpu().numpy().squeeze().astype(np.uint8)
        scan_mask.append(mask)
        
scan_preds = np.concatenate([img for img in scan_preds])
scan_mask = np.concatenate([img for img in scan_mask])

print('Preds shape:', scan_preds.shape)
print('Mask shape:', scan_mask.shape)

### Performance

In [None]:
res = calc_confusion_matrix(scan_mask, scan_preds)
res

In [None]:
test_scan = cubify_scan(scan_preds, 256, 256, 192, 256)
test_scan.shape

In [None]:
render_scan(test_scan, 256)