# MRI based brain tumor' 1p19q status classification with MONAI (3D multiparametric MRI)

This tutorial shows how to construct a training workflow of binary classification task.  
And it contains below features:
1. Transforms for Monai dictionary format data.
2. Define a new transform according MONAI transform API.
3. Load Nifti image with metadata, load a list of images and stack them.
5. 3D Voxel DynUNet model, Dice loss, cross entropy loss function for IDH classification task.
6. Deterministic training for reproducibility.

The Brain tumor dataset can be downloaded from 
https://ipp.cbica.upenn.edu/ and  http://medicaldecathlon.com/.  

Target: IDH classification based on whole brain, tumour core, whole tumor, and enhancing tumor from MRI 
Modality: Multimodal multisite MRI data (FLAIR, T1w, T1gd,T2w)  
training: 368 3D MRI \
validation:  \
testing: Not revealed

Source: BRATS 2020/2021 datasets.  
Challenge: RSNA-MICCAI Brain Tumor Radiogenomic Classification

Below figure shows image patches with the tumor sub-regions that are annotated in the different modalities (top left) and the final labels for the whole dataset (right). (Figure taken from the [BraTS IEEE TMI paper](https://ieeexplore.ieee.org/document/6975210/))  
![image](https://ieeexplore.ieee.org/mediastore_new/IEEE/content/media/42/7283692/6975210/6975210-fig-3-source-large.gif)

The image patches show from left to right:
1. the whole tumor (yellow) visible in T2-FLAIR (Fig.A).
2. the tumor core (red) visible in T2 (Fig.B).
3. the enhancing tumor structures (light blue) visible in T1Gd, surrounding the cystic/necrotic components of the core (green) (Fig. C).
4. The segmentations are used to generate the final labels of the tumor sub-regions (Fig.D): edema (yellow), non-enhancing solid core (red), necrotic/cystic core (green), enhancing core (blue).

In [1]:
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import shutil
import tempfile
import sys
import gc
import logging
import copy
import pdb

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import seaborn as sns
import numpy as np
import pandas as pd
import scipy
from scipy import ndimage
from sklearn.model_selection import KFold
import torch
import torch.nn as nn
import torch.nn.functional as F
from monai.networks.nets import DynUNet, EfficientNetBN, DenseNet121, SegResNet, SegResNetVAE, AttentionUnet
from monai.data import CacheDataset, Dataset, DataLoader, ThreadDataLoader, list_data_collate
from torch.utils.data import WeightedRandomSampler

import monai
from monai.transforms import (
    Activations,
    AsDiscrete,
    CastToTyped,
    Compose, 
    CropForegroundd,
    ResizeWithPadOrCrop,
    ResizeWithPadOrCropd,
    Spacingd,
    RandRotate90d,
    Resized,
    EnsureChannelFirstd, 
    Orientationd,
    LoadImaged,
    CopyItemsd,
    NormalizeIntensity,
    HistogramNormalize,
    NormalizeIntensityd,
    RandCropByPosNegLabeld,
    RandCropByLabelClassesd,
    RandAffined,
    RandFlipd,
    Flipd,
    RandGaussianNoised,
    RandGaussianSmoothd,
    RandGibbsNoised,
    RandStdShiftIntensityd,
    RandScaleIntensityd,
    RandZoomd, 
    SpatialCrop, 
    SpatialPadd, 
    MapTransform,
    CastToType,
    ToTensord,
    AddChanneld,
    MapTransform,
    Orientationd,
    ScaleIntensityd,
    ScaleIntensity,
    ScaleIntensityRangePercentilesd,
    KeepLargestConnectedComponentd,
    KeepLargestConnectedComponent,
    ScaleIntensityRange,
    RandShiftIntensityd,
    RandAdjustContrastd,
    AdjustContrastd,
    Rotated,
    ToNumpyd,
    ToDeviced,
    EnsureType,
    EnsureTyped,
    DataStatsd,
)

from monai.config import KeysCollection
from monai.transforms.compose import MapTransform, Randomizable
from collections.abc import Iterable
from typing import Any, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union
from monai.utils import set_determinism
from monai.utils import (
    ensure_tuple,
    ensure_tuple_rep,
    ensure_tuple_size,
)

from monai.optimizers import LearningRateFinder

from monai.transforms.compose import MapTransform
from monai.transforms.utils import generate_spatial_bounding_box
from skimage.transform import resize
from monai.losses import DiceCELoss, DiceLoss
from monai.utils import set_determinism
from monai.inferers import sliding_window_inference


from monai.metrics import DiceMetric, ROCAUCMetric, HausdorffDistanceMetric
from monai.data import decollate_batch
import glob
import monai
from monai.metrics import compute_meandice
import random
import pickle
from collections import OrderedDict
from typing import Sequence, Optional
import ipywidgets as widgets
from itertools import compress
import SimpleITK as sitk
import torchio as tio

import sklearn
from sklearn.metrics import mean_squared_error, roc_auc_score, accuracy_score, recall_score, \
accuracy_score, precision_score, f1_score, make_scorer,balanced_accuracy_score 

from monai.utils import ensure_tuple_rep
from monai.networks.layers.factories import Conv, Dropout, Norm, Pool
import matplotlib.pyplot as plt
from ranger21 import Ranger21


from tqdm import tqdm
from itkwidgets import view
import random
monai.config.print_config()
#from sliding_window_inference_classes import sliding_window_inference_classes

MONAI version: 0.9.0
Numpy version: 1.22.3
Pytorch version: 1.10.1
MONAI flags: HAS_EXT = False, USE_COMPILED = False
MONAI rev id: af0e0e9f757558d144b655c63afcea3a4e0a06f5
MONAI __file__: /home/mmiv-ml/anaconda3/envs/sa_tumorseg22/lib/python3.9/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.9
Nibabel version: 4.0.1
scikit-image version: 0.19.3
Pillow version: 9.0.1
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.11.2
tqdm version: 4.64.0
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.1
pandas version: 1.4.2
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-rec

In [2]:
MAX_THREADS =2
sitk.ProcessObject.SetGlobalDefaultNumberOfThreads(MAX_THREADS)

In [3]:
seeds = 40961024
set_determinism(seed=seeds)
##np.random.seed(seeds) np random seed does not work here
!nvidia-smi

Tue Jul 19 14:54:39 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.172.01   Driver Version: 450.172.01   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-DGXS...  On   | 00000000:07:00.0 Off |                    0 |
| N/A   43C    P0    43W / 300W |    108MiB / 32505MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-DGXS...  On   | 00000000:08:00.0 Off |                    0 |
| N/A   47C    P0   120W / 300W |  18386MiB / 32508MiB |      0%      Default |
|       

In [4]:
#patch_size = (128, 128, 128)
spacing = (1.0, 1.0, 1.0)
os.environ["CUDA_VISIBLE_DEVICES"] ="0"
device = torch.device('cuda:0')
deviceName = 'cuda:0'

In [5]:
pd.set_option('display.max_colwidth', None)
data_rpath = '/home/mmiv-ml/data'


In [6]:
BraTS20Subjectsp1q19WithMetaDF  = pd.read_csv('assets/BraTS_TCGA_LGG_GBM_LGG_1p19qDFMoreMeta_N4CorrectLatDF.csv')
BraTS20Subjectsp1q19WithMetaDF

Unnamed: 0,BraTS2021,t1wPath,t1cwPath,t1cw_N4CorrectPath,t2wPath,t2w_N4CorrectPath,flairPath,segPath,brain_maskPath,brain_mask_ch2Path,...,ET_CoordX,ET_CoordY,ET_CoordZ,ED_CoordX,ED_CoordY,ED_CoordZ,NEC_CoordX,NEC_CoordY,NEC_CoordZ,is_merged_3
0,BraTS2021_00140,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_00140/BraTS2021_00140_t1.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_00140/BraTS2021_00140_t1ce.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_00140/BraTS2021_00140_t1ce_afterN4Correct.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_00140/BraTS2021_00140_t2.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_00140/BraTS2021_00140_t2_afterN4Correct.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_00140/BraTS2021_00140_flair.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_00140/BraTS2021_00140_seg.nii.gz,/raid/brats2021/T1wx4Brain_ROIs/BraTS2021_00140/ROI_BraTS2021_00140.nii.gz,/raid/brats2021/T1wx2Brain_ROIs_BraTS21_Training/BraTS2021_00140/BraTS2021_00140_BrainROIT1cwx2.nii.gz,...,168.685087,167.653671,79.886450,162.346647,173.396768,87.441763,168.083333,167.200000,78.066667,both
1,BraTS2021_01283,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01283/BraTS2021_01283_t1.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01283/BraTS2021_01283_t1ce.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01283/BraTS2021_01283_t1ce_afterN4Correct.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01283/BraTS2021_01283_t2.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01283/BraTS2021_01283_t2_afterN4Correct.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01283/BraTS2021_01283_flair.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01283/BraTS2021_01283_seg.nii.gz,/raid/brats2021/T1wx4Brain_ROIs/BraTS2021_01283/ROI_BraTS2021_01283.nii.gz,/raid/brats2021/T1wx2Brain_ROIs_BraTS21_Training/BraTS2021_01283/BraTS2021_01283_BrainROIT1cwx2.nii.gz,...,145.484701,134.678620,59.585174,152.096980,146.947874,73.214571,147.219848,134.146249,59.135090,both
2,BraTS2021_01528,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01528/BraTS2021_01528_t1.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01528/BraTS2021_01528_t1ce.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01528/BraTS2021_01528_t1ce_afterN4Correct.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01528/BraTS2021_01528_t2.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01528/BraTS2021_01528_t2_afterN4Correct.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01528/BraTS2021_01528_flair.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01528/BraTS2021_01528_seg.nii.gz,/raid/brats2021/T1wx4Brain_ROIs/BraTS2021_01528/ROI_BraTS2021_01528.nii.gz,/raid/brats2021/T1wx2Brain_ROIs_BraTS21_Training/BraTS2021_01528/BraTS2021_01528_BrainROIT1cwx2.nii.gz,...,77.531023,144.899230,82.371416,94.469503,140.150948,66.994481,71.698179,136.327992,62.269723,both
3,BraTS2021_01503,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01503/BraTS2021_01503_t1.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01503/BraTS2021_01503_t1ce.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01503/BraTS2021_01503_t1ce_afterN4Correct.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01503/BraTS2021_01503_t2.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01503/BraTS2021_01503_t2_afterN4Correct.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01503/BraTS2021_01503_flair.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01503/BraTS2021_01503_seg.nii.gz,/raid/brats2021/T1wx4Brain_ROIs/BraTS2021_01503/ROI_BraTS2021_01503.nii.gz,/raid/brats2021/T1wx2Brain_ROIs_BraTS21_Training/BraTS2021_01503/BraTS2021_01503_BrainROIT1cwx2.nii.gz,...,110.542553,73.074468,70.808511,107.090113,82.676138,76.029439,105.099771,65.077985,76.992437,both
4,BraTS2021_01453,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01453/BraTS2021_01453_t1.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01453/BraTS2021_01453_t1ce.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01453/BraTS2021_01453_t1ce_afterN4Correct.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01453/BraTS2021_01453_t2.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01453/BraTS2021_01453_t2_afterN4Correct.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01453/BraTS2021_01453_flair.nii.gz,/raid/brats2021/RSNA_ASNR_MICCAI_BraTS2021_TrainingData/BraTS2021_01453/BraTS2021_01453_seg.nii.gz,/raid/brats2021/T1wx4Brain_ROIs/BraTS2021_01453/ROI_BraTS2021_01453.nii.gz,/raid/brats2021/T1wx2Brain_ROIs_BraTS21_Training/BraTS2021_01453/BraTS2021_01453_BrainROIT1cwx2.nii.gz,...,86.031397,128.011381,67.940149,81.830275,119.381631,64.750845,83.705329,127.626959,65.589342,both
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
363,,,/raid/brats2021/LGG_1p19q_rawNifti/LGG_1p19q_BraTSLikeProcess_mnibet/LGG-651/LGG-651_t1Gd.nii.gz,/raid/brats2021/LGG_1p19q_rawNifti/LGG_1p19q_BraTSLikeProcess_mnibet/LGG-651/LGG-651_t1Gd_afterN4Correct.nii.gz,/raid/brats2021/LGG_1p19q_rawNifti/LGG_1p19q_BraTSLikeProcess_mnibet/LGG-651/LGG-651_t2.nii.gz,/raid/brats2021/LGG_1p19q_rawNifti/LGG_1p19q_BraTSLikeProcess_mnibet/LGG-651/LGG-651_t2_afterN4Correct.nii.gz,,/raid/brats2021/LGG_1p19q_rawNifti/4Ensemble_LGG_1p19q_Infer/LGG-651/LGG-651_pred.nii.gz,,/raid/brats2021/LGG_1p19q_rawNifti/T1wx2Brain_ROIs_LGG_1p19q/LGG-651/LGG-651_BrainROIT1cwx2.nii.gz,...,,,,150.500422,115.393242,68.151900,,,,both
364,,,/raid/brats2021/LGG_1p19q_rawNifti/LGG_1p19q_BraTSLikeProcess_mnibet/LGG-658/LGG-658_t1Gd.nii.gz,/raid/brats2021/LGG_1p19q_rawNifti/LGG_1p19q_BraTSLikeProcess_mnibet/LGG-658/LGG-658_t1Gd_afterN4Correct.nii.gz,/raid/brats2021/LGG_1p19q_rawNifti/LGG_1p19q_BraTSLikeProcess_mnibet/LGG-658/LGG-658_t2.nii.gz,/raid/brats2021/LGG_1p19q_rawNifti/LGG_1p19q_BraTSLikeProcess_mnibet/LGG-658/LGG-658_t2_afterN4Correct.nii.gz,,/raid/brats2021/LGG_1p19q_rawNifti/4Ensemble_LGG_1p19q_Infer/LGG-658/LGG-658_pred.nii.gz,,/raid/brats2021/LGG_1p19q_rawNifti/T1wx2Brain_ROIs_LGG_1p19q/LGG-658/LGG-658_BrainROIT1cwx2.nii.gz,...,,,,136.202745,166.640807,107.448810,140.095694,173.392344,97.712919,both
365,,,/raid/brats2021/LGG_1p19q_rawNifti/LGG_1p19q_BraTSLikeProcess_mnibet/LGG-659/LGG-659_t1Gd.nii.gz,/raid/brats2021/LGG_1p19q_rawNifti/LGG_1p19q_BraTSLikeProcess_mnibet/LGG-659/LGG-659_t1Gd_afterN4Correct.nii.gz,/raid/brats2021/LGG_1p19q_rawNifti/LGG_1p19q_BraTSLikeProcess_mnibet/LGG-659/LGG-659_t2.nii.gz,/raid/brats2021/LGG_1p19q_rawNifti/LGG_1p19q_BraTSLikeProcess_mnibet/LGG-659/LGG-659_t2_afterN4Correct.nii.gz,,/raid/brats2021/LGG_1p19q_rawNifti/4Ensemble_LGG_1p19q_Infer/LGG-659/LGG-659_pred.nii.gz,,/raid/brats2021/LGG_1p19q_rawNifti/T1wx2Brain_ROIs_LGG_1p19q/LGG-659/LGG-659_BrainROIT1cwx2.nii.gz,...,,,,131.580447,105.934087,122.249243,130.666667,116.000000,123.000000,both
366,,,/raid/brats2021/LGG_1p19q_rawNifti/LGG_1p19q_BraTSLikeProcess_mnibet/LGG-660/LGG-660_t1Gd.nii.gz,/raid/brats2021/LGG_1p19q_rawNifti/LGG_1p19q_BraTSLikeProcess_mnibet/LGG-660/LGG-660_t1Gd_afterN4Correct.nii.gz,/raid/brats2021/LGG_1p19q_rawNifti/LGG_1p19q_BraTSLikeProcess_mnibet/LGG-660/LGG-660_t2.nii.gz,/raid/brats2021/LGG_1p19q_rawNifti/LGG_1p19q_BraTSLikeProcess_mnibet/LGG-660/LGG-660_t2_afterN4Correct.nii.gz,,/raid/brats2021/LGG_1p19q_rawNifti/4Ensemble_LGG_1p19q_Infer/LGG-660/LGG-660_pred.nii.gz,,/raid/brats2021/LGG_1p19q_rawNifti/T1wx2Brain_ROIs_LGG_1p19q/LGG-660/LGG-660_BrainROIT1cwx2.nii.gz,...,,,,88.208202,185.948429,73.406706,88.338498,194.035995,76.130393,both


## Cearing a list of dictionaries in order to feed into Monai's Dataset
Keys:
- ***image:*** T1, T1c, T2, and flair image
- ***label:*** Segmented mask GT
- ***brain_mask:*** Whole brain area (brain area=1 and Non brain area=0)
- ***IDH_value:*** 1p19q_co_deletion class corresponding to the subject/images


### Creating/extracting 3 splits for cross validaion (3 cross validaion)

In [7]:
def get_train_val_test_df(BraTS20SubjectsIDHWithMetaDF):
    
    
    BraTS20SubjectsIDHTrainDCT = {}
    BraTS20SubjectsIDHValDCT = {}
    BraTS20SubjectsIDHTestDCT = {}
    
    aDCT = {'fold0':[1, 2, 3], 'fold1':[2, 3, 1], 'fold3': [3, 1, 2]}
    
    for indx, (akey, aval) in enumerate(aDCT.items()):
        
    
        BraTS20SubjectsIDHWithMetaDFTrain = BraTS20SubjectsIDHWithMetaDF.loc[BraTS20SubjectsIDHWithMetaDF['CV_group']==aval[0]]
        BraTS20SubjectsIDHWithMetaDFVal = BraTS20SubjectsIDHWithMetaDF.loc[BraTS20SubjectsIDHWithMetaDF['CV_group']==aval[1]]
        BraTS20SubjectsIDHWithMetaDFTest = BraTS20SubjectsIDHWithMetaDF.loc[BraTS20SubjectsIDHWithMetaDF['CV_group']==aval[2]]

        train_files = [{'image': (image_nameT1ce, image_nameT2), 'label': label_name, 'brain_mask':brain_mask, 'IDH_label':np.array(IDH_label_name).astype(np.float32)} 
                       for image_nameT1ce, image_nameT2, label_name, brain_mask, IDH_label_name 
                       in zip(BraTS20SubjectsIDHWithMetaDFTrain['t1cwPath'], BraTS20SubjectsIDHWithMetaDFTrain['t2wPath'], BraTS20SubjectsIDHWithMetaDFTrain['segPath'], \
                              BraTS20SubjectsIDHWithMetaDFTrain['brain_mask_ch2Path'], BraTS20SubjectsIDHWithMetaDFTrain['1p19q_co_deletion_bin'].values)]
        
        val_files =[{'image': (image_nameT1ce, image_nameT2), 'label': label_name, 'brain_mask':brain_mask, 'IDH_label':np.array(IDH_label_name).astype(np.float32)} 
                    for image_nameT1ce, image_nameT2, label_name, brain_mask, IDH_label_name 
                    in zip(BraTS20SubjectsIDHWithMetaDFVal['t1cwPath'], BraTS20SubjectsIDHWithMetaDFVal['t2wPath'],BraTS20SubjectsIDHWithMetaDFVal['segPath'],\
                           BraTS20SubjectsIDHWithMetaDFVal['brain_mask_ch2Path'], BraTS20SubjectsIDHWithMetaDFVal['1p19q_co_deletion_bin'].values)]
        
        test_files = [{'image': (image_nameT1ce, image_nameT2), 'label': label_name, 'brain_mask':brain_mask, 'IDH_label':np.array(IDH_label_name).astype(np.float32)} 
                      for image_nameT1ce, image_nameT2, label_name, brain_mask, IDH_label_name 
                      in zip(BraTS20SubjectsIDHWithMetaDFTest['t1cwPath'], BraTS20SubjectsIDHWithMetaDFTest['t2wPath'], BraTS20SubjectsIDHWithMetaDFTest['segPath'], \
                             BraTS20SubjectsIDHWithMetaDFTest['brain_mask_ch2Path'], BraTS20SubjectsIDHWithMetaDFTest['1p19q_co_deletion_bin'].values)]
        
        
        BraTS20SubjectsIDHTrainDCT[f'fold{indx}'] = copy.deepcopy(train_files)
        BraTS20SubjectsIDHValDCT[f'fold{indx}'] = copy.deepcopy(val_files)
        BraTS20SubjectsIDHTestDCT[f'fold{indx}'] = copy.deepcopy(test_files)
        
        
        
    return BraTS20SubjectsIDHTrainDCT, BraTS20SubjectsIDHValDCT, BraTS20SubjectsIDHTestDCT
        
        
        
BraTS20SubjectsIDHTrainDCT, BraTS20SubjectsIDHValDCT, BraTS20SubjectsIDHTestDCT =  get_train_val_test_df(BraTS20Subjectsp1q19WithMetaDF)    
        
        

# train_files_image = [(image_nameT1, image_nameT1ce, image_nameT2, image_nameFl) 
#                      for image_nameT1,image_nameT1ce, image_nameT2, image_nameFl 
#                      in zip(dfTrainLbl['t1wPath'], dfTrainLbl['t1cwPath'], dfTrainLbl['T2wPath'], dfTrainLbl['FlairPath'])]
# train_files_label = dfTrainLbl['segPath'].tolist()
# train_files_brain_mask = dfTrainLbl['brain_maskPath'].tolist()
# train_files_IDH_label = dfTrainLbl['IDH_value'].values.ravel().tolist()


In [8]:
# n_splits = 3
# #train_index = np.linspace(0, train_features.shape[0]-1, num = train_features.shape[0], dtype = np.uint16, endpoint=True)
# #partition_data = monai.data.utils.partition_dataset_classes(train_index, train_labels.values.ravel().tolist(), shuffle=True, num_partitions=n_splits) 
# #partition_data = monai.data.utils.partition_dataset_classes(train_files, dfTrainLbl['IDH_value'].values.ravel().tolist(), shuffle=True, num_partitions=n_splits)
# partition_data = monai.data.partition_dataset_classes(train_files, BraTS20SubjectsIDHWithMetaDF['IDH_value'].values.ravel().tolist(), shuffle=True, num_partitions=n_splits)
# print(len(partition_data), len(partition_data[0]), len(partition_data[1]), len(partition_data[2]))


# # val_folds = {}
# # train_folds = {}
# # flds = np.linspace(0, n_splits, num=n_splits, dtype = np.int8)
# # for cfold in range(n_splits):
# #     not_cfold = np.delete(flds, cfold)
# #     val_folds[cfold] = partition_data[cfold]
# # #     train_folds[cfold] = 
# # # sub_flds = flds[..., ~0]   
# # # sub_flds

# val_folds = {}
# train_folds = {}
# flds = np.linspace(0, n_splits, num=n_splits, dtype = np.uint8)
# for cfold in range(n_splits):
#     #val_folds[f"fold{cfold}"] = train_features.values[partition_data[cfold],:]
#     #train_folds[f"fold{cfold}"] = np.delete(train_features.values, partition_data[cfold], axis=0)
#     #not_cfold = np.delete(flds, cfold)
    
#     val_folds[f"fold{cfold}"] = copy.deepcopy(partition_data[cfold])
#     val_folds[f"fold{cfold}_IDH_label"] = copy.deepcopy([adct['IDH_label'].item() for adct in partition_data[cfold]])
#     train_folds_masks = [1]*n_splits
#     train_folds_masks[cfold] = 0
#     partition_data_non_cfold = list()
#     for aDctLstitem in compress(partition_data, train_folds_masks):
#         partition_data_non_cfold.extend(aDctLstitem)
        
        
#     train_folds[f"fold{cfold}"] = copy.deepcopy(partition_data_non_cfold)
#     train_folds[f"fold{cfold}_IDH_label"] = copy.deepcopy([adct['IDH_label'].item() for adct in partition_data_non_cfold])

# for i in range(n_splits):
#     print('val: ', len(val_folds[f'fold{i}']), 'train: ', len(train_folds[f'fold{i}']), '\n')

In [9]:
# len(train_folds["fold0"]), len(train_files)

# for i_cv in range(n_splits):
#     print('Training classes\n')
#     print(np.unique([train_folds[f'fold{i_cv}'][i]['IDH_label'].item() for i in range(len(train_folds[f'fold{i_cv}']))], return_counts = True))
#     print('\nValidation classes\n')
#     print(np.unique([val_folds[f'fold{i_cv}'][i]['IDH_label'].item() for i in range(len(val_folds[f'fold{i_cv}']))], return_counts = True))
#     print('#'*4, '\n\n')

In [10]:
n_splits = 3
dfFolds = BraTS20SubjectsIDHTrainDCT
for i_cv in range(n_splits):
    print('Training classes\n')
    print(np.unique([BraTS20SubjectsIDHTrainDCT[f'fold{i_cv}'][i]['IDH_label'].item() for i in range(len(BraTS20SubjectsIDHTrainDCT[f'fold{i_cv}']))], return_counts = True))
    print('\nValidation classes\n')
    print(np.unique([BraTS20SubjectsIDHValDCT[f'fold{i_cv}'][i]['IDH_label'].item() for i in range(len(BraTS20SubjectsIDHValDCT[f'fold{i_cv}']))], return_counts = True))
    #print('#'*4, '\n')
    print('\nTesting classes\n')
    print(np.unique([BraTS20SubjectsIDHTestDCT[f'fold{i_cv}'][i]['IDH_label'].item() for i in range(len(BraTS20SubjectsIDHTestDCT[f'fold{i_cv}']))], return_counts = True))
    print('#'*40, '\n\n')

Training classes

(array([0., 1.]), array([79, 43]))

Validation classes

(array([0., 1.]), array([80, 44]))

Testing classes

(array([0., 1.]), array([79, 43]))
######################################## 


Training classes

(array([0., 1.]), array([80, 44]))

Validation classes

(array([0., 1.]), array([79, 43]))

Testing classes

(array([0., 1.]), array([79, 43]))
######################################## 


Training classes

(array([0., 1.]), array([79, 43]))

Validation classes

(array([0., 1.]), array([79, 43]))

Testing classes

(array([0., 1.]), array([80, 44]))
######################################## 




In [11]:
#train_folds['fold0'][2]

***HistogramStandardization***

Implementing histogram standardization from [torchIO](https://github.com/fepegar/torchio) library

Bases: [torchio.transforms.preprocessing.intensity.normalization_transform.NormalizationTransform](https://torchio.readthedocs.io/transforms/preprocessing.html#torchio.transforms.preprocessing.intensity.NormalizationTransform)

Perform histogram standardization of intensity values.

Implementation of [New variants of a method of MRI scale standardization](https://ieeexplore.ieee.org/document/836373).

We can visit in [torchio.transforms.HistogramStandardization.train()]((https://torchio.readthedocs.io/transforms/preprocessing.html#torchio.transforms.HistogramStandardization.train)) for more details.

PARAMETERS
landmarks – Dictionary (or path to a PyTorch file with .pt or .pth extension in which a dictionary has been saved) whose keys are image names in the subject and values are NumPy arrays or paths to NumPy arrays defining the landmarks after training with [torchio.transforms.HistogramStandardization.train()](https://torchio.readthedocs.io/transforms/preprocessing.html#torchio.transforms.HistogramStandardization.train).

Here, ***save_dir*** is a path where the trained histogram files for four channels (T1w, T1cw, T2w, and Flair), and trained model's weights will be saved

In [12]:
file_prefix = 'DynUnetCommon_OnlyBrats21_Full'
hist_save_dir = '/home/mmiv-ml/saruarlive/IDHRadiogenomics2022/assets'

    


hiseq_t1cnpyfile =  os.path.join(hist_save_dir, f"histeq_t1cw_{file_prefix}.npy")
t1cw_landmarks = (hiseq_t1cnpyfile if os.path.isfile(hiseq_t1cnpyfile) else \
                  tio.HistogramStandardization.train(image_t1cwpaths, output_path = hiseq_t1cnpyfile))



hiseq_t2npyfile = os.path.join(hist_save_dir, f"histeq_t2w_{file_prefix}.npy")
t2w_landmarks = (hiseq_t2npyfile if os.path.isfile(hiseq_t2npyfile) else \
                 tio.HistogramStandardization.train(image_t2wpaths, output_path = hiseq_t2npyfile))


landmarks_dict = {'t1cw': t1cw_landmarks, 't2w': t2w_landmarks}
#tio_landmarktransform = tio.HistogramStandardization(landmarks_dict)
landmarks_dict

{'t1cw': '/home/mmiv-ml/saruarlive/IDHRadiogenomics2022/assets/histeq_t1cw_DynUnetCommon_OnlyBrats21_Full.npy',
 't2w': '/home/mmiv-ml/saruarlive/IDHRadiogenomics2022/assets/histeq_t2w_DynUnetCommon_OnlyBrats21_Full.npy'}

In [13]:
file_prefix = 'AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand'
#file_prefix = 'AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_WeightSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand'
savedirname = 'DynUNetVariants_TCGA'
save_dir = os.path.join('/raid/brats2021/pthTCGA_1p19q_CoDeletion', savedirname)
if not os.path.exists(save_dir):
    os.makedirs(save_dir)
    

## Classes for Monai/Pytorch compose class

A class to rearrange label mask array as 
- [0, :, :, :] = the multi class mask (class labels: 0 (background), 1, 2, and 4)
- [1, :, :, :] = the whole tumor mask (class labels: 0 (background), and 1)\
Not using here

In [14]:
class ConvertToMultiChannelPlusWT(MapTransform):
    
    """
     GD-enhancing tumor (ET — label 4), 
     the peritumoral edema (ED — label 2), and 
     the necrotic and non-enhancing tumor core (NCR/NET — label 1)

    """
    
    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            result = []
            
            d[key]=np.squeeze(d[key], axis = 0) # Converting 1, H, W, D to H, W, D
            result.append(d[key])

            # merge labels 1, 2 and 4 to construct WT
            result.append(
                np.logical_or(
                    np.logical_or(d[key] == 2, d[key] == 4), d[key] == 1
                )
            )
            ## merge label 1 and label 4 to construct TC
            #result.append(np.logical_or(d[key] == 1, d[key] == 4))
            ## label 4 is ET
            #result.append(d[key] == 4)
            d[key] = np.stack(result, axis=0).astype(np.uint8)
        return d

#### Define a new transform to convert brain tumor labels
Here we convert the multi-classes labels into multi-labels segmentation task in One-Hot format.\
Not using here

In [15]:
class ConvertToMultiChannelBasedOnBratsClassesd(MapTransform):
    
    """
     GD-enhancing tumor (ET — label 4), 
     the peritumoral edema (ED — label 2), and 
     the necrotic and non-enhancing tumor core (NCR/NET — label 1)

    """
    
    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            result = []
            
            d[key]=np.squeeze(d[key], axis = 0) # Converting 1, H, W, D to H, W, D

            # merge labels 1, 2 and 4 to construct WT
            result.append(
                np.logical_or(
                    np.logical_or(d[key] == 2, d[key] == 4), d[key] == 1
                )
            )
            # merge label 1 and label 4 to construct TC
            result.append(np.logical_or(d[key] == 1, d[key] == 4))
            # label 4 is ET
            result.append(d[key] == 4)
            d[key] = np.stack(result, axis=0).astype(np.float32)
        return d

In [16]:
class ConvertToIDHLabel2WTd(MapTransform):
    
    """
     GD-enhancing tumor (ET — label 4), 
     the peritumoral edema (ED — label 2), and 
     the necrotic and non-enhancing tumor core (NCR/NET — label 1)

    """
    
    def __init__(self, keys: KeysCollection, IDH_label_key:str = 'IDH_label') -> None:

        super().__init__(keys)
        self.IDH_label_key = IDH_label_key
       
    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            # merge labels 1, 2 and 4 to construct WT
            #WT = np.logical_or(np.logical_or(d[key] == 2, d[key] == 4), d[key] == 1).astype(np.uint8)
            result = []
            WT = np.squeeze(d[key], axis = 0)
            if d[self.IDH_label_key].item() == 1:
                WT=np.multiply(WT, 2)
                #WT = 2*WT
            
            result.append(WT==1)
            result.append(WT==2)
            
            d[key] = np.stack(result, axis = 0).astype(np.float32)
            
    
        return d

#### A class to add new key having the tumor mask (GT) to the existing data dictionary
The new key, ***label_mask*** will have the same dimension (size: 4,x,x,x) with image array (size: 4,x,x,x)\
Using in ***compose*** class

In [17]:
class Convert2WTd(MapTransform):
    
    """
     GD-enhancing tumor (ET — label 4), 
     the peritumoral edema (ED — label 2), and 
     the necrotic and non-enhancing tumor core (NCR/NET — label 1)

    """
    
    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            result = []
            output_classes = 2
            
            # merge labels 1, 2 and 4 to construct WT
            WT = np.logical_or(np.logical_or(d[key] == 2, d[key] == 4), d[key] == 1).astype(np.float32)
            d[f'{key}'] = WT
         
            WT = np.expand_dims(ndimage.binary_dilation(np.squeeze(WT, axis=0), iterations=2), axis = 0)
            #WT = np.stack(tuple([ndimage.binary_dilation((np.squeeze(WT, axis = 0)==_k).astype(WT.dtype), iterations=5).astype(WT.dtype) for _k in range(output_classes)]), axis = 0)
            d[f'{key}_mask'] = WT
            d[f'{key}_mask_meta_dict'] = copy.deepcopy(d[f"{key}_meta_dict"])
            
        
        return d

In [18]:
class SpatialCropWTCOMd(MapTransform):
    
    """
     GD-enhancing tumor (ET — label 4), 
     the peritumoral edema (ED — label 2), and 
     the necrotic and non-enhancing tumor core (NCR/NET — label 1)

    """
    def __init__(self, keys: KeysCollection, roi_size, COM_label_key:str = 'label_mask') -> None:

        super().__init__(keys)
        self.COM_label_key = COM_label_key
        self.roi_size = roi_size
    
    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            result = []
            Coms = np.array([ndimage.measurements.center_of_mass(lbl) for lbl in list(d[self.COM_label_key])])
            Coms[np.isnan(Coms)] = 70
            Coms=Coms[0].astype(np.uint16).tolist()
        
            sc_com= SpatialCrop(roi_center= Coms, roi_size=self.roi_size)
            d[key] = sc_com(d[key])
                
        return d

In [19]:
class ConcatLabelBrainmaskd(MapTransform):
    """
          we do not need labels as it is a generative problem
    """
    
    def __init__(self, keys: KeysCollection, image_key = 'image', label_key = 'label', 
                 brain_mask_key = 'brain_mask') -> None:
        
        super().__init__(keys)
        self.brain_mask_key = brain_mask_key
        self.image_key = image_key
        self.label_key = label_key
    
    
    def __call__(self, data):
     
        d = dict(data)
        #d[self.image_key] = np.concatenate((d[self.image_key], d[self.label_key], d[self.brain_mask_key][0:1]), axis = 0)
        d[self.image_key] = np.concatenate((d[self.image_key], d[self.label_key][0:1]), axis = 0)
        
        return d

### Implementing channelwise histogram normalization
(Not using here)

In [20]:
class HistogramNormalizeChannelWised(MapTransform):
    """
          we do not need labels as it is a generative problem
    """
    
    def __init__(self, keys: KeysCollection, brain_mask_key = 'brain_mask', min=0, max=255) -> None:
        
        super().__init__(keys)
        self.brain_mask_key = brain_mask_key
        self.histnorms = HistogramNormalize(num_bins=256, min=min, max=max)
    
    
    def __call__(self, data):
     
        d = dict(data)
        for key in self.keys:
            nchnl = d[key].shape[0]
            for ch in range(nchnl):
                d[key][ch] = self.histnorms(d[key][ch], d[self.brain_mask_key][ch])
        
        return d

In [21]:
class adapter_tio2monai(MapTransform):
    """
    # wrapper for tio affine transformation
    """

    def __init__(
        self,
        #keys: KeysCollection,
        mode = 'train',
        tiofn=None,
        **mykwargs,
    ) -> None:
        
        """
           Wrapper from torchio to monai
        """
        
        #super().__init__(keys)
        #super().__init__(**mykwargs)
        self.tiofn = tiofn(**mykwargs)
        self.mode = mode
        self.mykwargs = mykwargs
        
    
    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
        d = dict(data)
        
        if self.mode =='train':
            subject = tio.Subject(
                image=tio.ScalarImage(tensor=d["image"], affine = d['image_meta_dict']['affine']),  # this class is new
                label=tio.LabelMap(tensor=d["label"], affine = d['label_meta_dict']['affine']),
                brain_mask = tio.LabelMap(tensor=d["brain_mask"], affine = d['brain_mask_meta_dict']['affine'])
            )
            transformed = self.tiofn(subject)
            d["image"] = transformed["image"].numpy()
            d["label"] = transformed["label"].numpy()
            d["brain_mask"] = transformed["brain_mask"].numpy()
            d["image_meta_dict"]['affine'] = transformed["image"].affine.copy()
            d["label_meta_dict"]['affine'] = transformed["label"].affine.copy()
            d["brain_mask_meta_dict"]['affine'] = transformed["brain_mask"].affine.copy()
            
            
        elif self.mode =='infer':
            
            subject = tio.Subject(
                image=tio.ScalarImage(tensor=d["image"], affine = d['image_meta_dict']['affine']),  # this class is new
                brain_mask = tio.LabelMap(tensor=d["brain_mask"], affine = d['brain_mask_meta_dict']['affine'])
            )
            
            transformed = self.tiofn(subject)
            d["image"] = transformed["image"].numpy()
            d["brain_mask"] = transformed["brain_mask"].numpy()
            d["image_meta_dict"]['affine'] = transformed["image"].affine.copy()
            d["brain_mask_meta_dict"]['affine'] = transformed["brain_mask"].affine.copy()
        
        else:
            print('Please select mode either train or infer')

        return d

In [22]:
class adapter_tioChannelWise2monai(MapTransform):
    """
    # wrapper for tio affine transformation
    """

    def __init__(
        self,
        #keys: KeysCollection,
        mode = 'train',
        tiofn=None,
        **mykwargs,
    ) -> None:
        
        """
           Wrapper from torchio to monai
        """
        
        #super().__init__(keys)
        #super().__init__(**mykwargs)
        self.tiofn = tiofn(**mykwargs)
        self.mode = mode
        self.mykwargs = mykwargs
        
    
    def __call__(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
        d = dict(data)
        if self.mode =='train':
            subject = tio.Subject(
               
                t1cw=tio.ScalarImage(tensor=d["image"][0:1,...], affine = d['image_meta_dict']['affine']),
                t2w=tio.ScalarImage(tensor=d["image"][1:2,...], affine = d['image_meta_dict']['affine']),
               
                label=tio.LabelMap(tensor=d["label"], affine = d['label_meta_dict']['affine']),
                brain_mask = tio.LabelMap(tensor=d["brain_mask"], affine = d['brain_mask_meta_dict']['affine'])
            )
            transformed = self.tiofn(subject)
            d["image"] = np.concatenate([transformed["t1cw"].numpy(), transformed["t2w"].numpy()], axis = 0)
            d["label"] = transformed["label"].numpy()
            d["brain_mask"] = transformed["brain_mask"].numpy()
            d["image_meta_dict"]['affine'] = transformed["t1cw"].affine.copy()
            d["label_meta_dict"]['affine'] = transformed["label"].affine.copy()
            d["brain_mask_meta_dict"]['affine'] = transformed["brain_mask"].affine.copy()
            
        
        elif self.mode =='infer':
            
            subject = tio.Subject(
                t1cw=tio.ScalarImage(tensor=d["image"][0:1,...], affine = d['image_meta_dict']['affine']),
                t2w=tio.ScalarImage(tensor=d["image"][1:2,...], affine = d['image_meta_dict']['affine']),
                brain_mask = tio.LabelMap(tensor=d["brain_mask"], affine = d['brain_mask_meta_dict']['affine'])
            )
            
            transformed = self.tiofn(subject)
            d["image"] = np.concatenate([transformed["t1cw"].numpy(), transformed["t2w"].numpy()], axis = 0)
    
            d["brain_mask"] = transformed["brain_mask"].numpy()
            d["image_meta_dict"]['affine'] = transformed["t1cw"].affine.copy()
            d["brain_mask_meta_dict"]['affine'] = transformed["brain_mask"].affine.copy()
        
        else:
            print('Please select mode either train or infer')

        return d

### Defining traning and validation transforms

- Training transform includes:
    - LoadImaged
    - EnsureChannelFirstd
    - HistogramNormalizeChannelWised: Histogram normalization channel wise (custom class defined aboove)
    - NormalizeIntensityd
    - RandRotate90d
    - RandZoomd
    - ConvertToIDHLabel2WTd (custom class defined above)
    - CropForegroundd: Cropping foreground based on the whole tumor mask (WT GT)
    - RandCropByPosNegLabeld: Randomly cropping 8 patches based on 3: 1 (WT : non tumor tissus) ratio
    - RandGaussianNoised
    - RandStdShiftIntensityd
    - RandFlipd
    
- validation transform includes:
    - LoadImaged
    - EnsureChannelFirstd
    - HistogramNormalizeChannelWised: Histogram normalization channel wise (custom class defined aboove)
    - NormalizeIntensityd
    - ConvertToIDHLabel2WTd (custom class defined above)
    
Most of transfroms are implemented using [Monai](https://docs.monai.io/en/latest/transforms.html#dictionary-transforms) library



In [23]:
def threshold_foreground(x):
    # threshold at not equal to 0
    return x == 1


#Resized(keys=keys[0:-1], spatial_size=patch_size, mode = ('area','nearest','nearest')),

# ConvertToMultiChannelBasedOnBratsClassesd(keys = ['label']),
# ConcatLabelBrainmaskd(keys = None, image_key = 'image', label_key = 'label', brain_mask_key = 'brain_mask'),
# CropForegroundd(keys=keys[0:-1], source_key="brain_mask", select_fn = threshold_foreground, start_coord_key='fg_start_coord', end_coord_key='fg_end_coord'),
# SpatialPadd(keys=keys[0:-1], spatial_size=patch_size),

# RandGaussianSmoothd(
#     keys=["image"],
#     sigma_x=(0.5, 1.15),
#     sigma_y=(0.5, 1.15),
#     sigma_z=(0.5, 1.15),
#     prob=0.3,
# ),

# RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.3),
# RandGibbsNoised(keys=["image"], prob=0.3, alpha=(0.1, 0.5), as_tensor_output=False),

          
# DataStatsd(keys=keys[0:-1], prefix="Data", data_type=True, data_shape=True, value_range=True, data_value=False),
#DataStatsd(keys=keysExt[0:-1], prefix="Data", data_type=True, data_shape=True, value_range=True, data_value=False),


def get_task_transforms(patch_size, task='train', pos_sample_num=1, neg_sample_num=1, num_samples=1, num_classes = 2, cratio = [1, 3]):
    
    #spatial_size=(30, 30, 30)
    orig_img_size = (240, 240, 155)

    if task=='train':
        keys = ["image", 'label', 'brain_mask', 'IDH_label']
        keysExt = ["image", 'label', 'brain_mask', 'label_mask', 'IDH_label']
        
        all_transform = [
            
            LoadImaged(keys=keys[0:-1], reader = "NibabelReader"),
            EnsureChannelFirstd(keys=keys[0:-1]),
            adapter_tioChannelWise2monai(tiofn = tio.HistogramStandardization, mode = 'train', landmarks = landmarks_dict),
            NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
            #HistogramNormalizeChannelWised(keys = ['image'], brain_mask_key = 'brain_mask', min = 1, max = 65535),
            
            adapter_tio2monai(tiofn = tio.OneOf, transforms={tio.RandomAffine(scales=(0.9, 1.2),degrees=15, isotropic=True): 0.6, \
                                            tio.RandomElasticDeformation(): 0.4}, p = 0.4), #p = 0.4
            #adapter_tio2monai(tiofn = tio.RandomAffine, scales=(0.9, 1.2), degrees=15, isotropic=True, p = 0.2), 
            

            
            #ConvertToIDHLabel2WTd(keys = ["label"]),
            CopyItemsd(keys=["label"], names=["label_mask"], times=1),
            Convert2WTd(keys = ["label"]),
            ConvertToIDHLabel2WTd(keys = ["label"], IDH_label_key = 'IDH_label'),
            CropForegroundd(keys=keysExt[0:-1], source_key="brain_mask", select_fn = threshold_foreground, start_coord_key='fg_start_coord', end_coord_key='fg_end_coord'),
            #Spacingd(keys = keysExt[0:-1], pixdim=(1.25, 1.25, 1.25), mode = ('bilinear','nearest', 'nearest', 'nearest')),
            RandZoomd(
                keys=keysExt[0:-1],
                min_zoom=0.9,
                max_zoom=1.1,
                mode=("trilinear", "nearest", "nearest", "nearest"),
                align_corners=(True, None, None, None),
                prob=0.15,
            ),
           
            #ResizeWithPadOrCropd(keys = keysExt[0:-1], spatial_size = (128, 160, 128)),
            RandGaussianNoised(keys=["image"], std=0.01, prob=0.15),
            RandGaussianSmoothd(
                keys=["image"],
                sigma_x=(0.5, 1.15),
                sigma_y=(0.5, 1.15),
                sigma_z=(0.5, 1.15),
                prob=0.15,
            ),
            RandStdShiftIntensityd(keys = ["image"], factors=0.3, nonzero=True, channel_wise=True, prob=0.15), 
            RandScaleIntensityd(keys=["image"], factors=0.3, prob=0.15),
            RandGibbsNoised(keys=["image"], prob=0.15, alpha=(0.1, 0.5), as_tensor_output=False),
            RandFlipd(keys=keysExt[0:-1], prob=0.5, spatial_axis=0),
            RandFlipd(keys=keysExt[0:-1], prob=0.5, spatial_axis=1),
            RandFlipd(keys=keysExt[0:-1], prob=0.5, spatial_axis=2),
                        
            CropForegroundd(keys=keysExt[0:-1], source_key="label_mask", select_fn = threshold_foreground, start_coord_key='fg_start_coord', end_coord_key='fg_end_coord', margin=2),
            #ResizeWithPadOrCropd(keys = keysExt[0:-1], spatial_size = patch_size),
            SpatialPadd(keys = keysExt[0:-1], spatial_size = patch_size),
#             RandCropByLabelClassesd(
#                 keys=keysExt[0:-1],            
#                 label_key = "label_mask",
#                 spatial_size = patch_size,    
#                 ratios= cratio,
#                 num_classes=num_classes,              
#                 num_samples=num_samples,
#                 image_key="brain_mask",
#                 image_threshold=0.0,
#                 #allow_smaller = True,
#             ),
            
            RandCropByPosNegLabeld(
                keys=keysExt[0:-1],
                label_key="label_mask",
                spatial_size=patch_size,
                pos=pos_sample_num,
                neg=neg_sample_num,
                num_samples=num_samples,
                image_key="brain_mask",
                image_threshold=0.,
            ),
                        
            #SpatialCropWTCOMd(keys=keysExt[0:-1], roi_size=patch_size, COM_label_key = "label_mask"),
            SpatialPadd(keys = keysExt[0:-1], spatial_size = patch_size),
        
            #CastToTyped(keys=keysExt, dtype=(np.float32, np.uint8, np.uint8, np.uint8, np.float32)),
            CastToTyped(keys=keysExt, dtype=(np.float32, np.float32, np.float32, np.float32, np.float32)),
            #dtype = (torch.float32, torch.float32, torch.float32, torch.float32, torch.float32)
            #ToTensord(keys=keysExt),

#             #EnsureTyped(keys=keys, data_type = "tensor"),
#             #ToDeviced(keys = keys, device = deviceName),
        ]
        
        
        
    elif task=='validation':
    
        keys = ["image", 'label', 'brain_mask', 'IDH_label']
        keysExt = ["image", 'label', 'brain_mask', 'label_mask', 'IDH_label']
        all_transform = [
            
            LoadImaged(keys=keys[0:-1], reader = "NibabelReader"),
            EnsureChannelFirstd(keys=keys[0:-1]),
            adapter_tioChannelWise2monai(tiofn = tio.HistogramStandardization, mode = 'train', landmarks = landmarks_dict),
            #HistogramNormalizeChannelWised(keys = ['image'], brain_mask_key = 'brain_mask', min = 1, max = 65535),
            NormalizeIntensityd(keys=["image"], nonzero=True, channel_wise=True),
            #ConvertToIDHLabel2WTd(keys = ["label"]),
           
            CopyItemsd(keys=["label"], names=["label_mask"], times=1),
            Convert2WTd(keys = ["label"]),
            ConvertToIDHLabel2WTd(keys = ["label"], IDH_label_key = 'IDH_label'),
            CropForegroundd(keys=keysExt[0:-1], source_key="brain_mask", select_fn = threshold_foreground, start_coord_key='fg_start_coord', end_coord_key='fg_end_coord'),
            
            #Spacingd(keys = keysExt[0:-1], pixdim=(1.25, 1.25, 1.25), mode = ('bilinear','nearest', 'nearest', 'nearest')),
            #ResizeWithPadOrCropd(keys = keysExt[0:-1], spatial_size = (128, 160, 128)),
            CropForegroundd(keys=keysExt[0:-1], source_key="label_mask", select_fn = threshold_foreground, start_coord_key='fg_start_coord', end_coord_key='fg_end_coord', margin=2), 
            #SpatialCropWTCOMd(keys=keysExt[0:-1], roi_size=patch_size, COM_label_key = "label_mask"),
            SpatialPadd(keys = keysExt[0:-1], spatial_size = patch_size),
#             RandCropByPosNegLabeld(
#                 keys=keysExt[0:-1],
#                 label_key="label_mask",
#                 spatial_size=patch_size,
#                 pos=pos_sample_num,
#                 neg=neg_sample_num,
#                 num_samples=num_samples,
#                 image_key="brain_mask",
#                 image_threshold=0,
#             ),
        
#             RandCropByLabelClassesd(
#                 keys=keysExt[0:-1],            
#                 label_key = "label_mask",
#                 spatial_size = patch_size,    
#                 ratios= cratio,
#                 num_classes=num_classes,              
#                 num_samples=num_samples,
#                 image_key="brain_mask",
#                 image_threshold=0.0,
#                 #allow_smaller = True,
#             ),

            #CastToTyped(keys=keysExt, dtype=(np.float32, np.uint8, np.uint8, np.uint8, np.float32)),
            CastToTyped(keys=keysExt, dtype=(np.float32, np.float32, np.float32, np.float32, np.float32)),
            #ToTensord(keys=keysExt, dtype=(torch.float32, torch.float32, torch.float32, torch.float32, torch.float32)),
            #EnsureTyped(keys=keys, data_type = "tensor"),
            #ToDeviced(keys = keys, device = deviceName),
        ]
        
    else:
        print('print task either train or validation here')


    return Compose(all_transform)

# def create_cachedir(cache_dir):
#     if not os.path.exists(cache_dir):
#         os.makedirs(cache_dir)
#     return 1

### Section for visual inspection and debugging

In [24]:
#patch_size=(128, 160, 128)
#patch_size=(64, 80, 64)
patch_size=(32, 32, 32)
train_transforms = get_task_transforms(patch_size, task='train', pos_sample_num=3, neg_sample_num=1, num_samples=16, cratio = [1, 3])
val_transforms = get_task_transforms(patch_size, task='validation', pos_sample_num=1, neg_sample_num=1, num_samples=1, cratio = [1, 3])
len(train_transforms), len(val_transforms)

(23, 11)

In [25]:
# investi_files = BraTS20SubjectsIDHTrainDCT['fold0']
# all_train_dataset = monai.data.Dataset(data=investi_files, transform=train_transforms)
# all_train_dataset[0][0]['label_meta_dict']

In [26]:
# sub_patch = all_train_dataset[15][3]
# view(image = sub_patch['image'][0], label_image = sub_patch['label_mask'][0])

In [27]:
# vpatch = all_train_dataset[0][14]
# view(image = vpatch['image'][1], label_image = vpatch['label'][1])


#np.unique(all_train_dataset[0]['label'][1], return_counts = True)

In [28]:
#view(image = vpatch['image'][1], label_image = vpatch['label_mask'][0])

In [29]:
#investifiles = copy.deepcopy(BraTS20SubjectsIDHTrainDCT["fold0"])

# for i in range(len(investifiles)):
#     investifiles[i]['IDH_label'] = investifiles[i]['IDH_label'].astype(np.float32) 
# all_train_dataset = monai.data.Dataset(data=investifiles, transform=train_transforms)
#all_train_dataset[10][3]['IDH_label']

In [30]:
# for cfold in tqdm(range(len(BraTS20SubjectsIDHTrainDCT))):
#     all_train_dataset = monai.data.Dataset(data=copy.deepcopy(BraTS20SubjectsIDHTrainDCT[f"fold{cfold}"]), transform=train_transforms)
#     dls = monai.data.DataLoader(all_train_dataset, batch_size=8, shuffle=False, collate_fn=list_data_collate)
#     # abatch = next(iter(dls))
#     # print(abatch['image'].shape)
#     # print(abatch['label'].shape)
#     for epoch in range(5):
#         for abatch in dls:
#             print(abatch['image'].shape)
#             print(abatch['label'].shape)
#             print(abatch['IDH_label'].shape)
#             print(abatch['IDH_label'], '\n', '###'*10, '\n')

In [31]:
#few_train_dataset = Dataset(data=train_files[0:10], transform=train_transforms)
# asub = few_train_dataset[5] 
# view(image = asub['image'][3].cpu(), label_image = asub['label_mask'][0].cpu())


### Few investigation

In [34]:
#afold_train_dataset[200]['IDH_label'], afold_train_dataset[200]['label'].unique(return_counts = True)

## Defining model

In [35]:
class DownBasicBlock(nn.Module):

        
    def __init__(self, input_channels, output_channels,
         conv_op=nn.Conv3d, conv_kwargs=None,
         norm_op=nn.BatchNorm3d, norm_op_kwargs=None,
         dropout_op=nn.Dropout3d, dropout_op_kwargs=None,
         nonlin=nn.LeakyReLU, nonlin_kwargs=None):
        super(DownBasicBlock, self).__init__()
        
        if nonlin_kwargs is None:
            nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        if dropout_op_kwargs is None:
            dropout_op_kwargs = {'p': 0.0, 'inplace': True}
        if norm_op_kwargs is None:
            norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
        if conv_kwargs is None:
            conv_kwargs3x3_0 = {'kernel_size': 3, 'stride': 2, 'padding': 1, 'dilation': 1, 'bias': True}
            conv_kwargs3x3_1 = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}
        else:
            conv_kwargs3x3_0 = conv_kwargs
            conv_kwargs3x3_1 = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}
            
            

        self.nonlin = nonlin
        self.nonlin_kwargs = nonlin_kwargs

        self.dropout_op = dropout_op
        self.dropout_op_kwargs = dropout_op_kwargs
        
        self.conv_op = conv_op
        self.conv_kwargs3x3_0 = conv_kwargs3x3_0
        self.conv_kwargs3x3_1 = conv_kwargs3x3_1
        self.conv_kwargs1x1 = {'kernel_size': 1, 'stride': 1, 'padding': 0, 'dilation': 1, 'bias': True}
        
        self.norm_op = norm_op
        self.norm_op_kwargs = norm_op_kwargs
        
        

        self.conv3x3_0 = self.conv_op(input_channels, output_channels, **self.conv_kwargs3x3_0)
        self.instnorm3x3_0 = self.norm_op(output_channels, **self.norm_op_kwargs)
        
        self.conv3x3_1 = self.conv_op(output_channels, output_channels, **self.conv_kwargs3x3_1)
        self.instnorm3x3_1 = self.norm_op(output_channels, **self.norm_op_kwargs)
        
        #self.conv1x1 = self.conv_op(input_channels, output_channels, **self.conv_kwargs1x1)
        #self.instnorm1x1 = self.norm_op(output_channels, **self.norm_op_kwargs)
        
        if self.dropout_op is not None and self.dropout_op_kwargs['p'] is not None and self.dropout_op_kwargs['p'] > 0:
            self.dropout = self.dropout_op(**self.dropout_op_kwargs)
        else:
            self.dropout = None
        
        self.lrelu = self.nonlin(**self.nonlin_kwargs)
  

    def forward(self, x):
        
        out = self.conv3x3_0(x)
        out = self.instnorm3x3_0(out)
        out = self.lrelu(out)

        out = self.conv3x3_1(out)
        out = self.instnorm3x3_1(out)
        out = self.lrelu(out)
        #print(out.shape)
        return out

In [36]:
class UpBasicBlock(nn.Module):

        
    def __init__(self, input_channels, output_channels,
         conv_op=nn.Conv3d, conv_kwargs=None,
         norm_op=nn.BatchNorm3d, norm_op_kwargs=None,
         dropout_op=nn.Dropout3d, dropout_op_kwargs=None,
         nonlin=nn.LeakyReLU, nonlin_kwargs=None):
        super(UpBasicBlock, self).__init__()
        
        if nonlin_kwargs is None:
            nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        if dropout_op_kwargs is None:
            dropout_op_kwargs = {'p': 0.0, 'inplace': True}
        if norm_op_kwargs is None:
            norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
        if conv_kwargs is None:
            conv_kwargs3x3_0 = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}
            conv_kwargs3x3_1 = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}
        else:
            conv_kwargs3x3_0 = conv_kwargs
            conv_kwargs3x3_1 = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}
            
            
        

        self.nonlin = nonlin
        self.nonlin_kwargs = nonlin_kwargs

        self.dropout_op = dropout_op
        self.dropout_op_kwargs = dropout_op_kwargs
        
        self.conv_op = conv_op
        self.conv_kwargs3x3_0 = conv_kwargs3x3_0
        self.conv_kwargs3x3_1 = conv_kwargs3x3_1
        self.conv_kwargs1x1 = {'kernel_size': 1, 'stride': 1, 'padding': 0, 'dilation': 1, 'bias': True}
        
        self.norm_op = norm_op
        self.norm_op_kwargs = norm_op_kwargs
        
        

        self.conv3x3_0 = self.conv_op(input_channels, output_channels, **self.conv_kwargs3x3_0)
        self.instnorm3x3_0 = self.norm_op(output_channels, **self.norm_op_kwargs)
        
        self.conv3x3_1 = self.conv_op(output_channels, output_channels, **self.conv_kwargs3x3_1)
        self.instnorm3x3_1 = self.norm_op(output_channels, **self.norm_op_kwargs)
        
        #self.conv1x1 = self.conv_op(input_channels, output_channels, **self.conv_kwargs1x1)
        #self.instnorm1x1 = self.norm_op(output_channels, **self.norm_op_kwargs)
        
        if self.dropout_op is not None and self.dropout_op_kwargs['p'] is not None and self.dropout_op_kwargs['p'] > 0:
            self.dropout = self.dropout_op(**self.dropout_op_kwargs)
        else:
            self.dropout = None
        
        self.lrelu = self.nonlin(**self.nonlin_kwargs)
  

    def forward(self, x):
        
        out = self.conv3x3_0(x)
        out = self.instnorm3x3_0(out)
        out = self.lrelu(out)

        out = self.conv3x3_1(out)
        out = self.instnorm3x3_1(out)
        out = self.lrelu(out)
        #print(out.shape)
        return out

In [37]:
class ConvDropoutNormNonlin(nn.Module):
    """
    fixes a bug in ConvDropoutNormNonlin where lrelu was used regardless of nonlin. Bad.
    """

    def __init__(self, input_channels, output_channels,
                 conv_op=nn.Conv3d, conv_kwargs=None,
                 norm_op=nn.BatchNorm3d, norm_op_kwargs=None,
                 dropout_op=nn.Dropout3d, dropout_op_kwargs=None,
                 nonlin=nn.LeakyReLU, nonlin_kwargs=None):
        super(ConvDropoutNormNonlin, self).__init__()
        if nonlin_kwargs is None:
            nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        if dropout_op_kwargs is None:
            dropout_op_kwargs = {'p': 0.0, 'inplace': True}
        if norm_op_kwargs is None:
            norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
        if conv_kwargs is None:
            conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}

        self.nonlin_kwargs = nonlin_kwargs
        self.nonlin = nonlin
        self.dropout_op = dropout_op
        self.dropout_op_kwargs = dropout_op_kwargs
        self.norm_op_kwargs = norm_op_kwargs
        self.conv_kwargs = conv_kwargs
        self.conv_op = conv_op
        self.norm_op = norm_op

        self.conv = self.conv_op(input_channels, output_channels, **self.conv_kwargs)
        if self.dropout_op is not None and self.dropout_op_kwargs['p'] is not None and self.dropout_op_kwargs[
            'p'] > 0:
            self.dropout = self.dropout_op(**self.dropout_op_kwargs)
        else:
            self.dropout = None
        self.instnorm = self.norm_op(output_channels, **self.norm_op_kwargs)
        self.lrelu = self.nonlin(**self.nonlin_kwargs)

    def forward(self, x):
        x = self.conv(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return self.lrelu(self.instnorm(x))
    

class ConvDropoutNonlinNorm(ConvDropoutNormNonlin):
    def forward(self, x):
        x = self.conv(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return self.instnorm(self.lrelu(x))


class ConvNonlinSeg(nn.Module):
    """
    fixes a bug in ConvDropoutNormNonlin where lrelu was used regardless of nonlin. Bad.
    """

    def __init__(self, input_channels, output_channels,
                 conv_op=nn.Conv3d, conv_kwargs=None,
                 nonlin=nn.LeakyReLU, nonlin_kwargs=None):
        super(ConvNonlinSeg, self).__init__()
        if nonlin_kwargs is None:
            nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}

        if conv_kwargs is None:
            conv_kwargs = {'kernel_size': 1, 'stride': 1, 'padding': 0, 'dilation': 1, 'bias': False}

        self.nonlin_kwargs = nonlin_kwargs
        self.nonlin = nonlin
        self.conv_kwargs = conv_kwargs
        self.conv_op = conv_op


        self.conv = self.conv_op(input_channels, output_channels, **self.conv_kwargs)

        #self.lrelu = self.nonlin(**self.nonlin_kwargs)

    def forward(self, x):
        x = self.conv(x)
        #x = self.conv(self.lrelu(x))

        return x

In [38]:
class UEncoder(nn.Module):
    def __init__(self, num_input_channels, encodFilters, norm_op=nn.InstanceNorm3d):
        super(UEncoder, self).__init__()
        
        
        init_contexts_conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}        
                
        """ First block """
        
        self.context0_encod =  DownBasicBlock(num_input_channels, encodFilters[0], conv_kwargs = init_contexts_conv_kwargs, norm_op=norm_op) 
        
        
        self.context1_encod =  DownBasicBlock(encodFilters[0], encodFilters[1], norm_op=norm_op) 
  
        
        self.context2_encod =  DownBasicBlock(encodFilters[1], encodFilters[2], norm_op=norm_op) 
             
        self.context3_encod =  DownBasicBlock(encodFilters[2], encodFilters[3], norm_op=norm_op) 
      
        self.context4_encod =  DownBasicBlock(encodFilters[3], encodFilters[4], norm_op=norm_op) 
      
        #self.context5_encod =  DownBasicBlock(encodFilters[4], encodFilters[5], norm_op=norm_op)
        
        #self.context6_encod =  DownBasicBlock(encodFilters[5], encodFilters[6], norm_op=norm_op)
        #self.reduced_pool = nn.MaxPool3d(3, stride=2, padding = 1)
       
                    
                                                              
    def forward(self, ax):
        
        ax = self.context0_encod(ax)
        axdecod0 = ax
        #ax=self.reduced_pool(ax)
        
    
        ax = self.context1_encod(ax)
        axdecod1 = ax
        #ax=self.reduced_pool(ax)
        
        ax = self.context2_encod(ax)
        axdecod2 = ax
        #ax=self.reduced_pool(ax)
        
        ax = self.context3_encod(ax)
        axdecod3 = ax
        #ax=self.reduced_pool(ax)
        
        ax = self.context4_encod(ax)
        #axdecod4 = ax
        #ax=self.reduced_pool(ax)
        
        #ax = self.context5_encod(ax)
        #axdecod5 = ax
        
        #ax = self.context6_encod(ax)

        return ax, axdecod3, axdecod2, axdecod1, axdecod0

In [39]:
class Attention_block(nn.Module):
    
    def __init__(self, F_g, F_l, F_int,
         conv_op=nn.Conv3d, conv_kwargs=None,
         norm_op=nn.InstanceNorm3d, norm_op_kwargs=None,
         nonlin=nn.LeakyReLU, nonlin_kwargs=None):
        #norm_op=nn.BatchNorm3d
        super(Attention_block, self).__init__()
        
        if nonlin_kwargs is None:
            nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
        if norm_op_kwargs is None:
            norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
        if conv_kwargs is None:
            conv_kwargs1x1 = {'kernel_size': 1, 'stride': 1, 'padding': 0, 'dilation': 1, 'bias': True}
            
        self.W_g = nn.Sequential(
            conv_op(F_g, F_int, **conv_kwargs1x1),
            norm_op(F_int, **norm_op_kwargs)
            )
        
        self.W_x = nn.Sequential(
            conv_op(F_l, F_int, **conv_kwargs1x1),
            norm_op(F_int, **norm_op_kwargs)
        )

        self.psi = nn.Sequential(
            conv_op(F_int, 1, **conv_kwargs1x1),
            norm_op(1, **norm_op_kwargs),
            nn.Sigmoid()
        )
        
        self.lrelu = nonlin(**nonlin_kwargs)
        
    def forward(self, gA, xA):
        gA1 = self.W_g(gA)
        xA1 = self.W_x(xA)
        psi = self.lrelu(gA1+xA1)
        psi = self.psi(psi)

        return xA*psi

In [40]:
class DynUOneEncodAttn(nn.Module):
    def __init__(self, num_classes=4, num_input_channels=4, base_filters=32, dropout_p=0.0,
                 final_nonlin=None, leakiness=1e-2, conv_bias=True, inst_norm_affine=True,
                 lrelu_inplace=True, do_ds=True):
        super(DynUOneEncodAttn, self).__init__()

        self.do_ds = do_ds
        self.lrelu_inplace = lrelu_inplace
        self.inst_norm_affine = inst_norm_affine
        self.conv_bias = conv_bias
        self.leakiness = leakiness
        self.final_nonlin = final_nonlin
        norm_op = nn.BatchNorm3d

        
        nonsymetry_loc_upTrans_kwargs = {'kernel_size': (2, 3, 2), 'stride': (2, 1, 2), 'padding': 0, 'dilation': 1, 'bias': False}
        loc_upTrans_kwargs = {'kernel_size': 2, 'stride': 2, 'padding': 0, 'dilation': 1, 'bias': False}
        loc_conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}
        loc_seg_conv_kwargs = {'kernel_size': 1, 'stride': 1, 'padding': 0, 'dilation': 1, 'bias': False}
        
        
        
        dropout_op_kwargs = {'p': 0.0, 'inplace': True}
        
        #encodFilters = [base_filters, 32, 64, 128, 160, 160]
        encodFilters = [base_filters, 64, 128, 256, 320]
        #encodFilters = [base_filters, 128, 256, 384, 512]
        
        
        
        
        self.get_skip_encodcontext_mtch4 = UEncoder(num_input_channels = num_input_channels, encodFilters = encodFilters, norm_op=norm_op)
#        self.get_skip_encodcontext_mt0 = UEncoder(num_input_channels = num_input_channels-3, encodFilters = encodFilters, norm_op=norm_op)
#         self.get_skip_encodcontext_mt1 = UEncoder(num_input_channels = num_input_channels-1, encodFilters = encodFilters)
#         self.get_skip_encodcontext_mt2 = UEncoder(num_input_channels = num_input_channels-1, encodFilters = encodFilters)
#         self.get_skip_encodcontext_mt3 = UEncoder(num_input_channels = num_input_channels-1, encodFilters = encodFilters)
        
        ch_mult = 1
        #loc_upTrans_kwargs['kernel_size'], loc_upTrans_kwargs['stride'], bias = loc_upTrans_kwargs['bias']
        self.upTrans1 =  nn.ConvTranspose3d(encodFilters[-1]*ch_mult, encodFilters[-2]*ch_mult, **loc_upTrans_kwargs)
        self.att1 = Attention_block(encodFilters[-2]*ch_mult, encodFilters[-2]*ch_mult, encodFilters[-3]*ch_mult, norm_op=norm_op)
        self.uploc1 = UpBasicBlock((encodFilters[-2]+encodFilters[-2])*ch_mult, encodFilters[-2]*ch_mult, norm_op=norm_op)
        
        self.upTrans2 =  nn.ConvTranspose3d(encodFilters[-2]*ch_mult, encodFilters[-3]*ch_mult, **loc_upTrans_kwargs)
        self.att2 = Attention_block(encodFilters[-3]*ch_mult, encodFilters[-3]*ch_mult, encodFilters[-4]*ch_mult, norm_op=norm_op)
        self.uploc2 = UpBasicBlock((encodFilters[-3]+encodFilters[-3])*ch_mult, encodFilters[-3]*ch_mult, norm_op=norm_op)
        self.loc2_seg = ConvNonlinSeg(encodFilters[-3]*ch_mult, num_classes, conv_kwargs = loc_seg_conv_kwargs)
        
        self.upTrans3 =  nn.ConvTranspose3d(encodFilters[-3]*ch_mult, encodFilters[-4]*ch_mult, **loc_upTrans_kwargs) 
        self.att3 = Attention_block(encodFilters[-4]*ch_mult, encodFilters[-4]*ch_mult, encodFilters[-5]*ch_mult, norm_op=norm_op)
        self.uploc3 = UpBasicBlock((encodFilters[-4]+encodFilters[-4])*ch_mult, encodFilters[-4]*ch_mult, norm_op=norm_op)
        self.loc3_seg = ConvNonlinSeg(encodFilters[-4]*ch_mult, num_classes, conv_kwargs = loc_seg_conv_kwargs)
        
        
        self.upTrans4 =  nn.ConvTranspose3d(encodFilters[-4]*ch_mult, encodFilters[-5]*ch_mult, **loc_upTrans_kwargs)
        self.att4 = Attention_block(encodFilters[-5]*ch_mult, encodFilters[-5]*ch_mult, encodFilters[-5]*ch_mult//2, norm_op=norm_op)
        self.uploc4 = UpBasicBlock((encodFilters[-5]+encodFilters[-5])*ch_mult, encodFilters[-5]*ch_mult, norm_op=norm_op)
        self.loc4_seg = ConvNonlinSeg(encodFilters[-5]*ch_mult, num_classes, conv_kwargs = loc_seg_conv_kwargs)
        
#         self.upTrans5 =  nn.ConvTranspose3d(encodFilters[-5]*ch_mult, encodFilters[-6]*ch_mult, **loc_upTrans_kwargs)
#         self.att5 = Attention_block(encodFilters[-6]*ch_mult, encodFilters[-6]*ch_mult, encodFilters[-6]*ch_mult//2, norm_op=norm_op)
#         self.uploc5 = UpBasicBlock((encodFilters[-6]+encodFilters[-6])*ch_mult, encodFilters[-6]*ch_mult, norm_op=norm_op)
#         self.loc5_seg = ConvNonlinSeg(encodFilters[-6]*ch_mult, num_classes, conv_kwargs = loc_seg_conv_kwargs)
        
        
#         self.upTrans6 =  nn.ConvTranspose3d(encodFilters[-6]*ch_mult, encodFilters[-7]*ch_mult, **loc_upTrans_kwargs)
#         self.att6 = Attention_block(encodFilters[-7]*ch_mult, encodFilters[-7]*ch_mult, (encodFilters[-7]*ch_mult)//2, norm_op=norm_op)
#         self.uploc6 = UpBasicBlock((encodFilters[-7]+encodFilters[-7])*ch_mult, encodFilters[-7]*ch_mult, norm_op=norm_op)
#         self.loc6_seg = ConvNonlinSeg(encodFilters[-7]*ch_mult, num_classes, conv_kwargs = loc_seg_conv_kwargs)
        
        
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(torch.as_tensor(m.weight), a=0.01)
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(torch.as_tensor(m.weight), 1)
                nn.init.constant_(torch.as_tensor(m.bias), 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(torch.as_tensor(m.weight), 1)
                nn.init.constant_(torch.as_tensor(m.bias), 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(torch.as_tensor(m.bias), 0)
                
        
        
    def forward(self, x):
        seg_outputs = []
        size = list(x.shape[2:])

        
        xyzh, xdecod3, xdecod2, xdecod1, xdecod0 = self.get_skip_encodcontext_mtch4(x)

        
        ##########################
        ###Strating of up block ###
        
        xyzh = self.upTrans1(xyzh)
        xdecod3 = self.att1(gA=xyzh, xA= xdecod3)
        xyzh = torch.cat([xyzh, xdecod3], dim=1)
        xyzh = self.uploc1(xyzh) 
        
        xyzh = self.upTrans2(xyzh)
        xdecod2 = self.att2(gA=xyzh, xA= xdecod2)
        xyzh = torch.cat([xyzh, xdecod2], dim=1)
        xyzh = self.uploc2(xyzh)
        seg_outputs.append(F.interpolate(self.loc2_seg(xyzh), size = size))
        
        xyzh = self.upTrans3(xyzh)
        xdecod1 = self.att3(gA=xyzh, xA= xdecod1)
        xyzh = torch.cat([xyzh, xdecod1], dim=1)
        xyzh = self.uploc3(xyzh)
        seg_outputs.append(F.interpolate(self.loc3_seg(xyzh), size = size))
        
        xyzh = self.upTrans4(xyzh)
        xdecod0 = self.att4(gA=xyzh, xA= xdecod0)
        xyzh = torch.cat([xyzh,xdecod0], dim=1)
        xyzh = self.uploc4(xyzh)
        seg_outputs.append(F.interpolate(self.loc4_seg(xyzh), size = size))
        
        
        if self.training:
            return torch.stack([seg_outputs[-1], seg_outputs[-2], seg_outputs[-3]], dim=1)
        else:
            return seg_outputs[-1]
        
                 

In [41]:
transfer_model_save_dir = os.path.join('/raid/brats2021/pthBraTS2020_IDHGenomics/TwoEncodUNetVariants_TCGA')

transfer_mode_DCTList = {'fold0': glob.glob(f'{transfer_model_save_dir}/AttnDynUNet_BratsTCGA_HistStand_3CV_4Chnls1PatchSWIRngr21_2nclass_MorePatchBNormEp500_Fold0_0.8220_epoch287.pt*'),\
               'fold1':glob.glob(f'{transfer_model_save_dir}/AttnDynUNet_BratsTCGA_HistStand_3CV_4Chnls1PatchSWIRngr21_2nclass_MorePatchBNormEp500_Fold1_0.8581_epoch296.pth*'),\
               'fold2':glob.glob(f'{transfer_model_save_dir}/AttnDynUNet_BratsTCGA_HistStand_3CV_4Chnls1PatchSWIRngr21_2nclass_MorePatchBNormEp500_Fold2_0.8866_epoch256.pth*')}
transfer_modelPath = transfer_mode_DCTList['fold0'][0]
transfer_modelPath

'/raid/brats2021/pthBraTS2020_IDHGenomics/TwoEncodUNetVariants_TCGA/AttnDynUNet_BratsTCGA_HistStand_3CV_4Chnls1PatchSWIRngr21_2nclass_MorePatchBNormEp500_Fold0_0.8220_epoch287.pth'

In [43]:
# kernels, strides = get_kernels_strides((128, 128, 128), spacing)
# #kernels, strides
# kernels = [[3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 6, 3]]
# strides = [[1, 1, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2], [2, 2, 2]]

### Defining loss functions
- ***CrossEntropyLogitLoss*** Cross entropy logit loss from [PyTorch](https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html)
- ***DiceCELoss*** Dice + Cross entropy loss from Monai
https://docs.monai.io/en/latest/losses.html


In [44]:
class CrossEntropyInstWLogitLoss(nn.Module):
    def __init__(self, is_smooth=False, label_smoothing = 0.1):
        super().__init__()
        self.is_smooth = is_smooth
        self.label_smoothing = label_smoothing
        
       
    def forward(self, y_pred, y_true):

        if self.is_smooth == True:
            y_true = y_true.float() * (1 - self.label_smoothing) + 0.5 * self.label_smoothing

        y_true=y_true.type_as(y_pred)   ### y_pred, and y_true should be same size and same data type
        
        
        #deviceidx = y_pred.get_device()
        #device = torch.device('cpu') if deviceidx == -1 else torch.device(f'cuda:{deviceidx}')
        #loss = F.binary_cross_entropy_with_logits(y_pred.to(device), y_true.to(device), pos_weight = weight.to(device))  ##pos_weight = weight 
        loss = F.binary_cross_entropy_with_logits(y_pred, y_true) 
        return loss


# class DeepDiceCELogitInstLoss(nn.Module):
#     def __init__(self):
#         super().__init__()
#         #self.volweight = torch.softmax(torch.tensor([0.12, 0.33, 0.55]), dim = 0)
#         self.dice = DiceLoss(include_background=False, to_onehot_y=True, softmax=True, squared_pred=True, batch = False)  # reduction = "none", batch = True
#         #self.smcross_entropy = CrossEntropyInstLoss()  ### was none torch.Tensor([0.66, 0.33, 1]), torch.tensor(self.volweight)

#     def forward(self, y_pred, y_true):
        
#         y_true = y_true.unsqueeze(dim=0).expand(y_pred.shape[1],-1,-1,-1,-1, -1)
#         #return sum([0.5 ** i * ((self.dice(p, l)) + self.smcross_entropy(p, l)) \
#         #            for i, (p, l) in enumerate(zip(torch.unbind(y_pred, dim=1), torch.unbind(y_true, dim=0)))])
#         return sum([0.5 ** i * self.dice(p, l) for i, (p, l) in enumerate(zip(torch.unbind(y_pred, dim=1), torch.unbind(y_true, dim=0)))])




class DeepDiceCELogitInstLoss(nn.Module):
    def __init__(self):
        super().__init__()
        #self.volweight = torch.softmax(torch.tensor([0.12, 0.33, 0.55]), dim = 0)
        self.dice = DiceLoss(to_onehot_y=False, sigmoid=True, squared_pred=True, batch = True)  # reduction = "none", False
        self.logitcross_entropy = CrossEntropyInstWLogitLoss()  ### was none torch.Tensor([0.66, 0.33, 1]), torch.tensor(self.volweight)

    def forward(self, y_pred, y_true):
        
        y_true = y_true.unsqueeze(dim=0).expand(y_pred.shape[1],-1,-1,-1,-1, -1)
        return sum([0.5 ** i * ((self.dice(p, l)) + self.logitcross_entropy(p, l)) \
                    for i, (p, l) in enumerate(zip(torch.unbind(y_pred, dim=1), torch.unbind(y_true, dim=0)))])
    
loss_function = DeepDiceCELogitInstLoss()

In [45]:
xxp = torch.randint(0,4,size=(6,3, 128, 128, 128))
#pred = [torch.randn(6,3,8,8,6), torch.randn(6,3,8,8,6), torch.randn(6,3,8,8,6)]
pred = torch.stack([torch.randn(6, 3, 128, 128, 128), torch.randn(6, 3, 128, 128, 128), torch.randn(6, 3, 128, 128, 128), torch.randn(6, 3, 128, 128, 128)], dim=1)
loss_function(pred, xxp.float())
#loss_function(pred, xxp.float())

tensor(2.6448)

#### A function to create ***cache_dir*** to save transformed outputs

In [46]:
def removeAndcreate_cachedir(cache_dir):
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)
    else:
        #print("Pass")
        shutil.rmtree(cache_dir)
        os.makedirs(cache_dir)
    return 1

### Pytorch training loop

Following functionalities are added

- Implemeting learning rate finder
- Defining Ranger21 optimizer (learning rate scheduler attached to it)
- Implementing mixed precision (AMP)
- Saving the model weights based on the performance on validation data
- Executing 5 fold cross validation (CV) training/validation pipeline, saving a few best model's weights at each fold
- Defining train_dataset/train_loader, and val_dataset/val_loader at each fold
- Defining a CNN based classification model (DenseNet, EfficientNet, etc.) at each fold to make sure that all accumulated gradients get vanished

The key variables which are used here
- ***val_cache_dir:*** The path where the transformed ouputs of validaion files will be cached/saved
- ***train_cache_dir:*** The path where the transformed ouputs of training files will be cached/saved
- ***max_epochs:*** Total number of epochs
- ***save_dir:*** The path to save the checkpoints/weights of the model
- ***file_prefix:*** The text file where loss/accuracy is recoded like

```
current fold: 0 current epoch: 1, acc_metric: 0.4579 accuracy: 0.5085, f1score: 0.5085 epoch 1 average training loss: 0.7250 average validation loss: 0.7128 
current fold: 0 current epoch: 2, acc_metric: 0.4876 accuracy: 0.5424, f1score: 0.5424 epoch 2 average training loss: 0.6961 average validation loss: 0.6940 
current fold: 0 current epoch: 3, acc_metric: 0.4870 accuracy: 0.4915, f1score: 0.4915 epoch 3 average training loss: 0.6862 average validation loss: 0.6965 
current fold: 0 current epoch: 4, acc_metric: 0.4882 accuracy: 0.5593, f1score: 0.5593 epoch 4 average training loss: 0.6715 average validation loss: 0.6885 
current fold: 0 current epoch: 5, acc_metric: 0.5927 accuracy: 0.5593, f1score: 0.5593 epoch 5 average training loss: 0.6555 average validation loss: 0.6826 
```

- ***val_interval*** Epoch interval to investivate the model's performance on validation data. If the current performance is better than in previous epochs, the model's weights will be saved
- ***key_metric_n_saved*** The number of model checkpoints we want to save. It it is set as 5, top 5 checkpoints/weights will be saved in 5 different ***.pth*** files  


In [47]:

#***Executed pipeline***\
#<img src="assets/ProposedIDHClass.png" align="left" width="1024" height="1800">

In [48]:
def get_segclass(x, dim = 1):
    xdvc = x.device
    x_chlist = torch.unbind(x, dim = dim)
    xclassNoList = []
    xvalueList = []

    for x_i in x_chlist:

        xv, xc = torch.unique(x_i, return_counts  = True)

        if xc.shape[0]==1:
            if xv==0:
                xclassNoList.append(-1)
                xvalueList.append(xv[0].item())
            elif xv==1:
                xclassNoList.append(xc[0].item())
                xvalueList.append(xv[0].item())
            else:
                print('The function only supports binarized tensor (binarized unique values, 0(n=...) and 1(n=...) only)\n')


        elif xc.shape[0]==2:
                if torch.any(torch.eq(xv, 1)):
                    xclassNoList.append(xc[1].item())
                    xvalueList.append(xv[1].item())
                else:
                    print('The function only supports binarized tensor (binarized unique values, 0(n=...) and 1(n=...) only)\n')
        else:
            print('The function only supports binarized tensor (binarized unique values, 0(n=...) and 1(n=...) only)\n')

    #pdb.set_trace()
    #if torch.any(torch.eq(torch.tensor(xvalueList), 1)):

    if xclassNoList[0]!=xclassNoList[1]: 
        xclass = torch.argmax(torch.tensor(xclassNoList).to(xdvc))
    else:
        xclass = torch.tensor(float('NaN')).to(xdvc)

    #else:
        '''If all uniques class values are 0, we are assigning nan values as a class'''
    #    xclass = torch.tensor(float('NaN')).to(xdvc)


    return xclass

In [49]:
def getbatch_segclass(x):
    xdvc = x.device
    x_batchlist = torch.unbind(x, dim = 0)
    xbatchclass = []
    for xc in x_batchlist:
        xcclass = get_segclass(xc, dim = 0)            
        xbatchclass.append(xcclass.item())


    xbatchclass = torch.tensor(xbatchclass)
    if torch.all(torch.isnan(xbatchclass))==True:

            return torch.tensor(float('NaN')).to(xdvc)

    else:

        num_xbatchnanvalues = torch.isnan(xbatchclass).sum().item()
        not_xbatchnanmask = torch.logical_not(torch.isnan(xbatchclass))
        xbatchclass = xbatchclass[not_xbatchnanmask]

        xclassVal_01, xclassCnt_01 =xbatchclass.unique(return_counts = True)

        if xclassCnt_01.shape[0]==1:
            return xclassVal_01[0].to(xdvc)


        if xclassCnt_01.shape[0]==2:
            if xclassCnt_01[0]!=xclassCnt_01[1]:
                ''' xclassCnt_01 will always be two values converting [7, 8] to 1; [8, 7] to 0'''
                return torch.argmax(xclassCnt_01).to(xdvc)  

            else:

                return torch.tensor(float('NaN')).to(xdvc)
                

In [50]:
def get_binarize_tensor(x, dim=1):
    x_chlist = torch.unbind(x, dim = dim)
    bin_x = torch.zeros_like(x_chlist[0])
    for x_i in x_chlist:
        bin_x = torch.logical_or(x_i, bin_x)
    return bin_x.unsqueeze(dim=dim).to(torch.float32)

### Starting training loop

In [51]:
def train(train_files, train_files_IDH_label, val_files, val_files_IDH_label, batch_size = 2, epochs = 10, find_lr=False, cfold = 0, transfer_modelPath=None):
    

    #     model = DenseNet201(spatial_dims=2, in_channels=3,
    #                        out_channels=num_classes, pretrained=True).to(device)

    # create spatial 3D
    #model = MultiDenseNet(spatial_dims=3, in_channelsList=(4, 1, 1, 1, 1), out_channels=2, block_config = (6, 12, 24, 16)).to(device)
    #model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=4, out_channels=1).to(device)
    #model = monai.networks.nets.DenseNet264(spatial_dims=3, in_channels=4, out_channels=1, init_features=64, growth_rate=32, block_config=(6, 12, 64, 48)).to(device)
    #patch_size=(64, 80, 64)
    
    num_classes = 2
    model = DynUOneEncodAttn(num_classes=num_classes, num_input_channels=2, base_filters=32).to(device)
    
    
    
    auc_metric = ROCAUCMetric()
    

    #train_files, train_files_IDH_label, val_files, val_files_IDH_label = train_files[:48], train_files_IDH_label[:48], val_files[:16], val_files_IDH_label[:16]

    """
    Block for using Monai's caching mechanishm for faster training
    """
    
    file_prefixfold = file_prefix  ##or file_prefix f"{file_prefix}_fold{cfold}" if saving cv file
    data_rpath = '/home/mmiv-ml/data'
    train_cache_dir = os.path.join(data_rpath,f'cachingDataset/{file_prefixfold}/train')    
    val_cache_dir = os.path.join(data_rpath,f'cachingDataset/{file_prefixfold}/val')
    
    is_done_train = removeAndcreate_cachedir(train_cache_dir)
    is_done_val = removeAndcreate_cachedir(val_cache_dir)
    

    n_train_cache_n_trans = len(train_transforms) #15
    n_val_cache_n_trans = len(val_transforms)
    
     # create a training data loader

    train_dataset = monai.data.CacheNTransDataset(data=train_files, transform=train_transforms,\
                                               cache_n_trans = n_train_cache_n_trans, cache_dir = train_cache_dir)
    
    
    
#    train_dataset = monai.data.Dataset(data=train_files, transform=train_transforms)
    
#     #train_folds['fold0_IDH_label']

    train_files_IDH_labels = np.array([id_lbl['IDH_label'].item() for id_lbl in train_files])
    uval, ucnt = np.unique(train_files_IDH_labels, return_counts=True)
    weight = 1./(ucnt/ucnt.min())
    #weight = 1./ucnt
    #weight = np.array([0.55, 0.45])
    sample_weights = np.array([weight[int(t)] for t in train_files_IDH_labels])
    sample_weights = torch.from_numpy(sample_weights)
    sampler = WeightedRandomSampler(sample_weights, num_samples= len(sample_weights), replacement=True)


    #train_dataset = Dataset(data=train_files, transform=train_transforms)
    #train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=2, pin_memory=True)
    #train_dataset = CacheDataset(data=train_files, transform=train_transforms, cache_rate = 1.0, num_workers=8)
    #train_loader = ThreadDataLoader(train_dataset, num_workers=0, batch_size=batch_size, shuffle=True)
    
    #shiffle = False, sampler = sampler, shuffle=True doesnot work with patch, num_workers=2
    train_loader = monai.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=list_data_collate, sampler = sampler) 
    

    
    # create a validation data loader
    
    #val_dataset = monai.data.Dataset(data=val_files, transform=val_transforms)
    #val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=2, pin_memory=True)
    #val_dataset = CacheDataset(data=val_files, transform=val_transforms, cache_rate = 1.0, num_workers=5)
    
    val_dataset = monai.data.CacheNTransDataset(data=val_files, transform=val_transforms,\
                                            cache_n_trans = n_val_cache_n_trans, cache_dir = val_cache_dir)
    #val_loader = ThreadDataLoader(val_dataset, num_workers=0, batch_size=1)
    val_loader = monai.data.DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=2, pin_memory=True) ##
    
    
    
#     for ibatch in train_loader:
#         ibatch_IDH = ibatch['IDH_label']
#         print(torch.eq(ibatch_IDH, 0).sum(), torch.eq(ibatch_IDH, 1).sum())
#         print('#'*50)
    
    
    """
    just initialising some basic steps
    """
    
    max_epochs = epochs
    find_lr=False
    
    ### Calling the loss function ***CrossEntropyPlusMSELoss**,and optimizer   
    #loss_function = nn.CrossEntropyLoss()
    #loss_function=nn.BCEWithLogitsLoss()
    
    #optimizer = torch.optim.Adam(model.parameters(), lr=1e-03, weight_decay=1e-07)
    optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-05, weight_decay = 1e-4)
    #optimizer = torch.optim.SGD(model.parameters(), lr=1e-03, momentum= 0.99, nesterov=True)
    #optimizer = torch.optim.Adam(model.parameters(), 1e-03, weight_decay=1e-04)
    scaler = torch.cuda.amp.GradScaler()
    
    
    max_lr_init = 1e-04
    """
     ###################### Block for LR finder from pytorch ignite ########################

    """

    if find_lr:

        def prepare_batch(batch, device=None, non_blocking=False):
            return _prepare_batch((batch['image'], batch['IDH_label'].long()), device, non_blocking)

        #trainer = create_supervised_trainer(model, optimizer, loss_function, device, non_blocking=False, prepare_batch=prepare_batch)
        def train_step(engine, batch):
            model.train()
            optimizer.zero_grad()
            x, y = batch['image'].to(device), batch['IDH_label'].to(device)  #non_blocking=True
            with torch.cuda.amp.autocast():
                y_pred = model(x)
                loss4lr = loss_function(y_pred, y)
                
            scaler.scale(loss4lr).backward()
            scaler.step(optimizer)
            scaler.update()
            return loss4lr.item()

        trainer = Engine(train_step)

        ProgressBar(persist=True).attach(trainer, output_transform=lambda x: {"batch loss": x})
        lr_finder = FastaiLRFinder()
        to_save={'model': model, 'optimizer': optimizer}
        num_iter = 100  #2*len(train_loader)
        run_epochs = int(np.ceil(num_iter/len(train_loader)))
        with lr_finder.attach(trainer, to_save, end_lr=10, num_iter=num_iter, diverge_th=1.5) as trainer_with_lr_finder:    ####diverge_th=1.5

            trainer_with_lr_finder.run(train_loader, max_epochs=run_epochs)  #max_epochs=run_epochs or 5

        ax = lr_finder.plot()
        plt.show()
        
        max_lr = lr_finder.lr_suggestion() if lr_finder.lr_suggestion()<5e-03 else max_lr_init
        #max_lr = lr_finder.lr_suggestion() ##max_lr/10 i guess not needed, ignite does itself
        print(f'Suggested learning rate by LR finder for this fold: {lr_finder.lr_suggestion()}')
        
    else:
        max_lr = max_lr_init
        
    #max_lr_slice = 1e01*max_lr if max_lr<5e-03 else 1e-02
    #max_lr_slice = 1e-01*max_lr if max_lr<1e-05 else max_lr
    
    ''' Transfer learning section '''
    if transfer_modelPath is not None:
        current_model_dict = model.state_dict()
        loaded_state_dict = torch.load(transfer_modelPath, map_location=device)
        new_state_dict={k:v if v.size()==current_model_dict[k].size() else current_model_dict[k] for k,v in zip(current_model_dict.keys(), loaded_state_dict.values())}
        model.load_state_dict(new_state_dict, strict=False)
        
    
    """
    ### defining learning rate scheduler
    
    """
    #steps_per_epoch=len(train_loader)
    #optimizer.param_groups[0]['lr'] = max_lr #*1e-01
    #scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=max_lr_slice, steps_per_epoch=len(train_loader), epochs=max_epochs)
    #scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: (1 - epoch / max_epochs) ** 0.9)
    
    #max_lr = 1e-3   
    optimizer = Ranger21(model.parameters(), lr = max_lr, num_epochs = epochs, num_batches_per_epoch = len(train_loader))
    


    """
     ###################### Block for native pytorch training loop and  ########################
    """

    key_metric_n_saved = 2   ### Usually I keep it 5
    save_last = False 
    dispformat_specs = '.4f'

        
#     file_prefix = 'ConvEffNet_Brats21_5CV'
#     savedirname = 'ConvEffNet_Brats21'
#     save_dir = os.path.join('/raid/brats2021/pthBraTS2021Radiogenomics', savedirname)
#     if not os.path.exists(save_dir):
#         os.makedirs(save_dir)

    logsfile_path = f"{save_dir}/Logs_{file_prefix}.txt"


    epoch_num = max_epochs #  max_epochs
    val_interval = 1
    valstep = 0
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()


    numsiters = len(train_files) // train_loader.batch_size

    first_batch = monai.utils.misc.first(train_loader)
        
    
    #post_pred = AsDiscrete(argmax=True, to_onehot=num_classes)  ### num_classes=num_classes
    #post_label = AsDiscrete(to_onehot=num_classes) ###num_classes=num_classes
    #dice_metric = monai.metrics.DiceMetric(include_background=False, reduction='mean', get_not_nans=False)
    
    
    dice_metric = monai.metrics.DiceMetric(include_background=True, reduction='mean', get_not_nans=False)
    dice_metric_batch = monai.metrics.DiceMetric(include_background=True, reduction='mean_batch', get_not_nans=False)
    post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])  
    
    def one_hot_permute(x):
        return F.one_hot(x.squeeze(dim=0).long(), num_classes=num_classes).permute(3, 0, 1, 2)
    

    

    for epoch in range(epoch_num):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{epoch_num}")
        model.train()
        epoch_loss = 0.
        stepiter = 0
        for batch_data in train_loader:
            
            stepiter += 1
            inputs, labels, IDH_labels= (
                batch_data['image'].to(device),
                batch_data['label'].to(device),
                batch_data['IDH_label'].to(device),
            )
            
          
            optimizer.zero_grad()
            
            with torch.cuda.amp.autocast():

                # compute output
                outputs  = model(inputs)
                loss = loss_function(outputs, labels) 

            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            epoch_loss += loss.item()

            print(f"{stepiter}/{numsiters}, train_loss: {loss.item():.4f}")

            #scheduler.step() 
            
        epoch_loss /= stepiter
        epoch_loss_values.append(epoch_loss)
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        
        
        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():

                y_pred = torch.tensor([], dtype=torch.float32, device=device)
                y = torch.tensor([], dtype=torch.long, device=device)
                val_losses = torch.tensor([], dtype=torch.float32, device=device)


                for val_data in val_loader:

                    val_inputs, val_labels, val_IDH_labels = (
                        val_data['image'].to(device),
                        val_data['label'].to(device),
                        val_data['IDH_label'].to(device),
                    )
                
                    roi_size = patch_size #(32, 32, 32)
                    sw_batch_size = 8
                    val_overlap = 0.5
                    mode="gaussian"
                            
                    
                    with torch.cuda.amp.autocast():
                        
                        val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model, mode = mode, overlap = val_overlap, sw_device = device, device = device) 
                        #val_outputs = model(val_inputs)
                        val_ce_loss = loss_function(val_outputs.unsqueeze(dim=1), val_labels)

                    val_losses = torch.cat([val_losses, val_ce_loss.view(1)], dim = 0)
                    val_outputs = torch.stack([post_pred(i) for i in torch.unbind(val_outputs, dim = 0)], dim = 0)
                    
                    
                    #val_labels2hot = torch.stack([one_hot_permute(i) for i in torch.unbind(val_labels, dim = 0)], dim = 0)

                    
                    val_labels_bin = get_binarize_tensor(val_labels, dim=1)
                    val_outputs_bin = get_binarize_tensor(val_outputs, dim=1)
                    
                    dice_metric(y_pred=val_outputs_bin, y=val_labels_bin)
                    
                    
                    klcc = KeepLargestConnectedComponent(applied_labels = [0, 1])  ##is_onehot=True
                    #val_labels = klcc(val_labels.squeeze(dim=0)).unsqueeze(dim=0)
                    #val_outputs = klcc(val_outputs.squeeze(dim=0)).unsqueeze(dim=0)
                    val_outputs = torch.stack([klcc(i) for i in torch.unbind(val_outputs, dim = 0)], dim = 0)
                
                    val_label4mSeg_C = get_segclass(val_outputs, dim = 1)
                    #val_label4mSeg_C = getbatch_segclass(val_outputs)
                    #val_surv_labels = val_surv_labels.squeeze(dim=1)  ###Squeezing from B, 1 to B if needed
                    y_pred = torch.cat([y_pred, val_label4mSeg_C.view(1)], dim=0)
                
                    #pdb.set_trace()    
                    val_IDH_labels = torch.mode(val_IDH_labels)[0].view(1)           
                    y = torch.cat([y, val_IDH_labels], dim=0)

                mdice_value = dice_metric.aggregate()
                dice_metric.reset()
                
                
                
                y_pred, y = y_pred.cpu(), y.cpu()
                
                if torch.all(torch.isnan(y_pred))==True:
                    
                    auc_result, accscore, f1score = np.nan, np.nan, np.nan
                    #print('acc_metric#', np.nan, ', auc#', np.nan, ', f1#', np.nan, '\n')
                
                else:

                    num_nanvalues = torch.isnan(y_pred).sum().item()
                    not_nanmask = torch.logical_not(torch.isnan(y_pred))
                    y = y[not_nanmask]
                    y_pred = y_pred[not_nanmask]
                    
                    
                    acc_value = torch.eq(y_pred, y)
                    acc_metric = acc_value.sum().item() / len(acc_value)

                    '''auc metric'''
                    auc_metric(y_pred, y)
                    auc_result = auc_metric.aggregate()
                    auc_metric.reset()
                    
                    '''balanced accuracy and f1 score'''
                    accscore = balanced_accuracy_score(y, y_pred)
                    f1score = f1_score(y, y_pred, average='micro')
                    #print('acc_metric#', acc_metric, ', auc#', auc_result, ', f1#', f1score, '\n')
                

                del y, y_pred
                
            
                epoch_val_losses=torch.mean(val_losses).detach().cpu().item()
                #metric = auc_result
                mdice_value = mdice_value.item()
                #metric = mdice_value
                metric = (mdice_value+auc_result)/2
                metric= -1.0 if np.isnan(metric) else metric
                metric_values.append(metric) ######List of over number of epochs
                printstring = "Best PMetric"
                

                with open(logsfile_path, 'a') as file:
                    file.write(
                        f"current fold: {cfold} current epoch: {epoch + 1} dice_score: {mdice_value:^{dispformat_specs}} acc_metric: {auc_result:^{dispformat_specs}}" 
                        f" accuracy: {accscore:^{dispformat_specs}}, f1score: {f1score:^{dispformat_specs}}"
                        f" epoch {epoch + 1} average training loss: {epoch_loss:^{dispformat_specs}} average validation loss: {epoch_val_losses:^{dispformat_specs}} \n"

                    )

                if valstep < key_metric_n_saved:
                    
                    torch.save(model.state_dict(), os.path.join(save_dir, f"{file_prefix}_Fold{cfold}_{metric:^{dispformat_specs}}_epoch{epoch + 1}.pth"))
                    print(
                        f"current fold: {cfold} current epoch: {epoch + 1} dice_score: {mdice_value:^{dispformat_specs}} acc_metric: {auc_result:^{dispformat_specs}}" 
                        f" accuracy: {accscore:^{dispformat_specs}}, f1score: {f1score:^{dispformat_specs}}"
                        f" epoch {epoch + 1} average training loss: {epoch_loss:^{dispformat_specs}} average validation loss: {epoch_val_losses:^{dispformat_specs}}"
                        
                    )

                else:

                    #sortmetric_values = sorted(metric_values[:-1], reverse=True)  ###Higher loss needs to be deleted, so sorting is reversed
                    sortmetric_values = sorted(metric_values[:-1], reverse=False)  

                    if metric>=sortmetric_values[-key_metric_n_saved]:
                        savegood_metric = metric
                        good_metric_epoch = epoch + 1

                        #if os.path.exists(f"{save_dir}/{file_prefix}_{sortmetric_values[-key_metric_n_saved]:.4f}.pth"):
                        #    os.remove(f"{save_dir}/{file_prefix}_{sortmetric_values[-key_metric_n_saved]:.4f}.pth")
                        #else:
                        #    print("The file does not exist")
                        
                        
                        glblist = glob.glob(f"{save_dir}/{file_prefix}_Fold{cfold}_{sortmetric_values[-key_metric_n_saved]:^{dispformat_specs}}_*")

                        if not glblist:
                            print("The file does not exist")
                        else:
                            os.remove(glblist[0])


                        torch.save(model.state_dict(), os.path.join(save_dir, f"{file_prefix}_Fold{cfold}_{metric:^{dispformat_specs}}_epoch{epoch + 1}.pth"))
                        print("saved new best metric model")
                        print(
                            f"current fold: {cfold} current epoch: {epoch + 1} validation loss: {epoch_val_losses:^{dispformat_specs}}"
                            f" dice_score: {mdice_value:^{dispformat_specs}} acc_metric: {auc_result:^{dispformat_specs}}"
                            f" accuracy: {accscore:^{dispformat_specs}}, f1score: {f1score:^{dispformat_specs}}"
                            f"\n saved {printstring}: {savegood_metric:^{dispformat_specs}} at epoch: {good_metric_epoch}"
                        )

                    else:

                        f"current fold: {cfold} current epoch: {epoch + 1} validation loss: {epoch_val_losses:^{dispformat_specs}}"
                        f" dice_score: {mdice_value:^{dispformat_specs}} acc_metric: {auc_result:^{dispformat_specs}}"
                        f" accuracy: {accscore:^{dispformat_specs}}, f1score: {f1score:^{dispformat_specs}}"

                        #pass

                valstep += 1
        ####Saving last epoch
        if epoch==epoch_num-1:
            if save_last:
                torch.save(model.state_dict(), os.path.join(save_dir, f"{file_prefix}_Fold{cfold}_{metric:^{dispformat_specs}}_last_epoch{epoch + 1}.pth"))
                
            #break
            
    # Free up GPU memory after training
    model = None
    train_loader, val_loader = None, None
    gc.collect()
    torch.cuda.empty_cache()

