In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchio as tio
from torchsummary import summary
import torchvision
import torchvision.transforms as transforms


from modules import Evaluator3D, MRIDataset, MRIDatasets, Trainer3D, UNet3D
from modules.Transforms import *
from modules.LossFunctions import DC_and_CE_loss, GDiceLossV2
from modules.Tensorboard import TensorboardModules
from modules.Utils import *

In [None]:
# Run on the GPU when PyTorch can see one; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hyper-parameters. The whole dict is later written out as run meta data,
# so keep keys (and their order) stable.
params = dict(
    total_epochs=100,
    batch_size=1,
    patch_sizes=(128, 128, 128),
    # Optimizer: SGD with Nesterov momentum.
    # Commented-out alternative: Adam=dict(lr=1e-05, betas=(0.9, 0.999), eps=1e-8)
    SGD=dict(lr=0.1, momentum=0.9, nesterov=True),
    # Early stopping settings.
    ES=dict(patience=20, min_delta=0.001),
    # Cyclic learning-rate schedule settings.
    # Commented-out alternative: SLR=dict(step_size=13, gamma=1e-1)
    CLR=dict(base=1e-05, max=1e-02, up=3, down=5, mode="triangular2"),
)

# Root folder of this run: TensorBoard logs, details.txt and checkpoints go here.
output_path = "output/UNet3D/Iteration9/run4"
weight_path = os.path.join(output_path, "weights/")

In [None]:
# The same dataset exists in several data folders, each pre-processed differently,
# so both the dataset identifier and its folder are given explicitly.
dataset_train = MRIDatasets.dHCP_FeTA
dataset_path_train = 'data'
dataset_val = MRIDatasets.FeTA_BalancedDistribution
dataset_path_val = os.path.join('data', 'feta_processed')

#cv_ = "cv3" # 5-fold cross-validation. Folds [cv1-cv5]

# Random augmentations, applied to the training split only.
transform_train = transforms.Compose([RandomMotion(), RandomAffine(degrees=[15])])

# Evaluation transform (currently none; e.g. transforms.Compose([Mask()])).
# It is passed to both val and test so params["transform_eval"] reflects what is used.
transform_eval = None

# Split dataset.
train = MRIDataset(dataset_train, "train", dataset_path_train, transform=transform_train)
val = MRIDataset(dataset_val, "val", dataset_path_val, transform=transform_eval)
test = MRIDataset(dataset_val, "test", dataset_path_val, transform=transform_eval)

# Seed the global torch RNG before the shuffling loader and the model initialisation.
torch.manual_seed(0)
train_loader = torch.utils.data.DataLoader(dataset=train, batch_size=params["batch_size"], shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=val, batch_size=params["batch_size"])
test_loader = torch.utils.data.DataLoader(dataset=test, batch_size=params["batch_size"])

# Add dataset configuration to parameters so it is saved as run meta data.
params["dataset_train"] = dataset_train.name
params["dataset_path_train"] = dataset_path_train
params["dataset_val"] = dataset_val.name
params["dataset_path_val"] = dataset_path_val
#params["cross_validation"] = "None" if not cv_ else cv_
params["transform_train"] = "None" if transform_train is None else str(transform_train.transforms)
params["transform_eval"] = "None" if transform_eval is None else str(transform_eval.transforms)

In [None]:
model = UNet3D().to(device)
# Combined Dice + cross-entropy loss. NOTE(review): 'do_bg': False presumably
# excludes the background class from the Dice term — confirm in LossFunctions.
criterion = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'do_bg': False, 'square': False}, {})
pretrained = False

# Initialize weights, or load an already trained model.
if not pretrained:
    model.apply(init_weights_kaiming)
    params["initial_weights"] = init_weights_kaiming.__name__
else:
    model_path = "output/UNet3D/Iteration8/run1/weights/39_model.pth"
    # map_location remaps stored tensors onto `device`, so a checkpoint saved on a
    # GPU can also be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    params["initial_weights"] = model_path

# Section: Training

In [None]:
# SGD optimizer: the keys of params["SGD"] (lr, momentum, nesterov) are exactly
# torch.optim.SGD's keyword arguments, so the dict is forwarded as-is.
optimizer = torch.optim.SGD(model.parameters(), **params["SGD"])

