# Download, Install, Import :blush: :cd:

In [None]:
!pip install wandb==0.13.3 -qqq
!pip install monai -qqq
!pip install pytorch_lightning -qqq

In [None]:
!pip uninstall kornia -y
!pip install kornia

## Libraries :books: 

In [3]:

## misc
import os
import argparse
import yaml
import tempfile
import glob

from datetime import datetime, timedelta
from typing import Any, Callable, Dict, Hashable, no_type_check, List, Mapping, Optional, Sequence, Tuple, TypeVar, Union
from typing_extensions import Literal
from types import SimpleNamespace

## logging
import wandb

## math, dl, vision
import numpy as np

import torchvision
from kornia.losses.hausdorff import HausdorffERLoss3D

import torch
import torch.nn as nn
from torch import rand, cat, stack, amax, where, sigmoid
import torch.nn.functional as F
from torch.cuda import is_available
from torch.nn import Dropout3d
from torch.nn.init import kaiming_normal_, normal_, constant_


## torchmetrics
from torchmetrics import Metric
from torchmetrics.functional.classification import multilabel_stat_scores
from torchmetrics.functional.classification.dice import _dice_compute
from torchmetrics.utilities.enums import AverageMethod, MDMCAverageMethod

## pytorch lightning 
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import RichProgressBar, ModelCheckpoint, LearningRateMonitor, EarlyStopping
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme, CustomBarColumn, ProcessingSpeedColumn


#from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import LightningDataModule



## aesthetics
from rich.progress import ProgressColumn, Task, TextColumn
from rich.text import Text
from rich.style import Style
from rich.console import RenderableType

## MONAI
from monai.apps import DecathlonDataset
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
from monai.config import DtypeLike, KeysCollection
from monai.data import DataLoader, decollate_batch
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.inferers import sliding_window_inference
from monai.losses import DiceLoss, FocalLoss, DiceFocalLoss, GeneralizedDiceLoss, GeneralizedDiceFocalLoss
from monai.metrics import DiceMetric
from monai.networks.nets import SegResNet, SegResNetVAE
from monai.utils import TraceKeys, set_determinism, UpsampleMode
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type, convert_to_tensor, get_equivalent_dtype


## data-viz
import plotly
import plotly.express as px
plotly.offline.init_notebook_mode (connected = True)

# Env Variables & Dirs :open_file_folder:

In [5]:
# Create the W&B output folder up front (no error if it already exists).
os.makedirs('/kaggle/working/wandb/', exist_ok=True)

#os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
# Tell W&B where to write its run files.
# NOTE(review): presumably W&B creates its own `wandb/` sub-folder under WANDB_DIR,
# which is why the folder above is one level deeper -- confirm.
os.environ["WANDB_DIR"] = os.path.abspath("/kaggle/working/")

In [6]:
# Name of the dataset task used in this notebook.
task = "Task01_BrainTumour"
# Folder holding the experiment configuration files.
config_dir = "/kaggle/input/brats-configurations/"

In [4]:
## Seed both MONAI and Lightning for reproducibility.
## `workers=True` is passed so that dataloader workers do not end up applying
## the same data augmentation twice.
my_seed = 0
set_determinism(seed=my_seed)
seed_everything(my_seed, workers=True)

0

# src :gear:

In [7]:
from torch import tensor
# Generic tensor type used in the annotations below.
# The bound must be the tensor *class* (`torch.Tensor`); `torch.tensor` is a
# factory function and is not a valid upper bound for a TypeVar.
Tensor = TypeVar("Tensor", bound=torch.Tensor)

## Custom Metrics :chart_with_upwards_trend:

