# Train

## Imports 

In [1]:
# Task: Import the necessary libraries

from pathlib import Path

import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import imgaug.augmenters as iaa
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from celluloid import Camera

from dataset_lung import LungDataset
from model import UNet

## Dataset creation

In [2]:
# Task: Create the train and val dataset and the augmentation pipeline. Use Affine augmentations with:

#     1) 15% translation,
#     2) scaling between 0.85 and 1.15
#     3) rotations from -45 to 45°.
    
# Additionally use ElasticTransformation

# imgaug interprets a scalar `translate_percent` as a *fixed* shift applied to every
# image. A (min, max) range is sampled per image, and the dict form samples x and y
# independently, which is what "15% translation" augmentation intends.
# NOTE(review): ElasticTransformation() relies on imgaug's default alpha/sigma; confirm
# the installed imgaug version uses non-zero defaults, otherwise pass them explicitly.
seq = iaa.Sequential([
    iaa.Affine(translate_percent={"x": (-0.15, 0.15), "y": (-0.15, 0.15)},
               scale=(0.85, 1.15),
               rotate=(-45, 45)),
    iaa.ElasticTransformation()
])

In [3]:
# Both splits live below the same preprocessed root.
preprocessed_root = Path('Data/Atrium/Task06_Lung/Preprocessed')
train_path = preprocessed_root / 'train'
val_path = preprocessed_root / 'val'

# Only the training split receives the augmentation pipeline; validation slices are used as-is.
train_dataset = LungDataset(train_path, seq)
val_dataset = LungDataset(val_path, None)

## Oversampling to tackle strong class imbalance 

In [4]:
# Per-slice class flag: 1 if the label contains any foreground, else 0.
# The flags are computed on an un-augmented view of the training split.
# `train_dataset` was built with `seq` and presumably augments in __getitem__.
# Iterating it here would run the random affine/elastic pipeline on every slice,
# which is slow and makes the flags (and therefore the sampler weights) random.
# NOTE(review): assumes LungDataset enumerates slices in the same order for the same
# path, so index i here matches index i in `train_dataset`. Confirm against dataset_lung.
label_dataset = LungDataset(train_path, None)
target_list = [int(np.any(label)) for _, label in tqdm(label_dataset)]

  0%|          | 0/14484 [00:00<?, ?it/s]

In [5]:
# Count how many slices are tumor-free (class 0) and how many contain tumor (class 1).
slice_classes, slice_counts = np.unique(target_list, return_counts=True)
unique = (slice_classes, slice_counts)
unique

(array([0, 1]), array([12960,  1524], dtype=int64))

In [6]:
# Ratio of tumor-free to tumor slices; used below as the sampling weight of tumor slices.
n_negative, n_positive = unique[1][0], unique[1][1]
fraction = n_negative / n_positive
fraction

8.503937007874017

In [7]:
# Tumor slices (target 1) are weighted by `fraction`, tumor-free slices by 1,
# so that both classes are drawn with roughly equal probability.
weight_list = [fraction if target != 0 else 1 for target in target_list]

In [8]:
# Draw one epoch's worth of indices (with replacement) according to the class-balancing weights.
sampler = torch.utils.data.sampler.WeightedRandomSampler(
    weights=weight_list,
    num_samples=len(weight_list),
    replacement=True,
)

In [10]:
# Settings shared by both loaders.
loader_kwargs = dict(batch_size=1, num_workers=4)

# The weighted sampler takes the place of shuffling for training
# (DataLoader does not allow `sampler` together with `shuffle=True`).
train_loader = torch.utils.data.DataLoader(train_dataset, sampler=sampler, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)

## Loss 

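Use binary cross-entropy (`torch.nn.BCEWithLogitsLoss`). It applies the sigmoid internally, so the network outputs raw logits, and predictions must go through a sigmoid before being thresholded at 0.5.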

## Full Segmentation Model

In [11]:
# Task: Create the pytorch lightning model. Use Binary Cross Entropy as loss function and the 
# Adam optimizer with a learning rate of 1e-4

class LungSegmentation(pl.LightningModule):
    """Lightning wrapper that trains a UNet for binary lung-tumor segmentation.

    The loss is ``BCEWithLogitsLoss``, so the UNet's output is treated as raw
    logits. Anything that needs a binary mask must apply a sigmoid first.
    """

    def __init__(self):
        super().__init__()

        self.model = UNet()

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
        self.loss_fn = torch.nn.BCEWithLogitsLoss()

    def forward(self, data):
        """Return the UNet's raw logits for ``data``."""
        return self.model(data)

    def training_step(self, batch, batch_idx):
        """Compute and log the BCE loss for one ``(ct, mask)`` training batch.

        The value is logged under 'Train Dice' to keep existing dashboards working,
        but it is the BCE loss, not a Dice score. An image overlay is logged every
        50 batches.
        """
        ct, mask = batch
        mask = mask.float()

        pred = self(ct.float())
        loss = self.loss_fn(pred, mask)

        self.log('Train Dice', loss)
        if batch_idx % 50 == 0:
            # detach: the overlay must not keep the autograd graph alive.
            self.log_images(ct.cpu(), pred.detach().cpu(), mask.cpu(), 'Train')

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the BCE loss for one ``(ct, mask)`` validation batch.

        Logged under 'Val Dice' (still the BCE loss), which is the key the
        checkpoint callback monitors.
        """
        ct, mask = batch
        mask = mask.float()

        pred = self(ct.float())
        loss = self.loss_fn(pred, mask)

        self.log('Val Dice', loss)
        if batch_idx % 50 == 0:
            self.log_images(ct.cpu(), pred.detach().cpu(), mask.cpu(), 'Val')

        return loss

    def log_images(self, ct, pred, mask, name):
        """Log a ground-truth vs. prediction overlay of the first sample.

        Args:
            ct: CT batch; only ``ct[0][0]`` is drawn.
            pred: raw logits from the model; turned into a binary mask via sigmoid > 0.5.
            mask: ground-truth batch; only ``mask[0][0]`` is drawn.
            name: tag under which the figure is written to the logger's experiment.
        """
        # The network outputs logits, so convert to probabilities before thresholding.
        pred = torch.sigmoid(pred) > 0.5

        fig, axis = plt.subplots(1, 2)

        axis[0].imshow(ct[0][0], cmap='bone')
        mask_ = np.ma.masked_where(mask[0][0] == 0, mask[0][0])
        axis[0].imshow(mask_, alpha=0.6)
        axis[0].set_title("Ground Truth")

        axis[1].imshow(ct[0][0], cmap='bone')
        mask_ = np.ma.masked_where(pred[0][0] == 0, pred[0][0])
        axis[1].imshow(mask_, alpha=0.6, cmap='autumn')
        axis[1].set_title("Pred")

        self.logger.experiment.add_figure(name, fig, self.global_step)
        # Release the figure so repeated logging does not accumulate open figures.
        plt.close(fig)

    def configure_optimizers(self):
        """Return the Adam optimizer (lr=1e-4) created in ``__init__``."""
        return [self.optimizer]

In [12]:
# Task: Instantiate the model, create a checkpoint callback and define the trainer.
# Train for 30 epochs and log the training process with a TensorBoardLogger.

# A fresh, untrained segmentation module (UNet + BCE loss + Adam).
model = LungSegmentation()

In [13]:
# Keep the 30 checkpoints with the lowest value logged as 'Val Dice'
# (that value is the validation BCE loss, hence mode='min').
checkpoint_callback = ModelCheckpoint(
    monitor='Val Dice',
    mode='min',
    save_top_k=30,
)

In [14]:
# TensorBoard event files go below logs/lungs.
tb_logger = TensorBoardLogger(save_dir='logs/lungs')

trainer = pl.Trainer(
    gpus=1,
    max_epochs=30,
    logger=tb_logger,
    log_every_n_steps=1,
    callbacks=checkpoint_callback,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [15]:
# Train on the oversampled loader; val_loader feeds validation_step, which logs the
# 'Val Dice' value monitored by checkpoint_callback.
# NOTE(review): the recorded run fails with a 240 vs 241 size mismatch inside the model.
# Presumably the preprocessed slice size is not compatible with UNet's down/up-sampling;
# confirm the slice shape produced by LungDataset.
trainer.fit(model, train_loader, val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type              | Params
----------------------------------------------
0 | model   | UNet              | 7.8 M 
1 | loss_fn | BCEWithLogitsLoss | 0     
----------------------------------------------
7.8 M     Trainable params
0         Non-trainable params
7.8 M     Total params
31.127    Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]



RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 240 but got size 241 for tensor number 1 in the list.

## Evaluation

In [None]:
# Task: Load the latest checkpoint and evaluate the results by computing the prediction for the
# complete validation dataset and then compute the dice score for it

class DiceLoss(torch.nn.Module):
    """Soft Dice loss: ``1 - (2 * |P * M| + eps) / (|P| + |M| + eps)``.

    Both inputs are flattened, so any shape works as long as ``pred`` and ``mask``
    contain the same number of elements. ``eps`` is added to the numerator and the
    denominator. This prevents division by zero and makes an empty prediction for
    an empty mask count as a perfect match (loss 0) rather than a total miss.
    """

    def __init__(self, eps=1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, pred, mask):
        """Return ``1 - dice`` for ``pred`` against ``mask``."""
        pred = torch.flatten(pred)
        mask = torch.flatten(mask)

        intersection = (pred * mask).sum()
        denominator = pred.sum() + mask.sum()
        dice = (2 * intersection + self.eps) / (denominator + self.eps)

        return 1 - dice

In [None]:
# Load the most recently written checkpoint below logs/lungs. ModelCheckpoint got no
# dirpath, so presumably it stores checkpoints below the logger's save_dir.
checkpoint_paths = sorted(Path('logs/lungs').rglob('*.ckpt'), key=lambda path: path.stat().st_mtime)
if not checkpoint_paths:
    raise FileNotFoundError("No checkpoint (*.ckpt) found below 'logs/lungs'")
model = LungSegmentation.load_from_checkpoint(str(checkpoint_paths[-1]))

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.eval()
model.to(device);

In [None]:
# Collect raw logits for every validation slice together with its label.
preds = []
labels = []

for ct_slice, label in tqdm(val_dataset):
    # Training casts inputs with .float(); do the same here so the dtype matches the weights.
    ct_tensor = torch.tensor(ct_slice).float().unsqueeze(0).to(device)
    with torch.no_grad():
        pred = model(ct_tensor)
    preds.append(pred.cpu().numpy())
    labels.append(label)

preds = np.array(preds)
labels = np.array(labels)

In [None]:
# `preds` holds raw logits (the model is trained with BCEWithLogitsLoss), so binarize
# with sigmoid > 0.5 before scoring. DiceLoss returns 1 - dice, hence the subtraction.
pred_masks = (torch.sigmoid(torch.from_numpy(preds)) > 0.5).float()
dice_score = 1 - DiceLoss()(pred_masks, torch.from_numpy(labels).unsqueeze(0).float())
dice_score

## Visualization 

In [None]:
# Task: Compute a prediction for a patient and visualize the prediction.
import nibabel as nib
import cv2

subject = 'Data/Atrium/Task06_Lung/imagesTs/lung_002.nii.gz'

# Standardize the CT by dividing by 3071.
# NOTE(review): presumably the same scaling the preprocessing step used; confirm.
subject_ct = nib.load(subject).get_fdata() / 3071
# Crop: drop the first 30 slices along the last axis.
ct = subject_ct[:, :, 30:]

In [None]:
# Segment every slice of the cropped volume:
# `scan` keeps the resized CT slices, `segmentation` the binary predicted masks.
segmentation = []
scan = []

for i in range(ct.shape[-1]):
    ct_slice = torch.tensor(cv2.resize(ct[:, :, i], (256, 256)))
    scan.append(ct_slice)
    model_input = ct_slice.unsqueeze(0).unsqueeze(0).float().to(device)

    with torch.no_grad():
        logits = model(model_input)[0][0].cpu()

    # The model outputs logits, so threshold the sigmoid probability, not the raw value.
    segmentation.append(torch.sigmoid(logits) > 0.5)

In [None]:
# Animate every second slice: CT in 'bone', predicted mask overlaid in 'autumn'.
fig, ax = plt.subplots()
camera = Camera(fig)

for scan_slice, seg_slice in zip(scan[::2], segmentation[::2]):
    ax.imshow(scan_slice, cmap='bone')

    overlay = np.ma.masked_where(seg_slice == 0, seg_slice)
    ax.imshow(overlay, alpha=0.5, cmap='autumn')

    camera.snap()

animation = camera.animate()