In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
import torchvision
import torchvision.transforms as transforms

from modules.Dataset import FeTADataSet
from modules.UNet import UNet3D
from modules.Utils import create_patch_indexes
from modules.LossFunctions import GDiceLoss

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Hyper-parameters and run configuration ---
# Optimisation schedule.
# NOTE(review): the DataLoaders built in this notebook use batch_size=1,
# not the `batch_size` value below - confirm which one is intended.
num_epochs = 500
batch_size = 4
learning_rate = 1e-3

# Volume geometry: both tuples are handed to create_patch_indexes to lay a
# grid of patches over each full-size volume.
patch_counts = (2,) * 3
image_shape = (256,) * 3

# Directory that model checkpoints are written into.
weight_path = "weights"

In [None]:
# Tab-separated table shipped with the dataset, indexed by its "index" column
# (presumably the segmentation label definitions - confirm against the file).
labels = pd.read_csv("feta_2.1/dseg.tsv", sep='\t', index_col="index")

# Two dataset instances, selected through the `train` flag.
train = FeTADataSet(train=True)
test = FeTADataSet(train=False)

# One volume per batch; each volume is cut into patches by the training loop.
train_loader = torch.utils.data.DataLoader(dataset=train, batch_size=1)
# The evaluation loader must wrap `test`, not `train`; otherwise evaluation
# would silently run on the training data and `test` would never be used.
test_loader = torch.utils.data.DataLoader(dataset=test, batch_size=1)

In [None]:
# Build the network, then move its parameters onto the selected device.
model = UNet3D()
model = model.to(device)
# For a layer-by-layer overview, uncomment:
# print(summary(model.cuda(), input_size=(1, 128, 128, 128)))

# Training loss is GDiceLoss from modules.LossFunctions
# (nn.CrossEntropyLoss() is the alternative the author left noted here).
criterion = GDiceLoss()

# Plain SGD over every model parameter, using the configured learning rate.
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)

# Batches per epoch; only used for the progress message during training.
n_total_steps = len(train_loader)

In [None]:
# Start/end corner coordinates of every patch in the grid laid over a volume.
patch_indexes = create_patch_indexes(patch_counts, image_shape)

# Create the checkpoint directory up front: torch.save does not create
# missing parent directories and would otherwise fail at the first save.
os.makedirs(weight_path, exist_ok=True)
checkpoint_path = os.path.join(weight_path, "model.pth")

# Make the training mode explicit so a re-run after an evaluation cell
# does not keep the model in eval mode.
model.train()

for epoch in range(num_epochs):
    # Targets are bound to `label_volumes` so the loop does not overwrite the
    # `labels` table that was loaded from the dataset's .tsv file.
    for i, (images, label_volumes) in enumerate(train_loader):
        # Insert a channel axis: (batch, x, y, z) -> (batch, 1, x, y, z).
        images = images.to(device).unsqueeze(1)
        label_volumes = label_volumes.to(device)

        # One optimisation step per patch of the current volume.
        for coors in patch_indexes:
            [sx, sy, sz] = coors[0]
            [ex, ey, ez] = coors[1]
            patch_image = images[:, :, sx:ex, sy:ey, sz:ez]
            patch_label = label_volumes[:, sx:ex, sy:ey, sz:ez]

            outputs = model(patch_image.float())
            # Add the channel axis without hard-coding the batch size or the
            # patch size (previously fixed to 1 x 1 x 128 x 128 x 128).
            patch_label = patch_label.unsqueeze(1)
            loss = criterion(outputs, patch_label.long())

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        if (i+1) % 20 == 0:
            # The reported loss is that of the last patch of this volume.
            print (f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')
            torch.save(model.state_dict(), checkpoint_path)

# Persist the final weights as well: the periodic save above only fires every
# 20th step, so the last few updates would otherwise never reach disk.
torch.save(model.state_dict(), checkpoint_path)

print('Finished Training')