# Cyclic learning-rate schedule configured from params["CLR"].
# NOTE(review): CyclicLR resets each param group's lr to base_lr when constructed,
# so params["SGD"]["lr"] likely does not set the starting LR — confirm for the
# installed PyTorch version. Whether step_size_up/down count batches or epochs
# depends on how often Trainer3D calls scheduler.step() (not visible here).
clr_cfg = params["CLR"]
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer,
    base_lr=clr_cfg["base"],
    max_lr=clr_cfg["max"],
    step_size_up=clr_cfg["up"],
    step_size_down=clr_cfg["down"],
    mode=clr_cfg["mode"],
)

# Early stopping: params["ES"] holds exactly the patience / min_delta keywords.
early_stopping = EarlyStopping(**params["ES"])

# Trainer drives the training epochs; evaluator scores the model on the validation loader.
trainer = Trainer3D(criterion, model, optimizer, params["patch_sizes"],
                    params["total_epochs"], train_loader, scheduler)
evaluator = Evaluator3D(criterion, model, params["patch_sizes"], val_loader)

In [None]:
# Create the checkpoint directory (and output_path above it) when missing.
os.makedirs(weight_path, exist_ok=True)

# TensorBoard helper that writes this run's logs under output_path.
tb = TensorboardModules(output_path)

# Record every hyper-parameter / dataset setting as "key=value" lines.
params_table = pd.DataFrame.from_dict(data=params, orient='index')
params_table.to_csv(os.path.join(output_path, "details.txt"), header=False, sep="=")

# Log one validation image together with its mask.
# NOTE(review): the meaning of `slices` is defined by add_image_mask
# (presumably slice positions) — confirm in modules.Tensorboard.
mri_image, mri_mask = val[1]
slices = (80, 150, 10)
tb.add_image_mask(mri_image, mri_mask, slices)

# Log the model graph (patch size and device are forwarded to the helper).
tb.add_graph(model, params["patch_sizes"], device)

In [None]:
# Best validation loss seen so far and the checkpoint file that achieved it.
# Only that single best checkpoint is kept on disk.
best_val_loss = float("inf")
best_weights = ""

for epoch in range(params["total_epochs"]):
    # One pass over all training data.
    avg_train_loss = trainer.fit()

    # Evaluate the current model on validation data.
    avg_val_loss, dice_scores = evaluator.evaluate()
    avg_scores = sum(dice_scores) / len(dice_scores)

    print("-------------------------------------------------------------")

    # Add results to TensorBoard.
    tb.add_scalars(step=epoch + 1, lr=scheduler.get_last_lr()[0], ds=avg_scores,
                   train_loss=avg_train_loss, val_loss=avg_val_loss)

    # Checkpoint only when the loss beats the best so far (not merely the previous
    # epoch). Write the new file first, then delete the old best, so a failed save
    # never leaves us without a checkpoint.
    if avg_val_loss < best_val_loss:
        model_path = os.path.join(weight_path, f"{epoch}_model.pth")
        torch.save(model.state_dict(), model_path)
        if best_weights and os.path.isfile(best_weights):
            os.remove(best_weights)
        best_weights = model_path
        best_val_loss = avg_val_loss

    # Stop once early stopping reports that the model is no longer improving.
    early_stopping(avg_val_loss)
    if early_stopping.early_stop:
        break

# Note: the in-memory `model` is now the last epoch's model; the best weights
# are in `best_weights`.
print('Finished Training')

# Section: Evaluation

In [None]:
# Label names for the per-class Dice scores.
# NOTE(review): assumed to follow the label index order used by the evaluator — confirm.
tissue_classes = ["Background", "eCSF", "Gray Matter", "White Matter", "Ventricles",
                  "Cerebellum", "Deep Gray Matter", "Brain Stem"]
# Reported Dice columns (background excluded).
dice_columns = tissue_classes[1:]

