In [1]:
# Shell escape: print the kernel's current working directory.
!pwd

/data/marciano/experiments/one-to-rule-them-all/src/notebooks


In [2]:
# Move to the project root so that the `src.*` imports and the relative
# `data/` and `preprocessed/` paths used in later cells resolve.
# The root is located by walking up from the current directory until a folder
# containing `src/utils` is found, instead of relying on a machine-specific
# absolute path. Re-running this cell is safe: once at the root, the search
# matches the current directory immediately.
import os
from pathlib import Path

_start_dir = Path.cwd()
_project_root = next(
    (p for p in (_start_dir, *_start_dir.parents) if (p / "src" / "utils").is_dir()),
    None,
)
if _project_root is None:
    raise FileNotFoundError(
        "Could not find the project root (a directory containing src/utils) "
        "above {}".format(_start_dir)
    )
os.chdir(_project_root)
print(_project_root)

/data/marciano/experiments/one-to-rule-them-all


In [3]:
# Shell escape: print the working directory again to confirm the change of
# directory made in the previous cell.
!pwd

/data/marciano/experiments/one-to-rule-them-all


## Data preparation

In [4]:
import os
import numpy as np
from src.utils.utilsAcdc import generate_patient_info, preprocess

# Build one patient-info dict covering the training folder (ids 1-100) and the
# testing folder (ids 101-150), then persist it next to the preprocessed data.
patient_info = generate_patient_info("data/heart/training/", patient_ids=range(1, 101))
patient_info = {**patient_info, **generate_patient_info("data/heart/testing/", patient_ids=range(101, 151))}

# exist_ok=True replaces the check-then-create pattern: it is a no-op when the
# directory is already there and cannot race with a concurrent creation.
os.makedirs("preprocessed/heart/", exist_ok=True)
np.save(os.path.join("preprocessed/heart/", "patient_info"), patient_info)

# Fixed target spacing. The commented-out lines show how a median spacing over
# the training patients could be computed instead.
#spacings = [patient_info[id]["spacing"] for id in range(1, 101)]
#spacing_target = np.percentile(np.vstack(spacings), 50, 0)
spacing_target = [10, 1.25, 1.25]

# Preprocess the training ground truths. The first callback maps
# (folder, id) to the per-patient sub-folder; the second maps
# (patient_info, id, phase) to the ground-truth file name for that phase.
os.makedirs("preprocessed/heart/training/", exist_ok=True)
preprocess(
    range(1,101), patient_info, spacing_target,
    "data/heart/training", "preprocessed/heart/training",
    lambda folder, id: os.path.join(folder, 'patient{:03d}'.format(id)),
    lambda patient_info, id, phase: "patient{:03d}_frame{:02d}_gt.nii.gz".format(id, patient_info[id][phase])
)

# Preprocess every entry found under data/heart/predictions/ (one per model)
# for the testing patients. Here all files of a model live directly in its
# folder, so the folder callback returns the folder unchanged.
for model in os.listdir("data/heart/predictions/"):
    model_input_dir = "data/heart/predictions/{}".format(model)
    model_output_dir = "preprocessed/heart/predictions/{}".format(model)
    os.makedirs(model_output_dir, exist_ok=True)
    preprocess(
        range(101,151), patient_info, spacing_target,
        model_input_dir, model_output_dir,
        lambda folder, id: folder,
        lambda patient_info, id, phase: "patient{:03d}_{}.nii.gz".format(id,phase)
    )

## Dataset

In [5]:
import os
import numpy as np
import random
import torchvision
from src.utils.utilsAcdc import AddPadding, CenterCrop, OneHot, ToTensor, MirrorTransform, SpatialTransform

# Shuffle the 100 training patient ids and split them 80/20 into train and
# validation sets.
# The shuffle uses a dedicated, seeded generator so the split is identical
# after every kernel restart. Without this, resuming training from a
# checkpoint would reshuffle the patients and leak previously-trained-on
# patients into the validation set. A private Random instance is used so the
# global `random` state (which other code may rely on) is left untouched.
SPLIT_SEED = 0
ids = random.Random(SPLIT_SEED).sample(range(1, 101), 100)
train_ids = ids[:80]
val_ids = ids[80:]

# Pipeline applied when no data augmentation is requested.
_plain_steps = [
    AddPadding((256, 256)),
    CenterCrop((256, 256)),
    OneHot(),
    ToTensor(),
]
transform = torchvision.transforms.Compose(_plain_steps)

# Pipeline applied when data augmentation is requested: mirroring followed by
# a spatial transform (rotation range of +/- pi/6, scaling in [0.7, 1.4],
# random crop) producing 256x256 patches.
_augmented_steps = [
    MirrorTransform(),
    SpatialTransform(
        patch_size=(256, 256),
        angle_x=(-np.pi / 6, np.pi / 6),
        scale=(0.7, 1.4),
        random_crop=True,
    ),
    OneHot(),
    ToTensor(),
]
transform_augmentation = torchvision.transforms.Compose(_augmented_steps)



## Training

In [None]:
import sys, importlib

# Reload the `CA` module only when it is already loaded. No visible cell
# imports `CA`, so on a fresh kernel `sys.modules['CA']` raises KeyError and
# the notebook cannot be run top to bottom.
_ca_module = sys.modules.get('CA')
if _ca_module is not None:
    importlib.reload(_ca_module)

import torch
from src.ConvAE.basic_model import AE
from src.ConvAE.utils import hyperparameter_tuning, plot_history
# Same package path as the earlier cells (`src.utils.utilsAcdc`), which is the
# one known to resolve from the project root.
from src.utils.utilsAcdc import ACDCDataLoader

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# This cell lets you tune your own hyperparameters. To try the provided
# hyperparameters first, skip this cell.

