In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchio as tio
from torchsummary import summary
import torchvision
import torchvision.transforms as transforms


from modules.Dataset import FeTADataSet
from modules.LossFunctions import GDiceLossV2
from modules.UNet import UNet2Dv2
from modules.Utils import create_onehot_mask

import segmentation_models_pytorch as smp

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
num_epochs = 500
batch_size_ = 1
learning_rate = 0.1
weight_path = "weights/UNet2D"  # directory the training loop writes model.pth checkpoints into
shape = (256, 256, 256)  # presumably the expected (x, y, z) volume size — TODO confirm against FeTADataSet

# Seed the RNGs that affect training (weight initialisation, DataLoader
# shuffling) so that a Restart & Run All reproduces the same run.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, CUDA included

In [None]:
# FeTA label table from dseg.tsv (tab-separated, indexed by its "index" column);
# not used further in this cell.
labels = pd.read_csv("feta_2.1/dseg.tsv", sep='\t', index_col="index")

# Per-volume intensity z-normalisation. tio.ZNormalization.mean is used as the
# masking method, presumably so statistics are computed only over voxels above
# the volume mean (see torchio docs) — confirm if the foreground definition matters.
transform_ = transforms.Compose([tio.ZNormalization(masking_method=tio.ZNormalization.mean)])

train = FeTADataSet(train=True, transform=transform_)
test = FeTADataSet(train=False, transform=transform_)

# Shuffle the training volumes every epoch so SGD does not see them in a fixed
# order; evaluation order is irrelevant, so the test loader stays sequential.
train_loader = torch.utils.data.DataLoader(dataset=train, batch_size=batch_size_, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test, batch_size=batch_size_)

In [None]:
# 2-D UNet that is fed one axial slice at a time, trained with plain SGD
# against the generalised Dice loss from modules.LossFunctions.
net = UNet2Dv2()
model = net.to(device)

criterion = GDiceLossV2()
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

# Loader steps (volume batches) per epoch, used only for progress logging.
n_total_steps = len(train_loader)

In [None]:
# Slice-wise training: every 3-D volume is cut into 2-D slices along its last
# axis, and each slice that contains enough labelled (non-zero) pixels drives
# one SGD step of the 2-D UNet.
MIN_LABEL_PIXELS = 16  # a slice is used only if some foreground label covers more pixels than this

# torch.save does not create missing directories, so make sure it exists.
os.makedirs(weight_path, exist_ok=True)
checkpoint_path = os.path.join(weight_path, "model.pth")

for epoch in range(num_epochs):
    model.train()  # enable training-mode behaviour (e.g. BatchNorm/Dropout) for every epoch
    for i, (image, mask) in enumerate(train_loader):
        image = image.to(device)  # [bs, x, y, z]
        mask = mask.to(device)    # [bs, x, y, z]

        volume_losses = []
        # Iterate over every volume in the batch, not just the first one.
        for b in range(image.shape[0]):
            # Slices are taken along the last (z) axis, so iterate over its size.
            for slice_ix in range(image.shape[-1]):
                # [None, None] adds batch and channel dims -> [1, 1, x, y]
                slice_image = image[b, :, :, slice_ix][None, None]
                slice_mask = mask[b, :, :, slice_ix][None, None]

                label_ids, counts = torch.unique(slice_mask, return_counts=True)
                # Skip the slice unless some non-background label has more than
                # MIN_LABEL_PIXELS pixels. Selecting by label value (instead of
                # counts[1:]) stays correct when label 0 is absent from the slice.
                if not torch.any(counts[label_ids != 0] > MIN_LABEL_PIXELS):
                    continue

                outputs = model(slice_image.float())
                one_hot_mask = create_onehot_mask(outputs.shape, slice_mask, device)
                loss = criterion(outputs, one_hot_mask)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # detach() keeps the value without the graph; avoids a host sync per slice.
                volume_losses.append(loss.detach())

        if volume_losses:
            mean_loss = torch.stack(volume_losses).mean().item()
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {mean_loss} (mean over {len(volume_losses)} slices)')
        else:
            # No slice qualified: there is no loss to report for this step.
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], no slice with enough labelled pixels; skipped')

        if (i+1) % 5 == 0:
            torch.save(model.state_dict(), checkpoint_path)

# Final checkpoint so the last steps of the last epoch are not lost when
# n_total_steps is not a multiple of 5.
torch.save(model.state_dict(), checkpoint_path)
print('Finished Training')