### Loop to execute n_splits=3 fold cross validation
if the model is trained and the checkpoints are saved already, just setting the start_training flag as false, to run remaining part of the programs of this notebook

In [52]:
start_training = False

In [53]:
n_splits = 3
if start_training:
    #### Running 10 folds
    for i in range(0, n_splits):
        
        #train_files_fld, train_files_fld_IDH_label, val_files_fld, val_files_fld_IDH_label  = copy.deepcopy(train_folds[f'fold{i}']), copy.deepcopy(train_folds[f'fold{i}_IDH_label']),\
        #copy.deepcopy(val_folds[f'fold{i}']), copy.deepcopy(val_folds[f'fold{i}_IDH_label'])
             
        train_files_fld, val_files_fld  = copy.deepcopy(BraTS20SubjectsIDHTrainDCT[f'fold{i}']), copy.deepcopy(BraTS20SubjectsIDHValDCT[f'fold{i}'])
        train_files_fld_IDH_label, val_files_fld_IDH_label = None, None
        batch_size=8
        ### Need to change batch size if minimux training batch size == 1
        print('fold', i, "Bacth Investigation, minimum batch size", len(train_files_fld)%batch_size)        
        #train(train_files_fld, train_files_fld_IDH_label, val_files_fld, val_files_fld_IDH_label, batch_size=batch_size, epochs = 200, cfold = i)
        train(train_files_fld, train_files_fld_IDH_label, val_files_fld, val_files_fld_IDH_label, batch_size=batch_size, epochs = 500, cfold = i, transfer_modelPath = None)
    
    start_training = False