# Candidate values tested for each hyperparameter.
# "BK" stands for "background": a predominant, non-compulsory class that can
# lead to a dumb local minimum returning totally black images.
parameters = dict(
    DA=[True, False],  # data augmentation
    latent_size=[100, 500],  # size of the latent space of the autoencoder
    BATCH_SIZE=[8, 16, 4],
    optimizer=[torch.optim.Adam],
    lr=[2e-4, 1e-4, 1e-3],
    weight_decay=[1e-5],
    tuning_epochs=[5, 10],  # number of epochs each configuration is run for
    # Sets of loss functions to be evaluated.
    functions=[
        ["GDLoss", "MSELoss"],
        ["GDLoss"],
        ["BKGDLoss", "BKMSELoss"],
    ],
    # During these epochs BK has half the weight of LV, RV and MYO in the
    # evaluation of BKGDLoss.
    settling_epochs_BKGDLoss=[10, 0],
    # During these epochs BK has half the weight of LV, RV and MYO in the
    # evaluation of BKMSELoss.
    settling_epochs_BKMSELoss=[10, 0],
)

# Rules that cut useless hyperparameter combinations out of the tuning process.
rules = [
    '"settling_epochs_BKGDLoss" == 0 or "BKGDLoss" in "functions"',
    '"settling_epochs_BKMSELoss" == 0 or "BKMSELoss" in "functions"',
    '"BKGDLoss" not in "functions" or "settling_epochs_BKGDLoss" <= "tuning_epochs"',
    '"BKMSELoss" not in "functions" or "settling_epochs_BKMSELoss" <= "tuning_epochs"',
    #'"BKGDLoss" not in "functions" or "settling_epochs_BKGDLoss" >= "settling_epochs_BKMSELoss"'
]

# Loaders over the preprocessed training data for the tuning run; batch size
# and transform are left as None here and the two transform pipelines are
# handed to the tuner separately.
tuning_train_loader = ACDCDataLoader(
    "preprocessed/heart/training", patient_ids=train_ids, batch_size=None, transform=None
)
tuning_val_loader = ACDCDataLoader(
    "preprocessed/heart/training", patient_ids=val_ids, batch_size=None, transform=None
)

# `fast` is a very important parameter. When False, all combinations are
# tested and the one reaching the maximum DSC is returned. When True, the
# first combination that avoids dumb local minima is returned.
optimal_parameters = hyperparameter_tuning(
    parameters,
    tuning_train_loader,
    tuning_val_loader,
    transform,
    transform_augmentation,
    rules,
    fast=True,
)

# Persist the selected configuration alongside the preprocessed data.
optimal_parameters_path = os.path.join("preprocessed/heart/", "optimal_parameters")
np.save(optimal_parameters_path, optimal_parameters)

In [None]:
# Set to True to reuse the hyperparameters stored in
# preprocessed/heart/optimal_parameters.npy; leave False to use (and store)
# the provided defaults below.
upload_your_own_parameters = False

if not upload_your_own_parameters:
    optimal_parameters = dict(
        BATCH_SIZE=8,
        DA=False,
        latent_size=100,
        optimizer=torch.optim.Adam,
        lr=2e-4,
        weight_decay=1e-5,
        functions=["BKGDLoss", "BKMSELoss"],
        settling_epochs_BKGDLoss=10,
        settling_epochs_BKMSELoss=0,
    )
    # Note: this overwrites any previously saved optimal_parameters file.
    np.save(os.path.join("preprocessed/heart", "optimal_parameters"), optimal_parameters)
else:
    # The file stores a pickled dict, hence allow_pickle=True and .item().
    stored_parameters = np.load(
        os.path.join("preprocessed/heart", "optimal_parameters.npy"), allow_pickle=True
    )
    optimal_parameters = stored_parameters.item()

assert optimal_parameters is not None, "Be sure to continue with a working set of hyperparameters"

# Convenience aliases used when building the data loaders for training.
BATCH_SIZE = optimal_parameters["BATCH_SIZE"]
DA = optimal_parameters["DA"]

# Build the autoencoder from the selected hyperparameters and move it to the
# chosen device.
ae = AE(**optimal_parameters).to(device)

# Set `ckpt` to a checkpoint file path to resume training; leave it as None to
# start from epoch 0.
ckpt = None
if ckpt is not None:
    # map_location=device remaps the stored tensors onto the current device,
    # so a checkpoint saved on a GPU can still be loaded on a CPU-only machine.
    ckpt = torch.load(ckpt, map_location=device)
    ae.load_state_dict(ckpt["AE"])
    ae.optimizer.load_state_dict(ckpt["AE_optim"])
    # Continue with the epoch after the last one recorded in the checkpoint.
    start = ckpt["epoch"]+1
else:
    start = 0

print(ae)

# Train from `start` up to epoch 500, then plot the returned history.
# The training loader uses the augmentation pipeline only when DA is enabled;
# the validation loader always uses the plain pipeline.
epochs = range(start, 500)
train_loader = ACDCDataLoader(
    "preprocessed/heart/training",
    patient_ids=train_ids,
    batch_size=BATCH_SIZE,
    transform=transform_augmentation if DA else transform,
)
val_loader = ACDCDataLoader(
    "preprocessed/heart/training",
    patient_ids=val_ids,
    batch_size=BATCH_SIZE,
    transform=transform,
)
history = ae.training_routine(epochs, train_loader, val_loader, "checkpoints/")
plot_history(history)