# Continual Learning Pre - Flight Test Methods : Joint Learning, Sequential Learning

## Necessary Installs

In [None]:
!pip install -q monai einops

## Necessary Imports

In [None]:
# imports and installs
import torch
import torch.nn as nn
import torch.nn.functional as F
import nibabel as nib
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid
from torch.optim import SGD, Adam, ASGD

import numpy as np
import matplotlib.pyplot as plt
from glob import glob
import monai
from monai.transforms import (ScaleIntensityRange, Compose, AddChannel, RandSpatialCrop, ToTensor, 
                            RandAxisFlip, Activations, AsDiscrete, Resize, RandRotate, RandFlip, EnsureType,
                             KeepLargestConnectedComponent)
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.inferers import sliding_window_inference
from monai.losses import DiceLoss, FocalLoss, GeneralizedDiceLoss, DiceCELoss, DiceFocalLoss
from monai.networks.nets import UNet, VNet, UNETR, SwinUNETR, AttentionUnet
from monai.data import decollate_batch, ImageDataset
from monai.utils import set_determinism
import os
import wandb
from time import time
from einops import rearrange
from einops.layers.torch import Rearrange
from torch.optim.lr_scheduler import ExponentialLR, CosineAnnealingLR
from random import sample

# Seed torch's global RNG and apply MONAI's determinism helper with the same
# seed so that runs are repeatable.
torch.manual_seed(2000)
set_determinism(seed=2000)

# Master switch for Weights & Biases: the cells below only call wandb.login /
# wandb.init / wandb.log when this is True.
wandb_log = True

## Create Data Loader 

In [None]:
class DecathlonCorrectImage(monai.transforms.Transform):
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Intensity-normalise an image volume.

        The 99.9th percentile of ``x`` is taken as the upper intensity bound;
        intensities in [0, upper] are mapped linearly onto [0, 1] and the
        result is clipped to that range.

        NOTE: a second class with this same name is defined further down in
        this file and replaces this definition once that code has run.
        """
        # x shape is (channel, height, width, depth)
        upper = np.percentile(x, 99.9)
        rescale = ScaleIntensityRange(a_min=0, a_max=upper, b_min=0, b_max=1, clip=True)
        return rescale(x)

class DecathlonCorrectLabel(monai.transforms.Transform):
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Collapse a multi-class label map into a binary mask.

        Every element greater than 0 is overwritten with 1. The assignment is
        done in place, so the object passed in is modified and returned.

        NOTE: a second class with this same name is defined further down in
        this file and replaces this definition once that code has run.
        """
        # x shape is (channel, height, width, depth)
        foreground = x > 0
        x[foreground] = 1
        return x


class DecathlonProstaeJustMRI(monai.transforms.Transform):
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Select the first modality of a channel-last volume.

        The einops pattern treats the input as four axes ``(h, w, d, c)``,
        moves the last axis to the front and index 0 along it is returned,
        so the result has axes ``(h, w, d)``.

        For now only that first channel (the MRI) is used; the ADC map in the
        second channel is ignored.
        """
        channel_first = rearrange(x, 'h w d c -> c h w d')
        mri_only = channel_first[0]
        return mri_only

class DecathlonCorrectImage(monai.transforms.Transform):
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Re-orient an image volume and normalise its intensities.

        1. Rotates by 3 x 90 degrees in the plane of axes 1 and 2, then flips
           along axis 2 (for a NumPy array the two steps together amount to
           swapping axes 1 and 2).
        2. Uses the 99.9th percentile of the re-oriented volume as the upper
           intensity bound.
        3. Maps [0, upper] linearly onto [0, 1] and clips the result.
        """
        # x shape is (channel, height, width, depth)
        rotated = np.rot90(x, k=3, axes=(1, 2))
        x = np.flip(rotated, axis=2)

        upper = np.percentile(x, 99.9)
        rescale = ScaleIntensityRange(a_min=0, a_max=upper, b_min=0, b_max=1, clip=True)
        return rescale(x)