else:
    
    pass

In [54]:
# from sklearn.datasets import make_multilabel_classification
# from sklearn.multioutput import MultiOutputClassifier
# from sklearn.linear_model import LogisticRegression
# X, y = make_multilabel_classification(random_state=0, n_classes=2)
# inner_clf = LogisticRegression(solver="liblinear", random_state=0)
# clf = MultiOutputClassifier(inner_clf).fit(X, y)
# y_score = np.transpose([y_pred[:, 1] for y_pred in clf.predict_proba(X)])
# roc_auc_score(y, y_score, average=None)

## Inference 

In [55]:
def inferWithTA(data_loader,listmodels, prediction_folder="./", topk=1, num_channels = 4,\
                orientation="LPS", withoptimizer = False, softmaxEnsemble=False, save_inference = False, tta = False):
    """
    run inference, the output folder will be "./output"
    """        
    #device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    listmodels = listmodels[0:topk]
    for x in listmodels:
        print(f"available model file: {x}.")
        
    channel_nums =  monai.utils.misc.first(data_loader)['image'].shape[1] ##next(iter(val_loader["image"])).shape[1]
    channelNums = f"{channel_nums} channels"
    keys = ('image',)
    patch_size = (32, 32, 32)
    
    post_trans_sigbin = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
    post_trans_bin = Compose([EnsureType(), AsDiscrete(threshold=0.5)])
    post_trans_sig = Compose([EnsureType(), Activations(sigmoid=True)])
    
    auc_metric = ROCAUCMetric()
    dice_metric = monai.metrics.DiceMetric(include_background=True, reduction='mean', get_not_nans=False)
    dice_metric_batch = monai.metrics.DiceMetric(include_background=True, reduction='mean_batch', get_not_nans=False)
    
    
    HD_metric = HausdorffDistanceMetric(include_background=True, percentile = 95., reduction='mean', get_not_nans=False)
    HD_metric_batch = HausdorffDistanceMetric(include_background=True, percentile = 95., reduction='mean_batch', get_not_nans=False)
    
    
    
    post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])  
    
    def one_hot_permute(x):
        return F.one_hot(x.squeeze(dim=0).long(), num_classes=num_classes).permute(3, 0, 1, 2)
    
    def get_binarize_tensor(x, dim=1):
        x_chlist = torch.unbind(x, dim = dim)
        bin_x = torch.zeros_like(x_chlist[0])
        for x_i in x_chlist:
            bin_x = torch.logical_or(x_i, bin_x)
        return bin_x.unsqueeze(dim=dim).to(torch.float32)
        
    def get_segclass(x, dim = 1):
        xdvc = x.device
        x_chlist = torch.unbind(x, dim = dim)
        xclassNoList = []
        xvalueList = []

        for x_i in x_chlist:

            xv, xc = torch.unique(x_i, return_counts  = True)

            if xc.shape[0]==1:
                if xv==0:
                    xclassNoList.append(-1)
                    xvalueList.append(xv[0].item())
                elif xv==1:
                    xclassNoList.append(xc[0].item())
                    xvalueList.append(xv[0].item())
                else:
                    print('The function only supports binarized tensor (binarized unique values, 0(n=...) and 1(n=...) only)\n')


            elif xc.shape[0]==2:
                    if torch.any(torch.eq(xv, 1)):
                        xclassNoList.append(xc[1].item())
                        xvalueList.append(xv[1].item())
                    else:
                        print('The function only supports binarized tensor (binarized unique values, 0(n=...) and 1(n=...) only)\n')
            else:
                print('The function only supports binarized tensor (binarized unique values, 0(n=...) and 1(n=...) only)\n')

        #pdb.set_trace()
        #if torch.any(torch.eq(torch.tensor(xvalueList), 1)):

        if xclassNoList[0]!=xclassNoList[1]: 
            xclass = torch.argmax(torch.tensor(xclassNoList).to(xdvc))
        else:
            xclass = torch.tensor(float('NaN')).to(xdvc)

        #else:
            '''If all uniques class values are 0, we are assigning nan values as a class'''
        #    xclass = torch.tensor(float('NaN')).to(xdvc)


        return xclass
    
    
    
    keys = ("image",)
        
    with torch.no_grad():
        
        y_pred = torch.tensor([], dtype=torch.float32, device=device)
        y = torch.tensor([], dtype=torch.long, device=device)
    
        for infindx, infer_data in enumerate(tqdm(data_loader)):

            
            val_inputs, val_labels, val_IDH_labels = (
                infer_data['image'].to(device),
                infer_data['label'].to(device),
                infer_data['IDH_label'].to(device),
            )
                

            n_class = 2
            val_outputsAll = torch.zeros(val_inputs.shape[0], n_class, val_inputs.shape[2], val_inputs.shape[3], val_inputs.shape[4]).to(device)
            n_model = 0.
            
            #for mdlindx in (number+1 for number in range(topk)):
            for mdlindx in range(topk):
    
                print(f'Model {mdlindx}, {listmodels[mdlindx]} is running now')
                model = None        
                num_classes = 2
                model = DynUOneEncodAttn(num_classes=num_classes, num_input_channels=2, base_filters=32).to(device)
                
            
                if withoptimizer ==True:
                    
                    state_dictsAll = torch.load(listmodels[mdlindx], map_location=device)
                    model.load_state_dict(state_dictsAll["model_state_dict"])
                    model.eval()
                
                else:    
                    model.load_state_dict(torch.load(listmodels[mdlindx], map_location=device))
                    model.eval()
                
                
                n = 1.0
                roi_size = patch_size #(32, 32, 32)
                sw_batch_size = 8
                val_overlap = 0.5
                mode="gaussian"
                
                with torch.cuda.amp.autocast():

                    preds = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model, mode = mode, overlap = val_overlap, sw_device = device, device = device)
                    
                flip_val_inputs = torch.flip(val_inputs, dims=(2, 3, 4))
                
                with torch.cuda.amp.autocast():
                    
                    mfpred = sliding_window_inference(flip_val_inputs, roi_size, sw_batch_size, model, mode = mode, overlap = val_overlap, sw_device = device, device = device)
                 
                flip_pred = torch.flip(mfpred, dims=(2, 3, 4))
                preds  = preds + flip_pred
                n = n + 1.0
                
                if tta:
                    
                    for _ in range(4):
                        # test time augmentations
                        _img = RandGaussianNoised(keys[0], prob=1.0, std=0.01)(infer_data)[keys[0]]


                        with torch.cuda.amp.autocast():

                            #val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model, sw_device = device, device = device)
                            _img_pred = sliding_window_inference(_img.to(device), roi_size, sw_batch_size, model, mode = mode, overlap = val_overlap, sw_device = device, device = device)
                            preds = preds + _img_pred
                            n = n + 1.0


                        _img_flip = torch.flip(_img, dims=(2, 3, 4)) 
                        with torch.cuda.amp.autocast():
                            _mf_flip_pred = sliding_window_inference(_img_flip.to(device), roi_size, sw_batch_size, model, mode = mode, overlap = val_overlap, sw_device = device, device = device)

                        _img_flip_pred = torch.flip(_mf_flip_pred, dims=(2, 3, 4))
                        preds = preds + _img_flip_pred
                        n = n + 1.0
                 
                
                preds = preds / n
                
                if softmaxEnsemble:
                    preds = torch.stack([post_trans_sig(i) for i in torch.unbind(preds, dim = 0)], dim = 0)
                val_outputsAll = val_outputsAll + preds
                n_model = n_model+1.0
                
                # Free up GPU memory after training
                model = None
                del model
                #train_loader, val_loader = None, None        
                gc.collect()
                torch.cuda.empty_cache()
                           
            val_outputsAll = val_outputsAll / n_model
            
            val_outputs = post_trans_bin(val_outputsAll) if softmaxEnsemble else post_trans_sigbin(val_outputsAll)
            '''Sigmoid or logit'''
            val_outputsSig = val_outputsAll if softmaxEnsemble else post_trans_sig(val_outputsAll)
            #val_outputsSig = val_outputsAll
            
            

            #val_labels2hot = torch.stack([one_hot_permute(i) for i in torch.unbind(val_labels, dim = 0)], dim = 0)


            val_labels_bin = get_binarize_tensor(val_labels, dim=1)
            val_outputs_bin = get_binarize_tensor(val_outputs, dim=1)

            dice_metric(y_pred=val_outputs_bin, y=val_labels_bin)

            klcc = KeepLargestConnectedComponent(applied_labels = [0, 1], is_onehot = True)  ##is_onehot=True or None by default
            #val_labels = klcc(val_labels.squeeze(dim=0)).unsqueeze(dim=0)
            val_labels = torch.stack([klcc(i) for i in torch.unbind(val_labels, dim = 0)], dim = 0)

            val_label4mSeg_C = get_segclass(val_outputs)
            #val_surv_labels = val_surv_labels.squeeze(dim=1)  ###Squeezing from B, 1 to B if needed
            y_pred = torch.cat([y_pred, val_label4mSeg_C.view(1)], dim=0)
            y = torch.cat([y, val_IDH_labels], dim=0)
                
                        
            
        mdice_value = dice_metric.aggregate().item()
        dice_metric.reset()        
        
        
        y_pred, y = y_pred.cpu(), y.cpu()

        if torch.all(torch.isnan(y_pred))==True:

            auc_result, baccscore, f1score, rscore, pscore, justaccscore = np.nan, np.nan, np.nan, np.nan, np.nan, np.nan
            specificity_score, sensitivity_score = np.nan, np.nan
            #print('acc_metric#', np.nan, ', auc#', np.nan, ', f1#', np.nan, '\n')
            num_nanvalues = len(y_pred)

        else:

            num_nanvalues = torch.isnan(y_pred).sum().item()
            not_nanmask = torch.logical_not(torch.isnan(y_pred))
            y = y[not_nanmask]
            y_pred = y_pred[not_nanmask]


            acc_value = torch.eq(y_pred, y)
            acc_metric = acc_value.sum().item() / len(acc_value)
            
            
            '''auc metric'''
            auc_metric(y_pred, y)
            auc_result = auc_metric.aggregate()
            auc_metric.reset()
            
            y_pred, y = y_pred.numpy(), y.numpy()

            '''balanced accuracy and f1 score'''
            baccscore = balanced_accuracy_score(y, y_pred)
            f1score = f1_score(y, y_pred, average='micro')
            rscore=recall_score(y, y_pred)
            pscore = precision_score(y, y_pred)
            justaccscore = accuracy_score(y, y_pred)
            #print('acc_metric#', acc_metric, ', auc#', auc_result, ', f1#', f1score, '\n')
            tn, fp, fn, tp = confusion_matrix(y, y_pred).ravel()
 
            #specificity_score = cm1[0,0]/(cm1[0,0]+cm1[0,1])
            #sensitivity_score = cm1[1,1]/(cm1[1,0]+cm1[1,1])
            
            specificity_score = tn/(tn + fp)
            sensitivity_score = tp/(tp + fn)


        del y, y_pred
        

        
        #Accuracy (%) Precision (%) Sensitivity(%) Specificity(%) F1-Score(%)
        
        aDCT = {"dice_score": mdice_value,  "Balanced Accuracy": baccscore, 'Precision': pscore, 'Recall':rscore, 'Sensitivity':sensitivity_score, 'Specificity':specificity_score,\
                'F1-Score':f1score, 'Accuracy':justaccscore, 'AUC Metric':auc_result, 'NanSubjectNos':num_nanvalues}
        
                    