In [10]:
class Val_Dice(Metric):
    """
    Custom validation Dice metric adapted from the MONAI Dice metric, with the
    possibility to be extended to multiclass/multilabel segmentation.

    Neither the MONAI metric nor the torchmetrics Dice metric exposed the
    combination of parameters needed to obtain the desired reduction (over batch
    and classes) for multi-dimensional data, so the metric is rebuilt here from
    the prediction/ground-truth statistics (tp, fp, tn, fn, support).

    Args:
        zero_division: value forwarded to the Dice computation for the 0/0 case.
        num_classes: number of label channels; used as ``num_labels`` for the
            stat scores and to reshape them.
        threshold: threshold forwarded to ``multilabel_stat_scores``.
        average: reduction over classes, one of
            ``"micro", "macro", "weighted", "samples", "none", None``.
        mdmc_average: forwarded as ``multidim_average`` to the stat scores and
            as ``mdmc_average`` to the Dice computation.
        ignore_index: target value forwarded to ``multilabel_stat_scores``.
        top_k, multiclass: stored on the instance; not used by this class.
        **kwargs: forwarded to ``torchmetrics.Metric``.

    Raises:
        ValueError: if ``average`` is not one of the allowed values.
    """
    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = False

    @no_type_check
    def __init__(
        self,
        zero_division: int = 0,
        num_classes: Optional[int] = None,
        threshold: float = 0.5,
        average: Optional[Literal["micro", "macro", "weighted", "none"]] = "micro",
        mdmc_average: Optional[str] = "samplewise",
        ignore_index: Optional[int] = None,
        top_k: Optional[int] = None,
        multiclass: Optional[bool] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        allowed_average = ("micro", "macro", "weighted", "samples", "none", None)
        if average not in allowed_average:
            raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

        # The original code also wrote "reduce"/"mdmc_reduce" into `kwargs` here,
        # i.e. after `kwargs` had already been consumed by `super().__init__`;
        # those writes had no effect and were dropped.
        self.reduce = average
        self.mdmc_reduce = mdmc_average
        self.num_classes = num_classes
        self.threshold = threshold
        self.multiclass = multiclass
        self.ignore_index = ignore_index
        self.top_k = top_k

        # Each statistic is kept as a list of per-update tensors, concatenated
        # across processes ("cat") and again in `_get_final_stats`.
        for s in ("tp", "fp", "tn", "fn", "sup"):
            self.add_state(s, default=[], dist_reduce_fx="cat")

        self.average = average
        self.zero_division = zero_division

    @no_type_check
    def update(self, preds: Tensor, target: Tensor) -> None:
        """Update state with predictions and targets.

        The stat scores are reshaped to ``(batch, num_classes, 5)``.
        NOTE(review): this reshape presumably only succeeds when the stat scores
        are returned per sample and per class (e.g. ``average='none'`` with
        ``mdmc_average='samplewise'``) -- confirm before using other settings.
        """
        stats = multilabel_stat_scores(
            preds,
            target,
            self.num_classes,
            # Honour the values given to the constructor (they were previously
            # stored but never forwarded, so the library defaults always applied).
            threshold=self.threshold,
            average=self.average,
            multidim_average=self.mdmc_reduce,
            ignore_index=self.ignore_index,
        )
        # Last axis holds the five statistics in the order tp, fp, tn, fn, support.
        tp, fp, tn, fn, sup = torch.unbind(stats.reshape(preds.shape[0], self.num_classes, 5), dim=-1)

        # Update states
        self.tp.append(tp)
        self.fp.append(fp)
        self.tn.append(tn)
        self.fn.append(fn)
        self.sup.append(sup)

    @no_type_check
    def _get_final_stats(self) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Concatenate the accumulated stat scores (if still lists) before the compute step."""
        tp = torch.cat(self.tp) if isinstance(self.tp, list) else self.tp
        fp = torch.cat(self.fp) if isinstance(self.fp, list) else self.fp
        tn = torch.cat(self.tn) if isinstance(self.tn, list) else self.tn
        fn = torch.cat(self.fn) if isinstance(self.fn, list) else self.fn
        sup = torch.cat(self.sup) if isinstance(self.sup, list) else self.sup
        return tp, fp, tn, fn, sup

    @no_type_check
    def compute(self) -> Tensor:
        """Compute the Dice score from the accumulated tp/fp/fn statistics."""
        tp, fp, _, fn, _ = self._get_final_stats()
        return _dice_compute(tp, fp, fn, self.average, self.mdmc_reduce, self.zero_division)

## Pytorch Lightning Modules :zap:

### Data Module :zap: :file_cabinet:

In [8]:
# @title brats2018

class DatasetModule(LightningDataModule):
    """Lightning data module wrapping MONAI's ``DecathlonDataset``.

    Args:
        data_path: root directory handed to ``DecathlonDataset``.
        dataset: Decathlon task name (e.g. ``"Task01_BrainTumour"``).
        train_batch_size: batch size of the training dataloader.
        val_batch_size: batch size of the validation dataloader.
        num_workers: number of dataloader workers.
        download: forwarded to ``DecathlonDataset``.
        cache_rate: forwarded to ``DecathlonDataset``.
        transform: optional dict with ``'training'`` and ``'validation'`` entries.
        classes: accepted for interface compatibility; not used by this class.
        **kwargs: extra keyword arguments forwarded to ``DecathlonDataset``.
    """

    def __init__(
        self,
        data_path: str,
        dataset: str,
        train_batch_size: int = 1,
        val_batch_size: int = 2,
        num_workers: int = 0,
        download: bool = False,
        cache_rate = 0.0,
        transform: Optional[dict] = None,
        classes: int = 3,
        **kwargs,
    ):
        super().__init__()
        # Where the data lives and which task to load
        self.data_path, self.dataset = data_path, dataset
        # Dataloader settings
        self.train_batch_size, self.val_batch_size = train_batch_size, val_batch_size
        self.num_workers = num_workers
        # Dataset construction settings
        self.download = download
        self.cache_rate = cache_rate
        self.transform = transform
        self.args = kwargs  # extra DecathlonDataset arguments

    def _make_dataset(self, section: str) -> DecathlonDataset:
        """Build the Decathlon dataset for one section ("training" / "validation")."""
        section_transform = None if self.transform is None else self.transform[section]
        return DecathlonDataset(
            root_dir=self.data_path,
            task=self.dataset,
            transform=section_transform,
            section=section,
            download=self.download,
            cache_rate=self.cache_rate,
            **self.args,
        )

    def setup(self, stage: Optional[str] = None) -> None:
        """Create the datasets.

        The training set is only built for the brain-tumour task and for the
        listed stages; the validation set is always built.
        """
        is_brain_tumour = self.dataset.lower() == "task01_braintumour"
        if is_brain_tumour and stage in ["fit", "validation", None]:
            self.train_dataset = self._make_dataset("training")

        self.val_dataset = self._make_dataset("validation")

    ### DATALOADERS ###
    def train_dataloader(self) -> DataLoader:
        """Shuffled loader over the training set."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.train_batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self) -> DataLoader:
        """Sequential loader over the validation set."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.val_batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

### Experiment Module :zap: :brain: 

In [11]:
# @title ThesisExperiment

from pytorch_lightning import LightningModule

from torch.optim import SGD, Adam, RMSprop
from torch import cat

from torch.optim.lr_scheduler import CosineAnnealingLR

class ThesisExperiment(LightningModule):
    """Lightning module that trains and validates a segmentation network.

    Args:
        model_name: model identifier (lower-cased); ``"segresnetvae"`` and
            ``"attentionunet"`` receive special handling in the training step.
        model: the network to optimise.
        criterion_name: ``"dice"``, ``"focal"``, ``"dicefocal"`` or ``"gen_dicefocal"``.
        criterion: a callable loss for ``"dice"``/``"focal"``; for the combined
            losses a mapping with the entries ``'dice'``, ``'focal'``,
            ``'lambda_dice'`` and ``'lambda_focal'``.
        optimizer: ``"sgd"``, ``"adam"`` or ``"rmsprop"`` (case-insensitive).
        params: experiment configuration; must provide at least ``outClasses``
            and ``Hausdorff`` (with ``alpha``, ``k``, ``reduction``). Other
            entries are read lazily through ``self.params``.
    """

    def __init__(self,
                 model_name,
                 model,
                 criterion_name,
                 criterion,
                 optimizer: str,
                 params: dict) -> None:

        super().__init__()

        # FOR MANUAL OPTIMIZATION set to False
        self.automatic_optimization = True
        self.model_name = model_name.lower()
        self.model = model
        self.criterion_name = criterion_name
        self.criterion = criterion
        self.optimizer = optimizer.lower()
        self.params = SimpleNamespace(**params)
        self.curr_device = None
        self.val_dice = Val_Dice(num_classes=params['outClasses'],
                                 average='none',
                                 mdmc_average='samplewise')  # here params is still a dict
        self.hausdorff_loss = HausdorffERLoss3D(alpha=params['Hausdorff']['alpha'],
                                                k=params['Hausdorff']['k'],
                                                reduction=params['Hausdorff']['reduction'])
        # Weight of the Hausdorff term; re-estimated at the end of every training epoch.
        self.hausdorff_lambda = 1.0
        # Per-step loss values of the current epoch, consumed by `on_train_epoch_end`.
        self.hd_losses = []
        self.dsc_losses = []

        # `self.params` is a SimpleNamespace, so it must be read by attribute.
        # The previous subscript access always raised and was swallowed by a bare
        # `except`, which silently pinned `hold_graph` to False.
        self.hold_graph = getattr(self.params, 'retain_first_backpass', False)

    def forward(self, input: Tensor, **kwargs) -> Tensor:
        """Run the wrapped model."""
        return self.model(input, **kwargs)

    def _select_inputs(self, batch):
        """Pick the (image, label) pair from the batch according to `inModalities`."""
        if self.params.inModalities == 4:
            return batch['image'], batch['label']
        if self.params.inModalities == 1:
            return batch[f'image_{self.params.modality}'], batch['label']
        raise ValueError(f"Unsupported `inModalities`: {self.params.inModalities} (expected 1 or 4).")

    def _dice_focal_per_class(self, output, target):
        """Weighted sum of the batch-reduced dice and focal losses (one value per class)."""
        dl_batch = self.criterion['dice'](output, target).mean(dim=0).squeeze()  # reduction over batches
        fl_batch = self.criterion['focal'](output, target).mean(dim=[0, 2, 3, 4])  # reduction over batches
        return (self.criterion['lambda_dice'] * dl_batch) + (self.criterion['lambda_focal'] * fl_batch)

    def _criterion_loss(self, output, target):
        """Return `(scalar loss, per-class loss)` for the configured criterion."""
        if self.criterion_name == "dice":
            loss_batch = self.criterion(output, target).mean(dim=0).squeeze()  # reduction over batches
        elif self.criterion_name == "focal":
            loss_batch = self.criterion(output, target).mean(dim=[0, 2, 3, 4])  # reduction over batches
        elif self.criterion_name in ["dicefocal", "gen_dicefocal"]:
            loss_batch = self._dice_focal_per_class(output, target)
        else:
            raise ValueError(f"Unsupported criterion: {self.criterion_name!r}.")
        return loss_batch.mean(), loss_batch  # reduction over classes

    def _loss_names(self, suffix):
        """Names of the logged criterion losses: total, optional background, then TC/WT/ET."""
        names = [f'{self.criterion_name}_loss_{suffix}']
        if self.params.outClasses == 4:  # the bg is explicitly computed
            names.append(f'bg_loss_{suffix}')
        names += [f'TC_loss_{suffix}', f'WT_loss_{suffix}', f'ET_loss_{suffix}']
        return names

    def training_step(self, batch, batch_idx):
        """Compute the training loss (criterion + optional VAE and Hausdorff terms) and log it."""
        input_tensor, target = self._select_inputs(batch)
        input_tensor, target = input_tensor.as_tensor(), target.as_tensor()

        # With deep supervision the output is expected as (B, DS, K, H, W, D).
        output = self.forward(input_tensor)

        if self.model_name == "segresnetvae":
            output, vae_reg_loss, vae_mse_loss = output

        if self.model_name == "attentionunet" and self.params.deep_supervision:
            # Deep supervision: average the per-head losses with weights 1, 1/2, 1/4, ...
            # NOTE: this branch indexes `self.criterion` as a mapping, i.e. it
            # assumes one of the combined dice+focal criteria.
            loss, weights = 0., 0.
            for i in range(output.shape[1]):
                loss_batch = self._dice_focal_per_class(output[:, i], target)
                loss += loss_batch.mean() * 0.5 ** i  # reduction over classes
                weights += 0.5 ** i
            loss = loss / weights
        else:
            # Single-head output. This also covers "attentionunet" without deep
            # supervision, which previously left `loss` undefined.
            loss, loss_batch = self._criterion_loss(output, target)

        if self.model_name == "segresnetvae":
            loss = loss + 0.1 * vae_mse_loss + 0.1 * vae_reg_loss

        ## Differentiable 3D Hausdorff loss, one term per label, averaged.
        ## Default to zero so that logging and the final sum are always defined
        ## (it was undefined whenever the term was skipped).
        hd_loss = torch.zeros((), device=input_tensor.device)
        if self.params.use_hausdorff and self.params.deep_supervision:
            ## FOR NOW: compute HD on all the label channels.
            ## It could be worthwhile to compute HD on WT (and TC) only -> ET is
            ## usually far from convex and could hinder the loss descent.
            label_seg_pred = output[:, 0].unbind(dim=1)  # first 'head' of the DS output
            label_seg_gt = target.unbind(dim=1)

            ## The raw network output is used rather than a thresholded mask.
            ## Each call receives tensors of shape (B, 1, D, H, W).
            ## Stacking keeps the terms on the network's device instead of
            ## copying them into a fixed 3-element CPU tensor.
            hd_loss = torch.stack([
                self.hausdorff_loss(pred[:, None, :], gt[:, None, :])
                for pred, gt in zip(label_seg_pred, label_seg_gt)
            ]).mean()

        ## Log the losses
        loss_names = self._loss_names('train') + ['hausdorff_loss']
        losses = loss_batch.squeeze().tolist()

        self.log_dict(
            dict(zip(loss_names, [loss.item(), *losses, hd_loss.item()])),
            # Use the size of the batch actually processed (the validation batch
            # size was passed here before).
            batch_size=input_tensor.shape[0],
            on_step=False,
            on_epoch=True,
            rank_zero_only=True
        )

        ## Keep the values needed to update `hausdorff_lambda` at the end of the epoch
        self.hd_losses.append(hd_loss.item())
        self.dsc_losses.append(loss.item())

        # Combine the hd_loss and the criterion loss
        loss = (hd_loss * self.hausdorff_lambda) + loss

        return loss

    def on_train_epoch_end(self):
        """Update `hausdorff_lambda` from this epoch's losses and reset the buffers.

        The previous lambda is kept when no step was recorded, when the mean
        criterion loss is zero, or when a mean is not finite.
        """
        if self.hd_losses and self.dsc_losses:
            hd_loss_mean = np.mean(self.hd_losses)
            dsc_loss_mean = np.mean(self.dsc_losses)
            # NOTE(review): lambda multiplies the Hausdorff term and is set to
            # hd/dsc. If the goal is to bring both terms to the same magnitude,
            # the inverse ratio may be intended -- confirm before changing.
            if dsc_loss_mean != 0 and np.isfinite(hd_loss_mean) and np.isfinite(dsc_loss_mean):
                self.hausdorff_lambda = float(hd_loss_mean / dsc_loss_mean)

        self.hd_losses.clear()
        self.dsc_losses.clear()

    def validation_step(self, batch, batch_idx):
        """Sliding-window inference, validation losses, Dice scores and W&B mask logging.

        The Hausdorff term is not computed during validation.
        """
        input_tensor, target = self._select_inputs(batch)

        val_output = sliding_window_inference(
            inputs=input_tensor,
            roi_size=(240, 240, 160),
            sw_batch_size=4,
            predictor=self.model,
            overlap=0.5)

        batch['pred'] = val_output
        batch['pred_meta_dict'] = val_output.as_dict(key='pred')['pred_meta_dict']

        ## Compute the validation losses
        loss, loss_batch = self._criterion_loss(val_output.as_tensor(), target.as_tensor())

        ## Activate and discretize the outputs
        ## Maybe `val_act_threshold` could be applied directly on the batch, w/o
        ## decollating; this way there would be no need to stack the output again
        val_output = [val_act_threshold(i) for i in decollate_batch(val_output.as_tensor())]

        ## Update the validation metric with the discretized predictions
        self.val_dice.update(torch.stack(val_output), target.as_tensor())

        ## Log the losses
        self.log_dict(
            dict(zip(self._loss_names('val'), [loss.item(), *loss_batch.squeeze().tolist()])),
            batch_size=input_tensor.shape[0],
            on_step=False,
            on_epoch=True,
            rank_zero_only=True
        )

        ## Log the model predictions along with gt and its background image
        if batch_idx == 0:
            # Use the tensor selected above rather than `batch['image']`, which
            # does not exist in single-modality mode. Channel 0 is the background
            # image (FLAIR in the 4-modality case).
            input_t = input_tensor[0].as_tensor()[0, ...]
            output_t = batch['pred'][0].as_tensor()
            target_t = batch['label'][0].as_tensor()

            wandb_mask_list = log_brain_slices(input_t, output_t, target_t, total_slices=144)
            wandb.log({"Segmentation Mask": wandb_mask_list}, commit=True)

        # The metric is computed and reset for every batch; the epoch value is the
        # average produced by `log_dict(on_epoch=True)`.
        val_dsc = self.val_dice.compute()

        ## Drop the background score
        if self.params.outClasses == 4:
            val_dsc = val_dsc[1:]

        ## Log val_scores
        self.log_dict(
            dict(
                zip(
                    ['dice_score', 'TC_dice_score', 'WT_dice_score', 'ET_dice_score'],
                    [val_dsc.mean().item(), val_dsc[0].item(), val_dsc[1].item(), val_dsc[2].item()]
                )
            ),
            batch_size=input_tensor.shape[0],
            on_step=False,
            on_epoch=True,
            rank_zero_only=True
        )

        self.val_dice.reset()

    def configure_optimizers(self):
        """Build the configured optimizer with a cosine-annealing schedule.

        Raises:
            ValueError: if `self.optimizer` is not 'sgd', 'adam' or 'rmsprop'.
        """
        if self.optimizer == 'sgd':
            optimizer = SGD(self.model.parameters(),
                            lr=self.params.sgd['lr'],
                            momentum=self.params.sgd['momentum'],
                            weight_decay=self.params.sgd['weight_decay'])

        elif self.optimizer == 'adam':
            optimizer = Adam(self.model.parameters(),
                             lr=self.params.adam['lr'],
                             weight_decay=self.params.adam['weight_decay'])

        elif self.optimizer == 'rmsprop':
            optimizer = RMSprop(self.model.parameters(),
                                lr=self.params.rmsprop['lr'],
                                momentum=self.params.rmsprop['momentum'],
                                alpha=self.params.rmsprop['alpha'],
                                weight_decay=self.params.rmsprop['weight_decay'])
        else:
            # Previously this fell through to an UnboundLocalError below.
            raise ValueError(f"Unsupported optimizer: {self.optimizer!r} (expected 'sgd', 'adam' or 'rmsprop').")

        scheduler = CosineAnnealingLR(optimizer, T_max=self.params.max_epochs)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # Points at the validation loss actually logged in
                # `validation_step`; the previous name was never logged.
                "monitor": f'{self.criterion_name}_loss_val',
                "frequency": self.trainer.check_val_every_n_epoch
            }
        }

## Utils :toolbox:

### Log: Volume & Masks :pen:

In [12]:
def labels():
    """Return a fresh ``{class id: class name}`` mapping for the W&B masks."""
    return {
        0: "background",
        1: "Whole Tumor",
        2: "Tumor Core",
        3: "Enhancing Tumor",
    }


def wb_mask(bg_img, pred_mask, true_mask):
    """Wrap a background image and its two masks into a ``wandb.Image``.

    The image carries a 'prediction' and a 'ground_truth' overlay, each with its
    own copy of the class-label mapping.
    """
    overlays = {}
    for overlay_name, mask in (('prediction', pred_mask), ('ground_truth', true_mask)):
        overlays[overlay_name] = {'mask_data': mask, 'class_labels': labels()}

    return wandb.Image(bg_img, masks=overlays)

def onehotToIndexes(labels):
    """Collapse a channel-first one-hot label tensor into a ``uint8`` index map.

    Channels are painted in a fixed order, so a later entry overwrites an
    earlier one wherever both are set.
    """
    masks = labels.bool()
    indexes = torch.zeros_like(masks[0, ...], dtype=torch.uint8)

    # nb: to create a proper mask, this order must be respected
    paint_order = (
        (0, 0),  # background - CAN BE IGNORED
        (2, 1),  # Whole Tumor
        (1, 2),  # Tumor Core
        (3, 3),  # Enhancing Tumor
    )
    for channel, class_id in paint_order:
        indexes[masks[channel, ...]] = class_id

    return indexes

def log_brain_slices(volume, preds, gt, th: float=0.5, total_slices=120):
    """Build one W&B mask image per slice along the last axis of a volume.

    Args:
        volume: background image, indexed as ``volume[:, :, k]``.
        preds: channel-first raw predictions; if there are 4 channels the first
            one (background) is dropped. A sigmoid is applied before thresholding.
        gt: channel-first ground truth; a 4-channel input also loses its first channel.
        th: threshold applied to the activated predictions.
        total_slices: maximum number of slices to log. It is clamped to the
            number of slices actually available, so asking for more than the
            volume holds no longer raises an IndexError.

    Returns:
        A list of ``wandb.Image`` objects, one per logged slice.
    """
    volume = volume.detach().cpu()

    # Activate the predictions
    if preds.shape[0] == 4:  # if there's also the background
        preds = preds[1:, ...]  # ... remove it

    preds = sigmoid(preds).detach().cpu()

    # Keep the voxels whose activation exceeds the threshold
    activated_voxels = where(preds > th, 1, 0)
    # Channel 1 -> id 1 (WT), channel 0 -> id 2 (TC), channel 2 -> id 3 (ET);
    # the maximum picks the highest id wherever several channels are active.
    preds_idx = stack([activated_voxels[1], activated_voxels[0] * 2, activated_voxels[2] * 3], dim=0)
    preds_idx = amax(preds_idx, dim=0)

    # Remove the background class if present
    gt = gt.detach().cpu()

    if gt.shape[0] == 4:
        gt = gt[1:, ...]

    gt_idx = amax(stack([gt[1], gt[0] * 2, gt[2] * 3], dim=0), dim=0).detach().cpu()

    # Transform to numpy arrays
    preds_idx = preds_idx.numpy()
    gt_idx = gt_idx.numpy()

    # Never index past the last available slice of any of the three arrays.
    n_slices = min(total_slices, volume.shape[2], preds_idx.shape[2], gt_idx.shape[2])

    return [
        wb_mask(volume[:, :, k], preds_idx[:, :, k], gt_idx[:, :, k])
        for k in range(n_slices)
    ]

## Transforms :man_mechanic:

In [13]:
# @title Custom Transforms

from monai.transforms import (
    MapTransform,
    InvertibleTransform,
    Compose,
    Invertd,
    LoadImaged,
    CropForegroundd,
    RandSpatialCropd,
    SpatialPadd,
    CenterSpatialCropd,
    EnsureChannelFirstd,
    EnsureTyped,
    Orientationd,
    Spacingd,
    SplitDimd,
    RandSpatialCropd,
    RandFlipd,
    NormalizeIntensity,
    NormalizeIntensityd,
    RandScaleIntensityd,
    RandShiftIntensityd,
    Activations,
    Activationsd,
    AsDiscrete,
    AsDiscreted,
    )

class MyNormalizeIntensity(NormalizeIntensity, InvertibleTransform):
    """
    Invertible variant of MONAI's ``NormalizeIntensity``.

    It was created because, when MONAI inverts a pipeline, the inversion runs
    over the whole set of (invertible) transforms originally applied and custom
    keys cannot be targeted: the label ended up being (wrongly) de-normalized too.
    The statistics used in the forward pass are therefore recorded on the
    tensor, and ``inverse`` only acts when they match the channel count of the
    data it receives.

    A suggestion to make the upstream transform invertible was submitted to MONAI
    and, at the time of writing, it had not been implemented.

    Limitations of ``inverse``:
        * it needs scalar statistics per channel (those computed from the image);
        * with ``nonzero=True`` the voxels skipped in the forward pass are shifted
          as well, because the mask is not recorded;
        * data whose channel count differs from the recorded statistics (e.g.
          multi-channel data normalized with ``channel_wise=False``) is returned
          unchanged.
    """
    def __init__(self,
                 subtrahend: Union[Sequence, NdarrayOrTensor, None] = None,
                 divisor: Union[Sequence, NdarrayOrTensor, None] = None,
                 nonzero: bool = False,
                 channel_wise: bool = False,
                 dtype: DtypeLike = np.float32):
        super().__init__(subtrahend, divisor, nonzero, channel_wise, dtype)

    def _normalize(self, img: NdarrayOrTensor, sub=None, div=None) -> Tuple[NdarrayOrTensor, list]:
        """Normalize `img` and return ``(normalized image, [subtrahend, divisor])``."""
        img, *_ = convert_data_type(img, dtype=torch.float32)

        if self.nonzero:
            slices = img != 0
        else:
            if isinstance(img, np.ndarray):
                slices = np.ones_like(img, dtype=bool)
            else:
                slices = torch.ones_like(img, dtype=torch.bool)
        if not slices.any():
            # Nothing to normalize. Return identity statistics so that callers
            # can always unpack two values (a bare image used to be returned
            # here) and so that the inverse is a no-op for this channel.
            return img, [0.0, 1.0]

        _sub = sub if sub is not None else self._mean(img[slices])
        if isinstance(_sub, (torch.Tensor, np.ndarray)):
            _sub, *_ = convert_to_dst_type(_sub, img)
            _sub = _sub[slices]

        _div = div if div is not None else self._std(img[slices])

        # Guard against a division by zero
        if np.isscalar(_div):
            if _div == 0.0:
                _div = 1.0
        elif isinstance(_div, (torch.Tensor, np.ndarray)):
            _div, *_ = convert_to_dst_type(_div, img)
            _div = _div[slices]
            _div[_div == 0.0] = 1.0

        img[slices] = (img[slices] - _sub) / _div
        return img, [_sub, _div]

    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Apply the transform to `img`, assuming `img` is a channel-first array if
        `self.channel_wise` is True.

        Raises:
            ValueError: if `subtrahend`/`divisor` do not have one entry per channel
                in channel-wise mode.
        """
        img = convert_to_tensor(img, track_meta=get_track_meta())
        list_stats = []
        dtype = self.dtype or img.dtype
        if self.channel_wise:
            if self.subtrahend is not None and len(self.subtrahend) != len(img):
                raise ValueError(f"img has {len(img)} channels, but subtrahend has {len(self.subtrahend)} components.")
            if self.divisor is not None and len(self.divisor) != len(img):
                raise ValueError(f"img has {len(img)} channels, but divisor has {len(self.divisor)} components.")

            for i, d in enumerate(img):
                img[i], stats = self._normalize(
                    d,
                    sub=self.subtrahend[i] if self.subtrahend is not None else None,
                    div=self.divisor[i] if self.divisor is not None else None,
                )
                list_stats.append(stats)

        else:
            # `_normalize` returns a pair: unpack it instead of assigning the
            # whole tuple to `img`, and record the single set of statistics.
            img, stats = self._normalize(img, self.subtrahend, self.divisor)
            list_stats.append(stats)

        out = convert_to_dst_type(img, img, dtype=dtype)[0]

        if get_track_meta():
            self.update_meta(tensor=out, stats=list_stats)
            # One row per recorded entry: column 0 = subtrahend, column 1 = divisor.
            self.push_transform(out, extra_info={'stats': torch.tensor(list_stats)})

        return out

    def update_meta(self, tensor: MetaTensor, stats):
        """Attach the normalization statistics to `tensor`.

        `stats` is a list of ``[subtrahend, divisor]`` pairs; `tensor.mean`
        receives all subtrahends and `tensor.std` all divisors (previously they
        received the pairs of the first and of the second channel).

        NOTE(review): these attribute names shadow the tensor's own `mean()` /
        `std()` methods on this instance; kept for backward compatibility --
        consider storing the values under different names.
        """
        tensor.mean = [pair[0] for pair in stats]
        tensor.std = [pair[1] for pair in stats]

    def inverse(self, data: MetaTensor) -> MetaTensor:
        """Undo the normalization using the statistics recorded in the forward pass.

        If the number of recorded entries differs from the number of channels of
        `data`, `data` is returned unchanged.
        """
        transform = self.pop_transform(data)
        stats = transform[TraceKeys.EXTRA_INFO]["stats"]
        if stats.shape[0] != data.shape[0]:
            return data
        mean = stats[:, 0].to(data.device)
        std = stats[:, 1].to(data.device)

        # perform the inversion: x = x_norm * std + mean, broadcast over the
        # three trailing (spatial) dimensions
        out = data * std[..., None, None, None]
        out += mean[..., None, None, None]

        return out
    

class MyNormalizeIntensityd(NormalizeIntensityd, InvertibleTransform):
    """Dictionary-based wrapper around :class:`MyNormalizeIntensity`.

    Args:
        keys: keys of the items to normalize.
        subtrahend, divisor, nonzero, channel_wise, dtype: forwarded to
            :class:`MyNormalizeIntensity`.
        allow_missing_keys: do not raise if a key is missing from the data.
    """

    def __init__(
        self,
        keys: KeysCollection,
        subtrahend: Optional[NdarrayOrTensor] = None,
        divisor: Optional[NdarrayOrTensor] = None,
        nonzero: bool = False,
        channel_wise: bool = False,
        dtype: DtypeLike = np.float32,
        allow_missing_keys: bool = False
    ):
        # Use keyword arguments: the second positional parameter of the parent
        # constructor is `subtrahend`, so passing `allow_missing_keys`
        # positionally put it in the wrong slot and the flag was never honoured.
        super().__init__(
            keys,
            subtrahend=subtrahend,
            divisor=divisor,
            nonzero=nonzero,
            channel_wise=channel_wise,
            dtype=dtype,
            allow_missing_keys=allow_missing_keys,
        )
        # Replace the parent's normalizer with the invertible one.
        self.normalizer = MyNormalizeIntensity(subtrahend, divisor, nonzero, channel_wise, dtype)

    def inverse(self, data: Mapping[Hashable, MetaTensor]) -> Dict[Hashable, MetaTensor]:
        """Return a copy of `data` with the normalization undone for every configured key."""
        d = dict(data)
        for key in self.key_iterator(d):
            d[key] = self.normalizer.inverse(d[key])
        return d



class ToMultiChannelBratsClassesd(MapTransform):
    """
    Turn an integer BraTS label map into three binary channels.

    Label convention used here:
    label 0 is the background
    label 1 is the peritumoral edema
    label 2 is the GD-enhancing tumor
    label 3 is the necrotic and non-enhancing tumor core

    Output channels, in order: TC (tumor core), WT (whole tumor),
    ET (enhancing tumor).
    """

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            label = d[key]
            # TC: labels 2 and 3 merged
            tumor_core = (label == 2) | (label == 3)
            # WT: labels 1, 2 and 3 merged
            whole_tumor = (label == 1) | tumor_core
            # ET: label 2 alone
            enhancing_tumor = label == 2
            channels = [tumor_core, whole_tumor, enhancing_tumor]
            d[key] = stack(channels, axis=0).float()
        return d

### Train/Val/Test Transforms

In [14]:
# @title Data Transforms
# Spatial size that CenterSpatialCropd / SpatialPadd below bring every training sample to.
roi_size = (128, 128, 128) # (192, 192, 144)

# Training pipeline: load -> multi-channel labels -> reorient/resample ->
# crop/pad to `roi_size` -> random flips -> intensity normalisation and jitter.
train_transform = Compose(
  [
      # Load 4 Nifti images and stack them together (4 in_modalities)
      LoadImaged(keys=["image", "label"]),
      EnsureChannelFirstd(keys="image"),
      # here was SplitDimd
      EnsureTyped(keys=['image', 'label']),
      # Builds the (TC, WT, ET) channel axis of the label
      ToMultiChannelBratsClassesd(keys='label'),
      Orientationd(keys=['image', 'label'], axcodes="RAS"), # Right-Anterior-Superior
      Spacingd(
          keys=['image', 'label'],
          pixdim=(1.0, 1.0, 1.0), # Resample input image into the specified pixdim
          mode=('bilinear', 'nearest') # one interpolation mode per key: image, label
      ),
      CropForegroundd(
          keys=["image", "label"],
          source_key = "image",
          channel_indices = 0, # use FLAIR to compute the b_box
          #k_divisible = [roi_size[0], roi_size[1], roi_size[2]] # commented out since CenterSpatialCropd afterwards cuts the image to roi_size
                     ),
      # Crop first, then pad, with the same target size on both steps.
      CenterSpatialCropd(
          keys=['image', 'label'],
          roi_size=[roi_size[0], roi_size[1], roi_size[2]]
      ),
      SpatialPadd(
          keys = ['image', 'label'],
          spatial_size = [roi_size[0], roi_size[1], roi_size[2]]
      ),
      # Independent random flip along each of the three spatial axes
      RandFlipd(keys=['image', 'label'], prob=0.5, spatial_axis=0),
      RandFlipd(keys=['image', 'label'], prob=0.5, spatial_axis=1),
      RandFlipd(keys=['image', 'label'], prob=0.5, spatial_axis=2),
      NormalizeIntensityd(keys='image', nonzero=True, channel_wise=True), # If subtrahend and divisor are None -> use mean and std of the image/image channel (if `channel_wise` = True)
      RandScaleIntensityd(keys='image', factors=0.1, prob=1.0), # if `factors` is a single number 'f' -> range [-f; f]
      RandShiftIntensityd(keys='image', offsets=0.1, prob=1.0),
      
      ## NOTE: SplitDimd creates copies of the original data.
      ## Splitting a 4-channel (MC) image results in: the original 4 MC image + 4 single channel (SC) images.
      ## To date, it seems impossible to reduce memory consumption by loading SC images from Decathlon Dataset.
      ## The only viable option is to work with the more recent BraTS 2021 dataset,
      ##  which comes divided per patient, but the modalities are stored in separate files.
       
      #SplitDimd(dim = 0, keys=['image'], list_output=False), # Use it for single inModality training
  ]  
)

# Validation pipeline: same loading/resampling steps as training, but without
# the crop/pad steps and without any random transform.
val_transform = Compose(
  [
      LoadImaged(keys=['image', 'label']),
      EnsureChannelFirstd(keys='image'),
      # here was SplitDimd
      EnsureTyped(keys=['image', 'label']),
      ToMultiChannelBratsClassesd(keys='label'),
      Orientationd(keys=['image', 'label'], axcodes='RAS'),
      Spacingd(
          keys=['image', 'label'],
          pixdim=(1.0, 1.0, 1.0),
          mode=('bilinear', 'nearest')
      ),
      #RandSpatialCropd(keys=['image', 'label'], roi_size=[224, 224, 144], random_size=False), # Set `random_size` to False to get exactly the requested `roi_size`
      NormalizeIntensityd(keys='image', nonzero=True, channel_wise=True),
      #SplitDimd(dim = 0, keys=['image'], list_output=False),
  ]
)

# Sigmoid followed by a 0.5 threshold: turns network outputs into binary masks.
val_act_threshold = Compose(
    [
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5)
    ])