class DecathlonCorrectLabel(monai.transforms.Transform):
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """
        Re-orient a label volume and binarise it.

        1. Rotates by 3 x 90 degrees in the plane of axes 1 and 2, then flips
           along axis 2 - the same re-orientation DecathlonCorrectImage applies,
           so image and label stay aligned.
        2. Multi-class to binary segmentation: every label > 0 is set to 1.
        """
        # x shape is (channel, height, width, depth)
        rotated = np.rot90(x, k=3, axes=(1, 2))
        x = np.flip(rotated, axis=2)

        foreground = x > 0
        x[foreground] = 1
        return x

def get_img_label_folds(img_paths, label_paths):
    """
    Shuffle image and label paths with one shared random permutation.

    Args:
        img_paths: sequence of image file paths.
        label_paths: sequence of label file paths; entry ``i`` must belong to
            ``img_paths[i]``.

    Returns:
        tuple: ``(fold_imgs, fold_labels)`` - two lists in the same shuffled
        order, so that pairs stay aligned.

    Raises:
        ValueError: if the two sequences differ in length. Without this check
            surplus labels would be dropped silently (possibly leaving
            mismatched pairs) and missing labels would surface as a bare
            IndexError.
    """
    if len(img_paths) != len(label_paths):
        raise ValueError(
            f"Number of images ({len(img_paths)}) and labels ({len(label_paths)}) must match"
        )

    # One permutation of the indices, applied to both lists.
    order = sample(range(len(img_paths)), k=len(img_paths))
    fold_imgs = [img_paths[i] for i in order]
    fold_labels = [label_paths[i] for i in order]
    return fold_imgs, fold_labels