#     dfET = pd.DataFrame({"BraTS21ID":Infer_idLst, "Model": ["DynUnet"]* len(ETdices), "Channels":[channelNums]*len(ETdices), "Tumor regions": ["ET"]*len(ETdices), "Dice score": ETdices, "HD95": ETHD95s})
#     dfTC = pd.DataFrame({"BraTS21ID":Infer_idLst, "Model": ["DynUnet"]* len(TCdices), "Channels":[channelNums]*len(TCdices), "Tumor regions": ["TC"]*len(TCdices), "Dice score": TCdices, "HD95": TCHD95s})
#     dfWT = pd.DataFrame({"BraTS21ID":Infer_idLst,"Model": ["DynUnet"]* len(WTdices), "Channels":[channelNums]*len(WTdices), "Tumor regions": ["WT"]*len(WTdices), "Dice score": WTdices, "HD95": WTHD95s})
#     dfWT_TC_ET = pd.DataFrame({"BraTS21ID":Infer_idLst,"Model": ["DynUnet"]* len(WTdices), "Channels":[channelNums]*len(WTdices), "Tumor regions": ["WT_TC_ET_Regions"]*len(WTdices), "Dice score": Alldices, "HD95": AllHD95s})
    
#     dfRegions = pd.concat([dfTC, dfWT, dfET, dfWT_TC_ET])
#     return dfRegions


    return aDCT

    
            

