# Continual Learning Pre - Flight Test Methods :  **Replay Learning**

## Necessary Installs

In [5]:
!pip install -q monai einops

[0m

## Necessary Imports

In [6]:
# imports and installs
import torch
import torch.nn as nn
import torch.nn.functional as F
import nibabel as nib
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid
from torch.optim import SGD, Adam, ASGD

import numpy as np
import matplotlib.pyplot as plt
from glob import glob
import monai
from monai.transforms import (ScaleIntensityRange, Compose, AddChannel, RandSpatialCrop, ToTensor, 
                            RandAxisFlip, Activations, AsDiscrete, Resize, RandRotate, RandFlip, EnsureType,
                             KeepLargestConnectedComponent, CenterSpatialCrop)
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.inferers import sliding_window_inference
from monai.losses import DiceLoss, FocalLoss, GeneralizedDiceLoss, DiceCELoss, DiceFocalLoss
from monai.networks.nets import UNet, VNet, UNETR, SwinUNETR, AttentionUnet
from monai.data import decollate_batch, ImageDataset
from monai.utils import set_determinism
import os
import wandb
from time import time
from einops import rearrange
from einops.layers.torch import Rearrange
from torch.optim.lr_scheduler import ExponentialLR, CosineAnnealingLR
from random import sample

torch.manual_seed(2000)
set_determinism(seed=2000)

wandb_log = True

## Create Data Loader 

In [7]:
def get_img_label_folds(img_paths, label_paths):
    
    fold = list(range(0,len(img_paths)))
    fold = sample(fold, k=len(fold))
    fold_imgs = [img_paths[i] for i in fold]
    fold_labels = [label_paths[i] for i in fold]
    return fold_imgs, fold_labels

# Transforms for images & labels
train_roi_size = 160
transforms_map = {
        "train_img_transform" : [
            AddChannel(),
#             CenterSpatialCrop([train_roi_size, train_roi_size, -1]),
            RandSpatialCrop(roi_size= train_roi_size, random_center = True, random_size=False),
            ToTensor()
            ],
        "train_label_transform" : [
            AddChannel(),
#             CenterSpatialCrop([train_roi_size, train_roi_size, -1]),
            RandSpatialCrop(roi_size= train_roi_size, random_center = True, random_size=False),
            AsDiscrete(threshold=0.5),
            ToTensor()
            ],
        "test_img_transform" : [
            AddChannel(),
#             CenterSpatialCrop([train_roi_size, train_roi_size, -1]),
            ToTensor()
            ],
        "test_label_transform" : [
            AddChannel(),
#             CenterSpatialCrop([train_roi_size, train_roi_size, -1]),
            AsDiscrete(threshold=0.5),
            ToTensor()
            ],
    }

# 1. Image & Label paths

dataset_map = {
        "promise12" : {
            "data_dir" : "../input/promise12prostatealigned/",
            "test_size" : 0.1,
            'test' :  {'images' : [], 'labels' : []},
            'train' :  {'images' : [], 'labels' : []}
            
            },
        "decathlon" : {
            "data_dir" : "../input/decathlonprostatealigned/",
            "test_size" : 0.2,
            'test' :  {'images' : [], 'labels' : []},
            'train' :  {'images' : [], 'labels' : []}
            },
        "isbi" : {
            "data_dir" : "../input/isbiprostatealigned/",
            "test_size" : 0.2,
            'test' :  {'images' : [], 'labels' : []},
            'train' :  {'images' : [], 'labels' : []}
            }
    }


for dataset in dataset_map:
    print(f"------------{dataset}------------")
    data_dir = dataset_map[dataset]['data_dir']

    img_paths = glob(data_dir + "imagesTr/*.nii")
    label_paths = glob(data_dir + "labelsTr/*.nii")
    img_paths.sort()
    label_paths.sort()
    
    # 2. Folds

    images_fold, labels_fold  = get_img_label_folds(img_paths, label_paths)
    
    print("Number of images: {}".format(len(images_fold)))
    print("Number of labels: {}".format(len(labels_fold)))
    
    # Get train and test sets
    # 3. Split into train - test
    train_idx = int(len(images_fold) * (1 - dataset_map[dataset]['test_size']))
    
    # Store train & test sets 
    
    dataset_map[dataset]['train']['images'] = images_fold[:train_idx]
    dataset_map[dataset]['train']['labels'] = labels_fold[:train_idx]
    
    dataset_map[dataset]['test']['images'] = images_fold[train_idx:]
    dataset_map[dataset]['test']['labels'] = labels_fold[train_idx:]

------------promise12------------
Number of images: 50
Number of labels: 50
------------decathlon------------
Number of images: 32
Number of labels: 32
------------isbi------------
Number of images: 79
Number of labels: 79


In [8]:
batch_size = 1
def get_dataloader(img_paths : list, label_paths : list, train : bool):
    
    if train:
        ttset = "train"
    else:
        ttset = "test"
        
    dataset = ImageDataset(img_paths, label_paths,
                            transform=Compose(transforms_map[f'{ttset}_img_transform']), 
                            seg_transform=Compose(transforms_map[f'{ttset}_label_transform']))
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    return  dataloader

In [9]:
dataloaders_map = {}

for dataset in dataset_map:
    print(f"------------{dataset}------------")
    for ttset in ['train', 'test']:
        if ttset == 'train':
            train = True
        else:
            train = False
        dataloaders_map[dataset] = {}
        dataloaders_map[dataset]['test'] = get_dataloader(img_paths = dataset_map[dataset][ttset]['images'],
                                                          label_paths = dataset_map[dataset][ttset]['labels'],
                                                          train = train)
        
        print(f"""No of samples in {dataset}-{ttset} : {len(dataloaders_map[dataset]['test'])}""")

# 7. That's it

------------promise12------------
No of samples in promise12-train : 45
No of samples in promise12-test : 5
------------decathlon------------
No of samples in decathlon-train : 25
No of samples in decathlon-test : 7
------------isbi------------
No of samples in isbi-train : 63
No of samples in isbi-test : 16


## Visualize dataset

In [10]:
# # ----------------------------Get dataloaders--------------------------
# imgs,labels = next(iter(dataloaders_map['promise12']['train']))
# imgs = rearrange(imgs, 'b c h w d -> (b d) c h w')
# labels = rearrange(labels, 'b c h w d -> (b d) c h w')
# print(f"\nImage shape : {imgs.shape}")
# print(f"Label shape : {labels.shape}")

# img_no = 8
# plt.figure(figsize=(6*3,6*1))
# plt.subplot(1,3,1)
# plt.imshow(imgs[img_no,0], cmap='gray')
# plt.axis('off')
# plt.title('Image')
# plt.subplot(1,3,2)
# plt.imshow(labels[img_no,0], cmap='gray')
# plt.axis('off')
# plt.title('Label')
# plt.subplot(1,3,3)
# plt.imshow(imgs[img_no,0], cmap='gray')
# plt.imshow(labels[img_no,0], 'copper', alpha=0.2)
# plt.axis('off')
# plt.title('Overlay')
# plt.show()

### Train - Test Statistics

In [7]:
# print(f"\nTraining samples : {len(train_loader)}")
# # print(f"Testing samples : {len(test_loader)}")

# # -----------------------Slices-----------------------------

# plt.figure(figsize = (1*7, 1*7))

# slices = [label.shape[4] for _, label in train_loader]

# # plt.subplot(1,2,1)
# plt.hist(slices, )
# plt.xlabel('Slices')
# plt.ylabel('Count')
# plt.title('Training Set')

# # slices = [label.shape[4] for _, label in test_loader]
# # plt.subplot(1,2,2)
# # plt.hist(slices, )
# # plt.xlabel('Slices')
# # plt.ylabel('Count')
# # plt.title('Testing set')

# plt.suptitle('Slices in Training & Testing Sets')
# plt.show()

## Train Config, Loss, Metrics

In [8]:
# ----------------------------Train Config-----------------------------------------
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

model = UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)

