In [8]:

import os
import numpy as np
import torch
from tqdm import tqdm
import json
from dual_network import Dual3DCNN3, Dual3DCNN4, Dual3DCNN5
import torch
from torch.utils.data import Dataset
import SimpleITK as sitk
import numpy as np
import glob
from utilities import create_list_from_master_json, read_json_file, split_data
import re
import glob
import random
from torch.utils.data import Dataset, DataLoader
import numpy as np
import SimpleITK as sitk
import torch
from utilities import list_patient_folders, prepare_data_nrrd, split_data
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, Spacingd, ScaleIntensityd, SpatialPadd, CenterSpatialCropd, ScaleIntensityRanged
from monai.data import CacheDataset, DataLoader, Dataset
from monai.transforms import LoadImaged
from monai.data.image_reader import ITKReader
from monai.data import SmartCacheDataset
import random
import optuna
from optuna.integration import PyTorchLightningPruningCallback

In [29]:
data_path_NEW = '/home/shahpouriz/Data/DBP_newDATA/DBP/nrrd/test'


patient_list_NEW = list_patient_folders(data_path_NEW)
# Shuffle patient list if you want randomness
random.shuffle(patient_list_NEW)

# Define split sizes
total_patients = len(patient_list_NEW)
train_size = int(total_patients * 0.70)
val_size = int(total_patients * 0.20)
# The rest will be for the test set

# Split the patient list
train_patients = patient_list_NEW[:train_size]
val_patients = patient_list_NEW[train_size:train_size + val_size]
test_patients = patient_list_NEW[train_size + val_size:]

train_pct, train_rct, train_pos = prepare_data_nrrd(data_path_NEW, train_patients)
val_pct, val_rct, val_pos = prepare_data_nrrd(data_path_NEW, val_patients)
test_pct, test_rct, test_pos = prepare_data_nrrd(data_path_NEW, test_patients)

# Create dictionaries for each dataset
train_data = [{"plan": img, "repeat": tar, "pos": pos} for img, tar, pos in zip(train_pct, train_rct, train_pos)]
val_data = [{"plan": img, "repeat": tar, "pos": pos} for img, tar, pos in zip(val_pct, val_rct, val_pos)]
test_data = [{"plan": img, "repeat": tar, "pos": pos} for img, tar, pos in zip(test_pct, test_rct, test_pos)]


# Check the lengths of the sets
print("Number of training samples:", len(train_data))
print("Number of validation samples:", len(val_data))
print("Number of test samples:", len(test_data))
print(len(test_data)+len(val_data)+len(train_data))

Number of training samples: 320
Number of validation samples: 89
Number of test samples: 81
490


In [30]:

dim = 128
size = (dim, dim, dim)
pixdim = (3.0, 3.0, 3.0)
transforms = Compose([
        LoadImaged(keys=["plan", "repeat"], reader=ITKReader()),
        EnsureChannelFirstd(keys=["plan", "repeat"]),
        ScaleIntensityd(keys=["plan", "repeat"]),
        Spacingd(keys=["plan", "repeat"], pixdim=pixdim, mode='trilinear'),
        SpatialPadd(keys=["plan", "repeat"], spatial_size=size, mode='constant'),  # Ensure minimum size
        CenterSpatialCropd(keys=["plan", "repeat"], roi_size=size),  # Ensure uniform size
    ])


train_ds = CacheDataset(data=train_data, transform=transforms, cache_rate=0.8, num_workers=1)
train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=1)

val_ds = CacheDataset(data=val_data, transform=transforms, cache_rate=0.8, num_workers=1)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=1)


model = Dual3DCNN5(width=dim, height=dim, depth=dim)
device = torch.device("cuda:0")
model.to(device)

Loading dataset: 100%|██████████| 256/256 [00:37<00:00,  6.75it/s]
Loading dataset: 100%|██████████| 71/71 [00:10<00:00,  6.89it/s]


