# Training

## Part 1: Data loading and preprocessing

In [None]:
from pathlib import Path

import torchio as tio
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import matplotlib.pyplot as plt
import numpy as np

from model_3d import UNet

In [1]:
def change_img_to_label_path(path):
    """
    Map an image path to the corresponding label path.

    The first ``imagesTr`` component of *path* is replaced with ``labelsTr``;
    every other component, including the file name, is kept unchanged.

    Args:
        path: A ``pathlib.Path`` containing an ``imagesTr`` component.

    Returns:
        A new ``Path`` pointing into the sibling ``labelsTr`` directory.

    Raises:
        ValueError: If *path* has no ``imagesTr`` component.
    """
    parts = list(path.parts)
    # list.index raises a bare "x not in list" ValueError; re-raise it with a
    # message that names the offending path (same exception type for callers).
    try:
        idx = parts.index("imagesTr")
    except ValueError:
        raise ValueError(f"{path} has no 'imagesTr' component") from None
    parts[idx] = "labelsTr"
    return Path(*parts)

In [None]:
# NOTE(review): "Atrium" in this liver dataset path looks like a leftover from a
# heart-segmentation notebook — confirm the on-disk directory layout.
path = Path('Data/Atrium/Task03_Liver/imagesTr/')
# glob() yields files in filesystem order, which is not guaranteed to be the
# same across machines; sort so the positional train/val split is reproducible.
subject_paths = sorted(path.glob('liver_*'))
subjects = []

# Pair each CT image with its label map (same file name under labelsTr/).
for subject_path in subject_paths:
    label_path = change_img_to_label_path(subject_path)
    subject = tio.Subject({'CT': tio.ScalarImage(subject_path),
                           'Label': tio.LabelMap(label_path)})
    subjects.append(subject)

In [None]:
# Every CT volume is expected to be in ('R', 'A', 'S') orientation. Use an
# explicit check instead of `assert` (which is skipped under `python -O`), and
# report all offending files at once rather than stopping at the first.
non_ras = [subject['CT'].path for subject in subjects
           if subject['CT'].orientation != ('R', 'A', 'S')]
if non_ras:
    raise ValueError(f"CT volumes not in RAS orientation: {non_ras}")

In [None]:
# Deterministic preprocessing applied to every volume: crop/pad to a fixed
# spatial shape, then rescale intensities into [-1, 1].
# (Fixed typo: the torchio transform is `CropOrPad`, not `CropOrPas`.)
process = tio.Compose([
    tio.CropOrPad((256, 256, 200)),
    tio.RescaleIntensity((-1, 1))
])

# Random spatial augmentation for training only: scale factors in [0.9, 1.1]
# and rotations in [-10, 10] degrees.
augmentation = tio.RandomAffine(scales=(0.9, 1.1), degrees=(-10, 10))

val_transform = process
train_transform = tio.Compose([process, augmentation])

In [None]:
# Positional split: the first 105 subjects train, the remainder validate.
n_train = 105
train_dataset = tio.SubjectsDataset(subjects[:n_train], transform=train_transform)
val_dataset = tio.SubjectsDataset(subjects[n_train:], transform=val_transform)

# Patch sampler: 96-voxel patches whose centre voxel has label 0, 1 or 2 with
# probability 0.2, 0.3 and 0.5 respectively.
label_probs = {0: 0.2, 1: 0.3, 2: 0.5}
sampler = tio.data.LabelSampler(patch_size=96,
                                label_name='Label',
                                label_probabilities=label_probs)

In [None]:
# Train and validation queues share identical patch-extraction settings
# (and the same sampler); each queue loads volumes with its own workers.
queue_kwargs = dict(max_length=40, samples_per_volume=5,
                    sampler=sampler, num_workers=4)
train_patches_queue = tio.Queue(train_dataset, **queue_kwargs)
val_patches_queue = tio.Queue(val_dataset, **queue_kwargs)

In [None]:
# num_workers stays 0: the tio.Queue objects already use their own worker
# processes, and torchio's docs say the wrapping DataLoader should not.
train_loader, val_loader = (
    torch.utils.data.DataLoader(patch_queue, batch_size=2, num_workers=0)
    for patch_queue in (train_patches_queue, val_patches_queue)
)

## Part 2: Model definition and training