def get_dataloaders(
    dataset : str ,
    batch_size : int ,
    test_size : float ,
    roi_size : int ,
    test_shuffle : bool ,
):
    """
    Build the train and test DataLoaders for one of the supported datasets.

    Image and label files are collected from ``<data_dir>/imagesTr/*.nii`` and
    ``<data_dir>/labelsTr/*.nii``, sorted, shuffled together and then split
    into a train part and a test part.

    Args:
        dataset (str): dataset name; one of "promise12", "decathlon", "isbi".
        batch_size (int): batch size used for both loaders.
        test_size (float): fraction of the samples placed in the test split.
        roi_size (int): size given to RandSpatialCrop in the training transforms.
        test_shuffle (bool): whether the test loader shuffles its samples.

    Returns:
        tuple: ``(train_loader, test_loader)``.

    Raises:
        ValueError: if ``dataset`` is not supported, or the number of image
            files differs from the number of label files.
        FileNotFoundError: if no image files are found under the data directory.
    """

    dataset_map = {
        "promise12" : {
            "data_dir" : "../input/promise12-rot-intensity-scale-3d/",
#             "data_dir" : "../../MIS/datasets/promise12/rot_scale/",
            "train_img_transform" : [
                AddChannel(),
                RandSpatialCrop(roi_size= roi_size, random_center = True, random_size=False),
                ToTensor()
                ],
            "train_label_transform" : [
                AddChannel(),
                RandSpatialCrop(roi_size= roi_size, random_center = True, random_size=False),
                AsDiscrete(threshold=0.5),
                ToTensor()
                ],
            "test_img_transform" : [
                AddChannel(),
                ToTensor()
                ],
            "test_label_transform" : [
                AddChannel(),
                ToTensor()
                ],
            },
        # Issue in reading the original pixdim for decathlon prostate dataset.
        # Fix the pixdim by manually setting it to 1.
        "decathlon" : {
            "data_dir" : "../input/decathlonprostate/Task05_Prostate/",
#             "data_dir" : "../../MIS/datasets/Task05_Prostate/",
            "train_img_transform" : [
                DecathlonProstaeJustMRI(),
                AddChannel(),
                DecathlonCorrectImage(),
                RandSpatialCrop(roi_size= roi_size, random_center = True, random_size=False),
                ToTensor()
                ],
            "train_label_transform" : [
                AddChannel(),
                DecathlonCorrectLabel(),
                RandSpatialCrop(roi_size= roi_size, random_center = True, random_size=False),
                AsDiscrete(threshold=0.5),
                ToTensor()
                ],
            "test_img_transform" : [
                DecathlonProstaeJustMRI(),
                AddChannel(),
                DecathlonCorrectImage(),
                ToTensor()
                ],
            "test_label_transform" : [
                AddChannel(),
                DecathlonCorrectLabel(),
                AsDiscrete(threshold=0.5),
                ToTensor()
                ],
            },
        "isbi" : {
            "data_dir" : "../input/isbiv2-merged/ISBI_V2/",
#             "data_dir" : "../../MIS/datasets/ISBI_V2/",
            "remove_indexes" : [47],
            "train_img_transform" : [
                AddChannel(),
                DecathlonCorrectImage(),
                RandSpatialCrop(roi_size= roi_size, random_center = True, random_size=False),
                ToTensor()
                ],
            "train_label_transform" : [
                AddChannel(),
                DecathlonCorrectLabel(),
                RandSpatialCrop(roi_size= roi_size, random_center = True, random_size=False),
                AsDiscrete(threshold=0.5),
                ToTensor()
                ],
            "test_img_transform" : [
                AddChannel(),
                DecathlonCorrectImage(),
                ToTensor()
                ],
            "test_label_transform" : [
                AddChannel(),
                DecathlonCorrectLabel(),
                ToTensor()
                ],
            }
    }

    if dataset not in dataset_map:
        raise ValueError("Dataset {} is not supported".format(dataset))

    # Get image paths and label paths; sorting both lists is what pairs
    # image i with label i.
    dataset_dict = dataset_map[dataset]
    img_paths = sorted(glob(dataset_dict["data_dir"] + "imagesTr/*.nii"))
    label_paths = sorted(glob(dataset_dict["data_dir"] + "labelsTr/*.nii"))

    # Fail early with a clear message instead of building empty or
    # misaligned datasets.
    if not img_paths:
        raise FileNotFoundError(
            "No '*.nii' images found under {}imagesTr/".format(dataset_dict["data_dir"])
        )
    if len(img_paths) != len(label_paths):
        raise ValueError(
            "Dataset {}: found {} images but {} labels".format(
                dataset, len(img_paths), len(label_paths)
            )
        )

    # Remove the indexes that create problems during training
    # if "remove_indexes" in dataset_dict:
    #     remove_indexes = dataset_dict["remove_indexes"]
    #     img_paths = [img_paths[i] for i in range(len(img_paths)) if i not in remove_indexes]
    #     label_paths = [label_paths[i] for i in range(len(label_paths)) if i not in remove_indexes]

    # Shuffle images and labels together
    images_fold, labels_fold  = get_img_label_folds(img_paths, label_paths)
#     images_fold, labels_fold  = img_paths, label_paths

    print("Number of images: {}".format(len(images_fold)))
    print("Number of labels: {}".format(len(labels_fold)))

    # First (1 - test_size) share goes to training, the remainder to testing
    train_idx = int(len(images_fold) * (1 - test_size))

    train_set = ImageDataset(images_fold[:train_idx], labels_fold[:train_idx],
                            transform=Compose(dataset_dict['train_img_transform']),
                            seg_transform=Compose(dataset_dict['train_label_transform']))

    test_set = ImageDataset(images_fold[train_idx:], labels_fold[train_idx:],
                        transform=Compose(dataset_dict['test_img_transform']),
                        seg_transform=Compose(dataset_dict['test_label_transform']))

    # Get dataloaders for train and test sets (training data is always shuffled)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=test_shuffle)

    return train_loader, test_loader


## Visualize dataset

In [None]:
# ----------------------------Get dataloaders--------------------------
roi_single = 160
train_loader, test_loader = get_dataloaders(
    dataset="decathlon",
    batch_size=1,
    test_size=0.2,
    roi_size=roi_single,
    test_shuffle=True,
)