# Evaluate the last (not necessarily the best) model on the validation set.
evaluator = Evaluator3D(criterion, model, params["patch_sizes"], val_loader)
val_loss, val_scores = evaluator.evaluate()
avg_val_scores = sum(val_scores) / len(val_scores)
# Convert per-subject score tensors to plain lists.
val_scores = [score.tolist() for score in val_scores]
# Join per-subject scores with subject meta data. The scores frame is indexed by
# participant_id, and merge's `on` accepts index level names.
val_results = pd.merge(
    val.meta_data,
    pd.DataFrame(val_scores, index=val.meta_data["participant_id"], columns=tissue_classes)
    .drop(columns="Background"),
    on=["participant_id"],
)

# Display results.
print(f"Average Validation Dice Scores: {avg_val_scores}")
plt.boxplot(val_results[dice_columns])  # Dice-score columns only, selected by name.
plt.xticks(range(1, len(dice_columns) + 1), dice_columns, rotation=45)
plt.ylabel("Dice score")
plt.title("Validation Dice scores per tissue class")
plt.show()
val_results

In [None]:
# Evaluate the last model on the test set.
# If cross-validation was used, the test set is not available for evaluation.
evaluator = Evaluator3D(criterion, model, params["patch_sizes"], test_loader)
test_loss, test_scores = evaluator.evaluate()
avg_test_scores = sum(test_scores) / len(test_scores)
# Convert per-subject score tensors to plain lists.
test_scores = [score.tolist() for score in test_scores]
# Join per-subject scores with subject meta data (merged on the participant_id index).
test_results = pd.merge(
    test.meta_data,
    pd.DataFrame(test_scores, index=test.meta_data["participant_id"], columns=tissue_classes)
    .drop(columns="Background"),
    on=["participant_id"],
)

# Display results: the average scores, not the per-subject list.
test_dice_columns = tissue_classes[1:]
print(f"Average Test Dice Scores: {avg_test_scores}")
plt.boxplot(test_results[test_dice_columns])  # Dice-score columns only, selected by name.
plt.xticks(range(1, len(test_dice_columns) + 1), test_dice_columns, rotation=45)
plt.ylabel("Dice score")
plt.title("Test Dice scores per tissue class")
plt.show()
test_results

In [None]:
# Draw an example prediction of the trained model for one validation subject.
mri, mask = val[6]
with torch.no_grad():  # inference only; no autograd graph needed
    pred = F.softmax(evaluator.predict(mri.unsqueeze(0)), dim=1)
pred = torch.argmax(pred, dim=1)

index = 65     # slice index along the last axis (the axis that is plotted)
class_id = 0   # keep only this label in the displayed ground-truth slice

# Filter the same slice that is plotted below. The original filtered
# mask2[:, index, :] (axis 1), which never reached the figure.
mask2 = mask.clone()
mask_slice = mask2[:, :, index]  # basic slicing returns a view, so edits land in mask2
mask_slice[mask_slice != class_id] = 0
plot_sub(mri[:, :, index], mask2[:, :, index], pred[0, :, :, index])

In [None]:
# For every validation subject, save the per-class Dice scores (as text) and the
# MRI / ground-truth mask / predicted label volumes via save_nii.
output_folder = "eval"
# Nothing else creates this folder; without it, open() below raises FileNotFoundError.
os.makedirs(output_folder, exist_ok=True)

with torch.no_grad():  # inference only; no autograd graph needed
    for i, (mri, mask) in enumerate(val_loader, start=1):
        pred = F.softmax(evaluator.predict(mri), dim=1)
        one_hot_mask = create_onehot_mask(pred.shape, mask.unsqueeze(0))
        pred = torch.argmax(pred, dim=1, keepdim=True)
        one_hot_pred = create_onehot_mask(one_hot_mask.shape, pred)
        dice_scores = calculate_dice_score(one_hot_pred, one_hot_mask)
        with open(os.path.join(output_folder, f"{i}.txt"), 'w') as writer:
            writer.write(str([round(score, 3) for score in dice_scores.tolist()]))

        mri = mri.squeeze().detach().cpu().numpy()
        mask = mask.squeeze().detach().cpu().numpy()
        pred = pred.squeeze().detach().cpu().numpy().astype(np.float32)

        save_nii(output_folder, f"{i}_mri", mri)
        save_nii(output_folder, f"{i}_mask", mask)
        save_nii(output_folder, f"{i}_prediction", pred)