In [1]:
# Import default libraries
import os
from datetime import datetime
import yaml

In [2]:
# Import 3rd party libraries
import pandas as pd
import torch 
from torch.utils.data import DataLoader
import torchio as tio
from torchsummary import summary
import torchvision.transforms as transforms

In [3]:
# Import user defined libraries
from data.Dataset import FeTABalancedDistribution, MRIDataset
from data.transforms import *
from models import Evaluator3D, SDUNet3D, Trainer3D
from utils.Config import Config
from utils.LossFunctions import DC_and_CE_loss
from utils.Utils import *
from visualization.Tensorboard import TensorboardModules

### Experiment parameters

In [4]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed torch's global RNG so repeated runs are comparable.
torch.manual_seed(0)

# Experiment configuration. `cfg` gives attribute access; `cfg_dict` is a plain-dict copy
# that later cells extend with run metadata and write next to the outputs.
cfg = Config('config.yaml')
cfg_dict = cfg.as_dict()

# Each run writes into a fresh directory stamped to the minute; checkpoints go into "weights/".
run_stamp = datetime.now().strftime("%Y%m%d_%H%M")
output_path = f"../models/{run_stamp}/Test/"
weight_path = os.path.join(output_path, "weights/")

### Data

In [5]:
# Dataset selection.
# The same dataset exists in several folders, each preprocessed differently, so the dataset
# class and its data path (cfg.data.*_path) are chosen explicitly here.
dataset_train = FeTABalancedDistribution
dataset_val = FeTABalancedDistribution

# Cross-validation is currently disabled. To enable it, set e.g. cv_ = "cv3" (folds cv1-cv5).

# Training-time augmentation, applied in this order: random affine (scale range 0.95-1.05;
# other affine parameters left at torchio defaults), random motion artifact, random noise.
augmentations = [
    tio.transforms.RandomAffine(scales=(0.95, 1.05)),
    tio.transforms.RandomMotion(),
    tio.transforms.RandomNoise(),
]
transform_train = transforms.Compose(augmentations)

# No transform at evaluation time.
transform_eval = None

train = MRIDataset(dataset_train, "train", cfg.data.train_path, transform=transform_train)

# Patch-based training: the queue draws `samples_per_volume` uniformly-sampled patches of
# size cfg.data.patch_size from each subject and buffers up to `max_length` patches,
# filling itself with its own worker processes.
# NOTE(review): assumes `train.dataset` is a torchio SubjectsDataset — confirm in data.Dataset.
patch_sampler = tio.UniformSampler(patch_size=cfg.data.patch_size)
train_queue = tio.Queue(
    subjects_dataset=train.dataset,
    max_length=216,
    samples_per_volume=8,
    sampler=patch_sampler,
    num_workers=4,
)

val = MRIDataset(dataset_val, "val", cfg.data.val_path, transform=transform_eval)

In [6]:
# Training batches are drawn from the patch queue. The queue already runs its own workers
# (num_workers=4 above), so the DataLoader on top of it stays single-process.
# NOTE(review): the queue ignores the requested index, so shuffle=True does not reorder
# patches. It still consumes torch RNG state, so it is kept to preserve the random stream.
train_loader = DataLoader(
    train_queue,
    batch_size=cfg.data.batch_size,
    shuffle=True,
    num_workers=0,
)

# Validation uses whole items from the validation dataset in their natural order.
# NOTE(review): batching whole volumes assumes they share a shape when batch_size > 1 — confirm.
val_loader = DataLoader(val, batch_size=cfg.data.batch_size)

In [7]:
# Record which dataset classes and transforms this run used, so the saved config describes it.


def _object_name(obj):
    """Return the unqualified name parsed from ``str(obj)``.

    For example, "<class 'data.Dataset.Foo'>" becomes "Foo".
    """
    return str(obj).split("'")[1].split('.')[-1]


def _transform_repr(transform):
    """Return "None" if no transform is set, otherwise the string form of its ``.transforms`` list."""
    return "None" if not transform else str(transform.transforms)


cfg_dict["data"]["dataset_train"] = _object_name(dataset_train)
# Record the validation dataset here. This key previously stored dataset_train by mistake.
cfg_dict["data"]["dataset_val"] = _object_name(dataset_val)
# TODO: also record the cross-validation fold here once `cv_` is re-enabled in the data cell.
cfg_dict["data"]["transform_train"] = _transform_repr(transform_train)
cfg_dict["data"]["transform_eval"] = _transform_repr(transform_eval)

### Model

In [8]:
# 3D segmentation network placed on the selected device.
model = SDUNet3D().to(device)

# Combined Dice + cross-entropy loss.
# NOTE(review): presumably nnU-Net style, where the first dict is the soft-Dice kwargs and
# the second the CE kwargs (empty means defaults). Confirm in utils.LossFunctions.
criterion = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'do_bg': False, 'square': False}, {})

# Set to True to start from a saved checkpoint instead of random initialization.
pretrained = False

# Initialize weights or load an already trained model.
if not pretrained:
    cfg_dict["initial_weights"] = 'Random'
else:
    model_path = "../models/20230107/SDUNet/weights/34_model.pth"
    # Without map_location, tensors saved from a GPU cannot be deserialized on a CPU-only
    # machine, even though `device` falls back to CPU. Map them onto the active device.
    model.load_state_dict(torch.load(model_path, map_location=device))
    cfg_dict["initial_weights"] = model_path

### Training configuration