print(f"\nTraining samples : {len(train_loader)}")
print(f"Testing samples : {len(test_loader)}")

# Take one training volume and fold its depth axis into the batch axis so
# that every slice becomes a separate 2D sample.
imgs, labels = next(iter(train_loader))
imgs = rearrange(imgs, 'b c h w d -> (b d) c h w')
labels = rearrange(labels, 'b c h w d -> (b d) c h w')
print(f"\nImage shape : {imgs.shape}")
print(f"Label shape : {labels.shape}")

# Show slice `img_no`: the image, its label, and the label laid over the image.
img_no = 8
plt.figure(figsize=(6*3,6*1))
for col, title in enumerate(('Image', 'Label', 'Overlay'), start=1):
    plt.subplot(1, 3, col)
    if title == 'Label':
        plt.imshow(labels[img_no, 0], cmap='gray')
    else:
        plt.imshow(imgs[img_no, 0], cmap='gray')
    if title == 'Overlay':
        plt.imshow(labels[img_no, 0], 'copper', alpha=0.2)
    plt.axis('off')
    plt.title(title)
plt.show()

### Train - Test Statistics

In [None]:
print(f"\nTraining samples : {len(train_loader)}")
print(f"Testing samples : {len(test_loader)}")

# -----------------------Slices-----------------------------
# Histogram of the size of axis 4 of every label batch, one panel per split.

plt.figure(figsize = (2*7, 1*7))

for panel, (loader, title) in enumerate(
    ((train_loader, 'Training Set'), (test_loader, 'Testing set')), start=1
):
    slices = [label.shape[4] for _, label in loader]
    plt.subplot(1, 2, panel)
    plt.hist(slices)
    plt.xlabel('Slices')
    plt.ylabel('Count')
    plt.title(title)

plt.suptitle('Slices in Training & Testing Sets')
plt.show()

## Train Config, Loss, Metrics

In [None]:
# ----------------------------Train Config-----------------------------------------
# Use the first CUDA device when available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

# 2D UNet with one input channel and two output channels. train() and
# validate() fold the depth axis of each volume into the batch axis before
# calling it, so the network sees individual slices.
model = UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

epochs = 100
initial_lr = 1e-3
optimizer = Adam(model.parameters(), lr=initial_lr, weight_decay=1e-5)
# optimizer = ASGD(model.parameters(), lr=initial_lr)
# Cosine annealing with T_max = epochs; train() calls scheduler.step() once per call.
scheduler = CosineAnnealingLR(optimizer, T_max=epochs, verbose=True)
# scheduler = ExponentialLR(optimizer, gamma=0.98)
# Both metrics ignore the background channel; Hausdorff uses the 95th percentile.
dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
hd_metric = HausdorffDistanceMetric(include_background=False, percentile = 95.)


# post_trans = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
# Prediction post-processing: argmax -> 2-channel one-hot -> keep only the
# largest connected component of label 1.
post_pred = Compose([
    EnsureType(), AsDiscrete(argmax=True, to_onehot=2),
    KeepLargestConnectedComponent(applied_labels=[1], is_onehot=True, connectivity=2)
])
# Label post-processing: one-hot encode into 2 channels to match post_pred's output.
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])
# In the visible code `argmax` is referenced only by commented-out visualisation lines.
argmax = AsDiscrete(argmax=True)
# dice_loss = DiceLoss(to_onehot_y=True, softmax=True)
# bce_loss = nn.BCEWithLogitsLoss()
# ce_loss = nn.CrossEntropyLoss(weight = torch.tensor([1., 20.], device = device))
# focal_loss = FocalLoss(to_onehot_y = True, weight = [.5, 95.])
# Training loss: Dice + cross-entropy; the loss itself one-hot encodes the
# labels and applies softmax to the network output.
dice_ce_loss = DiceCELoss(to_onehot_y=True, softmax=True,)
# dice_focal_loss = DiceFocalLoss(to_onehot_y=True, softmax=True, focal_weight = torch.tensor([1., 5.], device = device))
# focal_weight = torch.tensor([1., 20.], device = device)
# weight = torch.tensor([1., 5.], device = device)