Dual3DCNN5(
  (input_fixed_blocks): ModuleList(
    (0): Sequential(
      (0): Conv3d(1, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): Conv3d(16, 16, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
      (2): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): LeakyReLU(negative_slope=0.01, inplace=True)
    )
    (1): Sequential(
      (0): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): Conv3d(32, 32, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
      (2): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): LeakyReLU(negative_slope=0.01, inplace=True)
    )
    (2): Sequential(
      (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
      (1): Conv3d(64, 64, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
      (2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, tra

In [31]:
# Set parameters
starting_epoch = 0
final_epoch = 2

# Condition for saving list
best_mae = np.inf
exception_list = ['']
mae_loss = torch.nn.L1Loss()

In [32]:
import functools

def objective(trial, model, train_loader, val_loader, device, final_epoch, mae_loss):
    lr = trial.suggest_float('lr', 1e-5, 1e-1, log=True)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()

    # Training loop
    model.train()

    for epoch in range(starting_epoch, final_epoch):
        model.train()  # Set model to training mode
        mae_list = []
        train_loss = []
        for i, batch_data in enumerate(train_loader):  # Use enumerate to get the batch index
            pCT, rCT = batch_data["plan"].to(device), batch_data["repeat"].to(device)
            reg = batch_data["pos"].clone().detach().requires_grad_(True).to(device)  # If gradients are required for 'reg'
            optimizer.zero_grad()

            output = model(pCT, rCT)
            loss_output = mae_loss(output, reg)

            loss_output.backward()
            optimizer.step()
            
            # Logging
            mae_list.append(loss_output.item())
            mean_mae = np.mean(mae_list)
            # Corrected to print the current batch number


        # Validation loop
        model.eval()
        val_loss = []
        with torch.no_grad():
            for batch_data in val_loader:
                pCT_val, rCT_val = batch_data["plan"].to(device), batch_data["repeat"].to(device)
                reg_val = batch_data["pos"].clone().detach().requires_grad_(True).to(device)  # If gradients are required for 'reg'

                output_val = model(pCT_val, rCT_val)
                loss_output_val = mae_loss(output_val, reg_val)

                val_loss.append(loss_output_val.item())

            mean_val_loss = np.mean(val_loss)
    return mean_val_loss

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Use functools.partial to pass extra arguments to the objective function
objective_with_args = functools.partial(
    objective,
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    final_epoch=final_epoch,
    mae_loss=mae_loss,

)

study = optuna.create_study(direction='minimize')
study.optimize(objective_with_args, n_trials=10) 

print("Best trial:", study.best_trial.number)
print("Best value (validation loss):", study.best_value)
print("Best hyperparameters:", study.best_params)

[I 2024-03-13 14:31:57,482] A new study created in memory with name: no-name-5d73c6d5-2b86-4821-b676-5dcd995a9c5e


[I 2024-03-13 14:32:36,741] Trial 0 finished with value: 0.22764338837580733 and parameters: {'lr': 2.9862034252760157e-05}. Best is trial 0 with value: 0.22764338837580733.
[I 2024-03-13 14:33:17,953] Trial 1 finished with value: 0.20791751846378104 and parameters: {'lr': 0.0015962784379023476}. Best is trial 1 with value: 0.20791751846378104.
[I 2024-03-13 14:34:20,875] Trial 2 finished with value: 0.21020508756379733 and parameters: {'lr': 0.010165108285484545}. Best is trial 1 with value: 0.20791751846378104.
[I 2024-03-13 14:35:23,806] Trial 3 finished with value: 0.20735523947305307 and parameters: {'lr': 0.016827959619010308}. Best is trial 3 with value: 0.20735523947305307.
[I 2024-03-13 14:36:26,785] Trial 4 finished with value: 0.2067016278618549 and parameters: {'lr': 5.742814249611241e-05}. Best is trial 4 with value: 0.2067016278618549.
[I 2024-03-13 14:37:29,956] Trial 5 finished with value: 0.20707001745324122 and parameters: {'lr': 0.006575352851050944}. Best is trial 4

Best trial: 8
Best value (validation loss): 0.20497857394701477
Best hyperparameters: {'lr': 0.0018966257353212146}


[I 2024-03-13 12:57:48,637] A new study created in memory with name: no-name-2cd05e81-8e1c-44d3-81a4-4be74941161f
[I 2024-03-13 12:58:47,573] Trial 0 finished with value: 0.309533763534079 and parameters: {'lr': 9.218304386771182e-06}. Best is trial 0 with value: 0.309533763534079.
[I 2024-03-13 12:59:54,444] Trial 1 finished with value: 0.3092626925185323 and parameters: {'lr': 6.610746908388528e-05}. Best is trial 1 with value: 0.3092626925185323.
[I 2024-03-13 13:01:01,542] Trial 2 finished with value: 0.30921216851100325 and parameters: {'lr': 2.4698381067064442e-05}. Best is trial 2 with value: 0.30921216851100325.
[I 2024-03-13 13:02:08,653] Trial 3 finished with value: 0.3096086760734518 and parameters: {'lr': 0.0007246235392743511}. Best is trial 2 with value: 0.30921216851100325.
[I 2024-03-13 13:03:15,620] Trial 4 finished with value: 0.3095567206541697 and parameters: {'lr': 1.2856561858081072e-05}. Best is trial 2 with value: 0.30921216851100325.
[I 2024-03-13 13:04:22,598] Trial 5 finished with value: 0.3095563782254855 and parameters: {'lr': 8.659948648935581e-08}. Best is trial 2 with value: 0.30921216851100325.
[I 2024-03-13 13:05:29,773] Trial 6 finished with value: 0.3095362058530251 and parameters: {'lr': 4.637608822923589e-06}. Best is trial 2 with value: 0.30921216851100325.
[I 2024-03-13 13:06:36,824] Trial 7 finished with value: 0.3088546040095389 and parameters: {'lr': 0.0002332795188029022}. Best is trial 7 with value: 0.3088546040095389.
[I 2024-03-13 13:07:44,302] Trial 8 finished with value: 0.3088908504260083 and parameters: {'lr': 0.0002125349283948824}. Best is trial 7 with value: 0.3088546040095389.
[I 2024-03-13 13:08:51,713] Trial 9 finished with value: 0.30929054742679 and parameters: {'lr': 0.0006778213291587758}. Best is trial 7 with value: 0.3088546040095389.
Best trial: 7
Best value (validation loss): 0.3088546040095389
Best hyperparameters: {'lr': 0.0002332795188029022}