epochs = 100
initial_lr = 1e-3
optimizer = Adam(model.parameters(), lr=initial_lr, weight_decay=1e-5)
# optimizer = ASGD(model.parameters(), lr=initial_lr)
scheduler = CosineAnnealingLR(optimizer, T_max=epochs, verbose=True)
# scheduler = ExponentialLR(optimizer, gamma=0.98)
dice_metric = DiceMetric(include_background=False, reduction="mean", get_not_nans=False)
hd_metric = HausdorffDistanceMetric(include_background=False, percentile = 95.)


# post_trans = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
post_pred = Compose([
    EnsureType(), AsDiscrete(argmax=True, to_onehot=2),
    KeepLargestConnectedComponent(applied_labels=[1], is_onehot=True, connectivity=2)
])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])
argmax = AsDiscrete(argmax=True)
# dice_loss = DiceLoss(to_onehot_y=True, softmax=True)
# bce_loss = nn.BCEWithLogitsLoss()
# ce_loss = nn.CrossEntropyLoss(weight = torch.tensor([1., 20.], device = device))
# focal_loss = FocalLoss(to_onehot_y = True, weight = [.5, 95.])
dice_ce_loss = DiceCELoss(to_onehot_y=True, softmax=True,)
# dice_focal_loss = DiceFocalLoss(to_onehot_y=True, softmax=True, focal_weight = torch.tensor([1., 5.], device = device))
# focal_weight = torch.tensor([1., 20.], device = device)
# weight = torch.tensor([1., 5.], device = device)