## WANDB Logging

In [None]:
# ------------------------------------WANDB Logging-------------------------------------
# Run description sent to Weights & Biases as the run config.
# NOTE(review): when the cells are run top to bottom, roi_single, train_loader
# and test_loader here are the ones created in the "Visualize dataset" cell
# (decathlon only), so the ROI size and sample counts below do not describe
# the full joint-training setup - confirm this is intended.
config = {
    "Model" : "UNet2D",
    "Train Input ROI size" : roi_single,
#     "Test Input size" : (1, 320, 320),
    "Test mode" : f"Sliding window inference roi = {roi_single}",
    "Batch size" : "No of slices in original volume",
    "No of volumes per batch" : 1,
    "Epochs" : epochs,
    "Optimizer" : "Adam",
    "Scheduler" : "CosineAnnealingLR",
    # Read back from the scheduler rather than from initial_lr.
    "Initial LR" : scheduler.get_last_lr()[0],
    "Loss" : "DiceCELoss", 
    "Train Data Augumentations" : "RandSpatialCrop",
    "Test Data Preprocess" : "None",
    "Train samples" : len(train_loader),
    "Test Samples" : len(test_loader),
#     RandFlip, RandRotate90, RandGaussianNoise, RandGaussSmooth, RandBiasField, RandContrast
    "Pred Post Processing" : "KeepLargestConnectedComponent"
}
# Only log in and start a run when logging is enabled.
if wandb_log:
    wandb.login()
    wandb.init(project="CL_Joint", entity="vinayu", config = config)

## Joint Training

In [None]:
batch_size = 1
test_shuffle = True

# Per-dataset test split ratio and crop size.
datasets = {
    'promise12': {'test_size': 0.1, 'roi_size': 160},
    'isbi': {'test_size': 0.2, 'roi_size': 192},
    'decathlon': {'test_size': 0.2, 'roi_size': 160},
}

# dataset name -> (train_loader, test_loader), built in the order listed above.
dataloaders_map = {
    name: get_dataloaders(
        dataset=name,
        batch_size=batch_size,
        test_size=cfg['test_size'],
        roi_size=cfg['roi_size'],
        test_shuffle=test_shuffle,
    )
    for name, cfg in datasets.items()
}

# dataset name -> running validation metrics; every key except 'Epoch' is
# prefixed with the dataset name.
metrics_map = {
    name: {
        f'{name}_curr_dice': 0,
        f'{name}_best_dice': 0,
        f'{name}_curr_hd': 1e10,
        f'{name}_best_hd': 1e10,
        'Epoch': 0,
    }
    for name in datasets
}

## Training & Validation 

