In [None]:

from monai.utils import first, set_determinism
from monai.transforms import (EnsureChannelFirstd, Compose, CropForegroundd, LoadImaged, Orientationd, RandCropByPosNegLabeld, ScaleIntensityRanged, Spacingd)
from monai.networks.nets import DynUNet
from monai.networks.layers import Norm
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset
from monai.apps import download_and_extract
from monai.transforms import CenterSpatialCropd
from monai.transforms import Resized
import torch
import matplotlib.pyplot as plt
import os
import glob
import torch.nn as nn
import json
from datetime import datetime
from data_preparation2 import DataHandling 
from UNet_model import create_unet
import numpy as np
import nibabel as nib

import math
import os
import glob
import torch
import numpy as np
import optuna
from optuna.integration import PyTorchLightningPruningCallback
from functools import partial
from monai.networks.nets import DynUNet
from monai.data import CacheDataset, DataLoader
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Spacingd, SpatialPadd, RandCropByPosNegLabeld, CenterSpatialCropd)
import random
import math
from collections import defaultdict


In [None]:
# Dataset root and output locations. Each can be overridden through an
# environment variable so the notebook runs outside the original author's
# account; the defaults are the original hard-coded paths.
data_dir = os.environ.get('DLP_DATA_DIR', '/homes/zshahpouri/DLP/ASC-PET-001')  # must contain NAC/ and ADCM/ subfolders
directory = os.environ.get('DLP_LOG_DIR', '/homes/zshahpouri/DLP/Practic/LOGTEST')
output_dir = os.environ.get('DLP_OUTPUT_DIR', '/homes/zshahpouri/DLP/Practic/OUT')

In [None]:

# Pair every volume under NAC/ with the volume at the same sorted position under
# ADCM/. The pairing is purely positional, so both folders must hold the same
# number of files -- zip() would otherwise silently drop the surplus.
train_images = sorted(glob.glob(os.path.join(data_dir, "NAC", "*.nii.gz")))
target_images = sorted(glob.glob(os.path.join(data_dir, "ADCM", "*.nii.gz")))
if len(train_images) != len(target_images):
    raise ValueError(
        f"Found {len(train_images)} NAC images but {len(target_images)} ADCM targets under {data_dir}; "
        "positional image/target pairing needs equal counts."
    )
data_dicts = [{"image": img, "target": tar} for img, tar in zip(train_images, target_images)]

# Seed Python's RNG so the per-center test sampling and the train/val shuffle
# below are reproducible.
random.seed(42)

# Group cases by center, taken from the file name.
# NOTE(review): this uses the SECOND '_'-separated token of the basename
# (e.g. 'ID_C3_xxx.nii.gz' -> 'C3'); names of the form 'Cx_...' would put the
# center in token 0 instead -- confirm against the real file names.
data_by_center = defaultdict(list)
for data in data_dicts:
    center = data["image"].split('/')[-1].split('_')[1]
    data_by_center[center].append(data)

# Every center-C5 case goes to the test set and never to train/val.
test_files = data_by_center.pop('C5', [])

# From each remaining center, move 2 randomly chosen cases to the test set.
# Centers with 2 or fewer cases are moved to the test set in full.
for center, files in data_by_center.items():
    if len(files) > 2:
        selected_for_test = random.sample(files, 2)
        test_files.extend(selected_for_test)
        for selected in selected_for_test:
            files.remove(selected)
    else:
        test_files.extend(files)
        data_by_center[center] = []  # reassigning an existing key is safe during iteration

# Whatever is left is shuffled and split 80/20 (floor) into train/val.
remaining_files = [file for files in data_by_center.values() for file in files]
random.shuffle(remaining_files)

total_size = len(remaining_files)
train_size = math.floor(total_size * 0.8)

train_files = remaining_files[:train_size]
val_files = remaining_files[train_size:]

# Sanity check: the three splits are disjoint and together contain every pair.
split_images = [d["image"] for d in train_files + val_files + test_files]
assert len(split_images) == len(set(split_images)) == len(data_dicts), "train/val/test splits overlap or lost cases"


In [None]:
from monai.transforms import NormalizeIntensityd

# Pipeline geometry: `spacing` is the voxel size every volume is resampled to,
# `spatial_size` the minimum extent after padding (and the exact validation
# size), `patch_size` the size of each random training crop.
patch_size = [168, 168, 16]
spacing = [4.07, 4.07, 3.00]
spatial_size = (168, 168, 320)

_io_keys = ["image", "target"]


def _shared_preprocessing():
    """Return new load -> channel-first -> resample -> pad transforms.

    Both pipelines begin with these four steps; building fresh instances on
    every call keeps the training and validation Compose objects independent.
    """
    return [
        LoadImaged(keys=_io_keys),
        EnsureChannelFirstd(keys=_io_keys),
        Spacingd(keys=_io_keys, pixdim=spacing, mode='trilinear'),
        SpatialPadd(keys=_io_keys, spatial_size=spatial_size, mode='constant'),
    ]


# Training: shared preprocessing, then 4 random patches per volume, sampled
# with equal positive/negative weighting using `target` as the label map.
train_transforms = Compose(
    _shared_preprocessing()
    + [
        RandCropByPosNegLabeld(
            keys=_io_keys,
            label_key="target",
            spatial_size=patch_size,
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
    ]
)

# Validation: shared preprocessing, then a deterministic center crop so that
# every padded volume ends up exactly `spatial_size`.
val_transforms = Compose(
    _shared_preprocessing() + [CenterSpatialCropd(keys=_io_keys, roi_size=spatial_size)]
)

train_ds = CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, num_workers=4)
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=4)

val_ds = CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=2, shuffle=False, num_workers=4)

In [None]:
# starting_epoch = 0
# decay_epoch = 5
# learning_rate = 0.001

# import torch
# from monai.networks.nets import DynUNet
# from torch import nn

# class DecayLR:
#     def __init__(self, epochs, offset, decay_epochs):
#         epoch_flag = epochs - decay_epochs
#         assert (epoch_flag > 0), "Decay must start before the training session ends!"
#         self.epochs = epochs
#         self.offset = offset
#         self.decay_epochs = decay_epochs

#     def step(self, epoch):
#         return 1.0 - max(0, epoch + self.offset - self.decay_epochs) / (self.epochs - self.decay_epochs)
    


---------
# Network configuration: derive DynUNet kernel sizes and strides from the patch size and voxel spacing

In [None]:
import os
import torch
from monai.networks.nets import DynUNet

def get_kernels_strides(patch_size, spacing):
    """
    Derive per-stage convolution kernel sizes and strides for a DynUNet.

    Each stage halves every axis whose spacing is at most twice the finest
    current spacing and whose current size is at least 8; all other axes keep
    stride 1. Kernels are 3 on those "near-isotropic" axes and 1 otherwise.
    Stages are added until no axis can be downsampled any more.

    Parameters:
    - patch_size: sequence of spatial sizes of the training patch.
    - spacing: sequence of voxel spacings, one per spatial axis.

    Returns:
    - (kernels, strides): lists of per-axis lists. `strides` starts with an
      all-ones entry; `kernels` ends with an all-threes entry.

    Raises:
    - ValueError: if a size that must be downsampled is not divisible by its stride.
    """
    cur_sizes = patch_size
    cur_spacings = spacing
    stage_kernels = []
    stage_strides = []

    while True:
        finest = min(cur_spacings)
        ratios = [sp / finest for sp in cur_spacings]
        stride = [2 if r <= 2 and n >= 8 else 1 for r, n in zip(ratios, cur_sizes)]
        # Strides are only ever 1 or 2, so "no 2 present" means nothing left to downsample.
        if 2 not in stride:
            break
        for axis, (n, st) in enumerate(zip(cur_sizes, stride)):
            if n % st != 0:
                raise ValueError(
                    f"Patch size is not supported, please try to modify the size {patch_size[axis]} in the spatial dimension {axis}."
                )
        stage_kernels.append([3 if r <= 2 else 1 for r in ratios])
        stage_strides.append(stride)
        cur_sizes = [n / st for n, st in zip(cur_sizes, stride)]
        cur_spacings = [sp * st for sp, st in zip(cur_spacings, stride)]

    ndim = len(cur_spacings)
    return stage_kernels + [[3] * ndim], [[1] * ndim] + stage_strides


def get_network(patch_size, spacing, *, in_channels=1, out_channels=1,
                deep_supervision=True, deep_supr_num=2, verbose=True):
    """
    Build a 3D DynUNet whose kernel sizes and strides are derived from the
    training patch size and voxel spacing via `get_kernels_strides`.

    Parameters:
    - patch_size: training patch size per spatial axis.
    - spacing: voxel spacing per spatial axis.
    - in_channels / out_channels: network input/output channel counts
      (defaults 1 / 1, the previously hard-coded values).
    - deep_supervision / deep_supr_num: forwarded unchanged to DynUNet
      (defaults True / 2, the previously hard-coded values).
    - verbose: when True (default), print the derived kernels, strides and
      number of stride entries.

    Returns:
    - The constructed, untrained DynUNet.
    """
    kernels, strides = get_kernels_strides(patch_size, spacing)
    if verbose:
        print(kernels)
        print(strides)
        print(len(strides))
    return DynUNet(
        spatial_dims=3,
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernels,
        strides=strides,
        # strides[0] is the all-ones entry prepended by get_kernels_strides,
        # so it has no matching upsampling step.
        upsample_kernel_size=strides[1:],
        norm_name="instance",
        deep_supervision=deep_supervision,
        deep_supr_num=deep_supr_num,
    )

# Build the network once and place it on the first GPU when CUDA is available,
# otherwise on the CPU. The tuning cell later passes this instance to the objective.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model = get_network(patch_size, spacing).to(device)

In [None]:


# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.5, 0.999))




# optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Criterion and epoch schedule used by the tuning objective, plus bookkeeping
# containers (each list is a distinct, initially empty object).
loss_function = torch.nn.MSELoss()
max_epochs, val_interval = 2, 2
best_metric, best_metric_epoch = float('inf'), -1
epoch_loss_values, metric_values = [], []
train_losses, val_losses = [], []

In [None]:
def deep_loss(outputs, target, loss_function, device, weights=None):
    """
    Compute a weighted loss over every supervision head of a stacked output.

    Parameters:
    - outputs: Tensor whose dim 1 indexes the supervision heads (presumably the
      stacked training-mode output of a DynUNet built with deep_supervision=True).
      Each slice `outputs[:, k]` must be a valid prediction for `loss_function`.
    - target: Ground-truth tensor laid out as (batch, channel, *spatial), like a
      single head. It is resized (nearest) to a head's spatial size only when
      the sizes differ.
    - loss_function: Callable `loss_function(prediction, target) -> Tensor`.
    - device: Device the per-head target is moved to.
    - weights: Optional per-head weights, exactly one per head; normalised to
      sum to 1. Defaults to equal weighting.

    Returns:
    - The weighted sum of per-head losses (a weighted average, since the
      weights sum to 1).

    Raises:
    - ValueError: if `outputs` has no heads along dim 1, if `weights` does not
      have one entry per head, or if the weights sum to zero.
    """
    # Split along the head dimension so each head is scored separately.
    output_maps = torch.unbind(outputs, dim=1)
    n_heads = len(output_maps)
    if n_heads == 0:
        raise ValueError("outputs has no supervision heads along dim 1.")

    if weights is None:
        weights = [1.0 / n_heads] * n_heads
    else:
        weights = [float(w) for w in weights]
        # zip() below would otherwise silently ignore extra heads or extra weights.
        if len(weights) != n_heads:
            raise ValueError(
                f"Expected {n_heads} weights (one per supervision head), got {len(weights)}."
            )
        total = sum(weights)
        if total == 0:
            raise ValueError("Deep-supervision weights must not sum to zero.")
        weights = [w / total for w in weights]

    total_loss = 0.0
    for output, weight in zip(output_maps, weights):
        if target.shape[2:] != output.shape[2:]:
            head_target = torch.nn.functional.interpolate(target, size=output.shape[2:], mode='nearest')
        else:
            # Same-size nearest interpolation would be an exact copy; skip it.
            head_target = target
        total_loss += weight * loss_function(output, head_target.to(device))

    return total_loss



In [None]:
import functools

def objective(trial, model, train_loader, val_loader, device, max_epochs, val_interval, loss_function, deep_loss,
              roi_size=(168, 168, 32), sw_batch_size=16):
    """
    Optuna objective: train a fresh copy of `model` with a sampled learning
    rate and return its validation loss.

    Parameters:
    - trial: Optuna trial; samples `lr` log-uniformly from [1e-5, 1e-1].
    - model: template network. It is deep-copied at the start of every trial,
      so each trial starts from the same initial weights and the template
      itself is never trained.
    - train_loader / val_loader: loaders yielding dicts with "image" and "target".
    - device: device batches are moved to (should match the model's device).
    - max_epochs: training epochs per trial.
    - val_interval: validate every `val_interval` epochs. The final epoch is
      always validated, so a value is returned even when `max_epochs` is not a
      multiple of `val_interval`.
    - loss_function: criterion for single-output predictions and for validation.
    - deep_loss: callable(outputs, targets, loss_function, device) used when the
      model returns stacked (or tuple) multi-head outputs.
    - roi_size / sw_batch_size: sliding-window inference settings for validation.

    Returns:
    - Mean per-batch validation loss from the last validation pass
      (inf if `max_epochs` is 0).
    """
    import copy

    lr = trial.suggest_float('lr', 1e-5, 1e-1, log=True)

    # Independent trials: training the shared `model` directly would make every
    # trial continue from the weights left by the previous one, biasing the
    # learning-rate comparison toward later trials.
    trial_model = copy.deepcopy(model)
    optimizer = torch.optim.Adam(trial_model.parameters(), lr=lr)

    val_loss = float('inf')
    for epoch in range(max_epochs):
        # Re-enter training mode every epoch; validation below switches to eval().
        trial_model.train()
        epoch_loss = 0.0
        step = 0

        for batch_data in train_loader:
            step += 1
            inputs, targets = batch_data["image"].to(device), batch_data["target"].to(device)
            optimizer.zero_grad()
            outputs = trial_model(inputs)
            if isinstance(outputs, (tuple, list)):
                # deep_loss unbinds along dim 1, so give it a stacked tensor.
                outputs = torch.stack(list(outputs), dim=1)
            if outputs.dim() > targets.dim():
                loss = deep_loss(outputs, targets, loss_function, device)
            else:
                # Drop only the channel dim; a bare squeeze() would also drop a
                # batch dim of size 1 and misalign outputs with targets.
                outputs = torch.squeeze(outputs, dim=1)
                targets = torch.squeeze(targets, dim=1)
                loss = loss_function(outputs, targets)

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        epoch_loss /= max(step, 1)  # guard against an empty loader
        trial.set_user_attr("last_train_loss", epoch_loss)

        if (epoch + 1) % val_interval == 0 or epoch == max_epochs - 1:
            trial_model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for val_data in val_loader:
                    val_inputs, val_targets = val_data["image"].to(device), val_data["target"].to(device)
                    val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, trial_model)
                    val_loss += loss_function(val_outputs, val_targets).item()
            val_loss /= max(len(val_loader), 1)

    return val_loss

# Move batches to the device the model actually lives on, so inputs and
# parameters can never end up on different devices.
device = next(model.parameters()).device

# Bind everything except the trial so Optuna can call objective(trial).
objective_with_args = functools.partial(
    objective,
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    max_epochs=max_epochs,
    val_interval=val_interval,
    loss_function=loss_function,
    deep_loss=deep_loss
)

# Seed the sampler so the sequence of sampled learning rates is reproducible;
# random.seed(42) elsewhere seeds only Python's `random` module.
study = optuna.create_study(direction='minimize', sampler=optuna.samplers.TPESampler(seed=42))
study.optimize(objective_with_args, n_trials=10)

print("Best trial:", study.best_trial.number)
print("Best value (validation loss):", study.best_value)
print("Best hyperparameters:", study.best_params)

---------------------

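# Results for ADCM (learning-rate search, 10 Optuna trials):

Caveat: the log below was presumably produced by an objective that trained the single `model` instance passed in through `functools.partial` without re-initialising it. In that case each trial started from the weights left by the previous trial. The steady trial-to-trial improvement may then partly reflect accumulated training rather than the sampled learning rate alone.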

[I 2024-03-13 07:50:52,836] A new study created in memory with name: no-name-9bbcd5b9-2984-4e90-bb09-d75679945a16
[I 2024-03-13 07:58:07,865] Trial 0 finished with value: 0.026119201553656775 and parameters: {'lr': 0.008072853294952104}. Best is trial 0 with value: 0.026119201553656775.
[I 2024-03-13 08:05:24,497] Trial 1 finished with value: 0.024910740006495926 and parameters: {'lr': 0.05699468256166564}. Best is trial 1 with value: 0.024910740006495926.
[I 2024-03-13 08:12:36,228] Trial 2 finished with value: 0.013740474365048987 and parameters: {'lr': 0.0009126015047894743}. Best is trial 2 with value: 0.013740474365048987.
[I 2024-03-13 08:20:06,077] Trial 3 finished with value: 0.013423842097194317 and parameters: {'lr': 6.830399550473302e-05}. Best is trial 3 with value: 0.013423842097194317.
[I 2024-03-13 08:27:29,051] Trial 4 finished with value: 0.012969923698726822 and parameters: {'lr': 0.00020615854432982388}. Best is trial 4 with value: 0.012969923698726822.
[I 2024-03-13 08:35:12,140] Trial 5 finished with value: 0.012929829809924258 and parameters: {'lr': 3.557901372471996e-05}. Best is trial 5 with value: 0.012929829809924258.
[I 2024-03-13 08:43:13,100] Trial 6 finished with value: 0.012899920029346557 and parameters: {'lr': 5.8365845148990886e-05}. Best is trial 6 with value: 0.012899920029346557.
[I 2024-03-13 08:50:58,246] Trial 7 finished with value: 0.012893239296424915 and parameters: {'lr': 1.8475974753442264e-05}. Best is trial 7 with value: 0.012893239296424915.
[I 2024-03-13 08:58:42,021] Trial 8 finished with value: 0.012873913805164835 and parameters: {'lr': 5.156637879231117e-05}. Best is trial 8 with value: 0.012873913805164835.
[I 2024-03-13 09:06:27,394] Trial 9 finished with value: 0.012866821701583616 and parameters: {'lr': 1.82173317485587e-05}. Best is trial 9 with value: 0.012866821701583616.
Best trial: 9
Best value (validation loss): 0.012866821701583616
Best hyperparameters: {'lr': 1.82173317485587e-05}