In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchio as tio
from torchsummary import summary
import torchvision
import torchvision.transforms as transforms


from modules.Dataset import FeTADataSet
from modules.Evaluator import Evaluator2D
from modules.LossFunctions import DC_and_CE_loss, GDiceLossV2
from modules.UNet import UNet2Dv2
from modules.Utils import calculate_dice_score, create_onehot_mask, init_weights_kaiming
from modules.Utils import TensorboardModules 

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyper-parameters 
num_epochs = 250
batch_size_ = 1
lr_ = 0.01
momentum_ = 0.9
nesterov_ = True
shape = (256, 256, 256)

output_path = "output/UNet2D/run1"
weight_path = os.path.join(output_path, "weights/")

In [3]:
# Create output and path if it is not exist.
if not os.path.isdir(weight_path):
    os.makedirs(weight_path)

tb = TensorboardModules(output_path)

In [4]:
labels = pd.read_csv("feta_2.1/dseg.tsv", sep='\t', index_col="index")

transform_ = transforms.Compose([tio.ZNormalization(masking_method=tio.ZNormalization.mean)])

train = FeTADataSet("train", path="data", transform=transform_)
val = FeTADataSet("val", path="data", transform=transform_)

train_loader = torch.utils.data.DataLoader(dataset=train, batch_size=batch_size_)
val_loader = torch.utils.data.DataLoader(dataset=val, batch_size=batch_size_)

In [7]:
# Add some images and corresponding masks into Tensorboard.
mri_image, mri_mask = val[8]
slices = (80, 150, 10)
tb.add_images("Fetal Brain Images", mri_image, slices)
tb.add_images("Fetal Brain Masks", mri_mask, slices)

torch.Size([126, 256, 256])
torch.Size([126, 256, 256])


In [8]:
model = UNet2Dv2().to(device)
model.apply(init_weights_kaiming)

# Add model graph to Tensorboard.
tb.add_graph(model, (256, 256), device)
#print(summary(model, input_size=(1, 256, 256)))

criterion = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'do_bg': False, 'square': False}, {})
optimizer = torch.optim.SGD(model.parameters(), lr=lr_, momentum=momentum_, nesterov=nesterov_)

# Initalize evaluator for validation.
evaluator = Evaluator2D(model, val_loader)

In [20]:
count_forward = 0
n_total_steps = len(train_loader)
running_loss = 0.0
running_dice_scores = torch.zeros(8).to(device)
bs_2d = 16

for epoch in range(num_epochs):
    for i, (image, mask) in enumerate(train_loader):
        image = image.to(device) #[bs,x,y,z]        
        mask = mask.to(device) #[x,y,z]
        
        # Slice 3D image. It's like splitting 3D images into batches.
        for slice_ix in range(0, image.shape[-1], bs_2d):
            start = slice_ix
            stop = slice_ix+bs_2d
            
            if stop > image.shape[-1]:
                stop = image.shape[-1]-1
            
            slice_image = image[:, start:stop]
            slice_mask = mask[:, start:stop]
            slice_image = slice_image.view(-1, 1, 256, 256)
            slice_mask = slice_mask.view(-1, 1, 256, 256)
                           
            outputs = model(slice_image.float())            
            one_hot_mask = create_onehot_mask(outputs.shape, slice_mask)
            
            loss = criterion(outputs, slice_mask)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Sum losses and dice scores for all predicitions.
            running_loss += loss.item()
            count_forward += 1
        
        
    # Save torch model.                
    model_name = "_".join([str(epoch+1), str(lr_), str(nesterov_), "model.pth"])
    torch.save(model.state_dict(), os.path.join(weight_path, model_name))    
    
    # Add average loss per 10th step to Tensorboard.
    avg_loss = running_loss / count_forward
    step = epoch * n_total_steps + i
    tb.add_loss(avg_loss, step)
    #Set values to zero for new calculations.
    count_forward = 0
    running_loss = 0.0
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss}")
    
    #Add average dice score per 10th step to Tensorboard.
    avg_scores = evaluator.evaluate(model)
    tb.add_dice_score(avg_scores, epoch)
    

print('Finished Training')

2.6758439540863037


AssertionError: 

In [None]:
model.load_state_dict(torch.load("weights/UNet2D/139_0.01_True_model.pth"))
model.eval()

In [None]:
# Show results.
im_id = 4
slice_id = 72
class_index = 0

image, mask = test[im_id]
image = torch.Tensor(image)

inp = image[:, :, slice_id].view(1, 1, 256, 256)
inp = torch.Tensor(inp).to(device)
out = F.softmax(model(inp.float()), dim=1)

gt = torch.Tensor(mask[:, :, slice_id]).view(1, 1, 256, 256)
gt = create_onehot_mask(out.shape, gt.to(device), device)

fig, ax = plt.subplots(1, 3, figsize=(15, 15))
ax[0].imshow(image[:, :, slice_id])
ax[1].imshow(out[0, class_index].cpu().detach().numpy(), cmap="gray")
ax[2].imshow(gt[0, class_index].cpu(), cmap="gray")
plt.show()