In [None]:
def train():
    """
    Inputs : No Inputs (reads the module-level model, optimizer, scheduler, loss,
             metrics, dataloaders_map, epoch, epochs and batch_interval)
    Outputs : No Outputs
    Function : Runs one training pass over every dataset in dataloaders_map,
               logs the epoch metrics to WANDB when wandb_log is set, then
               resets the metrics and steps the LR scheduler.
    """

    train_start = time()
    epoch_loss = 0
    # Batches seen across ALL datasets; used to average the loss. Dividing by
    # the length of the last loader only would give a wrong average.
    num_batches = 0
    model.train()
    print('\n')

    for dataset_name in dataloaders_map:
        train_loader, _ = dataloaders_map[dataset_name]
        print(f'\n----------------{dataset_name}----------------')
        print(f'Training samples : {len(train_loader)}')
        # Iterating over the dataset
        for i, (imgs, labels) in enumerate(train_loader, 1):

            imgs = imgs.to(device)
            labels = labels.to(device)
            # Fold depth into the batch axis: each slice is one 2D sample
            imgs = rearrange(imgs, 'b c h w d -> (b d) c h w')
            labels = rearrange(labels, 'b c h w d -> (b d) c h w')

            optimizer.zero_grad()
            preds = model(imgs)

            loss = dice_ce_loss(preds, labels)

            # Metric scores; computed on detached tensors so the
            # post-processing does not track gradients.
            with torch.no_grad():
                metric_preds = torch.stack(
                    [post_pred(p) for p in decollate_batch(preds.detach())]
                )
                metric_labels = torch.stack(
                    [post_label(l) for l in decollate_batch(labels)]
                )
                dice_metric(metric_preds, metric_labels)
                hd_metric(metric_preds, metric_labels)

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

            if i % batch_interval == 0:
                print(f"Epoch: [{epoch}/{epochs}], Batch: [{i}/{len(train_loader)}], Loss: {loss.item() :.4f}, \
                      Dice: {dice_metric.aggregate().item() * 100 :.2f}, HD: {hd_metric.aggregate().item() :.2f}")

    # Print metrics, log data, reset metrics

    avg_loss = epoch_loss / max(num_batches, 1)
    train_dice = dice_metric.aggregate().item() * 100
    train_hd = hd_metric.aggregate().item()

    print(f"\nEpoch: [{epoch}/{epochs}], Avg Loss: {avg_loss :.3f}, \
              Train Dice: {train_dice :.2f}, Train HD: {train_hd :.2f}, Time : {int(time() - train_start)} sec")

    if wandb_log:
        wandb.log({"Train Dice" : train_dice,
                   "Train Hausdorff Distance" : train_hd,
                   "Train Loss" : avg_loss,
                   "Learning Rate" : scheduler.get_last_lr()[0],
                   "Epoch" : epoch })

    dice_metric.reset()
    hd_metric.reset()
    scheduler.step()


In [None]:
def validate(test_loader : DataLoader, dataset_name : str):
    """
    Inputs : Testing dataloader and the name of the dataset it belongs to
    Outputs : Returns Dice, HD
    Function : Validate on the given dataloader with sliding-window inference
               and return the metrics; the shared metric objects are reset
               before returning.
    """
    val_start = time()
    model.eval()
    with torch.no_grad():
        # Iterate over all samples in the dataset
        for imgs, labels in test_loader:
            # Move to the device, then fold depth into the batch axis
            imgs = rearrange(imgs.to(device), 'b c h w d -> (b d) c h w')
            labels = rearrange(labels.to(device), 'b c h w d -> (b d) c h w')

            # Square window with the same side length as this dataset's crop size
            side = datasets[dataset_name]['roi_size']
            preds = sliding_window_inference(
                inputs=imgs,
                roi_size=(side, side),
                sw_batch_size=4,
                predictor=model,
                overlap=0.5,
                mode='gaussian',
                device=device,
            )

            preds = torch.stack([post_pred(p) for p in decollate_batch(preds)])
            labels = torch.stack([post_label(l) for l in decollate_batch(labels)])

            dice_metric(preds, labels)
            hd_metric(preds, labels)

        val_dice = dice_metric.aggregate().item()
        val_hd = hd_metric.aggregate().item()

        dice_metric.reset()
        hd_metric.reset()

        separator = "-"*75
        print(separator)
        print(f"Epoch : [{epoch}/{epochs}], Dataset : {dataset_name.upper()}, Test Avg Dice : {val_dice*100 :.2f}, Test Avg HD : {val_hd :.2f}, Time : {int(time() - val_start)} sec")
        print(separator)

        return val_dice, val_hd

In [None]:
# The main training & validation loop

print("Training started ... \n")

val_interval = 1
batch_interval = 20  # train() prints a progress line every `batch_interval` batches
best_dice = -1
# best_dice_epoch = 0
best_hd = 1e10
# best_hd_epoch = 0


