In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchio as tio
from torchsummary import summary
import torchvision
import torchvision.transforms as transforms


from modules.Dataset import FeTADataSet
from modules.Evaluator import Evaluator3D
from modules.LossFunctions import DC_and_CE_loss, GDiceLossV2
from modules.Trainer import Trainer3D
from modules.UNet import UNet3D
from modules.Utils import calculate_dice_score, create_onehot_mask, create_patch_indexes, init_weights_kaiming
from modules.Utils import EarlyStopping, LearningRateFinder, TensorboardModules 

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters (every entry is dumped to hyper_parameters.txt for provenance).
params = {
    "num_epochs": 250,               # upper bound; early stopping may end training sooner
    "batch_size": 1,
    "lr": 0.1,                       # initial SGD learning rate (decayed by the scheduler)
    "momentum": 0.9,
    "nesterov": True,
    "patch_sizes": (128, 128, 128),  # patch size passed to create_patch_indexes / add_graph
    "image_sizes": (256, 256, 256),  # full-volume size passed to create_patch_indexes
    "seed": 42,                      # RNG seed so weight init and data order are repeatable
}

# Seed the RNGs used by weight initialisation and data shuffling.
# torch.manual_seed also seeds every CUDA device; note that some cuDNN kernels
# can still be nondeterministic unless cudnn.deterministic is enabled.
np.random.seed(params["seed"])
torch.manual_seed(params["seed"])

output_path = "output/UNet3D/run1"
weight_path = os.path.join(output_path, "weights/")

In [None]:
# Create the output and weight directories; exist_ok makes this a no-op on re-runs
# and avoids the check-then-create race of an isdir() guard.
os.makedirs(weight_path, exist_ok=True)

# Pre-compute the patch indexes from the full-volume and patch sizes
# (presumably the grid of patches covering each volume — see modules.Utils).
patch_indexes = create_patch_indexes(params["image_sizes"], params["patch_sizes"])

# TensorBoard logger rooted at this run's output directory.
tb = TensorboardModules(output_path)

# Save the hyper-parameters as "key=value" lines for later reference.
# Written explicitly instead of via DataFrame.from_dict(orient='index'), whose
# layout depends on whether the first value is a scalar or a sequence, which is
# fragile with tuple-valued entries such as patch_sizes.
with open(os.path.join(output_path, "hyper_parameters.txt"), "w") as fh:
    fh.writelines(f"{key}={value}\n" for key, value in params.items())

In [None]:
# Root directory of the FeTA 2.1 dataset.
DATA_PATH = "feta_2.1"

# Segmentation label table, indexed by its "index" column
# (presumably one row per segmentation class — verify against dseg.tsv).
labels = pd.read_csv(os.path.join(DATA_PATH, "dseg.tsv"), sep='\t', index_col="index")

# Z-score intensity normalisation using torchio's mean-threshold masking.
transform_ = transforms.Compose([tio.ZNormalization(masking_method=tio.ZNormalization.mean)])

train = FeTADataSet("train", path=DATA_PATH, transform=transform_)
val = FeTADataSet("val", path=DATA_PATH, transform=transform_)

# Shuffle only the training set. SGD benefits from a fresh sample order each epoch;
# validation order does not affect the metrics, so it stays fixed.
train_loader = torch.utils.data.DataLoader(dataset=train, batch_size=params["batch_size"], shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=val, batch_size=params["batch_size"])

In [None]:
# Log one validation volume and its segmentation mask to TensorBoard as a visual
# sanity check. `slices` is forwarded unchanged to tb.add_images.
mri_image, mri_mask = val[8]
slices = (80, 150, 10)
for tag, volume in (("Fetal Brain Images", mri_image), ("Fetal Brain Masks", mri_mask)):
    tb.add_images(tag, volume, slices)

In [None]:
# Learning rate finder (currently disabled).
# NOTE(review): `trainer` is only created in the model-setup cell further down,
# so uncommenting these lines here raises NameError on a fresh kernel.
# Move them below that cell before enabling.
# lr_finder = LearningRateFinder(trainer)
# losses, lrs = lr_finder.find(startLR=1e-1)

In [None]:
# Build the 3D U-Net on the selected device and apply Kaiming weight initialisation.
model = UNet3D().to(device)
model.apply(init_weights_kaiming)

# Log the model graph to TensorBoard. The helper receives the patch size and the
# device; presumably it traces the model on a dummy patch (confirm in TensorboardModules).
tb.add_graph(model, params["patch_sizes"], device)

# Loss: presumably a combined Dice + cross-entropy loss. The first dict configures
# the Dice term and the second (empty) dict the CE term — TODO confirm in modules.LossFunctions.
dice_kwargs = {'batch_dice': True, 'smooth': 1e-5, 'do_bg': False, 'square': False}
ce_kwargs = {}
criterion = DC_and_CE_loss(dice_kwargs, ce_kwargs)

# SGD with momentum. StepLR multiplies the LR by 0.5 every 20 scheduler steps.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=params["lr"],
    momentum=params["momentum"],
    nesterov=params["nesterov"],
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
early_stopping = EarlyStopping()

# Trainer drives the training passes; the evaluator scores the model on the validation loader.
trainer = Trainer3D(model, train_loader, optimizer, criterion, patch_indexes,
                    params["num_epochs"], scheduler)
evaluator = Evaluator3D(criterion, model, patch_indexes, val_loader)

In [None]:
# Train for up to num_epochs. Metrics are logged every epoch, and only the
# checkpoint with the lowest validation loss seen so far is kept on disk.
best_weights = ""             # path of the currently saved best checkpoint ("" = none yet)
best_val_loss = float("inf")  # lowest validation loss seen so far (inf: first epoch always saves)

for epoch in range(params["num_epochs"]):
    # One pass over all training data.
    avg_train_loss = trainer.fit()

    # Evaluate the current model on the validation data.
    avg_val_loss, avg_scores = evaluator.evaluate(model)

    # Add results to TensorBoard.
    tb.add_scalars(step=epoch, lr=scheduler.get_last_lr()[0], ds=avg_scores,
                   train_loss=avg_train_loss, val_loss=avg_val_loss)

    # Compare against the best loss so far, not the previous epoch. Update the
    # bookkeeping only when a checkpoint was actually written, so best_weights
    # always names an existing file that can later be removed.
    if avg_val_loss < best_val_loss:
        model_name = "_".join([str(epoch), "model.pth"])
        model_path = os.path.join(weight_path, model_name)
        # Save first, then delete the old checkpoint, so a failed save never
        # leaves us without any best weights on disk.
        torch.save(model.state_dict(), model_path)
        if best_weights and os.path.isfile(best_weights):
            os.remove(best_weights)
        best_weights = model_path
        best_val_loss = avg_val_loss

    # Stop early if the validation loss has stopped improving.
    early_stopping(avg_val_loss)
    if early_stopping.early_stop:
        print(f"Early stopping at epoch {epoch}")
        break

print('Finished Training')