# MRI-based brain tumor IDH classification with MONAI (3D multiparametric MRI)

This tutorial shows how to construct a training workflow for a binary classification task.  
It covers the following features:
1. Transforms for MONAI dictionary-format data.
2. Defining new transforms according to the MONAI transform API.
3. Loading NIfTI images with metadata, loading a list of images, and stacking them.
4. A 3D voxel-wise DynUNet model with Dice and cross-entropy losses for the IDH classification task.
5. Deterministic training for reproducibility.

The brain tumor datasets can be downloaded from 
https://ipp.cbica.upenn.edu/ and http://medicaldecathlon.com/.

Target: IDH classification from MRI, based on the whole brain, tumor core, whole tumor, and enhancing tumor  
Modality: Multimodal, multisite MRI data (FLAIR, T1w, T1gd, T2w)  
Training: 135 3D MRI volumes \
Validation: 3-fold cross-validation on the training subjects (45 per fold) \
Testing: Not revealed

Source: BraTS 2020/2021 datasets.  
Challenge: RSNA-MICCAI Brain Tumor Radiogenomic Classification

The figure below shows image patches with the tumor sub-regions annotated in the different modalities (top left) and the final labels for the whole dataset (right). (Figure taken from the [BraTS IEEE TMI paper](https://ieeexplore.ieee.org/document/6975210/))  
![image](https://ieeexplore.ieee.org/mediastore_new/IEEE/content/media/42/7283692/6975210/6975210-fig-3-source-large.gif)

The image patches show, from left to right:
1. The whole tumor (yellow) visible in T2-FLAIR (Fig. A).
2. The tumor core (red) visible in T2 (Fig. B).
3. The enhancing tumor structures (light blue) visible in T1Gd, surrounding the cystic/necrotic components of the core (green) (Fig. C).
4. The segmentations are combined to generate the final labels of the tumor sub-regions (Fig. D): edema (yellow), non-enhancing solid core (red), necrotic/cystic core (green), and enhancing core (blue).

In [1]:
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
import sys
import gc
import logging
import copy

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import seaborn as sns
import numpy as np
import pandas as pd
import scipy
from scipy import ndimage
from sklearn.model_selection import KFold
import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.networks.nets import DynUNet, EfficientNetBN, DenseNet121, SegResNet, SegResNetVAE
from monai.data import CacheDataset, Dataset, DataLoader, ThreadDataLoader
from torch.utils.data import WeightedRandomSampler

import monai
from monai.transforms import (
    Activations,
    AsDiscrete,
    CastToTyped,
    Compose, 
    CropForegroundd,
    ResizeWithPadOrCrop,
    ResizeWithPadOrCropd,
    Spacingd,
    RandRotate90d,
    Resized,
    EnsureChannelFirstd, 
    Orientationd,
    LoadImaged,
    NormalizeIntensity,
    HistogramNormalize,
    NormalizeIntensityd,
    RandCropByPosNegLabeld,
    RandCropByLabelClassesd,
    RandAffined,
    RandFlipd,
    Flipd,
    RandGaussianNoised,
    RandGaussianSmoothd,
    RandGibbsNoised,
    RandStdShiftIntensityd,
    RandScaleIntensityd,
    RandZoomd, 
    SpatialCrop, 
    SpatialPadd, 
    MapTransform,
    CastToType,
    ToTensord,
    AddChanneld,
    ScaleIntensityd,
    ScaleIntensity,
    ScaleIntensityRangePercentilesd,
    ScaleIntensityRange,
    RandShiftIntensityd,
    RandAdjustContrastd,
    AdjustContrastd,
    Rotated,
    ToNumpyd,
    ToDeviced,
    EnsureType,
    EnsureTyped,
    DataStatsd,
)

from monai.config import KeysCollection
from monai.transforms import Randomizable
from collections.abc import Iterable
from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union
from monai.utils import set_determinism
from monai.utils import (
    ensure_tuple,
    ensure_tuple_rep,
    ensure_tuple_size,
)

from monai.optimizers import LearningRateFinder
from monai.transforms.utils import generate_spatial_bounding_box
from skimage.transform import resize
from monai.losses import DiceCELoss, DiceLoss
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric, ROCAUCMetric
from monai.metrics import compute_meandice
from monai.data import decollate_batch
from monai.networks.layers.factories import Conv, Dropout, Norm, Pool

import glob
import random
import pickle
from collections import OrderedDict
from itertools import compress
import ipywidgets as widgets
import SimpleITK as sitk
import torchio as tio

import sklearn
from sklearn.metrics import (
    mean_squared_error, roc_auc_score, accuracy_score, recall_score,
    precision_score, f1_score, make_scorer,
)

from ranger21 import Ranger21

### ignite-based imports
from ignite.engine import (
    Engine,
    Events,
    _prepare_batch,
    create_supervised_evaluator,
    create_supervised_trainer,
)
from ignite.contrib.handlers import FastaiLRFinder, ProgressBar
from ignite.handlers import EarlyStopping, ModelCheckpoint

from itkwidgets import view

monai.config.print_config()
#from sliding_window_inference_classes import sliding_window_inference_classes

MONAI version: 0.9.0
Numpy version: 1.22.3
Pytorch version: 1.10.1
MONAI flags: HAS_EXT = False, USE_COMPILED = False
MONAI rev id: af0e0e9f757558d144b655c63afcea3a4e0a06f5
MONAI __file__: /home/mmiv-ml/anaconda3/envs/sa_tumorseg22/lib/python3.9/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.9
Nibabel version: 4.0.1
scikit-image version: 0.19.3
Pillow version: 9.0.1
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.11.2
tqdm version: 4.64.0
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.1
pandas version: 1.4.2
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-rec

In [2]:
MAX_THREADS = 2
# Limit the number of threads SimpleITK uses
sitk.ProcessObject.SetGlobalDefaultNumberOfThreads(MAX_THREADS)

In [3]:
seeds = 40961024
# set_determinism seeds Python's `random`, NumPy, and PyTorch, and makes cuDNN deterministic
set_determinism(seed=seeds)
!nvidia-smi

Mon Jul 18 00:41:17 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.172.01   Driver Version: 450.172.01   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-DGXS...  On   | 00000000:07:00.0 Off |                    0 |
| N/A   37C    P0    54W / 300W |   2305MiB / 32505MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-DGXS...  On   | 00000000:08:00.0 Off |                    0 |
| N/A   39C    P0    53W / 300W |   7832MiB / 32508MiB |      0%      Default |
|       

In [4]:
#patch_size = (128, 128, 128)

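spacing = (1.0, 1.0, 1.0)
# Expose only physical GPU 2 to this process; this must be set before CUDA is initialized.
# The selected GPU is then visible to PyTorch as cuda:0.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
device = torch.device('cuda:0')
deviceName = 'cuda:0'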

In [5]:
pd.set_option('display.max_colwidth', None)
data_rpath = '/home/mmiv-ml/data'


In [6]:
BraTS20SubjectsIDHWithMetaDF  = pd.read_csv('assets/BraTS20SubjectsIDHWithMetaDF.csv')
BraTS20SubjectsIDHWithMetaDF

Unnamed: 0,BraTS2020,t1wPath,t1cwPath,t2wPath,flairPath,segPath,t1w_BrainmaskPath,IDH_value,BraTS2021,BraTS2019,...,ET_CoordX,ET_CoordY,ET_CoordZ,ED_CoordX,ED_CoordY,ED_CoordZ,NEC_CoordX,NEC_CoordY,NEC_CoordZ,is_merged_3
0,BraTS20_Training_274,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_274/BraTS20_Training_274_t1.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_274/BraTS20_Training_274_t1ce.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_274/BraTS20_Training_274_t2.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_274/BraTS20_Training_274_flair.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_274/BraTS20_Training_274_seg.nii.gz,/home/mmiv-ml/data/ROIBrain_MICCAI_BraTS2020/BraTS20_Training_274/BraTS20_Training_274_BrainROI.nii.gz,0,BraTS2021_01479,BraTS19_TCIA09_141_1,...,102.788589,97.639325,94.684363,118.038612,106.882949,88.290266,106.202255,91.666134,96.335216,both
1,BraTS20_Training_293,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_293/BraTS20_Training_293_t1.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_293/BraTS20_Training_293_t1ce.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_293/BraTS20_Training_293_t2.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_293/BraTS20_Training_293_flair.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_293/BraTS20_Training_293_seg.nii.gz,/home/mmiv-ml/data/ROIBrain_MICCAI_BraTS2020/BraTS20_Training_293/BraTS20_Training_293_BrainROI.nii.gz,0,BraTS2021_01498,BraTS19_TCIA10_410_1,...,101.672477,85.079209,81.828045,96.746334,107.863478,83.794617,99.950014,88.945683,89.526921,both
2,BraTS20_Training_190,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_190/BraTS20_Training_190_t1.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_190/BraTS20_Training_190_t1ce.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_190/BraTS20_Training_190_t2.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_190/BraTS20_Training_190_flair.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_190/BraTS20_Training_190_seg.nii.gz,/home/mmiv-ml/data/ROIBrain_MICCAI_BraTS2020/BraTS20_Training_190/BraTS20_Training_190_BrainROI.nii.gz,1,BraTS2021_01300,BraTS19_TCIA02_226_1,...,161.133893,117.501382,72.152114,158.558518,124.733076,69.336469,161.917178,114.233129,78.978528,both
3,BraTS20_Training_298,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_298/BraTS20_Training_298_t1.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_298/BraTS20_Training_298_t1ce.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_298/BraTS20_Training_298_t2.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_298/BraTS20_Training_298_flair.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_298/BraTS20_Training_298_seg.nii.gz,/home/mmiv-ml/data/ROIBrain_MICCAI_BraTS2020/BraTS20_Training_298/BraTS20_Training_298_BrainROI.nii.gz,0,BraTS2021_01503,BraTS19_TCIA10_276_1,...,110.542553,73.074468,70.808511,107.090113,82.676138,76.029439,105.099771,65.077985,76.992437,both
4,BraTS20_Training_334,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_334/BraTS20_Training_334_t1.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_334/BraTS20_Training_334_t1ce.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_334/BraTS20_Training_334_t2.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_334/BraTS20_Training_334_flair.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_334/BraTS20_Training_334_seg.nii.gz,/home/mmiv-ml/data/ROIBrain_MICCAI_BraTS2020/BraTS20_Training_334/BraTS20_Training_334_BrainROI.nii.gz,0,BraTS2021_01665,BraTS19_TCIA13_624_1,...,81.914264,146.606787,69.382751,80.744102,138.641146,66.538335,80.554348,144.624506,74.856719,both
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
130,BraTS20_Training_234,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_234/BraTS20_Training_234_t1.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_234/BraTS20_Training_234_t1ce.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_234/BraTS20_Training_234_t2.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_234/BraTS20_Training_234_flair.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_234/BraTS20_Training_234_seg.nii.gz,/home/mmiv-ml/data/ROIBrain_MICCAI_BraTS2020/BraTS20_Training_234/BraTS20_Training_234_BrainROI.nii.gz,0,BraTS2021_00134,BraTS19_TCIA05_444_1,...,135.527764,82.683842,82.832916,139.213709,94.138320,68.037037,136.153603,95.470173,68.157774,both
131,BraTS20_Training_249,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_249/BraTS20_Training_249_t1.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_249/BraTS20_Training_249_t1ce.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_249/BraTS20_Training_249_t2.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_249/BraTS20_Training_249_flair.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_249/BraTS20_Training_249_seg.nii.gz,/home/mmiv-ml/data/ROIBrain_MICCAI_BraTS2020/BraTS20_Training_249/BraTS20_Training_249_BrainROI.nii.gz,1,BraTS2021_00148,BraTS19_TCIA08_280_1,...,167.902679,124.229911,48.055506,161.954812,120.572396,53.495751,165.754088,122.700629,53.863522,both
132,BraTS20_Training_171,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_171/BraTS20_Training_171_t1.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_171/BraTS20_Training_171_t1ce.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_171/BraTS20_Training_171_t2.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_171/BraTS20_Training_171_flair.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_171/BraTS20_Training_171_seg.nii.gz,/home/mmiv-ml/data/ROIBrain_MICCAI_BraTS2020/BraTS20_Training_171/BraTS20_Training_171_BrainROI.nii.gz,1,BraTS2021_00999,BraTS19_TCIA01_180_1,...,104.061792,111.841124,95.391230,106.538165,100.652189,97.482928,102.660285,106.012682,93.015594,both
133,BraTS20_Training_185,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_185/BraTS20_Training_185_t1.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_185/BraTS20_Training_185_t1ce.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_185/BraTS20_Training_185_t2.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_185/BraTS20_Training_185_flair.nii.gz,/home/mmiv-ml/data/MICCAI_BraTS2020_TrainingData/BraTS20_Training_185/BraTS20_Training_185_seg.nii.gz,/home/mmiv-ml/data/ROIBrain_MICCAI_BraTS2020/BraTS20_Training_185/BraTS20_Training_185_BrainROI.nii.gz,1,BraTS2021_00142,BraTS19_TCIA02_314_1,...,145.337610,127.855479,61.117421,145.918486,133.762711,67.759784,145.855751,132.161974,59.734923,both


## Creating a list of dictionaries to feed into MONAI's Dataset
Keys:
- ***image:*** T1, T1c, T2, and FLAIR images
- ***label:*** Ground-truth (GT) tumor segmentation mask
- ***brain_mask:*** Whole-brain mask (brain = 1, non-brain = 0)
- ***IDH_label:*** IDH class of the subject (taken from the `IDH_value` column)

In [7]:
train_files = [
    {'image': (image_nameT1, image_nameT1ce, image_nameT2, image_nameFl),  # 4 modalities -> 4 channels
     'label': label_name, 'brain_mask': brain_mask, 'IDH_label': np.array(IDH_label_name)}
    for image_nameT1, image_nameT1ce, image_nameT2, image_nameFl, label_name, brain_mask, IDH_label_name
    in zip(BraTS20SubjectsIDHWithMetaDF['t1wPath'], BraTS20SubjectsIDHWithMetaDF['t1cwPath'],
           BraTS20SubjectsIDHWithMetaDF['t2wPath'], BraTS20SubjectsIDHWithMetaDF['flairPath'],
           BraTS20SubjectsIDHWithMetaDF['segPath'], BraTS20SubjectsIDHWithMetaDF['t1w_BrainmaskPath'],
           BraTS20SubjectsIDHWithMetaDF['IDH_value'])
]

# train_files_image = [(image_nameT1, image_nameT1ce, image_nameT2, image_nameFl) 
#                      for image_nameT1,image_nameT1ce, image_nameT2, image_nameFl 
#                      in zip(dfTrainLbl['t1wPath'], dfTrainLbl['t1cwPath'], dfTrainLbl['T2wPath'], dfTrainLbl['FlairPath'])]
# train_files_label = dfTrainLbl['segPath'].tolist()
# train_files_brain_mask = dfTrainLbl['brain_maskPath'].tolist()
# train_files_IDH_label = dfTrainLbl['IDH_value'].values.ravel().tolist()
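For reference, the record layout that the comprehension above builds can be reproduced on toy data. This is a minimal sketch in which plain Python lists stand in for the DataFrame columns. The file names are hypothetical placeholders, not real BraTS paths.

```python
import numpy as np

# Hypothetical stand-ins for the DataFrame columns (two subjects)
columns = {
    "t1wPath": ["s1_t1.nii.gz", "s2_t1.nii.gz"],
    "t1cwPath": ["s1_t1ce.nii.gz", "s2_t1ce.nii.gz"],
    "t2wPath": ["s1_t2.nii.gz", "s2_t2.nii.gz"],
    "flairPath": ["s1_flair.nii.gz", "s2_flair.nii.gz"],
    "segPath": ["s1_seg.nii.gz", "s2_seg.nii.gz"],
    "t1w_BrainmaskPath": ["s1_BrainROI.nii.gz", "s2_BrainROI.nii.gz"],
    "IDH_value": [0, 1],
}

records = [
    {
        # four modality paths, later loaded and stacked into one 4-channel volume
        "image": (t1, t1ce, t2, flair),
        "label": seg,
        "brain_mask": mask,
        "IDH_label": np.array(idh),  # 0-d array, hence the `.item()` calls later on
    }
    for t1, t1ce, t2, flair, seg, mask, idh in zip(
        columns["t1wPath"], columns["t1cwPath"], columns["t2wPath"], columns["flairPath"],
        columns["segPath"], columns["t1w_BrainmaskPath"], columns["IDH_value"],
    )
]

print(records[1]["image"][3], records[1]["IDH_label"].item())
```

Wrapping the label in `np.array` is why the fold-building code below reads it back with `.item()`.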


## Creating 3 stratified splits for 3-fold cross-validation

In [8]:
n_splits = 3
#train_index = np.linspace(0, train_features.shape[0]-1, num = train_features.shape[0], dtype = np.uint16, endpoint=True)
#partition_data = monai.data.utils.partition_dataset_classes(train_index, train_labels.values.ravel().tolist(), shuffle=True, num_partitions=n_splits) 
#partition_data = monai.data.utils.partition_dataset_classes(train_files, dfTrainLbl['IDH_value'].values.ravel().tolist(), shuffle=True, num_partitions=n_splits)
partition_data = monai.data.partition_dataset_classes(train_files, BraTS20SubjectsIDHWithMetaDF['IDH_value'].values.ravel().tolist(), shuffle=True, num_partitions=n_splits)
print(len(partition_data), len(partition_data[0]), len(partition_data[1]), len(partition_data[2]))


# val_folds = {}
# train_folds = {}
# flds = np.linspace(0, n_splits, num=n_splits, dtype = np.int8)
# for cfold in range(n_splits):
#     not_cfold = np.delete(flds, cfold)
#     val_folds[cfold] = partition_data[cfold]
# #     train_folds[cfold] = 
# # sub_flds = flds[..., ~0]   
# # sub_flds

val_folds = {}
train_folds = {}
flds = np.linspace(0, n_splits, num=n_splits, dtype = np.uint8)
for cfold in range(n_splits):
    # Partition `cfold` is the validation set of this fold
    val_folds[f"fold{cfold}"] = partition_data[cfold]
    val_folds[f"fold{cfold}_IDH_label"] = [adct['IDH_label'].item() for adct in partition_data[cfold]]

    # The remaining n_splits - 1 partitions are concatenated into the training set
    train_folds_masks = [1]*n_splits
    train_folds_masks[cfold] = 0
    partition_data_non_cfold = list()
    for aDctLstitem in compress(partition_data, train_folds_masks):
        partition_data_non_cfold.extend(aDctLstitem)

    train_folds[f"fold{cfold}"] = partition_data_non_cfold
    train_folds[f"fold{cfold}_IDH_label"] = [adct['IDH_label'].item() for adct in partition_data_non_cfold]

for i in range(n_splits):
    print('val: ', len(val_folds[f'fold{i}']), 'train: ', len(train_folds[f'fold{i}']), '\n')

3 45 45 45
val:  45 train:  90 

val:  45 train:  90 

val:  45 train:  90 
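The fold construction above follows a leave-one-partition-out scheme: partition k is used for validation and the others are used for training. The sketch below shows the same logic in compact form on hypothetical toy records. `build_cv_folds` is an illustrative helper and is not part of MONAI.

```python
from typing import Dict, List, Sequence

def build_cv_folds(partitions: Sequence[List[dict]]) -> Dict[str, Dict[str, List[dict]]]:
    """Leave-one-partition-out folds: partition k is validation, the others are training."""
    folds = {}
    for k, val_part in enumerate(partitions):
        train_part = [item for j, part in enumerate(partitions) if j != k for item in part]
        folds[f"fold{k}"] = {"train": train_part, "val": list(val_part)}
    return folds

# Hypothetical stand-in for partition_data: 3 partitions of 45 records each
toy_partitions = [[{"id": 45 * j + i, "IDH_label": i % 2} for i in range(45)] for j in range(3)]
folds = build_cv_folds(toy_partitions)
for name, fold in folds.items():
    print(name, "val:", len(fold["val"]), "train:", len(fold["train"]))
```

Each record appears in exactly one validation set, and the training and validation sets within a fold never overlap.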



In [9]:

for i_cv in range(n_splits):
    print('Training classes\n')
    print(np.unique([train_folds[f'fold{i_cv}'][i]['IDH_label'].item() for i in range(len(train_folds[f'fold{i_cv}']))], return_counts = True))
    print('\nValidation classes\n')
    print(np.unique([val_folds[f'fold{i_cv}'][i]['IDH_label'].item() for i in range(len(val_folds[f'fold{i_cv}']))], return_counts = True))
    print('#'*4, '\n\n')

Training classes

(array([0, 1]), array([38, 52]))

Validation classes

(array([0, 1]), array([19, 26]))
#### 


Training classes

(array([0, 1]), array([38, 52]))

Validation classes

(array([0, 1]), array([19, 26]))
#### 


Training classes

(array([0, 1]), array([38, 52]))

Validation classes

(array([0, 1]), array([19, 26]))
#### 
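Every fold has the same 38:52 class split, so the training data are mildly imbalanced. The notebook imports `torch.utils.data.WeightedRandomSampler`. One common way to counter the imbalance is to give each sample a weight inversely proportional to its class frequency. The sketch below is a hypothetical NumPy-only illustration; this section does not show whether the original training actually uses a sampler.

```python
import numpy as np

def inverse_frequency_weights(labels):
    """Per-sample weights equal to 1 / (number of samples in that sample's class)."""
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    class_weight = dict(zip(classes.tolist(), (1.0 / counts).tolist()))
    return np.array([class_weight[int(y)] for y in labels])

# Class counts of one training fold, as printed above: 38 x class 0, 52 x class 1
train_labels = [0] * 38 + [1] * 52
weights = inverse_frequency_weights(train_labels)
print(weights[0], weights[-1])
```

Each class then carries the same total weight (1.0). Passing the weights to `WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)` would draw both classes about equally often in expectation.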




***HistogramStandardization***

Histogram standardization is implemented with the [TorchIO](https://github.com/fepegar/torchio) library.

Bases: [torchio.transforms.preprocessing.intensity.normalization_transform.NormalizationTransform](https://torchio.readthedocs.io/transforms/preprocessing.html#torchio.transforms.preprocessing.intensity.NormalizationTransform)

Perform histogram standardization of intensity values.

Implementation of [New variants of a method of MRI scale standardization](https://ieeexplore.ieee.org/document/836373).

See [torchio.transforms.HistogramStandardization.train()](https://torchio.readthedocs.io/transforms/preprocessing.html#torchio.transforms.HistogramStandardization.train) for more details.

PARAMETERS
landmarks – Dictionary (or path to a PyTorch file with .pt or .pth extension in which a dictionary has been saved) whose keys are image names in the subject and values are NumPy arrays or paths to NumPy arrays defining the landmarks after training with [torchio.transforms.HistogramStandardization.train()](https://torchio.readthedocs.io/transforms/preprocessing.html#torchio.transforms.HistogramStandardization.train).

Here, ***save_dir*** is the path where the trained histogram landmarks for the four channels (T1w, T1cw, T2w, and FLAIR) and the trained model's weights will be saved.
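The core idea of landmark-based (Nyúl-Udupa) histogram standardization can be sketched in NumPy. During training, each image's percentile landmarks are mapped onto a common scale and averaged. At transform time, each image is mapped piecewise-linearly from its own landmarks to those averaged reference landmarks. This is not TorchIO's code. The percentile set and the [0, 100] output scale below are assumptions for illustration, and TorchIO's defaults may differ.

```python
import numpy as np

# Assumed landmark percentiles (1st/99th cut-offs plus deciles) and output scale
PERCENTILES = np.array([1, 10, 20, 30, 40, 50, 60, 70, 80, 90, 99], dtype=float)
S_MIN, S_MAX = 0.0, 100.0

def image_landmarks(img, mask=None):
    """Intensities at PERCENTILES, optionally restricted to a mask (e.g. the brain)."""
    values = img[mask > 0] if mask is not None else img.ravel()
    return np.percentile(values, PERCENTILES)

def train_landmarks(images, masks=None):
    """Average landmarks after linearly mapping each image's [p1, p99] onto [S_MIN, S_MAX]."""
    masks = masks if masks is not None else [None] * len(images)
    mapped = []
    for img, mask in zip(images, masks):
        lm = image_landmarks(img, mask)
        mapped.append(S_MIN + (lm - lm[0]) / (lm[-1] - lm[0]) * (S_MAX - S_MIN))
    return np.mean(mapped, axis=0)

def standardize(img, ref_landmarks, mask=None):
    """Piecewise-linear map from this image's landmarks to the reference landmarks."""
    lm = image_landmarks(img, mask)
    # np.interp clamps values beyond the first/last landmark; the original method extrapolates
    return np.interp(img, lm, ref_landmarks)

rng = np.random.default_rng(0)
train_imgs = [rng.gamma(2.0, scale=s, size=(16, 16, 16)) for s in (1.0, 3.0, 5.0)]
ref = train_landmarks(train_imgs)
out = standardize(train_imgs[1], ref)
print(np.round(ref, 1))
```

Because the mapping is monotonic, the intensity ordering within each image is preserved; only the histogram shape is aligned across subjects.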

In [10]:
# file_prefix = 'ConvEffNet_Brats21_1SplitV0'
# savedirname = 'ConvEffNet_Brats21'
# save_dir = os.path.join('/raid/brats2021/pthBraTS2021Radiogenomics', savedirname)
# if not os.path.exists(save_dir):
#     os.makedirs(save_dir)

# train_images20T1 = dfTrainLbl['t1wPath'].values
# train_images20T1ce = dfTrainLbl['t1cwPath'].values
# train_images20T2 = dfTrainLbl['T2wPath'].values
# train_images20Flair = dfTrainLbl['FlairPath'].values

# hiseq_t1npyfile = os.path.join(save_dir, f"histeq_t1w_{file_prefix}.npy")
# t1w_landmarks = (hiseq_t1npyfile if os.path.isfile(hiseq_t1npyfile) else \
#                  tio.HistogramStandardization.train(train_images20T1, output_path = hiseq_t1npyfile))
# # #torch.save(t1w_landmarks, hiseq_t1npyfile)

# hiseq_t1cnpyfile =  os.path.join(save_dir, f"histeq_t1cw_{file_prefix}.npy")
# t1cw_landmarks = (hiseq_t1cnpyfile if os.path.isfile(hiseq_t1cnpyfile) else \
#                   tio.HistogramStandardization.train(train_images20T1ce, output_path = hiseq_t1cnpyfile))
# #torch.save(t1cw_landmarks, hiseq_t1cnpyfile)


# hiseq_t2npyfile = os.path.join(save_dir, f"histeq_t2w_{file_prefix}.npy")
# t2w_landmarks = (hiseq_t2npyfile if os.path.isfile(hiseq_t2npyfile) else \
#                  tio.HistogramStandardization.train(train_images20T2, output_path = hiseq_t2npyfile))
# #torch.save(t2w_landmarks, hiseq_t2npyfile)

# hiseq_flairnpyfile = os.path.join(save_dir, f"histeq_flair_{file_prefix}.npy")
# flair_landmarks = (hiseq_flairnpyfile if os.path.isfile(hiseq_flairnpyfile) else \
#                    tio.HistogramStandardization.train(train_images20Flair, output_path = hiseq_flairnpyfile))
# #torch.save(flair_landmarks, hiseq_flairnpyfile)

file_prefix = 'DynUNet_Brats20_3CV_4Chnls1PatchSWIRngr21_V0'
savedirname = 'DynUNetVariants_Brats20'
save_dir = os.path.join('/raid/brats2021/pthBraTS2020_IDHGenomics', savedirname)
os.makedirs(save_dir, exist_ok=True)

## Custom transforms for MONAI/PyTorch `Compose`

A class to rearrange the label mask array as
- [0, :, :, :] = the multi-class mask (class labels: 0 (background), 1, 2, and 4)
- [1, :, :, :] = the whole tumor mask (class labels: 0 (background) and 1)\
Not used here

In [11]:
class ConvertToMultiChannelPlusWT(MapTransform):
    
    """
     GD-enhancing tumor (ET — label 4), 
     the peritumoral edema (ED — label 2), and 
     the necrotic and non-enhancing tumor core (NCR/NET — label 1)

    """
    
    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            result = []
            
            d[key]=np.squeeze(d[key], axis = 0) # Converting 1, H, W, D to H, W, D
            result.append(d[key])

            # merge labels 1, 2 and 4 to construct WT
            result.append(
                np.logical_or(
                    np.logical_or(d[key] == 2, d[key] == 4), d[key] == 1
                )
            )
            ## merge label 1 and label 4 to construct TC
            #result.append(np.logical_or(d[key] == 1, d[key] == 4))
            ## label 4 is ET
            #result.append(d[key] == 4)
            d[key] = np.stack(result, axis=0).astype(np.uint8)
        return d

#### Define a new transform to convert brain tumor labels
Here we convert the multi-class labels into a multi-label segmentation format with three overlapping channels (WT, TC, and ET).\
Not used here

In [12]:
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    
    """
     GD-enhancing tumor (ET — label 4), 
     the peritumoral edema (ED — label 2), and 
     the necrotic and non-enhancing tumor core (NCR/NET — label 1)

    """
    
    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            result = []
            
            d[key]=np.squeeze(d[key], axis = 0) # Converting 1, H, W, D to H, W, D

            # merge labels 1, 2 and 4 to construct WT
            result.append(
                np.logical_or(
                    np.logical_or(d[key] == 2, d[key] == 4), d[key] == 1
                )
            )
            # merge label 1 and label 4 to construct TC
            result.append(np.logical_or(d[key] == 1, d[key] == 4))
            # label 4 is ET
            result.append(d[key] == 4)
            d[key] = np.stack(result, axis=0).astype(np.float32)
        return d
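The nested `np.logical_or` calls above can be written more compactly with `np.isin`. The minimal sketch below applies the same WT/TC/ET mapping to a toy 2x2 label map. `brats_regions` is an illustrative helper, not a MONAI function. Note that the channels overlap (an ET voxel is also TC and WT), so this is not a mutually exclusive one-hot encoding.

```python
import numpy as np

def brats_regions(seg):
    """Split a BraTS label map (values 0, 1, 2, 4) into overlapping WT/TC/ET channels."""
    wt = np.isin(seg, (1, 2, 4))  # whole tumor: NCR/NET + ED + ET
    tc = np.isin(seg, (1, 4))     # tumor core: NCR/NET + ET
    et = seg == 4                 # enhancing tumor
    return np.stack([wt, tc, et], axis=0).astype(np.float32)

seg = np.array([[0, 1], [2, 4]])  # toy label map
regions = brats_regions(seg)
print(regions.shape)  # (3, 2, 2)
```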

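#### Classes to build the whole-tumor (WT) mask and encode the IDH class in it
- ***ConvertToIDHLabel2WTd*** takes a binary WT mask and multiplies it by 2 when `IDH_label` is 1, so tumor voxels are 1 for IDH class 0 and 2 for IDH class 1.
- ***Convert2WTd*** merges labels 1, 2, and 4 into a binary WT mask and stores it under the original key and under a new key, ***label_mask*** (for `label`), which has the same spatial size as the image array.

Both are used in the ***Compose*** pipeline.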

In [13]:
class ConvertToIDHLabel2WTd(MapTransform):
    
    """
     GD-enhancing tumor (ET — label 4), 
     the peritumoral edema (ED — label 2), and 
     the necrotic and non-enhancing tumor core (NCR/NET — label 1)

    """
    
    def __init__(self, keys: KeysCollection, IDH_label_key:str = 'IDH_label') -> None:

        super().__init__(keys)
        self.IDH_label_key = IDH_label_key
       
    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            # merge labels 1, 2 and 4 to construct WT
            #WT = np.logical_or(np.logical_or(d[key] == 2, d[key] == 4), d[key] == 1).astype(np.uint8)
            WT = d[key]
            if d[self.IDH_label_key].item() == 1:
                WT=np.multiply(WT, 2)
                #WT = 2*WT
            d[key] = WT
    
        return d
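With this encoding, the label becomes a three-class target (background, IDH-0 tumor, IDH-1 tumor). This is one way to read how a voxel-wise DynUNet with Dice/cross-entropy loss can serve IDH classification. The sketch below shows the encoding and one possible way to read the class back from a mask. Both helpers are hypothetical and are not part of the notebook.

```python
import numpy as np

def encode_idh_in_wt(wt_mask, idh_label):
    """Keep the binary WT mask for IDH 0; double it (tumor voxels = 2) for IDH 1."""
    return wt_mask * 2 if int(idh_label) == 1 else wt_mask

def decode_idh(encoded_mask):
    """Infer the IDH class from an encoded mask (None if it has no tumor voxels)."""
    peak = int(encoded_mask.max())
    return None if peak == 0 else peak - 1

wt = np.array([[0, 1], [1, 1]], dtype=np.uint8)
print(encode_idh_in_wt(wt, np.array(1)).tolist())  # [[0, 2], [2, 2]]
```

For a predicted segmentation, a voting rule over tumor voxels would be more robust than the maximum; the maximum is used here only because a ground-truth mask contains a single tumor value.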

In [14]:
class Convert2WTd(MapTransform):
    
    """
     GD-enhancing tumor (ET — label 4), 
     the peritumoral edema (ED — label 2), and 
     the necrotic and non-enhancing tumor core (NCR/NET — label 1)

    """
    
    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            result = []
            
            # merge labels 1, 2 and 4 to construct WT
            WT = np.logical_or(np.logical_or(d[key] == 2, d[key] == 4), d[key] == 1).astype(np.uint8)

            d[f'{key}_mask'] = WT
            d[f'{key}'] = WT
    
        return d

In [15]:
class SpatialCropWTCOMd(MapTransform):
    """
    Crop `keys` to `roi_size` around the center of mass (COM) of the mask stored under
    `COM_label_key` (by default the whole-tumor mask `label_mask`). Only the COM of the
    first mask channel is used; an empty mask falls back to voxel coordinate 70 on every axis.
    """
    def __init__(self, keys: KeysCollection, roi_size, COM_label_key:str = 'label_mask') -> None:

        super().__init__(keys)
        self.COM_label_key = COM_label_key
        self.roi_size = roi_size
    
    def __call__(self, data):
        d = dict(data)
        # Compute the COM once, before any key (possibly the mask itself) is cropped
        Coms = np.array([ndimage.center_of_mass(lbl) for lbl in list(d[self.COM_label_key])])
        Coms[np.isnan(Coms)] = 70  # empty mask channel -> fixed fallback center
        Coms = Coms[0].astype(np.uint16).tolist()
        sc_com = SpatialCrop(roi_center=Coms, roi_size=self.roi_size)

        for key in self.keys:
            d[key] = sc_com(d[key])
                
        return d
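The sketch below shows COM-centered cropping in plain NumPy. For a binary mask, the center of mass is simply the mean foreground coordinate. The window is clipped at the volume border instead of padded, which is roughly how `SpatialCrop` handles out-of-bounds ROIs as I understand it, so edge crops can be smaller than `roi_size`. `com_crop` is a hypothetical helper; the fallback center mirrors the value 70 used in the class above.

```python
import numpy as np

def com_crop(image, mask, roi_size, fallback=70):
    """Crop a (C, H, W, D) `image` around the center of mass of channel 0 of a binary `mask`."""
    fg = np.argwhere(mask[0] > 0)
    # For a binary mask the center of mass is the mean foreground coordinate
    com = fg.mean(axis=0) if len(fg) else np.full(mask.ndim - 1, float(fallback))
    center = com.astype(int)
    slices = [slice(None)]  # keep every channel
    for c, r, n in zip(center, roi_size, image.shape[1:]):
        start = int(c) - r // 2
        # clip the window to the volume (no padding), so edge crops can be smaller than roi_size
        slices.append(slice(max(start, 0), min(max(start + r, 0), n)))
    return image[tuple(slices)]

img = np.random.default_rng(0).normal(size=(4, 20, 20, 20))
mask = np.zeros((1, 20, 20, 20))
mask[0, 8:12, 8:12, 8:12] = 1
print(com_crop(img, mask, (8, 8, 8)).shape)  # (4, 8, 8, 8)
```

A fixed fallback such as 70 only makes sense for volumes larger than about 70 voxels per axis. For smaller inputs, the geometric center of the volume is a safer default.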

### A class for blending (alpha or other types) the 4-channel image array with the label (GT) if necessary
Not used here

In [16]:
class AddBlendImaged(MapTransform):
    """
    Append a blended channel to each image in `keys`: 0.7 * the first image channel
    + 0.3 * a weighted label map, in which label values are replaced by class weights.
    """

    def __init__(self, keys: KeysCollection, label_key:str = 'label') -> None:
        
        super().__init__(keys)
        self.label_key = label_key
    
    def __call__(self, data):
     
        d = dict(data)
        class_weights = [0.04, 0.15, 0.5, .2]
        label = d[self.label_key]
        values = np.unique(label)

        # Replace the i-th non-background label value with class_weights[i];
        # np.where returns new arrays, so the label itself is not overwritten
        weighted_label = label
        for cl in range(1, len(values)):
            weighted_label = np.where(label == values[cl], class_weights[cl], weighted_label)

        for key in self.keys:
            blend_d = 0.7 * d[key][0:1, :, :, :] + 0.3 * weighted_label
            d[key] = np.concatenate((d[key], blend_d), axis=0)
            
        return d

### A class for concatenating the brain mask and label (GT) with the 4-channel image array if necessary
(Not used here)

In [17]:
class ConcatLabelBrainmaskd(MapTransform):
    """
          we do not need labels as it is a generative problem
    """
    
    def __init__(self, keys: KeysCollection, image_key = 'image', label_key = 'label', 
                 brain_mask_key = 'brain_mask') -> None:
        
        super().__init__(keys)
        self.brain_mask_key = brain_mask_key
        self.image_key = image_key
        self.label_key = label_key
    
    
    def __call__(self, data):
     
        d = dict(data)
        #d[self.image_key] = np.concatenate((d[self.image_key], d[self.label_key], d[self.brain_mask_key][0:1]), axis = 0)
        d[self.image_key] = np.concatenate((d[self.image_key], d[self.label_key][0:1]), axis = 0)
        
        return d

### Implementing channel-wise histogram normalization
(Not used here)

In [18]:
class HistogramNormalizeChannelWised(MapTransform):
    """
    Apply MONAI's HistogramNormalize to each channel of `keys` separately, computing
    each channel's histogram only inside the brain mask.
    """
    
    def __init__(self, keys: KeysCollection, brain_mask_key = 'brain_mask', min=0, max=255) -> None:
        
        super().__init__(keys)
        self.brain_mask_key = brain_mask_key
        self.histnorms = HistogramNormalize(num_bins=256, min=min, max=max)
    
    def __call__(self, data):
     
        d = dict(data)
        brain_mask = d[self.brain_mask_key]
        for key in self.keys:
            nchnl = d[key].shape[0]
            for ch in range(nchnl):
                # The brain mask usually has a single channel; reuse it for every image channel
                mask = brain_mask[ch] if brain_mask.shape[0] > 1 else brain_mask[0]
                d[key][ch] = self.histnorms(d[key][ch], mask)
        
        return d
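The sketch below shows the general idea behind masked histogram equalization, as I understand MONAI's `HistogramNormalize`. The histogram and its cumulative distribution are computed from masked voxels only. The CDF is rescaled to [min, max], and every voxel (inside or outside the mask) is mapped through it by linear interpolation at the bin centers. It is a NumPy-only illustration with hypothetical helper names, not MONAI's implementation.

```python
import numpy as np

def masked_hist_equalize(img, mask=None, num_bins=256, out_min=0.0, out_max=255.0):
    """Histogram-equalize `img`, building the histogram only from voxels where mask > 0."""
    values = img[mask > 0] if mask is not None else img.ravel()
    hist, bin_edges = np.histogram(values, bins=num_bins)
    centers = (bin_edges[:-1] + bin_edges[1:]) / 2
    cdf = hist.cumsum().astype(float)
    cdf = (cdf - cdf.min()) / (cdf.max() - cdf.min())  # normalize the CDF to [0, 1]
    cdf = cdf * (out_max - out_min) + out_min          # rescale to the output range
    return np.interp(img, centers, cdf)                # map every voxel through the CDF

def equalize_channels(image, brain_mask, **kwargs):
    """Equalize each channel of a (C, H, W, D) array using a single-channel (1, H, W, D) mask."""
    return np.stack([masked_hist_equalize(ch, brain_mask[0], **kwargs) for ch in image])

rng = np.random.default_rng(0)
image = rng.gamma(2.0, size=(2, 8, 8, 8))
brain_mask = np.zeros((1, 8, 8, 8))
brain_mask[0, 2:6, 2:6, 2:6] = 1
out = equalize_channels(image, brain_mask)
print(out.shape, round(float(out.min()), 1), round(float(out.max()), 1))
```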

### Defining traning and validation transforms

- Training transforms:
    - LoadImaged
    - EnsureChannelFirstd
    - RandAffined, RandRotate90d, RandZoomd: random spatial augmentation
    - Convert2WTd and ConvertToIDHLabel2WTd (custom classes defined earlier in this notebook)
    - CropForegroundd: cropping to the brain mask (via `threshold_foreground`)
    - NormalizeIntensityd: channel-wise z-score normalization over non-zero voxels
    - ResizeWithPadOrCropd: padding/cropping to 128 × 160 × 128
    - RandGaussianNoised
    - RandStdShiftIntensityd
    - RandFlipd (along each of the three spatial axes)
    - CastToTyped, ToTensord
    
- Validation transforms:
    - LoadImaged
    - EnsureChannelFirstd
    - Convert2WTd and ConvertToIDHLabel2WTd (custom classes defined earlier in this notebook)
    - CropForegroundd
    - NormalizeIntensityd
    - ResizeWithPadOrCropd
    - CastToTyped, ToTensord

HistogramNormalizeChannelWised and patch-based cropping (RandCropByPosNegLabeld) are kept only as commented-out alternatives. Most of the transforms come from the [MONAI](https://docs.monai.io/en/latest/transforms.html#dictionary-transforms) library.
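The spatial part of both pipelines works in two steps. `CropForegroundd` crops each volume to the bounding box of the brain mask, and `ResizeWithPadOrCropd` then brings the result to a fixed grid. The NumPy sketch below reproduces these two steps on a toy volume. The symmetric zero-padding and centre-cropping conventions mirror MONAI's behaviour, but this code is an assumption and not MONAI's implementation.

```python
import numpy as np

def crop_to_mask(image, mask):
    """Crop a (C, H, W, D) array to the bounding box of voxels where mask == 1."""
    coords = np.argwhere(mask[0] == 1)               # (N, 3) foreground indices
    start = coords.min(axis=0)
    end = coords.max(axis=0) + 1                     # exclusive upper bound
    return image[:, start[0]:end[0], start[1]:end[1], start[2]:end[2]]

def center_pad_or_crop(image, spatial_size):
    """Zero-pad symmetrically or centre-crop each spatial axis to spatial_size."""
    out = image
    for axis, target in enumerate(spatial_size, start=1):
        size = out.shape[axis]
        if size < target:                            # pad both sides
            before = (target - size) // 2
            pad = [(0, 0)] * out.ndim
            pad[axis] = (before, target - size - before)
            out = np.pad(out, pad)
        elif size > target:                          # keep the centre
            start = (size - target) // 2
            out = np.take(out, np.arange(start, start + target), axis=axis)
    return out

rng = np.random.default_rng(0)
image = rng.random((4, 64, 64, 40), dtype=np.float32)   # toy 4-modality volume
brain_mask = np.zeros((1, 64, 64, 40), dtype=np.uint8)
brain_mask[:, 10:50, 5:35, 2:38] = 1                     # toy brain region
cropped = crop_to_mask(image, brain_mask)
fixed = center_pad_or_crop(cropped, (32, 40, 32))
print(cropped.shape, fixed.shape)  # (4, 40, 30, 36) (4, 32, 40, 32)
```

In the notebook, the target size is (128, 160, 128), and each side is a multiple of 32.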



In [19]:
def threshold_foreground(x):
    # foreground = voxels equal to 1 (i.e. inside the binary brain mask)
    return x == 1


#Resized(keys=keys[0:-1], spatial_size=patch_size, mode = ('area','nearest','nearest')),

# ConvertToMultiChannelBasedOnBratsClassesd(keys = ['label']),
# ConcatLabelBrainmaskd(keys = None, image_key = 'image', label_key = 'label', brain_mask_key = 'brain_mask'),
# CropForegroundd(keys=keys[0:-1], source_key="brain_mask", select_fn = threshold_foreground, start_coord_key='fg_start_coord', end_coord_key='fg_end_coord'),
# SpatialPadd(keys=keys[0:-1], spatial_size=patch_size),

# RandGaussianSmoothd(
#     keys=["image"],
#     sigma_x=(0.5, 1.15),
#     sigma_y=(0.5, 1.15),
#     sigma_z=(0.5, 1.15),
#     prob=0.3,
# ),

# RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.3),
# RandGibbsNoised(keys=["image"], prob=0.3, alpha=(0.1, 0.5), as_tensor_output=False),

          
# DataStatsd(keys=keys[0:-1], prefix="Data", data_type=True, data_shape=True, value_range=True, data_value=False),
#DataStatsd(keys=keysExt[0:-1], prefix="Data", data_type=True, data_shape=True, value_range=True, data_value=False),


def get_task_transforms(patch_size, task='train', pos_sample_num=1, neg_sample_num=1, num_samples=1):
    
    #spatial_size=(30, 30, 30)
    orig_img_size = (240, 240, 155)

    if task=='train':
        keys = ["image", 'label', 'brain_mask', 'IDH_label']
        keysExt = ["image", 'label', 'brain_mask', 'label_mask', 'IDH_label']
        
        all_transform = [
            
            LoadImaged(keys=keys[0:-1], reader = "NibabelReader"),
            EnsureChannelFirstd(keys=keys[0:-1]),
            #adapter_tioChannelWise2monai(tiofn = tio.HistogramStandardization, mode = 'train',landmarks = landmarks_dict),
            #HistogramNormalizeChannelWised(keys = ['image'], brain_mask_key = 'brain_mask', min = 1, max = 65535),
            
            RandAffined(
                keys = keys[0:-1],
                prob=0.2,
                spatial_size= orig_img_size, #(240, 240, 155),
                rotate_range=np.pi/9,
                scale_range=(0.1, 0.1),
                mode=("bilinear", "nearest", "nearest"),
                as_tensor_output=False,
                padding_mode = ("zeros", "zeros", "zeros"),
            ),
            
            RandRotate90d(keys=keys[0:-1], prob=0.3, spatial_axes=[0, 2]),

            RandZoomd(
                keys=keys[0:-1],
                min_zoom=0.9,
                max_zoom=1.1,
                mode=("trilinear", "nearest", "nearest"),
                align_corners=(True, None, None),
                prob=0.3,
            ),
            
            #ConvertToIDHLabel2WTd(keys = ["label"]),
            Convert2WTd(keys = ["label"]),
            ConvertToIDHLabel2WTd(keys = ["label"], IDH_label_key = 'IDH_label'),
            
            CropForegroundd(keys=keysExt[0:-1], source_key="brain_mask", select_fn = threshold_foreground, start_coord_key='fg_start_coord', end_coord_key='fg_end_coord'),
            #Spacingd(keys = keysExt[0:-1], pixdim=(1.25, 1.25, 1.25), mode = ('bilinear','nearest', 'nearest', 'nearest')),
            NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
            
            ResizeWithPadOrCropd(keys = keysExt[0:-1], spatial_size = (128, 160, 128)),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.3),
            RandStdShiftIntensityd(keys = ["image"], factors=0.3, nonzero=True, channel_wise=True, prob=0.3), 
            RandFlipd(keys=keysExt[0:-1], prob=0.5, spatial_axis=0),
            RandFlipd(keys=keysExt[0:-1], prob=0.5, spatial_axis=1),
            RandFlipd(keys=keysExt[0:-1], prob=0.5, spatial_axis=2),
            
            #CropForegroundd(keys=keysExt[0:-1], source_key="label_mask", select_fn = threshold_foreground, start_coord_key='fg_start_coord', end_coord_key='fg_end_coord'),
            #ResizeWithPadOrCropd(keys = keysExt[0:-1], spatial_size = patch_size),
#             RandCropByLabelClassesd(
#                 keys=keysExt[0:-1],            
#                 label_key = "label_mask",
#                 spatial_size = patch_size,    
#                 ratios= [1, 10],     
#                 num_classes=2,              
#                 num_samples=1,
#                 image_key="brain_mask",
#                 image_threshold=0,
#             ),
            #SpatialCropWTCOMd(keys=keysExt[0:-1], roi_size=patch_size, COM_label_key = "label_mask"),
            #SpatialPadd(keys = keysExt[0:-1], spatial_size = patch_size),
        
            CastToTyped(keys=keysExt, dtype=(np.float32, np.uint8, np.uint8, np.uint8, np.float32)),
            ToTensord(keys=keysExt),

#             #EnsureTyped(keys=keys, data_type = "tensor"),
#             #ToDeviced(keys = keys, device = deviceName),
        ]
        
        
        
    elif task=='validation':
    
        keys = ["image", 'label', 'brain_mask', 'IDH_label']
        keysExt = ["image", 'label', 'brain_mask', 'label_mask', 'IDH_label']
        all_transform = [
            
            LoadImaged(keys=keys[0:-1], reader = "NibabelReader"),
            EnsureChannelFirstd(keys=keys[0:-1]),
            #adapter_tioChannelWise2monai(tiofn = tio.HistogramStandardization, mode = 'train',landmarks = landmarks_dict),
            #HistogramNormalizeChannelWised(keys = ['image'], brain_mask_key = 'brain_mask', min = 1, max = 65535),
            #ConvertToIDHLabel2WTd(keys = ["label"]),
            Convert2WTd(keys = ["label"]),
            ConvertToIDHLabel2WTd(keys = ["label"], IDH_label_key = 'IDH_label'),
            CropForegroundd(keys=keysExt[0:-1], source_key="brain_mask", select_fn = threshold_foreground, start_coord_key='fg_start_coord', end_coord_key='fg_end_coord'),
            #Spacingd(keys = keysExt[0:-1], pixdim=(1.25, 1.25, 1.25), mode = ('bilinear','nearest', 'nearest', 'nearest')),
            
            NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
            ResizeWithPadOrCropd(keys = keysExt[0:-1], spatial_size = (128, 160, 128)),
            #CropForegroundd(keys=keysExt[0:-1], source_key="label_mask", select_fn = threshold_foreground, start_coord_key='fg_start_coord', end_coord_key='fg_end_coord'),
            #SpatialCropWTCOMd(keys=keysExt[0:-1], roi_size=patch_size, COM_label_key = "label_mask"),
            #SpatialPadd(keys = keysExt[0:-1], spatial_size = patch_size),
#             RandCropByPosNegLabeld(
#                 keys=keysExt[0:-1],
#                 label_key="label_mask",
#                 spatial_size=patch_size,
#                 pos=pos_sample_num,
#                 neg=neg_sample_num,
#                 num_samples=num_samples,
#                 image_key="brain_mask",
#                 image_threshold=0,
#             ),

            CastToTyped(keys=keysExt, dtype=(np.float32, np.uint8, np.uint8, np.uint8, np.float32)),
            ToTensord(keys=keysExt),
            #EnsureTyped(keys=keys, data_type = "tensor"),
            #ToDeviced(keys = keys, device = deviceName),
        ]
        
    else:
        raise ValueError(f"task must be 'train' or 'validation', got {task!r}")


    return Compose(all_transform)

# def create_cachedir(cache_dir):
#     if not os.path.exists(cache_dir):
#         os.makedirs(cache_dir)
#     return 1

### Section for visual inspection and debugging

In [20]:
patch_size=(128, 160, 128)
train_transforms = get_task_transforms(patch_size, task='train', pos_sample_num=1, neg_sample_num=1, num_samples=1)
val_transforms = get_task_transforms(patch_size, task='validation', pos_sample_num=1, neg_sample_num=1, num_samples=1)
len(train_transforms), len(val_transforms)

(17, 9)
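The two pipelines contain 17 and 9 transforms. This count matters for caching. A transform whose output is cached (as `CacheNTransDataset` does for its first `cache_n_trans` transforms) gives the same result in every epoch, so only the deterministic prefix of the training pipeline should be cached. The sketch below finds that prefix from the transform names. The rule "random transforms start with `Rand`" is an assumption that holds for this pipeline. In real code, `isinstance(t, monai.transforms.Randomizable)` is the more robust test.

```python
def cacheable_prefix(transform_names):
    """Number of leading transforms that are deterministic and safe to cache.

    Assumption for illustration: random transforms are exactly those whose
    class name starts with 'Rand' (true for the MONAI transforms used here).
    """
    for i, name in enumerate(transform_names):
        if name.startswith("Rand"):
            return i
    return len(transform_names)

train_names = ["LoadImaged", "EnsureChannelFirstd", "RandAffined", "RandRotate90d",
               "RandZoomd", "Convert2WTd", "ConvertToIDHLabel2WTd", "CropForegroundd",
               "NormalizeIntensityd", "ResizeWithPadOrCropd", "RandGaussianNoised",
               "RandStdShiftIntensityd", "RandFlipd", "RandFlipd", "RandFlipd",
               "CastToTyped", "ToTensord"]
val_names = ["LoadImaged", "EnsureChannelFirstd", "Convert2WTd", "ConvertToIDHLabel2WTd",
             "CropForegroundd", "NormalizeIntensityd", "ResizeWithPadOrCropd",
             "CastToTyped", "ToTensord"]
print(len(train_names), cacheable_prefix(train_names))  # 17 2
print(len(val_names), cacheable_prefix(val_names))      # 9 9
```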

In [21]:
#few_train_dataset = Dataset(data=train_files[0:10], transform=train_transforms)
# asub = few_train_dataset[5] 
# view(image = asub['image'][3].cpu(), label_image = asub['label_mask'][0].cpu())


### Checking class balance and sampling weights

In [22]:
afold_train_dataset = monai.data.Dataset(data=train_folds['fold0'], transform=train_transforms)
#train_folds['fold0_IDH_label']
uval, ucnt = np.unique(train_folds['fold0_IDH_label'], return_counts=True)
weight = 1. / ucnt
#weight = np.array([0.55, 0.45])
sample_weights = np.array([weight[int(t)] for t in train_folds['fold0_IDH_label']])
sample_weights = torch.from_numpy(sample_weights)
weight, ucnt

(array([0.02631579, 0.01923077]), array([38, 52]))
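Weighting each sample by the inverse of its class count gives each class the same total weight (38 × 1/38 = 52 × 1/52 = 1). As a result, `WeightedRandomSampler` draws the two IDH classes with roughly equal probability. Below is a stdlib-only sketch of that arithmetic, using the fold-0 counts printed above.

```python
from collections import Counter

def inverse_frequency_weights(labels):
    """Per-sample weights 1 / count(class), as passed to a weighted sampler."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]

labels = [0] * 38 + [1] * 52              # fold-0 class counts from above
w = inverse_frequency_weights(labels)
total = sum(w)
# probability that a single draw comes from each class
p0 = sum(wi for wi, y in zip(w, labels) if y == 0) / total
p1 = sum(wi for wi, y in zip(w, labels) if y == 1) / total
print(round(w[0], 8), round(w[-1], 8), round(p0, 3), round(p1, 3))  # 0.02631579 0.01923077 0.5 0.5
```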

In [23]:
#afold_train_dataset[200]['IDH_label'], afold_train_dataset[200]['label'].unique(return_counts = True)

## Computing DynUNet kernel sizes and strides

In [24]:
def get_kernels_strides(sizes, spacings):
    #sizes, spacings = patch_size[task_id], spacing[task_id]
    strides, kernels = [], []

    while True:
        spacing_ratio = [sp / min(spacings) for sp in spacings]
        stride = [
            2 if ratio <= 2 and size >= 8 else 1
            for (ratio, size) in zip(spacing_ratio, sizes)
        ]
        kernel = [3 if ratio <= 2 else 1 for ratio in spacing_ratio]
        if all(s == 1 for s in stride):
            break
        sizes = [i / j for i, j in zip(sizes, stride)]
        spacings = [i * j for i, j in zip(spacings, stride)]
        kernels.append(kernel)
        strides.append(stride)
    strides.insert(0, len(spacings) * [1])
    kernels.append(len(spacings) * [3])
    return kernels, strides
#task_id = "01"
kernels, strides = get_kernels_strides(patch_size, spacing)
#kernels.append([3, 3, 3])
#strides.append([2, 2, 2])

print(kernels,'\n', strides)

print('strides length', len(strides))
#filters = [64, 96, 128, 192, 256, 384, 512, 768, 1024][: len(strides)]
#filters = [16, 32, 64, 128, 160, 160][: len(strides)]
#print("Filters:", filters)

[[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]] 
 [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
strides length 6
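With five stride-2 levels, each spatial axis is downsampled by 2^5 = 32, so the 128×160×128 patch reaches a 4×5×4 bottleneck. Our reading of how strided convolutions and transposed-convolution upsampling line up is that side lengths which are multiples of 32 keep the encoder and decoder feature maps the same size at every skip connection. This is an inference, not a documented DynUNet rule. The following sketch checks this condition for a given patch size.

```python
from math import prod

def bottleneck_shape(patch_size, strides):
    """Total downsampling per axis, bottleneck size, and whether it divides evenly."""
    factors = [prod(s[axis] for s in strides) for axis in range(len(patch_size))]
    divisible = all(p % f == 0 for p, f in zip(patch_size, factors))
    return factors, [p // f for p, f in zip(patch_size, factors)], divisible

strides = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]
print(bottleneck_shape((128, 160, 128), strides))  # ([32, 32, 32], [4, 5, 4], True)
print(bottleneck_shape((64, 80, 64), strides))     # ([32, 32, 32], [2, 2, 2], False)
```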


In [25]:
# kernels, strides = get_kernels_strides((128, 128, 128), spacing)
# #kernels, strides
# kernels = [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 6, 3]]
# strides = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]

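# assumes `model` is a DynUNet built with the kernels/strides above (as in train() below)
# each spatial size should be a multiple of 32 (2**5, from the five stride-2 levels)
inps = torch.randn(3, 4, 64, 96, 64).to(device)
x = model(inps)
# # model
# in training mode with deep supervision: (batch, 1 + deep_supr_num, out_channels, D, H, W)
x.shape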

In [26]:
# from torchsummary import summary
# summary(model, (4, 64, 80, 64))
# # inps = torch.randn(3, 4, 48, 64, 48)
# # litConv = nn.Conv3d(4, 64, kernel_size=(7, 7, 7), stride=(2, 2, 2), padding=(3, 3, 3), bias=False)
# # litConv(inps).shape

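### Defining loss functions
- ***CrossEntropyInstLoss*** Multi-class cross entropy on logits from PyTorch ([`torch.nn.functional.cross_entropy`](https://pytorch.org/docs/stable/generated/torch.nn.functional.cross_entropy.html)), with optional label smoothing
- ***DeepDiceCELogitInstLoss*** Deep-supervision loss: the MONAI [DiceLoss](https://docs.monai.io/en/latest/losses.html) is applied to every DynUNet output head, and head i is weighted by 0.5 ** i (the cross-entropy term is commented out, so only Dice is used)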
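Label smoothing has to act on the target *distribution*, not on the integer class indices: if smoothed values are cast back to integer labels, they are truncated. The stdlib-only sketch below computes a smoothed cross entropy for a single voxel. It assumes the convention q = (1 - ε)·onehot + ε/C, which is the one documented for `label_smoothing` in `torch.nn.functional.cross_entropy`.

```python
import math

def smoothed_cross_entropy(logits, target, eps=0.1):
    """Cross entropy for one sample with label smoothing.

    Target distribution: (1 - eps) * one_hot(target) + eps / C.
    """
    c = len(logits)
    m = max(logits)                                   # for numerical stability
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    log_probs = [z - log_z for z in logits]
    q = [(1 - eps) * (k == target) + eps / c for k in range(c)]
    return -sum(qk * lp for qk, lp in zip(q, log_probs))

logits = [2.0, 0.5, -1.0]                             # toy 3-class logits
print(smoothed_cross_entropy(logits, 0, eps=0.0), smoothed_cross_entropy(logits, 0, eps=0.1))
```

If the correct class already has the highest logit, smoothing slightly raises the loss. This penalizes over-confident predictions.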


In [27]:
class CrossEntropyInstLoss(nn.Module):
    """Multi-class cross entropy on logits, with optional label smoothing."""
    def __init__(self, is_smooth=False, label_smoothing = 0.1):
        super().__init__()
        self.is_smooth = is_smooth
        self.label_smoothing = label_smoothing
        
       
    def forward(self, y_pred, y_true):
        # y_true: (B, 1, ...) class indices -> (B, ...) as expected by F.cross_entropy
        y_true = y_true.squeeze(dim=1).long()
        # smooth the target distribution inside F.cross_entropy (PyTorch >= 1.10);
        # smoothing the integer labels and casting back to long would truncate them
        smoothing = self.label_smoothing if self.is_smooth else 0.0
        loss = F.cross_entropy(y_pred, y_true, weight = None, label_smoothing = smoothing)
        return loss
    

class DeepDiceCELogitInstLoss(nn.Module):
    """Deep-supervision Dice loss for DynUNet outputs.

    y_pred: (B, n_heads, C, D, H, W) logits; head 0 is the final output, the others
            are the deep-supervision heads upsampled to the same size
    y_true: (B, 1, D, H, W) integer labels; head i is weighted by 0.5 ** i
    """
    def __init__(self):
        super().__init__()
        #self.volweight = torch.softmax(torch.tensor([0.12, 0.33, 0.55]), dim = 0)
        self.dice = DiceLoss(include_background=False, to_onehot_y=True, softmax=True, squared_pred=True, batch = False)  # reduction = "none", batch = True
        #self.smcross_entropy = CrossEntropyInstLoss()  ### was none torch.Tensor([0.66, 0.33, 1]), torch.tensor(self.volweight)

    def forward(self, y_pred, y_true):
        # repeat the labels once per output head: (n_heads, B, 1, D, H, W)
        y_true = y_true.unsqueeze(dim=0).expand(y_pred.shape[1], -1, -1, -1, -1, -1)
        #return sum([0.5 ** i * ((self.dice(p, l)) + self.smcross_entropy(p, l)) \
        #            for i, (p, l) in enumerate(zip(torch.unbind(y_pred, dim=1), torch.unbind(y_true, dim=0)))])
        return sum([0.5 ** i * self.dice(p, l) for i, (p, l) in enumerate(zip(torch.unbind(y_pred, dim=1), torch.unbind(y_true, dim=0)))])
    
loss_function = DeepDiceCELogitInstLoss()

In [28]:
xxp = torch.randint(0,3,size=(6,1, 128, 128, 128))
#pred = [torch.randn(6,3,8,8,6), torch.randn(6,3,8,8,6), torch.randn(6,3,8,8,6)]
pred = torch.stack([torch.randn(6, 3, 128, 128, 128), torch.randn(6, 3, 128, 128, 128), torch.randn(6, 3, 128, 128, 128), torch.randn(6, 3, 128, 128, 128)], dim=1)
loss_function(pred, xxp.float())

tensor(1.0315)
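The deep-supervision loss is a weighted sum of per-head Dice losses. With `deep_supr_num=2`, as in `train()` below, DynUNet produces three heads, which are weighted 1, 0.5 and 0.25. The NumPy sketch below approximates MONAI's `DiceLoss(squared_pred=True, include_background=False, batch=False)` (squared terms in the denominator, averaged over samples and foreground classes). It assumes softmax has already been applied and is not MONAI's code.

```python
import numpy as np

def soft_dice_loss(probs, onehot, eps=1e-5, include_background=False):
    """Squared-denominator soft Dice, averaged over samples and classes.

    probs, onehot: (B, C, ...) arrays; probs are already softmax-normalized.
    """
    if not include_background:
        probs, onehot = probs[:, 1:], onehot[:, 1:]
    axes = tuple(range(2, probs.ndim))                 # sum over spatial axes
    inter = (probs * onehot).sum(axis=axes)
    denom = (probs ** 2).sum(axis=axes) + (onehot ** 2).sum(axis=axes)
    return float(np.mean(1.0 - (2.0 * inter + eps) / (denom + eps)))

def deep_supervision_loss(head_probs, onehot):
    """Sum of per-head Dice losses weighted 1, 0.5, 0.25, ... (head 0 = final output)."""
    return sum(0.5 ** i * soft_dice_loss(p, onehot) for i, p in enumerate(head_probs))

labels = np.random.default_rng(0).integers(0, 3, size=(2, 8, 8, 8))
onehot = np.moveaxis(np.eye(3)[labels], -1, 1)         # (B, C, D, H, W)
perfect = [onehot.copy() for _ in range(3)]            # three heads, all correct
uniform = [np.full_like(onehot, 1 / 3) for _ in range(3)]
print(deep_supervision_loss(perfect, onehot), deep_supervision_loss(uniform, onehot))
```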

#### A function to (re)create an empty ***cache_dir*** for the cached transform outputs

In [29]:
def removeAndcreate_cachedir(cache_dir):
    """Delete cache_dir if it exists, then create it empty (avoids stale cached data)."""
    if os.path.exists(cache_dir):
        shutil.rmtree(cache_dir)
    os.makedirs(cache_dir)
    return 1

### PyTorch training loop

The following functionality is included:

- An optional learning rate finder (PyTorch-Ignite `FastaiLRFinder`)
- The Ranger21 optimizer (with its built-in learning rate warm-up/warm-down schedule)
- Mixed precision training (AMP)
- Saving the model weights based on the Dice score on the validation data
- Running k-fold cross validation (CV) training/validation (n_splits = 3 below) and saving the best model weights for each fold
- Defining train_dataset/train_loader and val_dataset/val_loader for each fold
- Instantiating a fresh DynUNet segmentation model for each fold, so that no weights or optimizer state carry over from the previous fold

The key variables used here:
- ***val_cache_dir:*** The path where the transformed outputs of the validation files are cached
- ***train_cache_dir:*** The path where the transformed outputs of the training files are cached
- ***max_epochs:*** Total number of epochs
- ***save_dir:*** The path to save the checkpoints/weights of the model
- ***file_prefix:*** Prefix of the log file (`Logs_{file_prefix}.txt`) in which losses and metrics are recorded, for example:

```
current fold: 0 current epoch: 1, acc_metric: 0.4579 accuracy: 0.5085, f1score: 0.5085 epoch 1 average training loss: 0.7250 average validation loss: 0.7128 
current fold: 0 current epoch: 2, acc_metric: 0.4876 accuracy: 0.5424, f1score: 0.5424 epoch 2 average training loss: 0.6961 average validation loss: 0.6940 
current fold: 0 current epoch: 3, acc_metric: 0.4870 accuracy: 0.4915, f1score: 0.4915 epoch 3 average training loss: 0.6862 average validation loss: 0.6965 
current fold: 0 current epoch: 4, acc_metric: 0.4882 accuracy: 0.5593, f1score: 0.5593 epoch 4 average training loss: 0.6715 average validation loss: 0.6885 
current fold: 0 current epoch: 5, acc_metric: 0.5927 accuracy: 0.5593, f1score: 0.5593 epoch 5 average training loss: 0.6555 average validation loss: 0.6826 
```

- ***val_interval*** Epoch interval at which the model is evaluated on the validation data. If the current validation Dice score is among the best so far, the model's weights are saved
- ***key_metric_n_saved*** The number of model checkpoints to keep. If it is set to 5, the top 5 checkpoints/weights are saved in 5 different ***.pth*** files
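With `key_metric_n_saved = k`, the loop keeps the k checkpoints with the highest validation Dice and deletes whichever file drops out of the top k. The stdlib sketch below does the same bookkeeping with a min-heap. It is an alternative to re-sorting `metric_values` and globbing filenames each epoch. File I/O is left out, and the Dice values are made up for illustration.

```python
import heapq

def update_topk(saved, score, name, k):
    """Decide whether checkpoint `name` with `score` enters the top-k.

    `saved` is a min-heap of (score, name), so saved[0] is the worst kept entry.
    Returns (keep, evicted): keep=True means save this checkpoint; evicted is
    the name of an older checkpoint file to delete, or None.
    """
    if len(saved) < k:
        heapq.heappush(saved, (score, name))
        return True, None
    if score >= saved[0][0]:
        _, evicted = heapq.heapreplace(saved, (score, name))
        return True, evicted
    return False, None

saved = []
for epoch, dice in enumerate([0.41, 0.55, 0.48, 0.60, 0.52], start=1):
    keep, evicted = update_topk(saved, dice, f"epoch{epoch}.pth", k=2)
    print(epoch, dice, "save" if keep else "skip", "delete:", evicted)
print(sorted(saved, reverse=True))
```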


In [30]:

#***Executed pipeline***\
#<img src="assets/ProposedIDHClass.png" align="left" width="1024" height="1800">

In [31]:
def train(train_files, train_files_IDH_label, val_files, val_files_IDH_label, batch_size = 2, epochs = 10, find_lr=False, cfold = 0):
    

    #     model = DenseNet201(spatial_dims=2, in_channels=3,
    #                        out_channels=num_classes, pretrained=True).to(device)

    # create spatial 3D
    #model = MultiDenseNet(spatial_dims=3, in_channelsList=(4, 1, 1, 1, 1), out_channels=2, block_config = (6, 12, 24, 16)).to(device)
    #model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=4, out_channels=1).to(device)
    #model = monai.networks.nets.DenseNet264(spatial_dims=3, in_channels=4, out_channels=1, init_features=64, growth_rate=32, block_config=(6, 12, 64, 48)).to(device)
    #patch_size=(64, 80, 64)
    num_classes = 3
        
    model = DynUNet(
        spatial_dims=3,
        in_channels=4,
        out_channels=num_classes,
        kernel_size=kernels,
        strides=strides,
        upsample_kernel_size=strides[1:],
        norm_name="batch",
        #filters = filters,
        deep_supervision=True,
        res_block=False,
        deep_supr_num=2,
    ).to(device)
    
    
    
    auc_metric = ROCAUCMetric()
    

    #train_files, train_files_IDH_label, val_files, val_files_IDH_label = train_files[:48], train_files_IDH_label[:48], val_files[:16], val_files_IDH_label[:16]

    """
    Block for using MONAI's caching mechanism for faster training
    """
    
    file_prefixfold = file_prefix  ##or file_prefix f"{file_prefix}_fold{cfold}" if saving cv file
    data_rpath = '/home/mmiv-ml/data'
    train_cache_dir = os.path.join(data_rpath,f'cachingDataset/{file_prefixfold}/train')    
    val_cache_dir = os.path.join(data_rpath,f'cachingDataset/{file_prefixfold}/val')
    
    is_done_train = removeAndcreate_cachedir(train_cache_dir)
    is_done_val = removeAndcreate_cachedir(val_cache_dir)
    

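    # Cache only the deterministic transforms before the first random one;
    # caching random augmentations would replay one fixed draw in every epoch.
    from monai.transforms import Randomizable
    n_train_cache_n_trans = next((i for i, t in enumerate(train_transforms.transforms)
                                  if isinstance(t, Randomizable)), len(train_transforms))
    n_val_cache_n_trans = len(val_transforms)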
    
     # create a training data loader

    train_dataset = monai.data.CacheNTransDataset(data=train_files, transform=train_transforms,\
                                                cache_n_trans = n_train_cache_n_trans, cache_dir = train_cache_dir)
    #train_folds['fold0_IDH_label']
    uval, ucnt = np.unique(train_files_IDH_label, return_counts=True)
    weight = 1./ucnt
    #weight = np.array([0.55, 0.45])
    sample_weights = np.array([weight[int(t)] for t in train_files_IDH_label])
    sample_weights = torch.from_numpy(sample_weights)
    sampler = WeightedRandomSampler(sample_weights, num_samples= len(sample_weights), replacement=True)


    #train_dataset = Dataset(data=train_files, transform=train_transforms)
    #train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2, pin_memory=True)
    #train_dataset = CacheDataset(data=train_files, transform=train_transforms, cache_rate = 1.0, num_workers=8)
    #train_loader = ThreadDataLoader(train_dataset, num_workers=0, batch_size=batch_size, shuffle=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, num_workers=1, sampler = sampler)
    
    # create a validation data loader
    
    #val_dataset = Dataset(data=val_files, transform=val_transforms)
    #val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=2, pin_memory=True)
    #val_dataset = CacheDataset(data=val_files, transform=val_transforms, cache_rate = 1.0, num_workers=5)
    
    val_dataset = monai.data.CacheNTransDataset(data=val_files, transform=val_transforms,\
                                            cache_n_trans = n_val_cache_n_trans, cache_dir = val_cache_dir)
    #val_loader = ThreadDataLoader(val_dataset, num_workers=0, batch_size=1)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False) #num_workers=2, pin_memory=True
    
    
    
#     for ibatch in train_loader:
#         ibatch_IDH = ibatch['IDH_label']
#         print(torch.eq(ibatch_IDH, 0).sum(), torch.eq(ibatch_IDH, 1).sum())
#         print('#'*50)
    
    
    """
    just initialising some basic steps
    """
    
    max_epochs = epochs
    
    ### Loss function (the global `loss_function` defined above) and optimizer
    #loss_function = nn.CrossEntropyLoss()
    #loss_function=nn.BCEWithLogitsLoss()
    
    #optimizer = torch.optim.Adam(model.parameters(), lr=1e-03, weight_decay=1e-07)
    optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-05, weight_decay = 1e-4)
    #optimizer = torch.optim.SGD(model.parameters(), lr=1e-03, momentum= 0.99, nesterov=True)
    #optimizer = torch.optim.Adam(model.parameters(), 1e-03, weight_decay=1e-04)
    scaler = torch.cuda.amp.GradScaler()
    
    
    max_lr_init = 1e-03
    """
     ###################### Block for LR finder from pytorch ignite ########################

    """

    if find_lr:

        def prepare_batch(batch, device=None, non_blocking=False):
            return _prepare_batch((batch['image'], batch['IDH_label'].long()), device, non_blocking)

        #trainer = create_supervised_trainer(model, optimizer, loss_function, device, non_blocking=False, prepare_batch=prepare_batch)
        def train_step(engine, batch):
            model.train()
            optimizer.zero_grad()
            x, y = batch['image'].to(device), batch['label'].to(device)  # loss_function expects segmentation labels
            with torch.cuda.amp.autocast():
                y_pred = model(x)
                loss4lr = loss_function(y_pred, y)
                
            scaler.scale(loss4lr).backward()
            scaler.step(optimizer)
            scaler.update()
            return loss4lr.item()

        trainer = Engine(train_step)

        ProgressBar(persist=True).attach(trainer, output_transform=lambda x: {"batch loss": x})
        lr_finder = FastaiLRFinder()
        to_save={'model': model, 'optimizer': optimizer}
        num_iter = 100  #2*len(train_loader)
        run_epochs = int(np.ceil(num_iter/len(train_loader)))
        with lr_finder.attach(trainer, to_save, end_lr=10, num_iter=num_iter, diverge_th=1.5) as trainer_with_lr_finder:    ####diverge_th=1.5

            trainer_with_lr_finder.run(train_loader, max_epochs=run_epochs)  #max_epochs=run_epochs or 5

        ax = lr_finder.plot()
        plt.show()
        
        max_lr = lr_finder.lr_suggestion() if lr_finder.lr_suggestion()<5e-03 else max_lr_init
        #max_lr = lr_finder.lr_suggestion() ##max_lr/10 i guess not needed, ignite does itself
        print(f'Suggested learning rate by LR finder for this fold: {lr_finder.lr_suggestion()}')
        
    else:
        max_lr = max_lr_init
        
    #max_lr_slice = 1e01*max_lr if max_lr<5e-03 else 1e-02
    #max_lr_slice = 1e-01*max_lr if max_lr<1e-05 else max_lr
        
    
    """
    ### defining learning rate scheduler
    
    """
    #steps_per_epoch=len(train_loader)
    #optimizer.param_groups[0]['lr'] = max_lr #*1e-01
    #scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=max_lr_slice, steps_per_epoch=len(train_loader), epochs=max_epochs)
    #scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: (1 - epoch / max_epochs) ** 0.9)
    
    #max_lr = 1e-3   
    optimizer = Ranger21(model.parameters(), lr = max_lr, num_epochs = epochs, num_batches_per_epoch = len(train_loader))
    


    """
     ###################### Block for the native PyTorch training loop ########################
    """

    key_metric_n_saved = 1   ### Usually I keep it 5
    save_last = False 
    dispformat_specs = '.4f'

        
#     file_prefix = 'ConvEffNet_Brats21_5CV'
#     savedirname = 'ConvEffNet_Brats21'
#     save_dir = os.path.join('/raid/brats2021/pthBraTS2021Radiogenomics', savedirname)
#     if not os.path.exists(save_dir):
#         os.makedirs(save_dir)

    logsfile_path = f"{save_dir}/Logs_{file_prefix}.txt"


    epoch_num = max_epochs #  max_epochs
    val_interval = 1
    valstep = 0
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()


    numsiters = len(train_files) // train_loader.batch_size

    first_batch = monai.utils.misc.first(train_loader)
        
    
    post_pred = AsDiscrete(argmax=True, to_onehot=num_classes)  ### num_classes=num_classes
    post_label = AsDiscrete(to_onehot=num_classes) ###num_classes=num_classes
    dice_metric = monai.metrics.DiceMetric(include_background=False, reduction='mean', get_not_nans=False)
    
    def one_hot_permute(x):
        return F.one_hot(x.squeeze(dim=0).long(), num_classes=num_classes).permute(3, 0, 1, 2)
    
    for epoch in range(epoch_num):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0.
        stepiter = 0
        for batch_data in train_loader:
            stepiter += 1
            inputs, labels, IDH_labels= (
                batch_data['image'].to(device),
                batch_data['label'].to(device),
                batch_data['IDH_label'].to(device),
            )
            
            optimizer.zero_grad()
            
            with torch.cuda.amp.autocast():

                # compute output
                outputs  = model(inputs)
                loss = loss_function(outputs, labels) 

            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            epoch_loss += loss.item()

            print(f"{stepiter}/{numsiters}, train_loss: {loss.item():.4f}")

            #scheduler.step() 
            
        epoch_loss /= stepiter
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        
        
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():

                y_pred = torch.tensor([], dtype=torch.float32, device=device)
                y = torch.tensor([], dtype=torch.long, device=device)
                val_losses = torch.tensor([], dtype=torch.float32, device=device)


                for val_data in val_loader:

                    val_inputs, val_labels, val_IDH_labels = (
                        val_data['image'].to(device),
                        val_data['label'].to(device),
                        val_data['IDH_label'].to(device),
                    )
                
                    #roi_size = patch_size #(32, 32, 32)
                    #sw_batch_size = 1
                    #overlap = 0.25
                    
                    with torch.cuda.amp.autocast():
                        
                        #val_outputsC = sliding_window_inference_classes(val_inputs, roi_size, sw_batch_size, model, overlap=overlap, sw_device = device, device = device)
                        val_outputs = model(val_inputs)
                        val_ce_loss = loss_function(val_outputs.unsqueeze(dim=1), val_labels)

                    val_losses = torch.cat([val_losses, val_ce_loss.view(1)], dim = 0)
                    val_outputs = torch.stack([post_pred(i) for i in torch.unbind(val_outputs, dim = 0)], dim = 0)
                    
                    
                    val_labels2hot = torch.stack([one_hot_permute(i) for i in torch.unbind(val_labels, dim = 0)], dim = 0)
                    dice_metric(y_pred=val_outputs, y=val_labels2hot)
                    
                    pred_classes = torch.argmax(val_outputs, dim=1)
                    # bincount with minlength keeps index k tied to class k even if a class is absent
                    class_counts = torch.bincount(pred_classes.flatten(), minlength=num_classes)
                    val_label4mSeg_C = class_counts[1:].argmax()
                    
                    #val_surv_labels = val_surv_labels.squeeze(dim=1)  ###Squeezing from B, 1 to B if needed
                    y_pred = torch.cat([y_pred, val_label4mSeg_C.view(1)], dim=0)
                    y = torch.cat([y, val_IDH_labels], dim=0)

                mdice_value = dice_metric.aggregate()
                dice_metric.reset()
                
                y_pred, y = y_pred.cpu(), y.cpu()
                acc_value = torch.eq(y_pred, y)
                acc_metric = acc_value.sum().item() / len(acc_value)
                
                #y_onehot = [post_label(i) for i in torch.unbind(y, dim=0)]
                #y_pred_act = [post_pred(i) for i in torch.unbind(y_pred, dim=0)]
                
                auc_metric(y_pred, y)
                auc_result = auc_metric.aggregate()
                auc_metric.reset()
            
                accscore = accuracy_score(y, y_pred)
                f1score = f1_score(y, y_pred, average='micro')
                del y, y_pred
                
            
                epoch_val_losses=torch.mean(val_losses).detach().cpu().item()
                #metric = auc_result
                mdice_value = mdice_value.item()
                metric = mdice_value
                metric_values.append(metric) ######List of over number of epochs
                printstring = "Best Dice"
                

                with open(logsfile_path, 'a') as file:
                    file.write(
                        f"current fold: {cfold} current epoch: {epoch + 1} dice_score: {mdice_value:^{dispformat_specs}} acc_metric: {auc_result:^{dispformat_specs}}" 
                        f" accuracy: {accscore:^{dispformat_specs}}, f1score: {f1score:^{dispformat_specs}}"
                        f" epoch {epoch + 1} average training loss: {epoch_loss:^{dispformat_specs}} average validation loss: {epoch_val_losses:^{dispformat_specs}} \n"

                    )

                if valstep < key_metric_n_saved:
                    torch.save(model.state_dict(), os.path.join(save_dir, f"{file_prefix}_{metric:^{dispformat_specs}}_Fold{cfold}_epoch{epoch + 1}.pth"))
                    print(
                        f"current fold: {cfold} current epoch: {epoch + 1} dice_score: {mdice_value:^{dispformat_specs}} acc_metric: {auc_result:^{dispformat_specs}}" 
                        f" accuracy: {accscore:^{dispformat_specs}}, f1score: {f1score:^{dispformat_specs}}"
                        f" epoch {epoch + 1} average training loss: {epoch_loss:^{dispformat_specs}} average validation loss: {epoch_val_losses:^{dispformat_specs}}"
                        
                    )

                else:

                    #sortmetric_values = sorted(metric_values[:-1], reverse=True)  ###Higher loss needs to be deleted, so sorting is reversed
                    sortmetric_values = sorted(metric_values[:-1], reverse=False)  

                    if metric>=sortmetric_values[-key_metric_n_saved]:
                        savegood_metric = metric
                        good_metric_epoch = epoch + 1

                        #if os.path.exists(f"{save_dir}/{file_prefix}_{sortmetric_values[-key_metric_n_saved]:.4f}.pth"):
                        #    os.remove(f"{save_dir}/{file_prefix}_{sortmetric_values[-key_metric_n_saved]:.4f}.pth")
                        #else:
                        #    print("The file does not exist")

                        glblist = glob.glob(f"{save_dir}/{file_prefix}_{sortmetric_values[-key_metric_n_saved]:^{dispformat_specs}}_*")

                        if not glblist:
                            print("The file does not exist")
                        else:
                            os.remove(glblist[0])


                        torch.save(model.state_dict(), os.path.join(save_dir, f"{file_prefix}_{metric:^{dispformat_specs}}_Fold{cfold}_epoch{epoch + 1}.pth"))
                        print("saved new best metric model")
                        print(
                            f"current fold: {cfold} current epoch: {epoch + 1} validation loss: {epoch_val_losses:^{dispformat_specs}}"
                            f" dice_score: {mdice_value:^{dispformat_specs}} acc_metric: {auc_result:^{dispformat_specs}}"
                            f" accuracy: {accscore:^{dispformat_specs}}, f1score: {f1score:^{dispformat_specs}}"
                            f"\n saved {printstring}: {savegood_metric:^{dispformat_specs}} at epoch: {good_metric_epoch}"
                        )

                    else:
                        # Not among the best checkpoints: report the metrics without saving.
                        # (The f-strings were previously bare expressions, so nothing was printed.)
                        print(
                            f"current fold: {cfold} current epoch: {epoch + 1} validation loss: {epoch_val_losses:^{dispformat_specs}}"
                            f" dice_score: {mdice_value:^{dispformat_specs}} acc_metric: {auc_result:^{dispformat_specs}}"
                            f" accuracy: {accscore:^{dispformat_specs}}, f1score: {f1score:^{dispformat_specs}}"
                        )

                valstep += 1
        # Save the last-epoch checkpoint
        if epoch == epoch_num - 1 and save_last:
            torch.save(model.state_dict(), os.path.join(save_dir, f"{file_prefix}_{metric:^{dispformat_specs}}_Fold{cfold}_last_epoch{epoch + 1}.pth"))
    # Free up GPU memory after training
    model = None
    train_loader, val_loader = None, None
    gc.collect()
    torch.cuda.empty_cache()
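
The checkpointing code above keeps only the `key_metric_n_saved` best checkpoints: it compares the new metric with the k-th best previous value, then locates the file to delete by re-formatting that value into a `glob` pattern. The sketch below shows the same top-k retention idea with a min-heap that tracks saved paths directly, so no filename has to be reconstructed from a formatted float. It is a hypothetical helper, not part of MONAI or of this notebook: `TopKCheckpoints`, the `model_..._Fold{fold}_epoch{epoch}.pth` naming and the `save_fn` callback (standing in for `torch.save(model.state_dict(), path)`) are all assumptions.

```python
import heapq
import os
import tempfile


class TopKCheckpoints:
    """Keep only the k checkpoints with the highest metric (higher is better).

    Hypothetical helper: save_fn(path) stands in for
    torch.save(model.state_dict(), path).
    """

    def __init__(self, save_dir, k, save_fn, prefix="model"):
        if k < 1:
            raise ValueError("k must be at least 1")
        self.save_dir = save_dir
        self.k = k
        self.save_fn = save_fn
        self.prefix = prefix
        self._heap = []  # min-heap of (metric, path); _heap[0] is the worst kept checkpoint

    def update(self, metric, fold, epoch):
        """Save a checkpoint if it ranks in the current top k; return True if saved."""
        if len(self._heap) >= self.k and metric <= self._heap[0][0]:
            return False  # not better than the worst checkpoint we keep
        path = os.path.join(self.save_dir, f"{self.prefix}_{metric:.4f}_Fold{fold}_epoch{epoch}.pth")
        self.save_fn(path)
        heapq.heappush(self._heap, (metric, path))
        if len(self._heap) > self.k:
            _, worst_path = heapq.heappop(self._heap)  # evict the lowest metric
            if os.path.exists(worst_path):
                os.remove(worst_path)
        return True

    def kept(self):
        """(metric, path) pairs currently kept, best first."""
        return sorted(self._heap, reverse=True)


def fake_save(path):
    open(path, "w").close()  # empty file instead of a real state_dict


with tempfile.TemporaryDirectory() as tmp:
    ckpt = TopKCheckpoints(tmp, k=2, save_fn=fake_save)
    for epoch, dice in enumerate([0.10, 0.12, 0.11, 0.15], start=1):
        ckpt.update(dice, fold=0, epoch=epoch)
    kept_files = sorted(os.listdir(tmp))

print(kept_files)  # ['model_0.1200_Fold0_epoch2.pth', 'model_0.1500_Fold0_epoch4.pth']
```

One design difference: this sketch ignores a new value that only ties the worst kept metric, whereas the notebook's `>=` comparison saves it.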

### Loop to execute 3-fold cross-validation (n_splits=3)
If the model has already been trained and its checkpoints are saved, set the `start_training` flag to `False` to skip training and run the remaining cells of this notebook.
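
`train_folds` and `val_folds` are built earlier in the notebook. As a hypothetical illustration of the key layout the training loop reads (`fold{i}` for the image entries and `fold{i}_IDH_label` for the labels), here is a pure-Python stratified k-fold split. `make_folds` is an assumed helper, not the notebook's own code, and the case IDs and labels in the usage lines are synthetic.

```python
import random
from collections import defaultdict


def make_folds(files, labels, n_splits=3, seed=0):
    """Hypothetical stratified k-fold split returning two dicts keyed like
    the notebook's: 'fold{i}' -> entries, 'fold{i}_IDH_label' -> labels."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, lab in enumerate(labels):
        by_label[lab].append(idx)

    # Deal each class's shuffled indices round-robin with one running counter:
    # fold sizes differ by at most 1 and each fold keeps a similar label ratio.
    fold_idx = [[] for _ in range(n_splits)]
    pos = 0
    for lab in sorted(by_label):
        idxs = by_label[lab]
        rng.shuffle(idxs)
        for idx in idxs:
            fold_idx[pos % n_splits].append(idx)
            pos += 1

    train_folds, val_folds = {}, {}
    for i in range(n_splits):
        val_set = set(fold_idx[i])
        tr = [k for k in range(len(files)) if k not in val_set]
        va = sorted(val_set)
        train_folds[f"fold{i}"] = [files[k] for k in tr]
        train_folds[f"fold{i}_IDH_label"] = [labels[k] for k in tr]
        val_folds[f"fold{i}"] = [files[k] for k in va]
        val_folds[f"fold{i}_IDH_label"] = [labels[k] for k in va]
    return train_folds, val_folds


# Synthetic usage: 135 made-up case IDs with alternating labels
files = [f"case_{k:03d}" for k in range(135)]
labels = [k % 2 for k in range(135)]
train_folds, val_folds = make_folds(files, labels, n_splits=3)
for i in range(3):
    print(i, len(train_folds[f"fold{i}"]), len(val_folds[f"fold{i}"]))  # 90 training / 45 validation
```

`sklearn.model_selection.StratifiedKFold` is the usual library alternative when scikit-learn is available.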

In [32]:
start_training = True

In [33]:
n_splits = 3
if start_training:
    # Run n_splits-fold cross-validation (3 folds)
    for i in range(n_splits):
        train_files_fld = train_folds[f'fold{i}']
        train_files_fld_IDH_label = train_folds[f'fold{i}_IDH_label']
        val_files_fld = val_folds[f'fold{i}']
        val_files_fld_IDH_label = val_folds[f'fold{i}_IDH_label']
        batch_size = 8
        # len(train_files_fld) % batch_size is the size of the incomplete last batch
        # (0 if all batches are full); if it is 1, choose a different batch_size
        print('fold', i, "Bacth Investigation, minimum batch size", len(train_files_fld) % batch_size)
        # train_files_fld, val_files_fld = sklearn_transforms(train_files_fld, val_files_fld)
        train(train_files_fld, train_files_fld_IDH_label, val_files_fld, val_files_fld_IDH_label,
              batch_size=batch_size, epochs=500, cfold=i)

    start_training = False

fold 0 Bacth Investigation, minimum batch size 2
Ranger21 optimizer ready with following settings:

Core optimizer = AdamW
Learning rate of 0.001

Important - num_epochs of training = ** 500 epochs **
please confirm this is correct or warmup and warmdown will be off

Warm-up: linear warmup, over 2000 iterations

Lookahead active, merging every 5 steps, with blend factor of 0.5
Norm Loss active, factor = 0.0001
Stable weight decay of 0.0001
Gradient Centralization = On

Adaptive Gradient Clipping = True
	clipping value of 0.01
	steps for clipping = 0.001

Warm-down: Linear warmdown, starting at 72.0%, iteration 4320 of 6000
warm down will decay until 3e-05 lr


Modifying image pixdim from [1. 1. 1. 1.] to [  1.           1.           1.         239.00209204]


----------
epoch 1/500


Modifying image pixdim from [1. 1. 1. 1.] to [  1.           1.           1.         239.00209204]


params size saved
total param groups = 1
total params in groups = 81
1/11, train_loss: 1.6227
2/11, train_loss: 1.6118
3/11, train_loss: 1.6257
4/11, train_loss: 1.6052
5/11, train_loss: 1.5962
6/11, train_loss: 1.6348
7/11, train_loss: 1.6555
8/11, train_loss: 1.6276
9/11, train_loss: 1.6371
10/11, train_loss: 1.6171
11/11, train_loss: 1.6208
12/11, train_loss: 1.6274
epoch 1 average loss: 1.6235


Modifying image pixdim from [1. 1. 1. 1.] to [  1.           1.           1.         239.00209204]


current fold: 0 current epoch: 1 dice_score: 0.0987 acc_metric: 0.5000 accuracy: 0.5778, f1score: 0.5778 epoch 1 average training loss: 1.6235 average validation loss: 0.9200
----------
epoch 2/500


Modifying image pixdim from [1. 1. 1. 1.] to [  1.           1.           1.         239.00209204]


1/11, train_loss: 1.6297
2/11, train_loss: 1.5678
3/11, train_loss: 1.6583
4/11, train_loss: 1.5833
5/11, train_loss: 1.6406
6/11, train_loss: 1.6270
7/11, train_loss: 1.6065
8/11, train_loss: 1.6355
9/11, train_loss: 1.6244
10/11, train_loss: 1.6091
11/11, train_loss: 1.6078
12/11, train_loss: 1.5839
epoch 2 average loss: 1.6145
----------
epoch 3/500


Modifying image pixdim from [1. 1. 1. 1.] to [  1.           1.           1.         239.00209204]


1/11, train_loss: 1.5636
2/11, train_loss: 1.5988
3/11, train_loss: 1.6239
4/11, train_loss: 1.6306
5/11, train_loss: 1.5875
6/11, train_loss: 1.5986
7/11, train_loss: 1.5910
8/11, train_loss: 1.6004
9/11, train_loss: 1.5942
10/11, train_loss: 1.5887
11/11, train_loss: 1.5779
12/11, train_loss: 1.5279
epoch 3 average loss: 1.5903
----------
epoch 4/500
1/11, train_loss: 1.5955
2/11, train_loss: 1.5458
3/11, train_loss: 1.6194
4/11, train_loss: 1.6167
5/11, train_loss: 1.5697
6/11, train_loss: 1.5831
7/11, train_loss: 1.5930
8/11, train_loss: 1.6003
9/11, train_loss: 1.6208
10/11, train_loss: 1.6270
11/11, train_loss: 1.5725
12/11, train_loss: 1.4195
epoch 4 average loss: 1.5803
saved new best metric model
current fold: 0 current epoch: 4 validation loss: 0.9145 dice_score: 0.1144 acc_metric: 0.5000 accuracy: 0.5778, f1score: 0.5778
 saved Best AUC: 0.1144 at epoch: 4
----------
epoch 5/500


Modifying image pixdim from [1. 1. 1. 1.] to [  1.           1.           1.         239.00209204]


1/11, train_loss: 1.5767
2/11, train_loss: 1.5916
3/11, train_loss: 1.5489
4/11, train_loss: 1.5549
5/11, train_loss: 1.5990
6/11, train_loss: 1.5726
7/11, train_loss: 1.5750
8/11, train_loss: 1.6025
9/11, train_loss: 1.5987
10/11, train_loss: 1.5655
11/11, train_loss: 1.6169
12/11, train_loss: 1.5274
epoch 5 average loss: 1.5775
saved new best metric model
current fold: 0 current epoch: 5 validation loss: 0.9078 dice_score: 0.1279 acc_metric: 0.5000 accuracy: 0.5778, f1score: 0.5778
 saved Best AUC: 0.1279 at epoch: 5
----------
epoch 6/500
1/11, train_loss: 1.5110
2/11, train_loss: 1.5649
3/11, train_loss: 1.5250
4/11, train_loss: 1.4356
5/11, train_loss: 1.5853
6/11, train_loss: 1.5411
7/11, train_loss: 1.5847
8/11, train_loss: 1.4702
9/11, train_loss: 1.5267
10/11, train_loss: 1.5674
11/11, train_loss: 1.5346
12/11, train_loss: 1.5524
epoch 6 average loss: 1.5332
saved new best metric model
current fold: 0 current epoch: 6 validation loss: 0.8963 dice_score: 0.1515 acc_metric: 0.50

Modifying image pixdim from [1. 1. 1. 1.] to [  1.           1.           1.         239.00209204]


3/11, train_loss: 1.4634
4/11, train_loss: 1.5460
5/11, train_loss: 1.5000
6/11, train_loss: 1.5736
7/11, train_loss: 1.5103
8/11, train_loss: 1.5782
9/11, train_loss: 1.4834
10/11, train_loss: 1.4570
11/11, train_loss: 1.5903
12/11, train_loss: 1.5207
epoch 7 average loss: 1.5216
saved new best metric model
current fold: 0 current epoch: 7 validation loss: 0.8839 dice_score: 0.1766 acc_metric: 0.5000 accuracy: 0.5778, f1score: 0.5778
 saved Best AUC: 0.1766 at epoch: 7
----------
epoch 8/500
1/11, train_loss: 1.5646
2/11, train_loss: 1.5446
3/11, train_loss: 1.4479
4/11, train_loss: 1.5426
5/11, train_loss: 1.5811
6/11, train_loss: 1.5263
7/11, train_loss: 1.5223
8/11, train_loss: 1.5021
9/11, train_loss: 1.4562
10/11, train_loss: 1.4799
11/11, train_loss: 1.4332
12/11, train_loss: 1.4007
epoch 8 average loss: 1.5001
saved new best metric model
current fold: 0 current epoch: 8 validation loss: 0.8754 dice_score: 0.1941 acc_metric: 0.5000 accuracy: 0.5778, f1score: 0.5778
 saved Best A

Modifying image pixdim from [1. 1. 1. 1.] to [  1.           1.           1.         239.00209204]


----------
epoch 1/500


Modifying image pixdim from [1. 1. 1. 1.] to [  1.           1.           1.         239.00209204]


params size saved
total param groups = 1
total params in groups = 81
1/11, train_loss: 1.6248
2/11, train_loss: 1.6369
3/11, train_loss: 1.6411
4/11, train_loss: 1.6212
5/11, train_loss: 1.6190
6/11, train_loss: 1.5994
7/11, train_loss: 1.6324
8/11, train_loss: 1.5810
9/11, train_loss: 1.6308
10/11, train_loss: 1.5857
11/11, train_loss: 1.5985
12/11, train_loss: 1.6680
epoch 1 average loss: 1.6199


Modifying image pixdim from [1. 1. 1. 1.] to [  1.           1.           1.         239.00209204]


current fold: 1 current epoch: 1 dice_score: 0.0611 acc_metric: 0.5000 accuracy: 0.5778, f1score: 0.5778 epoch 1 average training loss: 1.6199 average validation loss: 0.9330
----------
epoch 2/500


Modifying image pixdim from [1. 1. 1. 1.] to [  1.           1.           1.         239.00209204]


1/11, train_loss: 1.6239
2/11, train_loss: 1.6092
3/11, train_loss: 1.6651
4/11, train_loss: 1.5894
5/11, train_loss: 1.6143
6/11, train_loss: 1.5867
7/11, train_loss: 1.5903
8/11, train_loss: 1.5740
9/11, train_loss: 1.5956
10/11, train_loss: 1.6143
11/11, train_loss: 1.5772
12/11, train_loss: 1.5741
epoch 2 average loss: 1.6012
saved new best metric model
current fold: 1 current epoch: 2 validation loss: 0.9301 dice_score: 0.0706 acc_metric: 0.5000 accuracy: 0.5778, f1score: 0.5778
 saved Best AUC: 0.0706 at epoch: 2
----------
epoch 3/500


Modifying image pixdim from [1. 1. 1. 1.] to [  1.           1.           1.         239.00209204]


1/11, train_loss: 1.5733
2/11, train_loss: 1.6222
3/11, train_loss: 1.6100
4/11, train_loss: 1.5950
5/11, train_loss: 1.6165
6/11, train_loss: 1.6074
7/11, train_loss: 1.6031
8/11, train_loss: 1.5885
9/11, train_loss: 1.6382
10/11, train_loss: 1.5807
11/11, train_loss: 1.5977
12/11, train_loss: 1.5248
epoch 3 average loss: 1.5964
saved new best metric model
current fold: 1 current epoch: 3 validation loss: 0.9262 dice_score: 0.0806 acc_metric: 0.5000 accuracy: 0.5778, f1score: 0.5778
 saved Best AUC: 0.0806 at epoch: 3
----------
epoch 4/500


Modifying image pixdim from [1. 1. 1. 1.] to [  1.           1.           1.         239.00209204]


1/11, train_loss: 1.5811
2/11, train_loss: 1.5964
3/11, train_loss: 1.5971
4/11, train_loss: 1.5627
5/11, train_loss: 1.6016
6/11, train_loss: 1.5546
7/11, train_loss: 1.6253
8/11, train_loss: 1.5754
9/11, train_loss: 1.6439
10/11, train_loss: 1.5457
11/11, train_loss: 1.5907
12/11, train_loss: 1.6767
epoch 4 average loss: 1.5959
saved new best metric model
current fold: 1 current epoch: 4 validation loss: 0.9172 dice_score: 0.1008 acc_metric: 0.5000 accuracy: 0.5778, f1score: 0.5778
 saved Best AUC: 0.1008 at epoch: 4
----------
epoch 5/500
1/11, train_loss: 1.6048
2/11, train_loss: 1.6190
3/11, train_loss: 1.5731
4/11, train_loss: 1.6029
5/11, train_loss: 1.5431
6/11, train_loss: 1.5816
7/11, train_loss: 1.5768
8/11, train_loss: 1.6073
9/11, train_loss: 1.6324
10/11, train_loss: 1.5606
11/11, train_loss: 1.6191
12/11, train_loss: 1.4697
epoch 5 average loss: 1.5825
saved new best metric model
current fold: 1 current epoch: 5 validation loss: 0.9090 dice_score: 0.1192 acc_metric: 0.50

Modifying image pixdim from [1. 1. 1. 1.] to [  1.           1.           1.         239.00209204]


----------
epoch 1/500


Modifying image pixdim from [1. 1. 1. 1.] to [  1.           1.           1.         239.00209204]


params size saved
total param groups = 1
total params in groups = 81
1/11, train_loss: 1.6370
2/11, train_loss: 1.6089
3/11, train_loss: 1.6170
4/11, train_loss: 1.6145
5/11, train_loss: 1.5966
6/11, train_loss: 1.6036
7/11, train_loss: 1.6050
8/11, train_loss: 1.6012
9/11, train_loss: 1.6010
10/11, train_loss: 1.6367
11/11, train_loss: 1.5886
12/11, train_loss: 1.6738
epoch 1 average loss: 1.6153


Modifying image pixdim from [1. 1. 1. 1.] to [  1.           1.           1.         239.00209204]


current fold: 2 current epoch: 1 dice_score: 0.0906 acc_metric: 0.5000 accuracy: 0.5778, f1score: 0.5778 epoch 1 average training loss: 1.6153 average validation loss: 0.9056
----------
epoch 2/500


Modifying image pixdim from [1. 1. 1. 1.] to [  1.           1.           1.         239.00209204]


1/11, train_loss: 1.6314
2/11, train_loss: 1.5956
3/11, train_loss: 1.6157
4/11, train_loss: 1.6625
5/11, train_loss: 1.6117
6/11, train_loss: 1.6087
7/11, train_loss: 1.6272
8/11, train_loss: 1.6083
9/11, train_loss: 1.6324
10/11, train_loss: 1.6029
11/11, train_loss: 1.5743
12/11, train_loss: 1.5462
epoch 2 average loss: 1.6098
saved new best metric model
current fold: 2 current epoch: 2 validation loss: 0.9058 dice_score: 0.0939 acc_metric: 0.5000 accuracy: 0.5778, f1score: 0.5778
 saved Best AUC: 0.0939 at epoch: 2
----------
epoch 3/500


Modifying image pixdim from [1. 1. 1. 1.] to [  1.           1.           1.         239.00209204]


1/11, train_loss: 1.6066
2/11, train_loss: 1.6281
3/11, train_loss: 1.6190
4/11, train_loss: 1.6404
5/11, train_loss: 1.6232
6/11, train_loss: 1.6225
7/11, train_loss: 1.6091
8/11, train_loss: 1.5880
9/11, train_loss: 1.6174
10/11, train_loss: 1.6208
11/11, train_loss: 1.6070
12/11, train_loss: 1.6034
epoch 3 average loss: 1.6155
saved new best metric model
current fold: 2 current epoch: 3 validation loss: 0.9056 dice_score: 0.0956 acc_metric: 0.5000 accuracy: 0.5778, f1score: 0.5778
 saved Best AUC: 0.0956 at epoch: 3
----------
epoch 4/500
1/11, train_loss: 1.6158
2/11, train_loss: 1.5834
3/11, train_loss: 1.5712
4/11, train_loss: 1.5925
5/11, train_loss: 1.6103
6/11, train_loss: 1.6500
7/11, train_loss: 1.5700
8/11, train_loss: 1.6244
9/11, train_loss: 1.6090
10/11, train_loss: 1.6130
11/11, train_loss: 1.6484
12/11, train_loss: 1.6760
epoch 4 average loss: 1.6137
saved new best metric model
current fold: 2 current epoch: 4 validation loss: 0.9012 dice_score: 0.1039 acc_metric: 0.50

Modifying image pixdim from [1. 1. 1. 1.] to [  1.           1.           1.         239.00209204]


1/11, train_loss: 1.5588
2/11, train_loss: 1.5935
3/11, train_loss: 1.5980
4/11, train_loss: 1.5399
5/11, train_loss: 1.5870
6/11, train_loss: 1.5900
7/11, train_loss: 1.5817
8/11, train_loss: 1.5980
9/11, train_loss: 1.5762
10/11, train_loss: 1.5960
11/11, train_loss: 1.6130
12/11, train_loss: 1.6675
epoch 5 average loss: 1.5916
saved new best metric model
current fold: 2 current epoch: 5 validation loss: 0.8976 dice_score: 0.1101 acc_metric: 0.5000 accuracy: 0.5778, f1score: 0.5778
 saved Best AUC: 0.1101 at epoch: 5
----------
epoch 6/500


Modifying image pixdim from [1. 1. 1. 1.] to [  1.           1.           1.         239.00209204]


1/11, train_loss: 1.6088
2/11, train_loss: 1.5190
3/11, train_loss: 1.6207
4/11, train_loss: 1.5741
5/11, train_loss: 1.5630
6/11, train_loss: 1.5291
7/11, train_loss: 1.5382
8/11, train_loss: 1.5718
9/11, train_loss: 1.5603
10/11, train_loss: 1.5778
11/11, train_loss: 1.5444
12/11, train_loss: 1.5049
epoch 6 average loss: 1.5593
saved new best metric model
current fold: 2 current epoch: 6 validation loss: 0.8920 dice_score: 0.1210 acc_metric: 0.5000 accuracy: 0.5778, f1score: 0.5778
 saved Best AUC: 0.1210 at epoch: 6
----------
epoch 7/500
1/11, train_loss: 1.5561
2/11, train_loss: 1.4754
3/11, train_loss: 1.5368
4/11, train_loss: 1.5916
5/11, train_loss: 1.5552
6/11, train_loss: 1.5994
7/11, train_loss: 1.5257
8/11, train_loss: 1.5332
9/11, train_loss: 1.5199
10/11, train_loss: 1.5919
11/11, train_loss: 1.5896
12/11, train_loss: 1.5406
epoch 7 average loss: 1.5513
saved new best metric model
current fold: 2 current epoch: 7 validation loss: 0.8857 dice_score: 0.1342 acc_metric: 0.50
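
The first line of each fold's output reports `len(train_files_fld) % batch_size = 2`, and every epoch runs 12 steps while the step counter prints `/11`. Both are consistent with 90 training cases per fold (two thirds of the 135 cases) and `batch_size=8`, assuming the loader keeps the incomplete last batch (`drop_last=False`) and the counter divides by floor(90 / 8). The same count matches the Ranger21 banner: 500 epochs of 12 batches is 6000 iterations, and 72% of that is iteration 4320. The arithmetic as a small sketch (`batch_stats` is a hypothetical helper; the 90-case figure is inferred from the log):

```python
def batch_stats(n_samples, batch_size):
    """(full batches, size of the incomplete last batch, batches per epoch)
    for a loader that keeps the last partial batch (drop_last=False)."""
    full, rem = divmod(n_samples, batch_size)
    return full, rem, full + (1 if rem else 0)


n_train, batch_size, epochs = 90, 8, 500   # 90 is inferred from the log, not read from the data
full, rem, per_epoch = batch_stats(n_train, batch_size)
print(full, rem, per_epoch)                # 11 2 12 -> log shows "12/11" and "minimum batch size 2"

total_iters = epochs * per_epoch
warmdown_start = total_iters * 72 // 100   # Ranger21 banner: warm-down starts at 72.0%
print(total_iters, warmdown_start)         # 6000 4320
```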

In [None]:
# from sklearn.datasets import make_multilabel_classification
# from sklearn.multioutput import MultiOutputClassifier
# from sklearn.linear_model import LogisticRegression
# X, y = make_multilabel_classification(random_state=0, n_classes=2)
# inner_clf = LogisticRegression(solver="liblinear", random_state=0)
# clf = MultiOutputClassifier(inner_clf).fit(X, y)
# y_score = np.transpose([y_pred[:, 1] for y_pred in clf.predict_proba(X)])
# roc_auc_score(y, y_score, average=None)

In [33]:
# roc_auc_score(y, y_score)
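
In the logs above, `acc_metric` (filled from `auc_result` in the training loop) stays at 0.5000 and `accuracy` at 0.5778 for every fold and epoch shown. An AUC of exactly 0.5 is what a binary classifier gets when it assigns the same score to every case, and 0.5778 equals 26/45, the accuracy of always predicting one class when 26 of 45 validation cases belong to it. The sketch below checks both numbers with a rank-based AUC (the Mann-Whitney formulation, which for binary labels should agree with sklearn's `roc_auc_score`). The 45-case, 26/19 label split is an illustrative assumption, not read from the dataset.

```python
def auc_rank(y_true, scores):
    """Binary ROC AUC as P(score of a random positive > score of a random negative),
    counting ties as 1/2 (Mann-Whitney U divided by n_pos * n_neg)."""
    pos = [s for s, y in zip(scores, y_true) if y == 1]
    neg = [s for s, y in zip(scores, y_true) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative case")
    wins = 0.0
    for p in pos:
        for n in neg:
            wins += 1.0 if p > n else 0.5 if p == n else 0.0
    return wins / (len(pos) * len(neg))


# Illustrative assumption: 45 validation cases, 26 of class 0 and 19 of class 1
y_true = [0] * 26 + [1] * 19
constant_scores = [0.3] * 45                 # same probability for every case
print(auc_rank(y_true, constant_scores))     # 0.5

y_pred = [0] * 45                            # thresholded: always class 0
accuracy = sum(p == y for p, y in zip(y_pred, y_true)) / len(y_true)
print(round(accuracy, 4))                    # 0.5778
```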

In [46]:
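def one_hot_permute(x):
    """One-hot encode a 2D label map of shape (1, H, W), values in {0, 1, 2},
    into a channel-first tensor of shape (3, H, W)."""
    return F.one_hot(x.squeeze(dim=0).long(), num_classes=3).permute(2, 0, 1)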