# Sigmoid only (no threshold); judging by the name, used when logging slices.
val_log_slice_transform = Compose(
    [
        Activations(sigmoid=True)
    ]
)



# Inverts `val_transform` on the 'image' and 'label' entries.
post_transform = Compose([
    Invertd(
        keys=['image', 'label'],
        transform = val_transform,
        orig_keys= ['image', 'label'],
        meta_keys = ['image_meta_dict', 'label_meta_dict'],
        orig_meta_keys = ['image_meta_dict', 'label_meta_dict'],
        meta_key_postfix = 'meta_dict',
        nearest_interp = False,
        to_tensor = True,
        device = "cuda" if torch.cuda.is_available() else "cpu",
        allow_missing_keys=False
    )
])

# Inverts `val_transform` on the 'pred' entry, using the transforms recorded for 'image'.
post_transform_pred = Compose([
    Invertd(
        keys="pred", # with only "pred" it works fine
        transform=val_transform,
        orig_keys="image",
        #meta_keys="pred_meta_dict",
        orig_meta_keys="image_meta_dict",
        meta_key_postfix="meta_dict",
        nearest_interp=False,
        to_tensor=False,
        device="cuda" if torch.cuda.is_available() else "cpu",
        allow_missing_keys=False
    ),
    #Activationsd(keys="pred", sigmoid=True),
    #AsDiscreted(keys="pred", threshold=0.5),
])