Device: cuda:0
Adjusting learning rate of group 0 to 1.0000e-03.


## WANDB Logging

In [9]:
# ------------------------------------WANDB Logging-------------------------------------
config = {
    "Model" : "UNet2D",
    "Seqential Strategy" : "Replay with 10% dataset buffer storage",
    "Batch Training Strategy" : "A batch from current dataset and a batch from episodic memeory are stacked. One backward pass and paramenter update.",
    "Train Input ROI size" : train_roi_size,
#     "Test Input size" : (1, 320, 320),
    "Test mode" : f"Sliding window inference roi = {train_roi_size}",
    "Batch size" : "No of slices in original volume",
    "No of volumes per batch" : 1,
    "Epochs" : epochs,
    "Optimizer" : "Adam",
    "Scheduler" : "CosineAnnealingLR",
    "Initial LR" : scheduler.get_last_lr()[0],
    "Loss" : "DiceCELoss", 
    "Train Data Augumentations" : "RandSpatialCrop",
    "Test Data Preprocess" : "None",
    "Train samples" : {"Promise12" : 45, "ISBI" : 63, "Decathlon" : 25},
    "Test Samples" : {"Promise12" : 5, "ISBI" : 16, "Decathlon" : 7},
#     RandFlip, RandRotate90, RandGaussianNoise, RandGaussSmooth, RandBiasField, RandContrast
    "Pred Post Processing" : "KeepLargestConnectedComponent"
}
if wandb_log:
    wandb.login()
    wandb.init(project="CL_Sequential", entity="vinayu", config = config)

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mvinayu[0m. Use [1m`wandb login --relogin`[0m to force relogin


## Replay Training

## Training function 

In [10]:
# def batch_train(batch_imgs : torch.Tensor, batch_labels : torch.Tensor):
    
    
#     imgs = batch_imgs.to(device)
#     labels = batch_labels.to(device)
#     imgs = rearrange(imgs, 'b c h w d -> (b d) c h w')
#     labels = rearrange(labels, 'b c h w d -> (b d) c h w')

#     preds = model(imgs)

#     loss = dice_ce_loss(preds, labels)

#     preds = [post_pred(i) for i in decollate_batch(preds)]
#     preds = torch.stack(preds)
#     labels = [post_label(i) for i in decollate_batch(labels)]
#     labels = torch.stack(labels)
    
#     # Metric scores
#     dice_metric(preds, labels)
#     hd_metric(preds, labels)

#     return loss

### Serial Batch Training 

In [11]:
# def train(train_loader : DataLoader, em_loader : DataLoader):
#     """
#     Inputs : No Inputs
#     Outputs : No Outputs
#     Function : Trains all samples in train_loader and log metrics to WANDB
#     """
    
#     train_start = time()
#     epoch_loss = 0
#     model.train()
#     print('\n')
    
    
#     # Iterating over the dataset
#     for i, (imgs, labels) in enumerate(train_loader, 1):
        
        
#         # Set all grads to zero before any training
        
#         optimizer.zero_grad()
        
#         # Get a single batch(imgs, labels) from train loader and call train_batch on it
        
#         loss = batch_train(batch_imgs = imgs, batch_labesl = labels)
        
#         # Get a single random batch(imgs, labels) from em_loader and call train_batch on it
        
#         loss += batch_train(batch_imgs = imgs, batch_labesl = labels)
        
#         # Accumulate the loss from both batches
        
#         # Compute loss with backward pass
        
#         # Optimize the parameters
        
#         # 

        
#         preds = model(imgs)

#         loss = dice_ce_loss(preds, labels)

#         preds = [post_pred(i) for i in decollate_batch(preds)]
#         preds = torch.stack(preds)
#         labels = [post_label(i) for i in decollate_batch(labels)]
#         labels = torch.stack(labels)
#     #         Metric scores
#         dice_metric(preds, labels)
#         hd_metric(preds, labels)

#         loss.backward()
#         optimizer.step()

#         epoch_loss += loss.item()

#         if i % batch_interval == 0:
#             print(f"Epoch: [{epoch}/{epochs}], Batch: [{i}/{len(train_loader)}], Loss: {loss.item() :.4f}, \
#                   Dice: {dice_metric.aggregate().item() * 100 :.2f}, HD: {hd_metric.aggregate().item() :.2f}")
    
#     # Print metrics, log data, reset metrics
    
#     print(f"\nEpoch: [{epoch}/{epochs}], Avg Loss: {epoch_loss / len(train_loader) :.3f}, \
#               Train Dice: {dice_metric.aggregate().item() * 100 :.2f}, Train HD: {hd_metric.aggregate().item() :.2f}, Time : {int(time() - train_start)} sec")

# #     if wandb_log:
# #         wandb.log({"Train Dice" : dice_metric.aggregate().item() * 100,
# #                    "Train Hausdorff Distance" : hd_metric.aggregate().item(),
# #                    "Train Loss" : epoch_loss / len(train_loader),
# #                    "Learning Rate" : scheduler.get_last_lr()[0],
# #                    "Epoch" : epoch })


#     dice_metric.reset()
#     hd_metric.reset()
#     scheduler.step()


### Stacked Batch Training

In [12]:
def train(train_loader : DataLoader, em_loader : DataLoader = None):
    """
    Inputs : No Inputs
    Outputs : No Outputs
    Function : Trains all datasets and logs metrics to WANDB
    """
    
    train_start = time()
    epoch_loss = 0
    model.train()
    print('\n')
    
    
    # Iterating over the dataset
    for i, (imgs, labels) in enumerate(train_loader, 1):

        imgs, labels = imgs.to(device), labels.to(device)
        
        if em_loader is not None:
            em_imgs, em_labels = next(iter(em_loader))
            em_imgs, em_labels = em_imgs.to(device), em_labels.to(device)
        
            # Stacking up batch from current dataset and episodic memeory 
            imgs, labels = torch.cat([imgs, em_imgs], dim=-1), torch.cat([labels, em_labels], dim=-1)
        
        imgs = rearrange(imgs, 'b c h w d -> (b d) c h w')
        labels = rearrange(labels, 'b c h w d -> (b d) c h w')

        optimizer.zero_grad()
        preds = model(imgs)

        loss = dice_ce_loss(preds, labels)

        preds = [post_pred(i) for i in decollate_batch(preds)]
        preds = torch.stack(preds)
        labels = [post_label(i) for i in decollate_batch(labels)]
        labels = torch.stack(labels)
    #         Metric scores
        dice_metric(preds, labels)
        hd_metric(preds, labels)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        if i % batch_interval == 0:
            print(f"Epoch: [{epoch}/{epochs}], Batch: [{i}/{len(train_loader)}], Loss: {loss.item() :.4f}, \
                  Dice: {dice_metric.aggregate().item() * 100 :.2f}, HD: {hd_metric.aggregate().item() :.2f}")
    
    # Print metrics, log data, reset metrics
    
    print(f"\nEpoch: [{epoch}/{epochs}], Avg Loss: {epoch_loss / len(train_loader) :.3f}, \
              Train Dice: {dice_metric.aggregate().item() * 100 :.2f}, Train HD: {hd_metric.aggregate().item() :.2f}, Time : {int(time() - train_start)} sec")

#     if wandb_log:
#         wandb.log({"Train Dice" : dice_metric.aggregate().item() * 100,
#                    "Train Hausdorff Distance" : hd_metric.aggregate().item(),
#                    "Train Loss" : epoch_loss / len(train_loader),
#                    "Learning Rate" : scheduler.get_last_lr()[0],
#                    "Epoch" : epoch })


    dice_metric.reset()
    hd_metric.reset()
    scheduler.step()

## Validation function

In [13]:
batch_size = 1
test_shuffle = True

test_map_config = {
            'promise12' : {'roi_size' : 160},
            'isbi' : {'roi_size' : 160},
            'decathlon' : {'roi_size' : 160},
           }


In [14]:
def validate(test_loader : DataLoader, dataset_name : str = None):
    """
    Inputs : Testing dataloader
    Outputs : Returns Dice, HD
    Function : Validate on the given dataloader and return the mertics 
    """
    train_start = time()
    model.eval()
    with torch.no_grad():
        # Iterate over all samples in the dataset
        for i, (imgs, labels) in enumerate(test_loader, 1):
            imgs = imgs.to(device)
            labels = labels.to(device)
            imgs = rearrange(imgs, 'b c h w d -> (b d) c h w')
            labels = rearrange(labels, 'b c h w d -> (b d) c h w')

            roi_size = (test_map_config[dataset_name]['roi_size'], test_map_config[dataset_name]['roi_size'])
            preds = sliding_window_inference(inputs=imgs, roi_size=roi_size, sw_batch_size=4,
                                            predictor=model, overlap = 0.5, mode = 'gaussian', device=device)
#                 preds = model(imgs)
            preds = [post_pred(i) for i in decollate_batch(preds)]
            preds = torch.stack(preds)
            labels = [post_label(i) for i in decollate_batch(labels)]
            labels = torch.stack(labels)

            dice_metric(preds, labels)
            hd_metric(preds, labels)

        val_dice = dice_metric.aggregate().item()
        val_hd = hd_metric.aggregate().item()
        
        dice_metric.reset()
        hd_metric.reset()
        
        print("-"*75)
        print(f"Epoch : [{epoch}/{epochs}], Dataset : {dataset_name.upper()}, Test Avg Dice : {val_dice*100 :.2f}, Test Avg HD : {val_hd :.2f}, Time : {int(time() - train_start)} sec")
        print("-"*75)
        
        return val_dice, val_hd

## Trainer & Validator

In [15]:
val_interval = 5
batch_interval = 25

## Main Training Sequence

### Initialize replay memory buffer 

In [16]:
# Empty replay buffer as a list
replay_buffer = {
    "train" : {
        'images' : [],
        'labels' : [],
    },
}

In [17]:
# Training on Promise12 --> Testing on Promise12
epochs = 100
initial_lr = 1e-3
optimizer = Adam(model.parameters(), lr=initial_lr, weight_decay=1e-5)
# optimizer = ASGD(model.parameters(), lr=initial_lr)
scheduler = CosineAnnealingLR(optimizer, T_max=epochs, verbose=True)

dataset_name = 'promise12'
train_loader = get_dataloader(img_paths = dataset_map[dataset_name]['train']['images'],
                              label_paths = dataset_map[dataset_name]['train']['labels'],
                              train = True)

test_dataset_names = ['promise12']

metric_prefix  = 'p12->'

metrics_map = {}
for dname in test_dataset_names:
    metrics_map[dname] = {
        f'{metric_prefix}_{dname}_curr_dice' : 0,
        f'{metric_prefix}_{dname}_best_dice' : 0,
        f'{metric_prefix}_{dname}_curr_hd' : 1e10,
        f'{metric_prefix}_{dname}_best_hd' : 1e10,
        'Epoch' : 0
    }

for epoch in range(1, epochs+1):   
    
        train(train_loader = train_loader, em_loader = None)
        
        if epoch % val_interval == 0:
            for dname in test_dataset_names:
                val_dice, val_hd = validate(test_loader = dataloaders_map[dname]['test'], dataset_name = dname)

                val_dice *= 100

                metrics = metrics_map[dname]

                metrics[f'Epoch'] = epoch
                metrics[f'{metric_prefix}_{dname}_curr_dice'] = val_dice 
                metrics[f'{metric_prefix}_{dname}_curr_hd'] = val_hd
                metrics[f'{metric_prefix}_{dname}_best_dice'] = max(val_dice, metrics[f'{metric_prefix}_{dname}_best_dice'])

                if val_hd < metrics[f'{metric_prefix}_{dname}_best_hd'] and val_hd > 0:
                    metrics[f'{metric_prefix}_{dname}_best_hd'] = val_hd

                if wandb_log:
                    # Quantiative metrics
                    wandb.log(metrics)
                    print('Logged data to wandb')

Adjusting learning rate of group 0 to 1.0000e-03.




  diff_b_a = subtract(b, a)


Epoch: [1/100], Batch: [25/45], Loss: 0.9604,                   Dice: 0.52, HD: 92.55





Epoch: [1/100], Avg Loss: 1.004,               Train Dice: 0.29, Train HD: 97.66, Time : 43 sec
Adjusting learning rate of group 0 to 9.9975e-04.


Epoch: [2/100], Batch: [25/45], Loss: 0.8095,                   Dice: 2.00, HD: 82.41

Epoch: [2/100], Avg Loss: 0.836,               Train Dice: 2.37, Train HD: 84.36, Time : 12 sec
Adjusting learning rate of group 0 to 9.9901e-04.


Epoch: [3/100], Batch: [25/45], Loss: 0.7148,                   Dice: 0.07, HD: 80.41

Epoch: [3/100], Avg Loss: 0.749,               Train Dice: 1.42, Train HD: 75.16, Time : 11 sec
Adjusting learning rate of group 0 to 9.9778e-04.


Epoch: [4/100], Batch: [25/45], Loss: 0.6459,                   Dice: 0.13, HD: 61.66

Epoch: [4/100], Avg Loss: 0.695,               Train Dice: 3.13, Train HD: 63.75, Time : 11 sec
Adjusting learning rate of group 0 to 9.9606e-04.


Epoch: [5/100], Batch: [25/45], Loss: 0.9526,                   Dice: 0.00, HD: 104.31

Epoch: [5/100], Avg Loss: 0.679,               Train Dice:

### Store few samples to replay memory buffer

In [18]:
replay_percentage = 0.1
train_samples_count = len(dataset_map[dataset_name]['train']['images'])
replay_count = int(train_samples_count * replay_percentage)
print(f"Storing {replay_count} Promise 12 Samples to replay buffer")
idxs = idxs = list(map(int, np.linspace(0, train_samples_count-1, num=replay_count).tolist()))
replay_buffer['train']['images'] +=  [dataset_map[dataset_name]['train']['images'][idx] for idx in idxs]
replay_buffer['train']['labels'] +=  [dataset_map[dataset_name]['train']['labels'][idx] for idx in idxs]
print(f"Current replay buffer size : {len(replay_buffer['train']['labels'])}")

Storing 4 Promise 12 Samples to replay buffer
Current replay buffer size : 4


In [19]:
# # Train on ISBI --> Test on Promise12 & ISBI

epochs = 100
initial_lr = 1e-3
optimizer = Adam(model.parameters(), lr=initial_lr, weight_decay=1e-5)
# optimizer = ASGD(model.parameters(), lr=initial_lr)
scheduler = CosineAnnealingLR(optimizer, T_max=epochs, verbose=True)


dataset_name = 'isbi'
# img_paths = dataset_map[dataset_name]['train']['images'] + replay_buffer['train']['images']
# label_paths = dataset_map[dataset_name]['train']['labels'] + replay_buffer['train']['labels']

img_paths = dataset_map[dataset_name]['train']['images'] 
label_paths = dataset_map[dataset_name]['train']['labels'] 

train_loader = get_dataloader(img_paths = img_paths,
                           label_paths = label_paths,
                           train = True)

em_loader = get_dataloader(img_paths = replay_buffer['train']['images'],
                              label_paths = replay_buffer['train']['labels'],
                              train = True)

test_dataset_names = ['promise12', 'isbi']

metric_prefix  += 'isbi->'

metrics_map = {}
for dname in test_dataset_names:
    metrics_map[dname] = {
        f'{metric_prefix}_{dname}_curr_dice' : 0,
        f'{metric_prefix}_{dname}_best_dice' : 0,
        f'{metric_prefix}_{dname}_curr_hd' : 1e10,
        f'{metric_prefix}_{dname}_best_hd' : 1e10,
        'Epoch' : 0
    }

for epoch in range(1, epochs+1):   
        train(train_loader = train_loader, em_loader = em_loader)
        
        if epoch % val_interval == 0:
            for dname in test_dataset_names:
                val_dice, val_hd = validate(test_loader = dataloaders_map[dname]['test'], dataset_name = dname)

                val_dice *= 100

                metrics = metrics_map[dname]

                metrics[f'Epoch'] = epoch
                metrics[f'{metric_prefix}_{dname}_curr_dice'] = val_dice 
                metrics[f'{metric_prefix}_{dname}_curr_hd'] = val_hd
                metrics[f'{metric_prefix}_{dname}_best_dice'] = max(val_dice, metrics[f'{metric_prefix}_{dname}_best_dice'])

                if val_hd < metrics[f'{metric_prefix}_{dname}_best_hd'] and val_hd > 0:
                    metrics[f'{metric_prefix}_{dname}_best_hd'] = val_hd

                if wandb_log:
                    # Quantiative metrics
                    wandb.log(metrics)
                    print('Logged data to wandb')

Adjusting learning rate of group 0 to 1.0000e-03.


Epoch: [1/100], Batch: [25/63], Loss: 0.6046,                   Dice: 58.43, HD: 24.58
Epoch: [1/100], Batch: [50/63], Loss: 0.4705,                   Dice: 64.39, HD: 22.09

Epoch: [1/100], Avg Loss: 0.446,               Train Dice: 63.60, Train HD: 23.29, Time : 49 sec
Adjusting learning rate of group 0 to 9.9975e-04.


Epoch: [2/100], Batch: [25/63], Loss: 0.4831,                   Dice: 66.90, HD: 23.24
Epoch: [2/100], Batch: [50/63], Loss: 0.4426,                   Dice: 65.45, HD: 23.29

Epoch: [2/100], Avg Loss: 0.436,               Train Dice: 65.93, Train HD: 22.47, Time : 29 sec
Adjusting learning rate of group 0 to 9.9901e-04.


Epoch: [3/100], Batch: [25/63], Loss: 0.4757,                   Dice: 66.24, HD: 19.35
Epoch: [3/100], Batch: [50/63], Loss: 0.2772,                   Dice: 71.11, HD: 17.70

Epoch: [3/100], Avg Loss: 0.408,               Train Dice: 71.65, Train HD: 17.53, Time : 28 sec
Adjusting learning rate of g

### Store few samples to replay memory buffer

In [20]:
replay_percentage = 0.1
train_samples_count = len(dataset_map[dataset_name]['train']['images'])
replay_count = int(train_samples_count * replay_percentage)
print(f"Storing {replay_count} ISBI Samples to replay buffer")
idxs = idxs = list(map(int, np.linspace(0, train_samples_count-1, num=replay_count).tolist()))
replay_buffer['train']['images'] +=  [dataset_map[dataset_name]['train']['images'][idx] for idx in idxs]
replay_buffer['train']['labels'] +=  [dataset_map[dataset_name]['train']['labels'][idx] for idx in idxs]
print(f"Current replay buffer size : {len(replay_buffer['train']['labels'])}")

Storing 6 ISBI Samples to replay buffer
Current replay buffer size : 10


In [21]:
# # Train on Decathlon --> Test on Promise12, ISBI & Decathlon
epochs = 100
initial_lr = 1e-3
optimizer = Adam(model.parameters(), lr=initial_lr, weight_decay=1e-5)
# optimizer = ASGD(model.parameters(), lr=initial_lr)
scheduler = CosineAnnealingLR(optimizer, T_max=epochs, verbose=True)


dataset_name = 'isbi'
# img_paths = dataset_map[dataset_name]['train']['images'] + replay_buffer['train']['images']
# label_paths = dataset_map[dataset_name]['train']['labels'] + replay_buffer['train']['labels']

img_paths = dataset_map[dataset_name]['train']['images'] 
label_paths = dataset_map[dataset_name]['train']['labels'] 

train_loader = get_dataloader(img_paths = img_paths,
                           label_paths = label_paths,
                           train = True)

em_loader = get_dataloader(img_paths = replay_buffer['train']['images'],
                              label_paths = replay_buffer['train']['labels'],
                              train = True)

test_dataset_names = ['promise12', 'isbi', 'decathlon']

metric_prefix  += 'dec->'

metrics_map = {}
for dname in test_dataset_names:
    metrics_map[dname] = {
        f'{metric_prefix}_{dname}_curr_dice' : 0,
        f'{metric_prefix}_{dname}_best_dice' : 0,
        f'{metric_prefix}_{dname}_curr_hd' : 1e10,
        f'{metric_prefix}_{dname}_best_hd' : 1e10,
        'Epoch' : 0
    }

for epoch in range(1, epochs+1):   
    
        train(train_loader)
        
        if epoch % val_interval == 0:
            for dname in test_dataset_names:
                val_dice, val_hd = validate(test_loader = dataloaders_map[dname]['test'], dataset_name = dname)

                val_dice *= 100

                metrics = metrics_map[dname]

                metrics[f'Epoch'] = epoch
                metrics[f'{metric_prefix}_{dname}_curr_dice'] = val_dice 
                metrics[f'{metric_prefix}_{dname}_curr_hd'] = val_hd
                metrics[f'{metric_prefix}_{dname}_best_dice'] = max(val_dice, metrics[f'{metric_prefix}_{dname}_best_dice'])

                if val_hd < metrics[f'{metric_prefix}_{dname}_best_hd'] and val_hd > 0:
                    metrics[f'{metric_prefix}_{dname}_best_hd'] = val_hd

                if wandb_log:
                    # Quantiative metrics
                    wandb.log(metrics)
                    print('Logged data to wandb')

Adjusting learning rate of group 0 to 1.0000e-03.


Epoch: [1/100], Batch: [25/63], Loss: 0.2645,                   Dice: 80.90, HD: 15.08
Epoch: [1/100], Batch: [50/63], Loss: 0.3405,                   Dice: 77.67, HD: 17.06

Epoch: [1/100], Avg Loss: 0.357,               Train Dice: 79.19, Train HD: 15.86, Time : 12 sec
Adjusting learning rate of group 0 to 9.9975e-04.


Epoch: [2/100], Batch: [25/63], Loss: 0.3937,                   Dice: 81.40, HD: 11.62
Epoch: [2/100], Batch: [50/63], Loss: 0.3842,                   Dice: 80.69, HD: 13.08

Epoch: [2/100], Avg Loss: 0.374,               Train Dice: 79.37, Train HD: 15.09, Time : 12 sec
Adjusting learning rate of group 0 to 9.9901e-04.


Epoch: [3/100], Batch: [25/63], Loss: 0.4959,                   Dice: 73.53, HD: 19.03
Epoch: [3/100], Batch: [50/63], Loss: 0.4530,                   Dice: 76.28, HD: 16.59

Epoch: [3/100], Avg Loss: 0.383,               Train Dice: 78.24, Train HD: 15.34, Time : 12 sec
Adjusting learning rate of g

In [22]:
replay_percentage = 0.1
train_samples_count = len(dataset_map[dataset_name]['train']['images'])
replay_count = int(train_samples_count * replay_percentage)
print(f"Storing {replay_count} Decathlon Samples to replay buffer")
idxs = idxs = list(map(int, np.linspace(0, train_samples_count-1, num=replay_count).tolist()))
replay_buffer['train']['images'] +=  [dataset_map[dataset_name]['train']['images'][idx] for idx in idxs]
replay_buffer['train']['labels'] +=  [dataset_map[dataset_name]['train']['labels'][idx] for idx in idxs]
print(f"Current replay buffer size : {len(replay_buffer['train']['labels'])}")

Storing 6 Decathlon Samples to replay buffer
Current replay buffer size : 16
