In [None]:
import copy
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchio as tio
from torchsummary import summary
import torchvision
import torchvision.transforms as transforms


from modules.Dataset import FeTADataSet
from modules.Evaluator import Evaluator3D
from modules.LossFunctions import DC_and_CE_loss, GDiceLossV2
from modules.Trainer import Trainer3D
from modules.UNet import UNet3D
from modules.Utils import calculate_dice_score, create_onehot_mask, create_patch_indexes, init_weights_kaiming
from modules.Utils import EarlyStopping, LearningRateFinder, TensorboardModules

In [None]:
# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hyper-parameters consumed by the loaders, optimiser and trainer in later cells.
#   ES  -> presumably keyword arguments for EarlyStopping — confirm in modules.Utils
#   SLR -> step-LR settings (step_size / gamma match torch.optim.lr_scheduler.StepLR)
params = dict(
    num_epochs=250,
    batch_size=1,
    lr=0.1,
    momentum=0.9,
    nesterov=True,
    patch_sizes=(128, 128, 128),   # passed to Trainer3D; presumably the training patch shape
    image_sizes=(256, 256, 256),   # presumably the full volume shape — confirm against the data
    ES=dict(patience=5, min_delta=1e-3),
    # Disabled alternative: cyclical learning-rate schedule.
    # CLR=dict(base=1e-07, max=0.1, up=4, down=8, mode="triangular2"),
    SLR=dict(step_size=10, gamma=1e-1),
)

In [None]:
# Label table for the segmentation classes, indexed by its "index" column
# (not used within this cell).
labels = pd.read_csv("feta_2.1/dseg.tsv", sep='\t', index_col="index")

# Intensity normalisation is currently disabled. To re-enable it, torchio's own
# composer is the native choice, e.g.:
#   tio.Compose([tio.ZNormalization(masking_method=tio.ZNormalization.mean)])
transform_ = None

train = FeTADataSet("train", path="data2", transform=transform_)
val = FeTADataSet("val", path="data2", transform=transform_)

# Number of volumes per split used for quick experiments. Subset is a lazy view
# (samples are fetched on demand), and clamping to len() mimics slice truncation.
N_SUBSET = 10
train_subset = torch.utils.data.Subset(train, range(min(N_SUBSET, len(train))))
val_subset = torch.utils.data.Subset(val, range(min(N_SUBSET, len(val))))

# Seed torch's global RNG; the shuffling sampler and the weight initialisation
# in the next cell both draw from it.
torch.manual_seed(0)
train_loader = torch.utils.data.DataLoader(dataset=train_subset, batch_size=params["batch_size"], shuffle=True)
# Validation does not benefit from shuffling; a fixed order keeps per-batch
# results reproducible and comparable across epochs.
val_loader = torch.utils.data.DataLoader(dataset=val_subset, batch_size=params["batch_size"], shuffle=False)

In [None]:
# Build the 3D U-Net on the selected device and initialise its weights with the
# project's Kaiming initialiser.
model = UNet3D().to(device)
model.apply(init_weights_kaiming)

# Combined Dice + cross-entropy loss. Presumably (nnU-Net convention) the first dict
# configures the soft-Dice term and the second the CE term — confirm in
# modules.LossFunctions.
criterion = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'do_bg': False, 'square': False}, {})
optimizer = torch.optim.SGD(model.parameters(), lr=params["lr"],
                            momentum=params["momentum"], nesterov=params["nesterov"])

# Initialise trainer for training.
trainer = Trainer3D(model, train_loader, optimizer, criterion, params["patch_sizes"],
                    params["num_epochs"])

# An LR range test trains through `trainer` with a sweep of learning rates, which
# alters the weights and optimiser settings that real training would later start from.
# Snapshot both now and restore them after the sweep so training begins from the
# freshly initialised model at params["lr"].
# NOTE(review): redundant but harmless if LearningRateFinder already restores state.
initial_model_state = copy.deepcopy(model.state_dict())
initial_optimizer_state = copy.deepcopy(optimizer.state_dict())

lr_finder = LearningRateFinder(trainer)
lr_finder_result = lr_finder.find(startLR=1e-9)

model.load_state_dict(initial_model_state)
optimizer.load_state_dict(initial_optimizer_state)

# Keep the finder's return value as the cell output, as before.
lr_finder_result