# Lookup of the per-split pipelines, nested under a single 'transform' key.
transform = {'transform' : 
              {
                  'training' : train_transform,
                  'validation' : val_transform
               }
             }


## Models

#### Unet Submodules

In [18]:
class UnetConv3(nn.Module):
    """Two stacked 3D convolution stages.

    Each stage is ``Conv3d -> [BatchNorm3d] -> ReLU6 -> Dropout3d``; the batch
    norm layer is present only when ``is_batchnorm`` is truthy. The first
    stage maps ``in_size`` to ``out_size`` channels with stride
    ``init_stride``; the second keeps ``out_size`` channels with stride 1.
    With batch norm enabled, the first convolution is created without bias.
    """

    def __init__(self, in_size, out_size, is_batchnorm, kernel_size=(3,3,1), padding_size=(1,1,0), init_stride=(1,1,1), dropout_rate=0.2):
        super().__init__()

        def build_stage(channels_in, stride, with_bias):
            # Layers are appended in the order they run in the forward pass.
            layers = [
                nn.Conv3d(channels_in, out_size, kernel_size=kernel_size, stride=stride,
                          padding=padding_size, bias=with_bias)
            ]
            if is_batchnorm:
                layers.append(nn.BatchNorm3d(out_size))
            layers.append(nn.ReLU6(inplace=True))
            layers.append(nn.Dropout3d(p=dropout_rate))
            return nn.Sequential(*layers)

        self.conv1 = build_stage(in_size, init_stride, not is_batchnorm)
        self.conv2 = build_stage(out_size, 1, True)

        # initialise both stages
        for stage in self.children():
            init_weights(stage, init_type='kaiming')

    def forward(self, inputs):
        return self.conv2(self.conv1(inputs))