In [56]:
prediction_folder = f'{save_dir}/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand' 
# modelDCTList = {'fold0': [glob.glob(f'{save_dir}/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7870_epoch469.pt*')[0],
#                           glob.glob(f'{save_dir}/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pt*')[0]],\
#                'fold1':[glob.glob(f'{save_dir}/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold1_0.7317_epoch333.pt*')[0],
#                         glob.glob(f'{save_dir}/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold1_0.7479_epoch335.pt*')[0]],\
#                'fold2':[glob.glob(f'{save_dir}/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold2_0.8145_epoch188.pt*')[0],
#                         glob.glob(f'{save_dir}/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold2_0.8179_epoch163.pt*')[0]]}

modelDCTList = {'fold0': glob.glob(f'{save_dir}/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pt*'),\
               'fold1':glob.glob(f'{save_dir}/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold1_0.7317_epoch333.pt*'),\
               'fold2':glob.glob(f'{save_dir}/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold2_0.8145_epoch188.pt**')}

modelDCTList

{'fold0': ['/raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth'],
 'fold1': ['/raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold1_0.7317_epoch333.pth'],
 'fold2': ['/raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold2_0.8145_epoch188.pth']}

In [57]:
start_inference = True

In [58]:
n_splits = 3
aDCTResultList = list()
if start_inference:
    #### Running 10 folds
    for i in range(0, n_splits):
        

        
        infer_files_fld = copy.deepcopy(BraTS20SubjectsIDHTestDCT[f'fold{i}'])
        
        infer_dataset_fld = monai.data.Dataset(data=infer_files_fld, transform=val_transforms)
       
        infer_loader_fld = monai.data.DataLoader(infer_dataset_fld, batch_size=1, shuffle=False) #num_workers=2, pin_memory=True
        
        aDCTResult = inferWithTA(data_loader = infer_loader_fld, listmodels=modelDCTList[f'fold{i}'], prediction_folder=prediction_folder, topk=len(modelDCTList[f'fold{i}']), num_channels=4,\
                    orientation="LPS", withoptimizer = False, softmaxEnsemble= True, tta = True)
        aDCTResult['TestSplitName'] = f"split{i}"
        aDCTResultList.append(aDCTResult.copy())
        
    
    start_inference = False