for epoch in range(1, epochs+1):
    # Trains on all datasets
    train()

    # Validation on each dataset individually and log metrics
    for dataset_name in dataloaders_map:
        _, test_loader = dataloaders_map[dataset_name]
        val_dice, val_hd = validate(test_loader, dataset_name)

        # Look the metrics up by the loop variable, not by the literal
        # string 'dataset_name' (which is not a key and raises KeyError).
        metrics = metrics_map[dataset_name]

        # Dice is stored and logged as a percentage, so compare it as one too.
        val_dice_pct = val_dice * 100

        metrics['Epoch'] = epoch
        metrics[f'{dataset_name}_curr_dice'] = val_dice_pct
        metrics[f'{dataset_name}_curr_hd'] = val_hd

        if val_dice_pct > metrics[f'{dataset_name}_best_dice']:

            metrics[f'{dataset_name}_best_dice'] = val_dice_pct
#             best_dice_epoch = epoch
#             torch.save(model.state_dict(), "best_model.pt")
#             print(f"Best model saved at epoch {best_metric_epoch} with Dice {best_metric*100:.2f}")


        # Only a strictly positive distance can become the new best value
        if 0 < val_hd < metrics[f'{dataset_name}_best_hd']:

            metrics[f'{dataset_name}_best_hd'] = val_hd
#             best_hd_epoch = epoch

        if wandb_log:
            # Quantitative metrics
            wandb.log(metrics)

            # Qualitative results

#             preds = torch.stack([argmax(c) for c in preds])
#             labels = torch.stack([argmax(c) for c in labels])

#             f = make_grid(torch.cat([imgs,labels,preds],dim=3), nrow =2, padding = 20, pad_value = 1)
#             images = wandb.Image(rearrange(f.cpu(), 'c h w -> h w c'), caption="Left: Input, Middle : Ground Truth, Right: Prediction")
#             wandb.log({"Predictions": images, "Epoch" : epoch})

            print('Logged data to wandb')

In [None]:
# print("Training started ... \n")

# val_interval = 1
# batch_interval = 20
# best_dice = -1
# best_dice_epoch = 0
# best_hd = 1e10
# best_hd_epoch = 0

# # For in epochs
# for epoch in range(1, epochs+1):

#     # Training
#     train_start = time()
#     model.train()
#     epoch_loss = 0
#     print('\n')
    
#     # Iterate over datasets
    
#     for dataset_name in dataloaders_map:
#         train_loader, test_loader = dataloaders_map[dataset_name]
#         print(f'\n----------------{dataset_name}----------------')
#         print(f'Training samples : {len(train_loader)}')
# #         print(f'Testing samples : {len(test_loader)}\n')
        
#         # Iterate over all samples in the dataset
        
#         for i, (imgs, labels) in enumerate(train_loader, 1):

#             imgs = imgs.to(device)
#             labels = labels.to(device)
#             imgs = rearrange(imgs, 'b c h w d -> (b d) c h w')
#             labels = rearrange(labels, 'b c h w d -> (b d) c h w')

#             optimizer.zero_grad()
#             preds = model(imgs)

#     #         loss = dice_loss(preds, labels)
#     #         labels_temp = rearrange(labels, 'b c h w -> (b c) h w')
#     #         loss = ce_loss(preds, labels_temp.long().to(device))
#     #         loss = focal_loss(preds, labels)
#             loss = dice_ce_loss(preds, labels)

#             preds = [post_pred(i) for i in decollate_batch(preds)]
#             preds = torch.stack(preds)
#             labels = [post_label(i) for i in decollate_batch(labels)]
#             labels = torch.stack(labels)
#     #         Metric scores
#             dice_metric(preds, labels)
#             hd_metric(preds, labels)

#             loss.backward()
#             optimizer.step()

#             epoch_loss += loss.item()

#             if i % batch_interval == 0:
#                 print(f"Epoch: [{epoch}/{epochs}], Batch: [{i}/{len(train_loader)}], Loss: {loss.item() :.4f}, \
#                       Dice: {dice_metric.aggregate().item() * 100 :.2f}, HD: {hd_metric.aggregate().item() :.2f}")
    
#     # Print metrics, log data, reset metrics
    