In [10]:
# Plain SGD; hyperparameters come from the `optimizer.SGD` section of the config.
sgd_cfg = cfg.optimizer.SGD
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=sgd_cfg.lr,
    momentum=sgd_cfg.momentum,
    nesterov=sgd_cfg.nesterov,
)

# Cyclic learning rate between `base` and `max`, with separate up/down step counts.
# NOTE(review): CyclicLR resets the learning rate to base_lr on construction. With its default
# cycle_momentum=True it also overwrites the SGD momentum (cycled between 0.8 and 0.9), so the
# configured lr/momentum above are effectively overridden. Confirm this is intended.
clr_cfg = cfg.lr_scheduler.CLR
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer,
    base_lr=clr_cfg.base,
    max_lr=clr_cfg.max,
    step_size_up=clr_cfg.up,
    step_size_down=clr_cfg.down,
    mode=clr_cfg.mode,
)

# Alternative schedule (disabled):
# scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=cfg.lr_scheduler.SLR.step_size,
#                                             gamma=cfg.lr_scheduler.SLR.gamma)

# Early stopping driven by the validation loss (see the training loop).
es_cfg = cfg.training.ES
early_stopping = EarlyStopping(patience=es_cfg.patience, min_delta=es_cfg.min_delta)

# Trainer: runs one training pass per call to fit().
trainer = Trainer3D(criterion, model, optimizer, cfg.training.total_epoch, train_loader, scheduler)

# Evaluator: validates the same model on the validation loader.
evaluator = Evaluator3D(criterion, model, cfg.data.patch_size, val_loader)

In [13]:
# Create the output directory tree. weight_path lives inside output_path, so this creates both.
# exist_ok avoids the check-then-create race of a separate isdir() test.
os.makedirs(weight_path, exist_ok=True)

# TensorBoard writer for this experiment's outputs.
tb = TensorboardModules(output_path)

# Save the experiment parameters (including the metadata added above) next to the outputs.
with open(os.path.join(output_path, 'config.yaml'), 'w') as outfile:
    yaml.dump(cfg_dict, outfile, default_flow_style=False)

# Log one validation image and its mask to TensorBoard.
# Fetch the sample once: each val[0] call goes through the dataset again, and with a random
# eval transform two calls could pair an image with a mask from a different draw.
preview_sample = val[0]
mri_image = preview_sample['mri']['data'].squeeze(0)
mri_mask = preview_sample['mask']['data'].squeeze(0)
slices = (50, 100, 10)  # NOTE(review): slice indices for add_image_mask; confirm which axis each selects.
tb.add_image_mask(mri_image, mri_mask, slices)

# Log the model graph, traced with an input of the configured patch size.
tb.add_graph(model, cfg.data.patch_size, device)
# For a layer-by-layer summary: print(summary(model, input_size=(1, 32, 128, 128)))

## Training

In [14]:
# Lowest validation loss seen so far and the checkpoint file that achieved it.
# Start at +inf so the first epoch is always saved; this combined loss can go negative.
best_val_loss = float("inf")
best_weights = ""

for epoch in range(cfg.training.total_epoch):
    # One pass over all training batches.
    avg_train_loss = trainer.fit()

    # Evaluate the current model on the validation data.
    avg_val_loss, dice_scores = evaluator.evaluate()
    # Mean of the Dice scores returned by the evaluator (NaN if it returned none).
    avg_scores = sum(dice_scores) / len(dice_scores) if len(dice_scores) > 0 else float("nan")

    print("-------------------------------------------------------------")

    # Log learning rate, mean Dice and both losses to TensorBoard (steps are 1-based).
    tb.add_scalars(step=epoch + 1, lr=scheduler.get_last_lr()[0], ds=avg_scores,
                   train_loss=avg_train_loss, val_loss=avg_val_loss)

    model_name = f"{epoch}_model.pth"
    model_path = os.path.join(weight_path, model_name)

    # Keep only the best checkpoint on disk. Compare against the best loss so far rather
    # than the previous epoch's. Save the new weights before deleting the old best so a
    # failed save never leaves the run without a checkpoint.
    if avg_val_loss < best_val_loss:
        torch.save(model.state_dict(), model_path)
        if best_weights and os.path.isfile(best_weights):
            os.remove(best_weights)
        best_val_loss = avg_val_loss
        best_weights = model_path

    # Stop when the validation loss has stopped improving.
    early_stopping(avg_val_loss)
    if early_stopping.early_stop:
        break

print('Finished Training')

Epoch [1/50]: 100%|██████████| 464/464 [07:04<00:00,  1.09it/s, Loss: 1.0879]
Validation : 100%|██████████| 10/10 [00:36<00:00,  3.66s/it, Loss: 0.5776]


-------------------------------------------------------------


Epoch [2/50]: 100%|██████████| 464/464 [07:04<00:00,  1.09it/s, Loss: -0.0427]
Validation : 100%|██████████| 10/10 [00:37<00:00,  3.70s/it, Loss: -0.4070]


-------------------------------------------------------------


Epoch [3/50]: 100%|██████████| 464/464 [07:04<00:00,  1.09it/s, Loss: -0.4919]
Validation : 100%|██████████| 10/10 [00:36<00:00,  3.67s/it, Loss: -0.6547]


-------------------------------------------------------------


Epoch [4/50]:  10%|█         | 48/464 [00:42<06:12,  1.12it/s, Loss: -0.4507]


KeyboardInterrupt: 