available model file: /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth.


Modifying image pixdim from [1. 1. 1. 1.] to [  1.           1.           1.         239.00209204]
  0%|                                                                                                                                 | 0/122 [00:00<?, ?it/s]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


  1%|▉                                                                                                                        | 1/122 [00:07<15:31,  7.69s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


  2%|█▉                                                                                                                       | 2/122 [00:16<16:32,  8.27s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


  2%|██▉                                                                                                                      | 3/122 [00:20<12:56,  6.52s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


  3%|███▉                                                                                                                     | 4/122 [00:26<12:13,  6.21s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


  4%|████▉                                                                                                                    | 5/122 [00:29<09:55,  5.09s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


  5%|█████▉                                                                                                                   | 6/122 [00:33<09:15,  4.79s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


  6%|██████▉                                                                                                                  | 7/122 [00:37<08:38,  4.51s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


  7%|███████▉                                                                                                                 | 8/122 [00:43<09:12,  4.85s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


  7%|████████▉                                                                                                                | 9/122 [00:47<08:26,  4.48s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


  8%|█████████▊                                                                                                              | 10/122 [00:51<08:25,  4.51s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


  9%|██████████▊                                                                                                             | 11/122 [00:55<08:09,  4.41s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 10%|███████████▊                                                                                                            | 12/122 [01:01<08:43,  4.76s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 11%|████████████▊                                                                                                           | 13/122 [01:09<10:40,  5.87s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 11%|█████████████▊                                                                                                          | 14/122 [01:14<10:10,  5.65s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 12%|██████████████▊                                                                                                         | 15/122 [01:19<09:44,  5.46s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 13%|███████████████▋                                                                                                        | 16/122 [01:25<09:35,  5.43s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 14%|████████████████▋                                                                                                       | 17/122 [01:30<09:13,  5.27s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 15%|█████████████████▋                                                                                                      | 18/122 [01:36<09:53,  5.71s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 16%|██████████████████▋                                                                                                     | 19/122 [01:42<09:58,  5.81s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 16%|███████████████████▋                                                                                                    | 20/122 [01:47<09:14,  5.44s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 17%|████████████████████▋                                                                                                   | 21/122 [01:52<08:55,  5.30s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 18%|█████████████████████▋                                                                                                  | 22/122 [01:58<09:18,  5.59s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 19%|██████████████████████▌                                                                                                 | 23/122 [02:02<08:11,  4.96s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 20%|███████████████████████▌                                                                                                | 24/122 [02:06<07:39,  4.69s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 20%|████████████████████████▌                                                                                               | 25/122 [02:13<08:43,  5.40s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 21%|█████████████████████████▌                                                                                              | 26/122 [02:18<08:35,  5.37s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 22%|██████████████████████████▌                                                                                             | 27/122 [02:24<08:46,  5.54s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 23%|███████████████████████████▌                                                                                            | 28/122 [02:30<08:57,  5.71s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 24%|████████████████████████████▌                                                                                           | 29/122 [02:37<09:15,  5.98s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 25%|█████████████████████████████▌                                                                                          | 30/122 [02:42<08:33,  5.59s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 25%|██████████████████████████████▍                                                                                         | 31/122 [02:46<08:08,  5.37s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 26%|███████████████████████████████▍                                                                                        | 32/122 [02:50<07:21,  4.91s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 27%|████████████████████████████████▍                                                                                       | 33/122 [02:58<08:23,  5.66s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 28%|█████████████████████████████████▍                                                                                      | 34/122 [03:02<07:47,  5.31s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 29%|██████████████████████████████████▍                                                                                     | 35/122 [03:07<07:20,  5.07s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 30%|███████████████████████████████████▍                                                                                    | 36/122 [03:10<06:43,  4.69s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 30%|████████████████████████████████████▍                                                                                   | 37/122 [03:18<08:04,  5.70s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 31%|█████████████████████████████████████▍                                                                                  | 38/122 [03:25<08:28,  6.05s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 32%|██████████████████████████████████████▎                                                                                 | 39/122 [03:32<08:24,  6.08s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 33%|███████████████████████████████████████▎                                                                                | 40/122 [03:35<07:25,  5.43s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 34%|████████████████████████████████████████▎                                                                               | 41/122 [03:44<08:44,  6.48s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 34%|█████████████████████████████████████████▎                                                                              | 42/122 [03:50<08:12,  6.15s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 35%|██████████████████████████████████████████▎                                                                             | 43/122 [03:54<07:12,  5.47s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 36%|███████████████████████████████████████████▎                                                                            | 44/122 [03:59<07:05,  5.46s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 37%|████████████████████████████████████████████▎                                                                           | 45/122 [04:10<09:12,  7.18s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 38%|█████████████████████████████████████████████▏                                                                          | 46/122 [04:17<08:51,  6.99s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 39%|██████████████████████████████████████████████▏                                                                         | 47/122 [04:26<09:33,  7.64s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 39%|███████████████████████████████████████████████▏                                                                        | 48/122 [04:30<07:58,  6.47s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 40%|████████████████████████████████████████████████▏                                                                       | 49/122 [04:34<06:58,  5.73s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 41%|█████████████████████████████████████████████████▏                                                                      | 50/122 [04:39<06:45,  5.64s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 42%|██████████████████████████████████████████████████▏                                                                     | 51/122 [04:45<06:37,  5.60s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 43%|███████████████████████████████████████████████████▏                                                                    | 52/122 [04:48<05:44,  4.92s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 43%|████████████████████████████████████████████████████▏                                                                   | 53/122 [04:58<07:31,  6.55s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 44%|█████████████████████████████████████████████████████                                                                   | 54/122 [05:06<07:48,  6.89s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 45%|██████████████████████████████████████████████████████                                                                  | 55/122 [05:09<06:30,  5.83s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 46%|███████████████████████████████████████████████████████                                                                 | 56/122 [05:17<07:08,  6.49s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 47%|████████████████████████████████████████████████████████                                                                | 57/122 [05:23<06:44,  6.22s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 48%|█████████████████████████████████████████████████████████                                                               | 58/122 [05:26<05:40,  5.32s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 48%|██████████████████████████████████████████████████████████                                                              | 59/122 [05:30<05:06,  4.86s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 49%|███████████████████████████████████████████████████████████                                                             | 60/122 [05:34<04:41,  4.53s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 50%|████████████████████████████████████████████████████████████                                                            | 61/122 [05:37<04:21,  4.29s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 51%|████████████████████████████████████████████████████████████▉                                                           | 62/122 [05:42<04:14,  4.24s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 52%|█████████████████████████████████████████████████████████████▉                                                          | 63/122 [05:45<03:54,  3.97s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 52%|██████████████████████████████████████████████████████████████▉                                                         | 64/122 [05:52<04:36,  4.76s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 53%|███████████████████████████████████████████████████████████████▉                                                        | 65/122 [05:55<04:06,  4.33s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 54%|████████████████████████████████████████████████████████████████▉                                                       | 66/122 [06:01<04:25,  4.74s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 55%|█████████████████████████████████████████████████████████████████▉                                                      | 67/122 [06:09<05:25,  5.92s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 56%|██████████████████████████████████████████████████████████████████▉                                                     | 68/122 [06:13<04:43,  5.25s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 57%|███████████████████████████████████████████████████████████████████▊                                                    | 69/122 [06:20<05:03,  5.72s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 57%|████████████████████████████████████████████████████████████████████▊                                                   | 70/122 [06:25<04:49,  5.57s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 58%|█████████████████████████████████████████████████████████████████████▊                                                  | 71/122 [06:29<04:15,  5.01s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 59%|██████████████████████████████████████████████████████████████████████▊                                                 | 72/122 [06:34<04:22,  5.24s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 60%|███████████████████████████████████████████████████████████████████████▊                                                | 73/122 [06:39<04:00,  4.91s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 61%|████████████████████████████████████████████████████████████████████████▊                                               | 74/122 [06:43<03:53,  4.87s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 61%|█████████████████████████████████████████████████████████████████████████▊                                              | 75/122 [06:47<03:37,  4.62s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 62%|██████████████████████████████████████████████████████████████████████████▊                                             | 76/122 [06:51<03:17,  4.29s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 63%|███████████████████████████████████████████████████████████████████████████▋                                            | 77/122 [06:57<03:33,  4.75s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 64%|████████████████████████████████████████████████████████████████████████████▋                                           | 78/122 [07:03<03:45,  5.12s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 65%|█████████████████████████████████████████████████████████████████████████████▋                                          | 79/122 [07:07<03:33,  4.97s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 66%|██████████████████████████████████████████████████████████████████████████████▋                                         | 80/122 [07:11<03:11,  4.57s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 66%|███████████████████████████████████████████████████████████████████████████████▋                                        | 81/122 [07:17<03:21,  4.93s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 67%|████████████████████████████████████████████████████████████████████████████████▋                                       | 82/122 [07:24<03:41,  5.55s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 68%|█████████████████████████████████████████████████████████████████████████████████▋                                      | 83/122 [07:28<03:22,  5.19s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 69%|██████████████████████████████████████████████████████████████████████████████████▌                                     | 84/122 [07:33<03:19,  5.24s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 70%|███████████████████████████████████████████████████████████████████████████████████▌                                    | 85/122 [07:39<03:14,  5.26s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 70%|████████████████████████████████████████████████████████████████████████████████████▌                                   | 86/122 [07:42<02:45,  4.59s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 71%|█████████████████████████████████████████████████████████████████████████████████████▌                                  | 87/122 [07:47<02:46,  4.75s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 72%|██████████████████████████████████████████████████████████████████████████████████████▌                                 | 88/122 [07:51<02:34,  4.55s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 73%|███████████████████████████████████████████████████████████████████████████████████████▌                                | 89/122 [07:55<02:20,  4.25s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 74%|████████████████████████████████████████████████████████████████████████████████████████▌                               | 90/122 [07:59<02:21,  4.43s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 75%|█████████████████████████████████████████████████████████████████████████████████████████▌                              | 91/122 [08:06<02:37,  5.08s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 75%|██████████████████████████████████████████████████████████████████████████████████████████▍                             | 92/122 [08:10<02:25,  4.85s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 76%|███████████████████████████████████████████████████████████████████████████████████████████▍                            | 93/122 [08:15<02:17,  4.76s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 77%|████████████████████████████████████████████████████████████████████████████████████████████▍                           | 94/122 [08:19<02:10,  4.66s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 78%|█████████████████████████████████████████████████████████████████████████████████████████████▍                          | 95/122 [08:24<02:04,  4.60s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 79%|██████████████████████████████████████████████████████████████████████████████████████████████▍                         | 96/122 [08:28<01:56,  4.47s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 80%|███████████████████████████████████████████████████████████████████████████████████████████████▍                        | 97/122 [08:32<01:47,  4.30s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 80%|████████████████████████████████████████████████████████████████████████████████████████████████▍                       | 98/122 [08:36<01:39,  4.14s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 81%|█████████████████████████████████████████████████████████████████████████████████████████████████▍                      | 99/122 [08:40<01:36,  4.20s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 82%|█████████████████████████████████████████████████████████████████████████████████████████████████▌                     | 100/122 [08:44<01:28,  4.04s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 83%|██████████████████████████████████████████████████████████████████████████████████████████████████▌                    | 101/122 [08:48<01:29,  4.27s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 84%|███████████████████████████████████████████████████████████████████████████████████████████████████▍                   | 102/122 [08:51<01:18,  3.92s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 84%|████████████████████████████████████████████████████████████████████████████████████████████████████▍                  | 103/122 [08:55<01:11,  3.77s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 85%|█████████████████████████████████████████████████████████████████████████████████████████████████████▍                 | 104/122 [08:59<01:11,  3.99s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 86%|██████████████████████████████████████████████████████████████████████████████████████████████████████▍                | 105/122 [09:04<01:11,  4.18s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 87%|███████████████████████████████████████████████████████████████████████████████████████████████████████▍               | 106/122 [09:08<01:06,  4.16s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 88%|████████████████████████████████████████████████████████████████████████████████████████████████████████▎              | 107/122 [09:12<01:01,  4.10s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 89%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▎             | 108/122 [09:17<00:58,  4.20s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 89%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▎            | 109/122 [09:21<00:53,  4.13s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 90%|███████████████████████████████████████████████████████████████████████████████████████████████████████████▎           | 110/122 [09:24<00:48,  4.00s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 91%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▎          | 111/122 [09:31<00:53,  4.84s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 92%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▏         | 112/122 [09:36<00:47,  4.73s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 93%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████▏        | 113/122 [09:39<00:39,  4.38s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 93%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████▏       | 114/122 [09:48<00:44,  5.60s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 94%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏      | 115/122 [09:52<00:36,  5.23s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 95%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏     | 116/122 [09:56<00:29,  4.94s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 96%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████     | 117/122 [10:00<00:22,  4.53s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 97%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████    | 118/122 [10:05<00:19,  4.83s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 98%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████   | 119/122 [10:09<00:13,  4.64s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 98%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████  | 120/122 [10:12<00:08,  4.14s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


 99%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████ | 121/122 [10:16<00:03,  3.96s/it]

Model 0, /raid/brats2021/pthTCGA_1p19q_CoDeletion/DynUNetVariants_TCGA/AttnDynUNet_BratsTCGA_1p19q_3CV_2ChnlsMorePatch_OnlyWSampler_Infer1PatchSWIRngr21_2nRatioclass_HistStand_Fold0_0.7880_epoch430.pth is running now


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 122/122 [10:20<00:00,  5.08s/it]


NameError: name 'confusion_matrix' is not defined

In [None]:
print('Result on the testset (3 testsest from 3 splits)')
DFResult = pd.DataFrame.from_dict(aDCTResultList)
display(DFResult)
DFResult.describe()

# roc_auc_score(y, y_score)

#### Section for testing the user-defined functions and classes

In [None]:
def one_hot_permute(x):
    return F.one_hot(x.squeeze(dim=0).long(), num_classes=3).permute(2, 0, 1)

In [None]:
xc= torch.tensor([0, 1, 0, 0, 2, 3, 10])
if torch.any(torch.eq(xc, 11)):
    print('Do')

In [None]:
xc = torch.tensor([0, 100, 500, 10000, 5])
torch.argmax(xc)

In [None]:
torch.argsort(xc)[-1]

In [None]:
import torch
is_onehot = True
img = torch.ones((1, 64, 64, 64))
is_onehot2 = img.shape[0] > 1 if is_onehot is None else is_onehot
is_onehot2

In [None]:
dd = np.array([[1.]])
dd.shape

In [None]:
dx = np.stack([dd, dd], axis = 0)
dx.shape

In [None]:
np.concatenate([dx, dx], axis = 0).shape

In [None]:
P1=torch.tensor([[0, 1, 1, 1,  0, 0, 0, 0, 0]])
P2=torch.tensor([[0, 1, 1, 1,   0, 0, 1, 1, 0]])
P3=torch.tensor([[1, 0, 1, 1,   0, 0, 1, 1, 1]])
P4=torch.tensor([[1, 0, 1, 1,   0, 0, 1, 1, 0]])
P5=torch.tensor([[1, 0, 1, 1,   0, 0, 1, 1, 1]])

In [None]:
cc = torch.mode(torch.cat((P1, P2, P3, P4, P5), dim=0), dim=0, keepdim=True)[0]
cc

In [None]:
P1.shape

In [None]:
P00=torch.tensor([1])
torch.mode(P00)[0]

In [None]:
torch.mode(P1)[0].view(1)

In [None]:
atnsr = torch.tensor((0, 0, 1, 1))
xv, xc = torch.unique(atnsr, return_counts  = True)

In [None]:
xc, xv[0].item()

In [None]:
torch.any(torch.eq(xv, 1))
xv

In [None]:
xc.shape[0]

In [None]:
torch.argmax(torch.tensor([1000, 2000, 3000, float('NaN'), 5, 600]))

In [None]:
torch.mode(torch.tensor([[5, 3, 3, float('NaN'),  float('NaN'), float('NaN'), float('NaN'), 0]]))

In [None]:
torch.mode(torch.tensor([60, 60, 50, 50, 60]))[0]

In [None]:
torch.tensor(float('NaN'))

In [None]:
xdvc = device
def get_bin_tensor(xbatchclass):
   
    if torch.all(torch.isnan(xbatchclass))==True:
            
            return torch.tensor(float('NaN')).to(xdvc)
        
    else:

        num_xbatchnanvalues = torch.isnan(xbatchclass).sum().item()
        not_xbatchnanmask = torch.logical_not(torch.isnan(xbatchclass))
        xbatchclass = xbatchclass[not_xbatchnanmask]

        xclassVal_01, xclassCnt_01 =xbatchclass.unique(return_counts = True)

        if xclassCnt_01.shape[0]==1:
            return xclassVal_01[0].to(xdvc)


        if xclassCnt_01.shape[0]==2:
            if xclassCnt_01[0]!=xclassCnt_01[1]:
                ''' xclassCnt_01 will always be two values converting [7, 8] to 1; [8, 7] to 0'''
                return torch.argmax(xclassCnt_01).to(xdvc)  

            else:
                return torch.tensor(float('NaN')).to(xdvc)
        
        

In [None]:
get_bin_tensor(torch.tensor([float('NaN'), float('NaN'), float('NaN'), 0]))

In [None]:
torch.argmax(torch.tensor([0, 1]))

In [None]:
xclassVal_01, xclassCnt_01 =torch.tensor([0, 0, 0, 1]).unique(return_counts = True)
xclassVal_01

In [None]:
torch.argmax(xclassCnt_01)

In [None]:
torch.argmax(torch.tensor([7, 8]))

In [None]:
def get_segclass(x_chlist):
    
    xclassNoList = list()
    xvalueList = list()
    
    for x_i in x_chlist:
    
        xv, xc = torch.unique(x_i, return_counts  = True)

        if xc.shape[0]==1:
            if xv==0:
                xclassNoList.append(-1)
                xvalueList.append(xv[0].item())
            elif xv==1:
                xclassNoList.append(xc[0].item())
                xvalueList.append(xv[0].item())
            else:
                print('The function only supports binarized tensor (binarized unique values, 0(n=...) and 1(n=...) only)\n')
                

        elif xc.shape[0]==2:
                if torch.any(torch.eq(xv, 1)):
                    xclassNoList.append(xc[1].item())
                    xvalueList.append(xv[1].item())
                else:
                    print('The function only supports binarized tensor (binarized unique values, 0(n=...) and 1(n=...) only)\n')
        else:
            print('The function only supports binarized tensor (binarized unique values, 0(n=...) and 1(n=...) only)\n')

    #pdb.set_trace()
    #if torch.any(torch.eq(torch.tensor(xvalueList), 1)):
        
    if xclassNoList[0]!=xclassNoList[1]: 
        xclass = torch.argmax(torch.tensor(xclassNoList).to(xdvc))
    else:
        xclass = torch.tensor(float('NaN')).to(xdvc)
            
    #else:
        '''If all uniques class values are 0, we are assigning nan values as a class'''
    #    xclass = torch.tensor(float('NaN')).to(xdvc)
    
    
    return xclass

In [None]:
get_segclass([torch.tensor([[1,1,1,1,1],[1,1,1,1,0]]), torch.tensor([[1,1,1,1,1], [1,1,1,1,0]])])

In [None]:
torch.tensor([0, 0, 0, float('NaN'), 0]).unique(return_counts = True)