#     print(f"\nEpoch: [{epoch}/{epochs}], Avg Loss: {epoch_loss / len(train_loader) :.3f}, \
#               Train Dice: {dice_metric.aggregate().item() * 100 :.2f}, Train HD: {hd_metric.aggregate().item() :.2f}, Time : {int(time() - train_start)} sec")

#     if wandb_log:
#         wandb.log({"Train Dice" : dice_metric.aggregate().item() * 100,
#                    "Train Hausdorff Distance" : hd_metric.aggregate().item(),
#                    "Train Loss" : epoch_loss / len(train_loader),
#                    "Learning Rate" : scheduler.get_last_lr()[0],
#                    "Epoch" : epoch })

#     dice_metric.reset()
#     hd_metric.reset()
#     scheduler.step()
        
    

#     # Validation
#     if epoch % val_interval == 0:
#         train_start = time()
#         model.eval()
#         with torch.no_grad():
#             for dataset_name in dataloaders_map:
#                 train_loader, test_loader = dataloaders_map[dataset_name]
#                 print(f'\n---------Validating on {dataset_name} - {len(test_loader)} samples---------')
# #                 print(f'Training samples : {len(train_loader)}')

#                 # Iterate over all samples in the dataset
#                 for i, (imgs, labels) in enumerate(test_loader, 1):
#                     imgs = imgs.to(device)
#                     labels = labels.to(device)
#                     imgs = rearrange(imgs, 'b c h w d -> (b d) c h w')
#                     labels = rearrange(labels, 'b c h w d -> (b d) c h w')

#                     roi_size = (datasets[dataset_name]['roi_size'], datasets[dataset_name]['roi_size'])
#                     preds = sliding_window_inference(inputs=imgs, roi_size=roi_size, sw_batch_size=4,
#                                                     predictor=model, overlap = 0.5, mode = 'gaussian', device=device)
#     #                 preds = model(imgs)
#                     preds = [post_pred(i) for i in decollate_batch(preds)]
#                     preds = torch.stack(preds)
#                     labels = [post_label(i) for i in decollate_batch(labels)]
#                     labels = torch.stack(labels)

#                     dice_metric(preds, labels)
#                     hd_metric(preds, labels)

#                 val_dice = dice_metric.aggregate().item()
#                 val_hd = hd_metric.aggregate().item()

#                 print("-"*75)
#                 print(f"Epoch : [{epoch}/{epochs}], Test Avg Dice : {val_dice*100 :.2f}, Test Avg HD : {val_hd :.2f}, Time : {int(time() - train_start)} sec")
#                 print("-"*75)



#     if val_dice > best_dice:

#         best_dice = val_dice
#         best_dice_epoch = epoch
# #                 torch.save(model.state_dict(), "best_model.pt")
# #                 print(f"Best model saved at epoch {best_metric_epoch} with Dice {best_metric*100:.2f}")


#     if val_hd < best_hd and val_hd > 0:

#         best_hd = val_hd
#         best_hd_epoch = epoch

#     if wandb_log:
#         wandb.log({"Test Dice" : val_dice * 100,
#                    "Test Best Dice" : best_dice * 100,
#                    "Test Hausdorff Distance" : hd_metric.aggregate().item(),
#                    "Test Best HD" : best_hd,
#                    "Epoch" : epoch })

#         preds = torch.stack([argmax(c) for c in preds])
#         labels = torch.stack([argmax(c) for c in labels])


#         f = make_grid(torch.cat([imgs,labels,preds],dim=3), nrow =2, padding = 20, pad_value = 1)
#         images = wandb.Image(rearrange(f.cpu(), 'c h w -> h w c'), caption="Left: Input, Middle : Ground Truth, Right: Prediction")
#         wandb.log({"Predictions": images, "Epoch" : epoch})
#         print('Logged data to wandb')

#     dice_metric.reset()
#     hd_metric.reset()

# # ---------------------------------------------------------------------

# print(f"Completed training, best model saved at epoch {best_dice_epoch} with Dice {best_dice*100:.2f} and Best HD {best_hd:.2f} at epoch {best_hd_epoch}")