In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchio as tio
from torchsummary import summary
import torchvision
import torchvision.transforms as transforms


from modules.Dataset import FeTADataSet
from modules.LossFunctions import GDiceLoss
from modules.UNet import UNet3D
from modules.Utils import create_patch_indexes, create_onehot_mask

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters
num_epochs = 500
batch_size_ = 1
learning_rate = 0.001
weight_path = "weights"         # directory the training loop writes model.pth into
shape = (256, 256, 256)         # (x, y, z) size every input volume is reshaped to
patch_sizes = (128, 128, 128)   # patch size handed to create_patch_indexes

# Seed torch's RNGs (CPU and CUDA) so weight initialisation and any
# DataLoader shuffling are reproducible across Restart & Run All.
seed = 42
torch.manual_seed(seed)

In [None]:
# Segmentation label table, indexed by its "index" column.
# It is not referenced by the training loop.
labels = pd.read_csv("feta_2.1/dseg.tsv", sep='\t', index_col="index")

# Z-score intensity normalisation. With masking_method=ZNormalization.mean, torchio
# computes the mean/std only over voxels brighter than the volume's mean intensity.
transform_ = transforms.Compose([tio.ZNormalization(masking_method=tio.ZNormalization.mean)])

train = FeTADataSet(train=True, transform=transform_)
test = FeTADataSet(train=False, transform=transform_)

# Shuffle the training set so SGD does not see the volumes in the same fixed
# order every epoch. This assumes FeTADataSet is map-style (len + getitem);
# verify, because shuffle=True is rejected for IterableDataset.
# The evaluation loader keeps a deterministic order.
train_loader = torch.utils.data.DataLoader(dataset=train, batch_size=batch_size_, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test, batch_size=batch_size_)

In [None]:
# Number of batches per epoch; used as the step total in the progress log.
n_total_steps = len(train_loader)

# Segmentation network, moved to the device chosen in the config cell.
model = UNet3D().to(device)

# Loss from modules.LossFunctions (presumably a generalised Dice loss).
criterion = GDiceLoss()

# Plain SGD (torch default: no momentum, no weight decay) over every model parameter.
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

In [None]:
# A patch is only trained on if at least one non-background label covers more
# than this many voxels inside it.
min_label_voxels = 70

patch_indexes = create_patch_indexes(shape, patch_sizes)

# torch.save does not create directories, so make sure the target exists.
os.makedirs(weight_path, exist_ok=True)
checkpoint_file = os.path.join(weight_path, "model.pth")

model.train()
for epoch in range(num_epochs):
    for i, (image, mask) in enumerate(train_loader):
        # Add a singleton channel axis -> [bs, 1, x, y, z]. This assumes the loader
        # yields [bs, x, y, z] or [bs, 1, x, y, z] (TODO confirm against FeTADataSet).
        # Using -1 for the batch axis keeps a smaller final batch valid.
        image = image.to(device).reshape(-1, 1, *shape)
        mask = mask.to(device).reshape(-1, 1, *shape)

        patch_losses = []
        for coors in patch_indexes:
            (sx, sy, sz), (ex, ey, ez) = coors[0], coors[1]
            patch_image = image[:, :, sx:ex, sy:ey, sz:ez]
            patch_mask = mask[:, :, sx:ex, sy:ey, sz:ez]

            # Count voxels per label inside *this patch*. Exclude background by
            # value, not by position, so a patch without label 0 keeps all its
            # labels. Label 0 is assumed to be background; verify against dseg.tsv.
            patch_labels, counts = torch.unique(patch_mask, return_counts=True)
            if not torch.any(counts[patch_labels != 0] > min_label_voxels):
                continue

            outputs = model(patch_image.float())
            one_hot_mask = create_onehot_mask(outputs.shape, patch_mask)
            loss = criterion(outputs, one_hot_mask)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            patch_losses.append(loss.item())

        # Report the mean over the patches trained in this step. Do not reuse a
        # possibly undefined or stale `loss` from an earlier step.
        if patch_losses:
            mean_loss = sum(patch_losses) / len(patch_losses)
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {mean_loss:.4f}')
        else:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], no patch passed the label filter')

        if (i+1) % 5 == 0:
            torch.save(model.state_dict(), checkpoint_file)

# Final checkpoint, so updates after the last multiple-of-5 step are not lost.
torch.save(model.state_dict(), checkpoint_file)
print('Finished Training')