In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchio as tio
from torchsummary import summary
import torchvision
import torchvision.transforms as transforms


from modules import Evaluator3D, FeTADataSet, Trainer3D, UNet3D
from modules.LossFunctions import DC_and_CE_loss, GDiceLossV2
from modules.Tensorboard import TensorboardModules
from modules.Utils import *

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyper-parameters of this run.
#   ES  - settings handed to EarlyStopping (patience / min_delta)
#   SLR - settings handed to the StepLR scheduler (step_size / gamma)
params = dict(
    num_epochs=100,
    batch_size=1,
    lr=1e-2,
    momentum=0.9,
    nesterov=True,
    patch_sizes=(128, 128, 128),
    image_sizes=(256, 256, 256),
    ES=dict(patience=5, min_delta=1e-3),
    # Alternative cyclic-LR settings (disabled):
    # CLR=dict(base=1e-07, max=0.1, up=4, down=8, mode="triangular2"),
    SLR=dict(step_size=10, gamma=1e-1),
)

# Root folder of this run and the sub-folder that receives the model weights.
output_path = "output/UNet3D/run5"
weight_path = os.path.join(output_path, "weights/")

In [None]:
# Make sure the weights directory exists. It is nested inside output_path, so
# this creates the run's output directory too. exist_ok=True makes the call a
# no-op when the directory is already there (e.g. when the cell is re-run).
os.makedirs(weight_path, exist_ok=True)

# Create patch indexes from the full image size and the patch size.
patch_indexes = create_patch_indexes(params["image_sizes"], params["patch_sizes"])

# Tensorboard helper for this run, constructed with the run's output directory.
tb = TensorboardModules(output_path)

# Save the hyper-parameters as a note: one "key=value" line per entry.
(pd.DataFrame.from_dict(data=params, orient='index')
 .to_csv(os.path.join(output_path,"hyper_parameters.txt"), header=False, sep="="))

In [None]:
# Read the label table (tab-separated, indexed by its "index" column).
labels = pd.read_csv("feta_2.1/dseg.tsv", index_col="index", sep="\t")

# Z-normalisation transform; the same pipeline is passed to both data sets.
transform_ = transforms.Compose(
    [tio.ZNormalization(masking_method=tio.ZNormalization.mean)]
)

train = FeTADataSet("train", path="data", transform=transform_)
val = FeTADataSet("val", path="data", transform=transform_)

# Seed torch's global RNG before the shuffling loaders are built.
torch.manual_seed(0)
train_loader = torch.utils.data.DataLoader(
    dataset=train,
    batch_size=params["batch_size"],
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(
    dataset=val,
    batch_size=params["batch_size"],
    shuffle=True,
)

# Send one validation sample (image and mask) to Tensorboard together with
# the slice selection below.
mri_image, mri_mask = val[8]
slices = (80, 150, 10)
tb.add_image_mask(mri_image, mri_mask, slices)

In [None]:
# Build the network on the selected device and apply the Kaiming initialiser.
model = UNet3D().to(device)
model.apply(init_weights_kaiming)

# Add model graph to Tensorboard.
tb.add_graph(model, params["patch_sizes"], device)
#print(summary(model, input_size=(1, 256, 256)))

# Loss: the first dict is passed as the first argument of DC_and_CE_loss,
# the second (empty) dict as its second argument.
criterion = DC_and_CE_loss(
    dict(batch_dice=True, smooth=1e-5, do_bg=False, square=False),
    {},
)

# SGD configured from the hyper-parameter dict.
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=params["lr"],
    momentum=params["momentum"],
    nesterov=params["nesterov"],
)

# Alternative cyclic learning-rate schedule (disabled):
#scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=params["CLR"]["base"], max_lr=params["CLR"]["max"],
#                                              step_size_up=params["CLR"]["up"], step_size_down=params["CLR"]["down"],
#                                              mode=params["CLR"]["mode"])

# Step schedule: multiply the learning rate by gamma every step_size scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=params["SLR"]["step_size"],
    gamma=params["SLR"]["gamma"],
)

early_stopping = EarlyStopping(
    patience=params["ES"]["patience"],
    min_delta=params["ES"]["min_delta"],
)

# Initialize the trainer used for training.
trainer = Trainer3D(
    model,
    train_loader,
    optimizer,
    criterion,
    params["patch_sizes"],
    params["num_epochs"],
    scheduler,
)

# Initialize the evaluator used for validation.
evaluator = Evaluator3D(criterion, model, params["patch_sizes"], val_loader)

In [None]:
# Initialize a separate trainer for the learning-rate finder (this cell is currently disabled).
#lr_trainer = Trainer3D(model, train_loader, optimizer, criterion, patch_indexes, 12)

#lr_finder = LearningRateFinder(lr_trainer)
#lr_finder.find(startLR=1e-9)

In [None]:
# Best-checkpoint bookkeeping:
#   prev_val_loss - lowest validation loss seen so far (inf until the first epoch)
#   prev_weights  - path of the checkpoint saved for that loss ("" until one exists)
# Both are updated ONLY when the validation loss improves, so exactly one
# checkpoint, the best one of this run, is kept on disk.
prev_weights = ""
prev_val_loss = float("inf")

for epoch in range(params["num_epochs"]):
    # Run the trainer once for this epoch and get the average training loss.
    avg_train_loss = trainer.fit()

    # Evaluate current model on validation data.
    avg_val_loss, avg_scores = evaluator.evaluate(model)

    print("-------------------------------------------------------------")

    # Add results to tensorboard.
    tb.add_scalars(step=epoch+1, lr=scheduler.get_last_lr()[0], ds=avg_scores,
                   train_loss=avg_train_loss, val_loss=avg_val_loss)

    # Checkpoint file name for this epoch, e.g. "3_model.pth".
    model_name = "_".join([str(epoch), "model.pth"])
    model_path = os.path.join(weight_path, model_name)

    if avg_val_loss < prev_val_loss:
        # New best model: write the new checkpoint first and delete the old
        # one afterwards, so a failed save never leaves the run without weights.
        torch.save(model.state_dict(), model_path)
        if os.path.isfile(prev_weights):
            os.remove(prev_weights)

        prev_weights = model_path
        prev_val_loss = avg_val_loss

    # If model is not learning stop the training.
    early_stopping(avg_val_loss)
    if early_stopping.early_stop:
        break

print('Finished Training')