class UnetGridGatingSignal3(nn.Module):
    """Single convolution stage that produces the gating signal.

    The stage is ``Conv3d -> [BatchNorm3d] -> ReLU6`` with stride 1 and no
    padding. When ``is_batchnorm`` is truthy the convolution has no bias and
    is followed by batch normalisation.
    """

    def __init__(self, in_size, out_size, kernel_size=(1,1,1), is_batchnorm=True):
        super().__init__()

        layers = [
            nn.Conv3d(in_size, out_size, kernel_size, stride=(1,1,1), padding=(0,0,0),
                      bias=not is_batchnorm)
        ]
        if is_batchnorm:
            layers.append(nn.BatchNorm3d(out_size))
        layers.append(nn.ReLU6(inplace=True))
        self.conv1 = nn.Sequential(*layers)

        for child in self.children():
            init_weights(child, init_type="kaiming")

    def forward(self, inputs):
        return self.conv1(inputs)


class UnetUp3(nn.Module):
    """Decoder step: upsample the deeper feature map, merge it with a skip, convolve.

    Args:
        in_size: channels of the deeper feature map (``inputs2``).
        out_size: channels produced by this step.
        is_deconv: if truthy, upsample with a transposed convolution
            (``in_size -> out_size``); otherwise use trilinear ``nn.Upsample``
            with scale ``(2, 2, 1)``.
        is_batchnorm: forwarded to ``UnetConv3``.
    """

    def __init__(self, in_size, out_size, is_deconv, is_batchnorm=True):
        super(UnetUp3, self).__init__()
        if is_deconv:
            self.conv = UnetConv3(in_size, out_size, is_batchnorm)
            self.up = nn.ConvTranspose3d(in_size, out_size, kernel_size=(4,4,1), stride=(2,2,1), padding=(1,1,0))
        else:
            self.conv = UnetConv3(in_size + out_size, out_size, is_batchnorm)
            self.up = nn.Upsample(scale_factor=(2,2,1), mode='trilinear')

        # initialise the blocks (UnetConv3 initialises itself)
        for m in self.children():
            if 'UnetConv3' in m.__class__.__name__: continue
            init_weights(m, init_type='kaiming')

    def forward(self, inputs1, inputs2):
        """Return ``conv(cat([padded inputs1, up(inputs2)], dim=1))``.

        ``inputs1`` is the skip connection and ``inputs2`` the deeper map.
        """
        outputs2 = self.up(inputs2)

        # Pad the skip so that every spatial axis matches the upsampled map.
        # F.pad expects (left, right) pairs starting from the LAST axis, so the
        # axes are walked backwards; each axis uses its own size difference,
        # and an odd difference puts the extra element on the right.
        padding = []
        for dim in range(outputs2.dim() - 1, 1, -1):
            diff = outputs2.size(dim) - inputs1.size(dim)
            padding += [diff // 2, diff - diff // 2]
        outputs1 = F.pad(inputs1, padding)
        return self.conv(cat([outputs1, outputs2], 1))

class UnetUp3_CT(nn.Module):
    """Decoder step with isotropic trilinear upsampling (scale 2 on every axis).

    Args:
        in_size: channels of the deeper feature map (``inputs2``).
        out_size: channels of the skip connection and of the output.
        is_batchnorm: forwarded to ``UnetConv3``.
    """

    def __init__(self, in_size, out_size, is_batchnorm=True):
        super(UnetUp3_CT, self).__init__()

        self.conv = UnetConv3(in_size + out_size, out_size, is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
        self.up = nn.Upsample(scale_factor=(2,2,2), mode='trilinear')

        # initialise the blocks (UnetConv3 initialises itself)
        for m in self.children():
            if 'UnetConv3' in m.__class__.__name__: continue
            init_weights(m, init_type='kaiming')

    def forward(self, inputs1, inputs2):
        """Return ``conv(cat([padded inputs1, up(inputs2)], dim=1))``.

        ``inputs1`` is the skip connection and ``inputs2`` the deeper map.
        """
        outputs2 = self.up(inputs2)

        # Pad the skip so that every spatial axis matches the upsampled map.
        # F.pad expects (left, right) pairs starting from the LAST axis, so the
        # axes are walked backwards; each axis uses its own size difference,
        # and an odd difference puts the extra element on the right.
        padding = []
        for dim in range(outputs2.dim() - 1, 1, -1):
            diff = outputs2.size(dim) - inputs1.size(dim)
            padding += [diff // 2, diff - diff // 2]
        outputs1 = F.pad(inputs1, padding)
        return self.conv(cat([outputs1, outputs2], 1))

#### Attention Submodules

In [19]:
class _GridAttentionBlockND(nn.Module):
    def __init__(self, in_channels, gating_channels, inter_channels=None, dimension=3, mode='concatenation',
                 sub_sample_factor=(2,2,2)):
        super(_GridAttentionBlockND, self).__init__()

        assert dimension in [2,3]
        assert mode in ['concatenation', 'concatenation_debug', 'concatenation_residual']

        # Downsampling rate for the input featuremap
        if isinstance(sub_sample_factor, tuple): self.sub_sample_factor = sub_sample_factor
        elif isinstance(sub_sample_factor, list): self.sub_sample_factor = tuple(sub_sample_factor)
        else: self.sub_sample_factor = tuple([sub_sample_factor]) * dimension

        # Default parameter set
        self.mode = mode
        self.dimension = dimension
        self.sub_sample_kernel_size = self.sub_sample_factor

        # Number of channels (pixel dimensions)
        self.in_channels = in_channels
        self.gating_channels = gating_channels
        self.inter_channels = inter_channels

        if self.inter_channels is None:
            self.inter_channels = in_channels // 2
            if self.inter_channels == 0:
                self.inter_channels = 1
        
        if dimension == 3:
            conv_nd = nn.Conv3d
            bn = nn.BatchNorm3d
            self.upsample_mode = 'trilinear'
        elif dimension == 2:
            conv_nd = nn.Conv2d
            bn = nn.BatchNorm2d
            self.upsample_mode = 'bilinear'
        else:
            raise NotImplementedError

        # Output transform
        self.W = nn.Sequential(
            conv_nd(in_channels = self.in_channels, out_channels= self.in_channels, kernel_size=1, stride=1, padding=0),
            bn(self.in_channels),
        )

        # Theta^T * x_ij + Phi^T * gating_signal + bias
        self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
                             kernel_size=self.sub_sample_kernel_size, stride=self.sub_sample_factor, padding=0, bias=False)
        
        self.phi = conv_nd(in_channels=self.gating_channels, out_channels=self.inter_channels,
                           kernel_size=1, stride=1, padding=0, bias=True)
        
        self.psi = conv_nd(in_channels=self.inter_channels, out_channels=1, kernel_size=1, stride=1, padding=0, bias=True)


        # Initialise weights
        for m in self.children():
            init_weights(m, init_type='kaiming')
        
        # Define the operation
        if mode == 'concatenation':
            self.operation_function = self._concatenation
        elif mode == 'concatenation_debug':
            self.operation_function = self._concatenation_debug
        elif mode == 'concatenation_residual':
            self.operation_function = self._concatenation_residual
        else:
            raise NotImplementedError('Unknown operation function.')
    
    def forward(self, x, g):
        '''
         :param x: (b, c, t, h, w)
         :param g: (b, g_d)
         :return:
        '''

        output = self.operation_function(x, g)
        return output
    

    def _concatenation(self, x, g):
        input_size = x.size()
        batch_size = input_size[0]
        assert batch_size == g.size(0)

        # theta -> (b, c, t, h, w) -> (b, i_c, t, h, w) -> (b, i_c, thw)
        # phi -> (b, g_d) -> (b, i_c)
        theta_x = self.theta(x)
        theta_x_size = theta_x.size()

        # g (b, c, t', h', w') -> phi_g (b, i_c, t', h', w')
        # Relu(theta_x + phi_g + bias) -> f = (b, i_c, thw) -> (b, i_c. t/s1, h/s2, w/s3)
        phi_g = F.interpolate(self.phi(g), size=theta_x_size[2:], mode=self.upsample_mode)
        f = F.relu6(theta_x + phi_g, inplace=True)

        # psi^T * f -> (b, psi_i_c, t/s1, h/s2, w/s3)
        sigm_psi_f = sigmoid(self.psi(f))

        # upsample the attention and multiply
        sigm_psi_f = F.interpolate(sigm_psi_f, size=input_size[2:], mode=self.upsample_mode)
        y = sigm_psi_f.expand_as(x) * x
        W_y = self.W(y) # check why this conv is performed, and how it does transform the 'element-wise' multiplication between x' and g'

        return W_y, sigm_psi_f







class GridAttentionBlock3D(_GridAttentionBlockND):
    """``_GridAttentionBlockND`` with ``dimension`` fixed to 3."""

    def __init__(self, in_channels, gating_channels, inter_channels=None, mode='concatenation',
                 sub_sample_factor=(2,2,2)):
        # Positional order of the parent: in_channels, gating_channels,
        # inter_channels, dimension, mode, sub_sample_factor.
        super().__init__(in_channels, gating_channels, inter_channels, 3, mode, sub_sample_factor)


class MultiAttentionBlock(nn.Module):
    """Two parallel 3D attention gates whose outputs are fused by a 1x1x1 conv.

    Both gates receive the same input and gating signal. ``forward`` returns
    the fused gated features and the two attention maps concatenated along
    dim 1.
    """

    def __init__(self, in_size, gate_size, inter_size, nonlocal_mode, sub_sample_factor):
        super().__init__()

        gate_kwargs = dict(in_channels=in_size, gating_channels=gate_size,
                           inter_channels=inter_size, mode=nonlocal_mode,
                           sub_sample_factor=sub_sample_factor)
        self.gate_block_1 = GridAttentionBlock3D(**gate_kwargs)
        self.gate_block_2 = GridAttentionBlock3D(**gate_kwargs)

        # in_size * 2 input channels: the outputs of the two gates are concatenated
        fuse_conv = nn.Conv3d(in_size*2, in_size, kernel_size=1, stride=1, padding=0)
        self.combine_gates = nn.Sequential(fuse_conv, nn.BatchNorm3d(in_size), nn.ReLU6(inplace=True))

        # initialise everything except the attention gates, which are skipped here
        for child in self.children():
            if 'GridAttentionBlock3D' in child.__class__.__name__:
                continue
            init_weights(child, init_type='kaiming')

    def forward(self, input, gating_signal):
        gated_a, attention_a = self.gate_block_1(input, gating_signal)
        gated_b, attention_b = self.gate_block_2(input, gating_signal)

        fused = self.combine_gates(cat([gated_a, gated_b], 1))
        attention_maps = cat([attention_a, attention_b], 1)
        return fused, attention_maps

#### DeepSupervision Submodules

In [20]:
class UnetDsv3(nn.Module):
    """Deep-supervision head: 1x1x1 convolution followed by trilinear upsampling.

    The convolution maps ``in_size`` to ``out_size`` channels; the result is
    then upsampled by ``scale_factor``.
    """

    def __init__(self, in_size, out_size, scale_factor):
        super().__init__()
        projection = nn.Conv3d(in_size, out_size, kernel_size=1, stride=1, padding=0)
        upsampling = nn.Upsample(scale_factor=scale_factor, mode='trilinear')
        self.dsv = nn.Sequential(projection, upsampling)

    def forward(self, input):
        return self.dsv(input)

#### utils

In [21]:
def init_weights(net, init_type='normal'):
    """Apply the initialisation scheme named by ``init_type`` to ``net``.

    Only ``'kaiming'`` is supported; it runs ``weights_init_kaiming`` on
    ``net`` through ``net.apply``. Any other value, including the default
    ``'normal'``, raises ``NotImplementedError``.
    """
    if init_type != 'kaiming':
        raise NotImplementedError('Initialization method [%s] is not implemented' % init_type)
    net.apply(weights_init_kaiming)

def weights_init_kaiming(m):
    """Initialise one module in place, dispatching on its class name.

    * name contains ``'Conv'`` or ``'Linear'``: Kaiming-normal weights
      (``fan_in``, ``relu`` gain);
    * name contains ``'BatchNorm'``: weights drawn from N(1.0, 0.02) and bias
      set to 0.

    Modules without a ``weight`` tensor are left untouched. This covers
    containers whose class name merely contains one of the substrings above
    and normalisation layers built with ``affine=False``; without the guard
    they would fail on the missing attribute.
    """
    classname = m.__class__.__name__
    #print(classname)
    weight = getattr(m, 'weight', None)
    if weight is None:
        return

    if 'Conv' in classname or 'Linear' in classname:
        # NOTE: with nonlinearity='relu' the gain does not depend on `a`;
        # the argument is kept as in the original call.
        kaiming_normal_(weight, a=0.2, mode='fan_in', nonlinearity='relu')
    elif 'BatchNorm' in classname:
        normal_(weight, mean=1.0, std=0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            constant_(bias, 0.0)

def printInfo(model):
    """Print a torchinfo summary of ``model`` and the shape of one forward pass.

    A random input of shape ``(1, 4, 128, 128, 128)`` is used, on ``cuda:0``
    when available and on the CPU otherwise.
    """
    n_channels, height, width, depth = 4, 128, 128, 128
    input_size = (1, n_channels, height, width, depth)
    device = "cuda:0" if is_available() else "cpu"

    # verbose=1: torchinfo prints the table itself, so the return value is not needed
    torchinfo.summary(
        model,
        device = device,
        input_size = input_size,
        mode = "train",
        col_names = ("input_size", "output_size", "num_params", "kernel_size", "mult_adds"),
        verbose = 1,
        depth=1
        )

    model = model.to(device)
    dummy_batch = rand(input_size, device=device)
    print(model(dummy_batch).shape)


# Per-channel weights; CreateLoss passes this list as `weight` to FocalLoss.
# NOTE(review): the comment below lists the channels as (WT, TC, ET), while
# ToMultiChannelBratsClassesd stacks the label channels as (TC, WT, ET) --
# confirm that the first two weights are in the intended order.
alphas = [0.2391, 0.6477, 0.1132] # alpha values for WT, TC, ET (rarest class)
#alphas = [1, 1, 0.1747] #  # alpha values based on the original target (EDEMA, ET, NECROTIC) -> downweight only the rarest label  

def CreateLoss(params: dict={}):
    """Build the loss selected by ``params['criterion_name']``.

    * ``"dice"``: returns a ``DiceLoss``;
    * ``"focal"``: returns a ``FocalLoss`` weighted with the module-level ``alphas``;
    * ``"dicefocal"``: returns a dict with both losses (keys ``"dice"`` and
      ``"focal"``) and their weights (keys ``"lambda_dice"`` and ``"lambda_focal"``).

    Any other name returns ``None``. A missing key raises ``KeyError``.
    """
    criterion_name = params['criterion_name']

    if criterion_name == "dicefocal":
        lambda_dice = params['lambda_dice']
        lambda_focal = params['lambda_focal']

        # Same parameters, with only the criterion name swapped for each part.
        return {
            "dice": CreateLoss({**params, 'criterion_name': 'dice'}),
            "focal": CreateLoss({**params, 'criterion_name': 'focal'}),
            "lambda_dice": lambda_dice,
            "lambda_focal": lambda_focal,
        }

    if criterion_name == "dice":
        return DiceLoss(
            # if n_classes == 4 -> there is also the 'bg' class and it must be excluded
            include_background=params['n_classes'] == 3,
            squared_pred=params['squared_pred'],
            to_onehot_y=params['to_onehot_y'],
            sigmoid=params['sigmoid'],
            reduction=params['dice_reduction'],
            smooth_nr=params['smooth_nr'],
            smooth_dr=params['smooth_dr'],
        )

    if criterion_name == "focal":
        return FocalLoss(
            # if n_classes == 4 -> there is also the 'bg' class and it must be excluded
            include_background=params['n_classes'] == 3,
            to_onehot_y=params['to_onehot_y'],
            gamma=params['gamma'],
            # the focal loss reuses the 'dice_reduction' setting
            reduction=params['dice_reduction'],
            weight=alphas,
        )

    return None


def createModel(params: dict={}):
    """Instantiate the network named by ``params['model_name']`` (case-insensitive).

    Supported names are ``"segresnet"``, ``"segresnetvae"`` and
    ``"attentionunet"``; every constructor argument is read from ``params``.
    Any other name returns ``None``. A missing key raises ``KeyError``.
    """
    model_name = params["model_name"].lower()

    # Keys shared by both SegResNet variants, in the order they are read.
    # upsample_mode: 'nontrainable' -> non trainable linear interp. (torch.nn.Upsample),
    #                'deconv' -> trainable transpose conv layers
    segresnet_keys = (
        'blocks_down',    # n down_sample blocks in each layer
        'blocks_up',      # n up_sample blocks in each layer
        'init_filters',   # n out_channels for initial convolution layer
        'in_channels',
        'out_channels',   # 4 = bg + relevant classes
        'dropout_prob',
        'upsample_mode',
    )

    if model_name == "segresnet":
        kwargs = {key: params[key] for key in segresnet_keys}
        return SegResNet(**kwargs)

    if model_name == "segresnetvae":
        vae_keys = ('input_image_size', 'vae_estimate_std', 'vae_default_std', 'vae_nz')
        kwargs = {key: params[key] for key in vae_keys + segresnet_keys}
        return MySegResNetVAE(**kwargs)

    if model_name == "attentionunet":
        kwargs = {
            key: params[key]
            for key in ('feature_scale', 'in_channels', 'out_channels', 'is_deconv', 'nonlocal_mode')
        }
        kwargs['attention_dsample'] = tuple(params['attention_dsample'])
        kwargs['is_batchnorm'] = params['is_batchnorm']
        kwargs['dropout_prob'] = params['dropout_prob']
        # the network calls its output channel count `n_classes`
        kwargs['n_classes'] = kwargs.pop('out_channels')
        return MultiAttentionUnet(**kwargs)

    return None

## Attention Unet

In [22]:
class MultiAttentionUnet(nn.Module):
    """3D U-Net with attention-gated skip connections and deep supervision.

    The encoder has four ``UnetConv3`` + max-pool levels followed by a center
    block. The three deepest skips go through a ``MultiAttentionBlock``
    before being merged in the decoder; the shallowest skip is merged as is.

    In training mode ``forward`` returns the three side outputs (``dsv1``,
    ``dsv2``, ``dsv3``) stacked along dim 1; in eval mode it returns only
    ``dsv1`` applied to the last decoder level.

    NOTE(review): ``dsv4`` and ``final`` are created but not used by
    ``forward``; ``dropout_prob`` only configures ``self.dropout`` (the
    ``UnetConv3`` blocks keep their own default rate); ``is_deconv`` is
    stored but not otherwise used in this class.
    """

    def __init__(self, feature_scale=4, n_classes=3, is_deconv=True, in_channels=4,
                 nonlocal_mode='concatenation', attention_dsample=(2,2,2), is_batchnorm=True, dropout_prob=0.2):
        super().__init__()
        self.is_deconv = is_deconv # might be unused
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale

        base_widths = (64, 128, 256, 512, 1024) # (32, 64, 128, 512, 1024)
        filters = [int(width / self.feature_scale) for width in base_widths]

        def conv_block(channels_in, channels_out):
            return UnetConv3(channels_in, channels_out, self.is_batchnorm,
                             kernel_size=(3,3,3), padding_size=(1,1,1))

        def attention_gate(level):
            # the gating signal comes from one level deeper
            return MultiAttentionBlock(in_size=filters[level], gate_size=filters[level + 1],
                                       inter_size=filters[level], nonlocal_mode=nonlocal_mode,
                                       sub_sample_factor=attention_dsample)

        # dropout - added to reduce overfitting and regularise the loss descent
        self.dropout = nn.Dropout3d(p=dropout_prob)

        # downsampling
        self.conv1 = conv_block(self.in_channels, filters[0])
        self.maxpool1 = nn.MaxPool3d(kernel_size=(2,2,2))

        self.conv2 = conv_block(filters[0], filters[1])
        self.maxpool2 = nn.MaxPool3d(kernel_size=(2,2,2))

        self.conv3 = conv_block(filters[1], filters[2])
        self.maxpool3 = nn.MaxPool3d(kernel_size=(2,2,2))

        self.conv4 = conv_block(filters[2], filters[3])
        self.maxpool4 = nn.MaxPool3d(kernel_size=(2,2,2))

        self.center = conv_block(filters[3], filters[4])
        self.gating = UnetGridGatingSignal3(filters[4], filters[4], kernel_size=(1,1,1), is_batchnorm=self.is_batchnorm)

        # attention blocks
        self.attentionblock2 = attention_gate(1)
        self.attentionblock3 = attention_gate(2)
        self.attentionblock4 = attention_gate(3)

        # upsampling
        self.up_concat4 = UnetUp3_CT(filters[4], filters[3], is_batchnorm)
        self.up_concat3 = UnetUp3_CT(filters[3], filters[2], is_batchnorm)
        self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm)
        self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm)

        # deep supervision
        self.dsv4 = UnetDsv3(in_size=filters[3],  out_size=n_classes, scale_factor=8)
        self.dsv3 = UnetDsv3(in_size=filters[2],  out_size=n_classes, scale_factor=4)
        self.dsv2 = UnetDsv3(in_size=filters[1],  out_size=n_classes, scale_factor=2)
        self.dsv1 = nn.Conv3d(in_channels=filters[0],  out_channels=n_classes, kernel_size=1)

        # final conv (without concat)
        self.final = nn.Conv3d(n_classes*4, n_classes, 1)

        # initialise weights of every conv / batch-norm layer in the network
        for module in self.modules():
            if isinstance(module, (nn.Conv3d, nn.BatchNorm3d)):
                init_weights(module, init_type='kaiming')

    def forward(self, inputs):
        # Feature Extraction: conv -> max-pool -> dropout at every level
        enc1 = self.conv1(inputs)
        enc2 = self.conv2(self.dropout(self.maxpool1(enc1)))
        enc3 = self.conv3(self.dropout(self.maxpool2(enc2)))
        enc4 = self.conv4(self.dropout(self.maxpool3(enc3)))

        # Gating Signal Generation
        bottleneck = self.dropout(self.center(self.dropout(self.maxpool4(enc4))))
        gate = self.gating(bottleneck)

        # Attention Mechanism + Upscaling Part (Decoder); attention maps are discarded
        gated4, _ = self.attentionblock4(enc4, gate)
        dec4 = self.up_concat4(gated4, bottleneck)

        gated3, _ = self.attentionblock3(enc3, dec4)
        dec3 = self.up_concat3(gated3, dec4)

        gated2, _ = self.attentionblock2(enc2, dec3)
        dec2 = self.up_concat2(gated2, dec3)
        dec1 = self.up_concat1(enc1, dec2)

        if not self.training:
            # Apply only the final convolution
            return self.dsv1(dec1)

        # Deep Supervision: convolve and upsample each decoder level
        side3 = self.dsv3(dec3)
        side2 = self.dsv2(dec2)
        side1 = self.dsv1(dec1)
        # stacked from the shallowest level to the deepest one
        return stack([side1, side2, side3], dim=1)

## MySegResNetVAE

In [23]:
class MySegResNetVAE(SegResNetVAE):
    """``SegResNetVAE`` variant that reports the two VAE loss terms separately.

    In training mode ``forward`` returns
    ``(segmentation, vae_reg_loss, vae_mse_loss)``; in eval mode it returns the
    segmentation only.

    NOTE(review): the VAE settings are not passed to the parent constructor;
    they are set afterwards and ``_prepare_vae_modules`` is called again here.
    Presumably this rebuilds the VAE layers with the requested settings --
    confirm against the parent class.
    """

    def __init__(
        self,
        input_image_size: Sequence[int],
        vae_estimate_std: bool = False,
        vae_default_std: float = 0.3,
        vae_nz: int = 256,
        spatial_dims: int = 3,
        init_filters: int = 8,
        in_channels: int = 1,
        out_channels: int = 2,
        dropout_prob: Optional[float] = None,
        act: Union[str, tuple] = ("RELU", {"inplace": True}),
        norm: Union[Tuple, str] = ("GROUP", {"num_groups": 8}),
        use_conv_final: bool = True,
        blocks_down: tuple = (1, 2, 2, 4),
        blocks_up: tuple = (1, 1, 1),
        upsample_mode: Union[UpsampleMode, str] = UpsampleMode.NONTRAINABLE,
    ):
        super().__init__(
            input_image_size=input_image_size,
            spatial_dims=spatial_dims,
            init_filters=init_filters,
            in_channels=in_channels,
            out_channels=out_channels,
            dropout_prob=dropout_prob,
            act=act,
            norm=norm,
            use_conv_final=use_conv_final,
            blocks_down=blocks_down,
            blocks_up=blocks_up,
            upsample_mode=upsample_mode,
        )

        self.input_image_size = input_image_size
        self.smallest_filters = 16

        # spatial size of the VAE bottleneck: input size divided by 2 * zoom
        zoom = 2 ** (len(self.blocks_down) - 1)
        self.fc_insize = [size // (2 * zoom) for size in self.input_image_size]

        self.vae_estimate_std = vae_estimate_std
        self.vae_default_std = vae_default_std
        self.vae_nz = vae_nz
        self._prepare_vae_modules()
        self.vae_conv_final = self._make_final_conv(in_channels)

    def _get_vae_loss(self, net_input: torch.Tensor, vae_input: torch.Tensor):
        """
        Compute the two VAE loss terms.

        Args:
            net_input: the original input of the network.
            vae_input: the input of VAE module, which is also the output of the network's encoder.

        Returns:
            ``(vae_reg_loss, vae_mse_loss)``: the regularisation term on the
            latent code and the MSE between ``net_input`` and its reconstruction.
        """
        flat = self.vae_down(vae_input).view(-1, self.vae_fc1.in_features)
        z_mean = self.vae_fc1(flat)

        noise = torch.randn_like(z_mean)
        noise.requires_grad_(False)

        if self.vae_estimate_std:
            z_sigma = F.softplus(self.vae_fc2(flat))
            vae_reg_loss = 0.5 * torch.mean(z_mean**2 + z_sigma**2 - torch.log(1e-8 + z_sigma**2) - 1)
        else:
            z_sigma = self.vae_default_std
            vae_reg_loss = torch.mean(z_mean**2)

        # sample the latent code: mean plus scaled noise
        latent = z_mean + z_sigma * noise

        recon = self.act_mod(self.vae_fc3(latent))
        recon = recon.view([-1, self.smallest_filters] + self.fc_insize)
        recon = self.vae_fc_up_sample(recon)

        for upsample, up_layer in zip(self.up_samples, self.up_layers):
            recon = up_layer(upsample(recon))

        recon = self.vae_conv_final(recon)
        vae_mse_loss = F.mse_loss(net_input, recon)
        return vae_reg_loss, vae_mse_loss

    def forward(self, x):
        net_input = x
        x, down_x = self.encode(x)
        down_x.reverse()

        vae_input = x
        x = self.decode(x, down_x)

        if not self.training:
            return x

        vae_reg_loss, vae_mse_loss = self._get_vae_loss(net_input, vae_input)
        return x, vae_reg_loss, vae_mse_loss

## Progress Bar :arrows_counterclockwise:

In [15]:
class BatchesProcessedColumn(ProgressColumn):
    """Progress column rendering ``completed/total`` for the current task."""

    def __init__(self, style: Union[str, Style]):
        self.style = style
        super().__init__()

    def render(self, task: "Task") -> RenderableType:
        """Return the ``completed/total`` counter as styled text.

        An infinite total is displayed as ``--``.
        """
        if task.total == float("inf"):
            shown_total = "--"
        else:
            shown_total = task.total
        counter = f"{int(task.completed)}/{shown_total}"
        return Text(counter, style=self.style)

class CustomTimeColumn(ProgressColumn):
    """
    Progress column showing ``elapsed • remaining • total`` times.

    ``elapsed`` and ``remaining`` are taken from the rich task; ``total`` is the
    wall-clock time since this column was constructed, rendered as zero-padded
    ``DD:HH:MM:SS``.
    """

    # Only refresh twice a second to prevent jitter
    max_refresh = 0.5

    def __init__(self, style: Union[str, Style]) -> None:
        self.style = style
        # reference point for the "total" field
        self.start_time = datetime.now()
        super().__init__()

    @staticmethod
    def _format_total(total: timedelta) -> str:
        """Format ``total`` as zero-padded ``DD:HH:MM:SS``."""
        hours, remainder = divmod(total.seconds, 3600)
        minutes, seconds = divmod(remainder, 60)
        return f"{total.days:02d}:{hours:02d}:{minutes:02d}:{seconds:02d}"

    def render(self, task: "Task") -> Text:
        """Return the three time fields of ``task`` as a single styled text.

        Elapsed or remaining times that are not available are shown as ``-:--:--``.
        """
        elapsed = task.finished_time if task.finished else task.elapsed
        remaining = task.time_remaining
        total = datetime.now() - self.start_time
        elapsed_delta = "-:--:--" if elapsed is None else str(timedelta(seconds=int(elapsed)))
        remaining_delta = "-:--:--" if remaining is None else str(timedelta(seconds=int(remaining)))
        return Text(
            f"{elapsed_delta} • {remaining_delta} • {self._format_total(total)}",
            style=self.style,
        )


class CustomRichProgressBar(RichProgressBar):
    """
    Custom Progress Bar for PyTorch Lightning.
    Added the possibility to show the current steps and total steps in the progress bar.
    Added the elapsed time from start training.

    NOTE: in pytorch-lightning>=2.0 the RichProgressBar has undergone minor refactoring.
            Some field names have been changed to better distinguish the progress bars displayed during the different phases of the training.
    """

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Create or reset the main progress bar at the start of a training epoch."""
        if self.is_disabled:
            return
        total_batches = self.total_batches_current_epoch
        train_description = self._get_train_description(trainer.current_epoch)

        # a bar already exists and `_leave` is set: restart the progress display
        if self.main_progress_bar_id is not None and self._leave:
            self._stop_progress()
            self._init_progress(trainer)
        if self.progress is not None:
            if self.main_progress_bar_id is None:
                self.main_progress_bar_id = self._add_task(total_batches, train_description)
            else:
                self.progress.reset(
                    self.main_progress_bar_id, total=total_batches, description=train_description, visible=True
                )

        self.refresh()

    def _pad_to_validation_width(self, description: str) -> str:
        """Right-pad ``description`` to the length of the validation description.

        Padding is required to avoid flickering due to uneven lengths of the
        training and "Validation" bar descriptions.
        """
        target_width = len(self.validation_description)
        if target_width > len(description):
            return f"{description:{target_width}}"
        return description

    def _get_train_description(self, current_epoch: int) -> str:
        """Build the label shown next to the training bar.

        Uses ``Epoch i/last`` when ``max_epochs`` is set, ``Step i/last`` when
        only ``max_steps`` is set, and otherwise the module-level ``max_time``
        time budget.
        """
        max_epochs = self.trainer.max_epochs
        if max_epochs not in [None, -1]:
            return self._pad_to_validation_width(f"Epoch {current_epoch}/{max_epochs - 1}")

        max_steps = self.trainer.max_steps
        if max_steps not in [None, -1]:
            return self._pad_to_validation_width(f"Step {self.trainer.global_step}/{max_steps - 1}")

        # `max_time` is a module-level global set by the training entry point;
        # fall back to a placeholder instead of raising NameError when it is unset.
        time_budget = globals().get("max_time", "--:--:--:--")
        return f"Training Time: {time_budget}"

    def configure_columns(self, trainer: "pl.Trainer") -> list:
        """Return the columns of the bar: description, bar, batch counter, times, speed."""
        return [
            TextColumn("[progress.description]{task.description}"),
            CustomBarColumn(
                complete_style=self.theme.progress_bar,
                finished_style=self.theme.progress_bar_finished,
                pulse_style=self.theme.progress_bar_pulse,
            ),
            BatchesProcessedColumn(style=self.theme.batch_progress),
            CustomTimeColumn(style=self.theme.time),
            ProcessingSpeedColumn(style=self.theme.processing_speed),
        ]

# Configs :notebook: 

In [None]:
# Resolve the path of the default YAML config from the command line.
# parse_known_args collects unrecognised arguments in `unknown` instead of failing.
parser = argparse.ArgumentParser(description="Thesis Experiment")
parser.add_argument('--default_config', '-c',
                    dest="default_config",
                    metavar='FILE',
                    help="path to the default config file",
                    default = config_dir + "t_attunet_4_3_dicefocal.yaml") # debug_vae_gpu_config mono0.yaml 

args, unknown = parser.parse_known_args() # add [] as argument
with open(args.default_config, 'r', encoding='utf-8') as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        # Report the parse error, then re-raise: continuing would leave `config`
        # unbound (or stale from an earlier run) for the code below.
        print(exc)
        raise

config

In [17]:
# Inline experiment configuration (overrides any `config` loaded from YAML above).
config = {
    'train_batch_size': 2,
    'val_batch_size': 1,
    'inChannels': 4,
    'outClasses': 3,
    'modelName': 'attentionunet',

    'sweep_params': {
        'ntrials': 2,
        'project': 'MICCAI 2018 Medical Image Segmentation',
    },

    'data_params': {
        'dataset': 'Task01_BrainTumour',
        'classes': 3,
        'data_path': '/kaggle/input/task01braintumour/',
        'train_batch_size': 2,
        'val_batch_size': 1,
        'num_workers': 2,
        'download': False,
        'cache_rate': 0.0,
    },

    'wandb_params': {
        'save_dir': '/kaggle/working/wandb',
        'manual_seed': 42,
        'ckpt_name': 'attentionunet',
        'model_name': 'attentionunet',
        'log_model': True,
    },

    'model_params': {
        'model_name': 'attentionunet',
        'feature_scale': 2,
        'in_channels': 4,
        'out_channels': 3,
        'is_deconv': True,
        'nonlocal_mode': 'concatenation',
        'attention_dsample': [2, 2, 2],
        'is_batchnorm': True,
        'dropout_prob': 0.2,
    },

    'criterion_params': {
        'criterion_name': 'dicefocal',
        'n_classes': 3,
        'squared_pred': True,
        'to_onehot_y': False,
        'sigmoid': True,
        'dice_reduction': 'none',
        'smooth_nr': 0,
        # numeric value, not the string '1e-5': it is a smoothing constant
        'smooth_dr': 1e-5,
        'gamma': 2.0,
        'lambda_dice': 0.8,
        'lambda_focal': 1.0,
    },

    'exp_params': {
        'deep_supervision': True,
        'use_hausdorff': True,
        'Hausdorff': {
            'alpha': 2.0,
            'k': 5,
            'reduction': 'mean',
        },
        'max_epochs': 25,
        'num_iters': 100,
        'inModalities': 4,
        'outClasses': 3,
        'train_batch_size': 2,
        'val_batch_size': 1,
        'adam': {'lr': 0.0008, 'weight_decay': 1e-05},
        'sgd': {'lr': 2e-05, 'momentum': 0.9, 'weight_decay': 0.01},
        'rmsprop': {
            'lr': 0.01,
            'momentum': 0.9,
            'alpha': 0.99,
            'weight_decay': 0.01,
        },
    },

    'trainer_params': {'accelerator': 'gpu', 'devices': 1},

    'checkpoint_params': {
        'mode': 'min',
        'save_top_k': 1,
        'every_n_epochs': 1,
        'save_on_train_epoch_end': True,
    },

    'training_params': {
        'debug': False,
        'resume_train': False,
        'ckpt_path': None,
        'n_models_to_save': 2,
    },
}

# Main :train2:

In [24]:
def main(config):
    """
    Run one training session: W&B logger, model, data, experiment, trainer, fit.

    Args:
        config: nested experiment configuration. It is handed to ``WandbLogger``
            and its sections are then read back through ``wandb.config``.

    Side effects:
        Sets the module-level ``max_time`` (read by ``CustomRichProgressBar``),
        prints the run id and a start banner, and launches ``Trainer.fit``.
    """
    global max_time

    # ------------------------
    # 1 WANDB LOGGER
    # ------------------------
    run_id = wandb.util.generate_id()
    print("Run id: ", run_id)

    wandb_logger = WandbLogger(
        project='MICCAI 2018 Medical Image Segmentation',
        save_dir="/kaggle/working/",
        log_model=True,  # "all": log while training
        config=config,
        resume="allow",
        name="AttentionUnet - Train",
        notes="BS: 2 + New DS + Both D + FD + FS: 2 + ReLU6 + a-Focal + HD_er",
        id=run_id,
    )

    # NOTE(review): `wandb.config` is read right after the logger is built, which
    # presumes the run is already initialised at this point — confirm for the
    # installed pytorch_lightning version.

    # ------------------------
    # 2 INIT LIGHTNING MODEL
    # ------------------------
    model_name = wandb.config["model_params"]["model_name"]
    model = createModel(wandb.config['model_params'])

    # ------------------------
    # 3 DATA PIPELINE
    # ------------------------
    # `transform` is defined elsewhere in the notebook and merged into the data parameters
    wandb.config['data_params'].update(transform)
    data = DatasetModule(**wandb.config['data_params'])

    # ------------------------
    # 4 LIGHTNING EXPERIMENT
    # ------------------------
    criterion_name = wandb.config['criterion_params']['criterion_name']
    criterion = CreateLoss(wandb.config['criterion_params'])

    optimizer = "adam"

    experiment = ThesisExperiment(
        model_name, model, criterion_name, criterion, optimizer, wandb.config['exp_params']
    )

    # ------------------------
    # 5 TRAINER
    # ------------------------
    # 5.1 PROGRESSBAR
    # ----------------------
    theme = RichProgressBarTheme(
        description="white",
        progress_bar="blue",
        progress_bar_finished="green",
        progress_bar_pulse="bright_blue",
        batch_progress="white",
        time="gray54",
        processing_speed="gray70",
        metrics="white",
    )
    # also passed to the Trainer below as its time budget
    max_time = "00:10:00:00"

    progress_bar = CustomRichProgressBar(theme=theme)

    # 5.2 CALLBACKS
    # ----------------------
    # the same metric name is used for monitoring and inside the checkpoint file name
    monitored_metric = criterion_name + "_loss_val"

    checkpointCBK = ModelCheckpoint(
        monitor=monitored_metric,
        mode=wandb.config['checkpoint_params']['mode'],
        dirpath=os.path.join(wandb.config['wandb_params']['save_dir'], "checkpoints"),
        filename=model_name + '-{epoch}-{' + monitored_metric + ':.2f}',
        save_top_k=wandb.config['checkpoint_params']['save_top_k'],
        # the two values below are hard-coded; the matching `checkpoint_params` entries are not read
        every_n_epochs=1,
        save_on_train_epoch_end=False,
    )

    lrMonitorCBK = LearningRateMonitor(logging_interval='epoch')

    callbacks = [
        checkpointCBK,
        lrMonitorCBK,
        progress_bar,
    ]

    # resuming is switched off; the checkpoint path is only used when the flag is set
    resume_training = False
    resume_ckpt = "/kaggle/working/AttentionUnet_epoch_25_aFocal.ckpt"

    trainer = Trainer(
        accelerator='gpu',
        devices=1,
        logger=wandb_logger,
        callbacks=callbacks,
        fast_dev_run=False,
        enable_checkpointing=True,
        benchmark=True,
        log_every_n_steps=50,
        max_epochs=25,
        check_val_every_n_epoch=1,
        max_time=max_time,
        precision=16,
        limit_train_batches=1.0,
        limit_val_batches=0.5,
    )

    # ------------------------
    # 6 START TRAINING
    # ------------------------
    print("\nStart training!\n")
    trainer.fit(experiment, datamodule=data, ckpt_path=resume_ckpt if resume_training else None)
    

In [None]:
if __name__ == '__main__':
    # Always close the W&B run, even when training raises, so that a failed
    # cell does not leave an active run behind for the next execution.
    # A non-zero exit code marks the run as failed.
    exit_code = 1
    try:
        main(config)
        exit_code = 0
        print("\nEnd training!\n")
    finally:
        wandb.finish(exit_code=exit_code)

Run id:  1n6vpurq


[34m[1mwandb[0m: Currently logged in as: [33mmattiacapparella[0m. Use [1m`wandb login --relogin`[0m to force relogin



Start training!




Checkpoint directory /kaggle/working/wandb/checkpoints exists and is not empty.



Output()