In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchio as tio
from torchsummary import summary
import torchvision
import torchvision.transforms as transforms


from modules.Dataset import FeTADataSet
from modules.Evaluator import Evaluator3D
from modules.LossFunctions import DC_and_CE_loss, GDiceLossV2
from modules.UNet import UNet3D
from modules.Utils import calculate_dice_score, create_onehot_mask, create_patch_indexes, init_weights_kaiming
from modules.Utils import TensorboardModules 

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed the RNGs this run relies on (Kaiming weight initialisation, DataLoader
# shuffling) so that Restart & Run All reproduces the same training run.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# Hyper-parameters
num_epochs = 250
batch_size_ = 1
lr_ = 0.01
momentum_ = 0.9
nesterov_ = True
shape = (256, 256, 256)        # full volume size used to reshape images/masks
patch_sizes = (128, 128, 128)  # patch size passed to create_patch_indexes

# Run artefacts: TensorBoard logs go to output_path, checkpoints to weights/.
output_path = "output/UNet3D/run1"
weight_path = os.path.join(output_path, "weights/")

In [None]:
# Create the output directory tree (including weights/) if it does not exist yet.
# exist_ok=True makes this idempotent and avoids a check-then-create race.
os.makedirs(weight_path, exist_ok=True)

# Pre-compute the patch coordinates; each entry is unpacked by the training loop
# as a pair of [x, y, z] start and [x, y, z] end indexes.
patch_indexes = create_patch_indexes(shape, patch_sizes)

# TensorBoard writer helper rooted at the run's output directory.
tb = TensorboardModules(output_path)

In [None]:
# Label table shipped with the FeTA dataset, indexed by its "index" column.
labels = pd.read_csv("feta_2.1/dseg.tsv", sep='\t', index_col="index")

# Z-normalise intensities using torchio's ZNormalization.mean masking method.
transform_ = transforms.Compose([tio.ZNormalization(masking_method=tio.ZNormalization.mean)])

train = FeTADataSet("train", transform=transform_)
val = FeTADataSet("val", transform=transform_)

# Shuffle the training set every epoch so SGD does not see subjects in a fixed
# order; the validation order stays deterministic.
train_loader = torch.utils.data.DataLoader(dataset=train, batch_size=batch_size_, shuffle=True)
val_loader = torch.utils.data.DataLoader(dataset=val, batch_size=batch_size_)

In [None]:
# Log one validation volume and its mask to TensorBoard for visual inspection.
preview_index = 8
mri_image, mri_mask = val[preview_index]
slices = (80, 150, 10)  # slice positions forwarded to tb.add_images

for preview_tag, preview_volume in (("Fetal Brain Images", mri_image),
                                    ("Fetal Brain Masks", mri_mask)):
    tb.add_images(preview_tag, preview_volume, slices)

In [None]:
# Build the 3-D U-Net on the selected device and apply Kaiming weight init.
model = UNet3D().to(device)
model.apply(init_weights_kaiming)

# Log the model graph to TensorBoard (patch_sizes is forwarded as the input
# size — see TensorboardModules.add_graph for how it is used).
tb.add_graph(model, patch_sizes, device)

# Dice + cross-entropy loss (per the class name); the first dict configures the
# Dice term, the second (empty) dict the cross-entropy term.
dice_kwargs = {'batch_dice': True, 'smooth': 1e-5, 'do_bg': False, 'square': False}
ce_kwargs = {}
criterion = DC_and_CE_loss(dice_kwargs, ce_kwargs)

# SGD with momentum (Nesterov when nesterov_ is True).
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=lr_,
    momentum=momentum_,
    nesterov=nesterov_,
)

# Initialise the evaluator used for validation after every epoch.
evaluator = Evaluator3D(model, patch_indexes, val_loader)

In [None]:
# A patch is used for training only if some foreground label covers more than
# this many voxels inside it.
min_label_voxels = 70
n_total_steps = len(train_loader)

for epoch in range(num_epochs):
    # evaluator.evaluate() runs at the end of every epoch and may switch the model
    # to eval mode (NOTE(review): confirm in Evaluator3D); always train in train mode.
    model.train()
    running_loss = 0.0
    count_forward = 0

    for i, (image, mask) in enumerate(train_loader):
        # [bs, x, y, z] -> [bs, 1, x, y, z]; -1 also handles a smaller final batch.
        image = image.to(device).view(-1, 1, *shape)
        mask = mask.to(device).view(-1, 1, *shape)

        for coors in patch_indexes:
            sx, sy, sz = coors[0]
            ex, ey, ez = coors[1]
            patch_image = image[:, :, sx:ex, sy:ey, sz:ez]
            patch_mask = mask[:, :, sx:ex, sy:ey, sz:ez]

            # Count voxels per label inside this patch (spatial axes only).
            values, counts = torch.unique(patch_mask, return_counts=True)

            # Skip the patch unless a foreground label (assumes 0 = background)
            # covers more than `min_label_voxels` voxels.
            if not torch.any(counts[values != 0] > min_label_voxels):
                continue

            outputs = model(patch_image.float())
            loss = criterion(outputs, patch_mask)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate the loss of every patch trained on this epoch.
            running_loss += loss.item()
            count_forward += 1

    # Save a checkpoint after every epoch: "<epoch>_<lr>_<nesterov>_model.pth".
    model_name = "_".join([str(epoch + 1), str(lr_), str(nesterov_), "model.pth"])
    torch.save(model.state_dict(), os.path.join(weight_path, model_name))

    # Log the mean loss over all trained patches of this epoch. The global step is
    # the index of the epoch's last batch (epoch * n_total_steps + last batch idx).
    step = (epoch + 1) * n_total_steps - 1
    if count_forward > 0:
        avg_loss = running_loss / count_forward
        tb.add_loss(avg_loss, step)
    else:
        # No patch passed the foreground filter this epoch; nothing to average.
        avg_loss = float("nan")

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss}")

    # Log this epoch's validation dice scores.
    avg_scores = evaluator.evaluate(model)
    tb.add_dice_score(avg_scores, epoch)

print('Finished Training')