In [None]:
class Segmenter(pl.LightningModule):
    """Lightning module wrapping the ``UNet`` from ``model_3d``.

    Trained with Adam (lr=1e-4) and cross-entropy loss. Batches are the dicts
    produced by the torchio patch queues: the input image is read from
    ``batch['CT']['data']`` and the label map from ``batch['Label']['data']``.
    """

    def __init__(self):
        super().__init__()

        self.model = UNet()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-4)
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, data):
        return self.model(data)

    def _shared_step(self, batch):
        """Run the model on one batch and return the cross-entropy loss."""
        img = batch['CT']['data']
        # Take channel 0 of the label map (assumed single-channel) and cast to
        # int64 class indices, the target dtype CrossEntropyLoss requires.
        mask = batch['Label']['data'][:, 0].long()
        pred = self(img)
        return self.loss_fn(pred, mask)

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("Train Loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        # This key must match the ModelCheckpoint `monitor` name.
        self.log("Val Loss", loss)
        return loss

    def configure_optimizers(self):
        return [self.optimizer]

In [None]:
# Build the Lightning module (creates the UNet, its Adam optimizer and the loss).
model = Segmenter()

In [None]:
# Keep the 10 checkpoints with the lowest logged "Val Loss".
checkpoint_callback = ModelCheckpoint(monitor='Val Loss', mode='min', save_top_k=10)

In [None]:
# Fixed: the Trainer keyword is `max_epochs`; `max_epoch` is rejected.
# NOTE(review): `gpus=` was removed in Lightning 2.0 — switch to
# `accelerator="gpu", devices=1` if this notebook is moved to a newer release.
trainer = pl.Trainer(gpus=1,
                     logger=TensorBoardLogger(save_dir='logs/liver'),
                     log_every_n_steps=1,
                     callbacks=[checkpoint_callback],
                     max_epochs=100)

In [None]:
# Train on the patch loaders; validation runs on val_loader during fitting.
trainer.fit(model, train_loader, val_loader)

## Part 3: Inference and visualization

In [None]:
from IPython.display import HTML
from celluloid import Camera
import numpy as np

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# NOTE(review): placeholder path — point it at a real .ckpt file. With the
# TensorBoardLogger above, checkpoints likely land under
# logs/liver/default/version_<N>/checkpoints/; confirm.
model = Segmenter.load_from_checkpoint('logs/liver/checkpoints/...')
model.eval()  # inference mode; eval() works in place and returns the module
model.to(device)

### Patch Aggregation 

In [None]:
IDX = 4
# Index the dataset once: every val_dataset[IDX] access reloads the volume
# from disk and re-runs the transform, and the original did this three times.
val_subject = val_dataset[IDX]
imgs = val_subject['CT']['data']
mask = val_subject['Label']['data']

# Tile the subject into 96-voxel patches that overlap by 8 voxels per axis.
grid_sampler = tio.inference.GridSampler(val_subject, 96, (8, 8, 8))

In [None]:
# Collects per-patch predictions and stitches them back into the full volume.
aggregator = tio.inference.GridAggregator(grid_sampler)

In [None]:
# Feed the grid patches to the model four at a time.
patch_loader = torch.utils.data.DataLoader(
    grid_sampler,
    batch_size=4,
)

In [None]:
# Predict every patch without tracking gradients and hand each batch of
# predictions, with the patch locations, to the aggregator.
with torch.no_grad():
    for patch_batch in patch_loader:
        patch_input = patch_batch['CT']['data'].to(device)
        patch_locations = patch_batch[tio.LOCATION]
        aggregator.add_batch(model(patch_input), patch_locations)

In [None]:
# Full-volume prediction assembled from all aggregated patches.
output_tensor = aggregator.get_output_tensor()

In [None]:
fig = plt.figure()
camera = Camera(fig)

# Per-voxel predicted class: argmax over axis 0 of the aggregated output.
pred = output_tensor.argmax(0)

# One frame for every second slice along the last spatial axis.
for slice_idx in range(0, output_tensor.shape[-1], 2):
    # Channel 0 of the CT slice as the grayscale background.
    plt.imshow(imgs[0, :, :, slice_idx], cmap='bone')

    # Semi-transparent overlay of predicted labels, hiding class-0 voxels.
    pred_slice = pred[:, :, slice_idx]
    overlay = np.ma.masked_where(pred_slice == 0, pred_slice)
    plt.imshow(overlay, alpha=0.5, cmap='autumn')

    camera.snap()

animation = camera.animate()

In [None]:
# Render the slice animation inline as an HTML5 video.
video_html = animation.to_html5_video()
HTML(video_html)