# Installs & Imports

In [1]:
!pip install -q albumentations
!pip install -q natsort
!pip install -q patchify
!pip install -q lightning
!pip install -q lightning[extra]
!pip install -q segmentation_models_pytorch
!pip install -q wandb
!pip install -q "monai[einops]==1.4"
!pip install -q nibabel
!pip install -q ttach
!pip install -q medpy

[0m

In [2]:
! pip install -q scikit-image

[0m

In [3]:
!pip install -q scikit-learn

[0m

In [4]:
!python -c "import monai" || pip install -q "monai-weekly[gdown, nibabel, tqdm, ignite]"
!python -c "import matplotlib" || pip install -q matplotlib
!pip install -q medpy
%matplotlib inline

[0m

In [1]:
# Standard libraries
import os
import json
import glob
import shutil
import tempfile
import random
import warnings


# Third-party libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
import nibabel as nib
import albumentations as A
import segmentation_models_pytorch as smp
import torch.nn as nn
import torch.nn.functional as F
from skimage import filters
from skimage.measure import label as label_fn, regionprops
from skimage import morphology
from sklearn.model_selection import train_test_split
from copy import deepcopy
from tqdm.notebook import tqdm
import pprint as pp

# MONAI related imports
from monai.config import print_config
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.transforms import (
    AsDiscrete, AsDiscreted, EnsureChannelFirstd, Compose, CropForegroundd,
    LoadImaged, Orientationd, RandCropByPosNegLabeld, SaveImaged, ScaleIntensityRanged,
    Spacingd, Invertd, ResizeWithPadOrCropd, Resized, MapTransform, ScaleIntensityd,
    LabelToContourd, ForegroundMaskd, HistogramNormalized, RandFlipd, RandGridDistortiond,
    RandHistogramShiftd, RandRotated
)
from monai.handlers.utils import from_engine
from monai.utils.type_conversion import convert_to_numpy

# PyTorch Lightning related imports
import lightning.pytorch as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch import seed_everything

# Weights & Biases
import wandb

# Patchify
from patchify import patchify, unpatchify

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

from natsort import natsorted, ns
from utiils import *
import monai

from boundaryloss.dataloader import dist_map_transform
from boundaryloss.utils import simplex
from boundaryloss.losses import *
from torch.utils.data import default_collate

# Set precision for matmul operations and print MONAI config
torch.set_float32_matmul_precision('medium')
print_config()

# Uncomment below line to ignore warnings
# warnings.filterwarnings("ignore")

MONAI version: 1.4.0
Numpy version: 1.26.3
Pytorch version: 2.2.1
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 46a5272196a6c2590ca2589029eed8e4d56ff008
MONAI __file__: /opt/conda/lib/python3.10/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.3.2
scikit-image version: 0.25.2
scipy version: 1.15.2
Pillow version: 10.2.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: 5.2.0
TorchVision version: 0.17.1
tqdm version: 4.65.0
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.0
pandas version: NOT INSTALLED or UNKNOWN VERSION.
einops version: 0.8.1
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing t

# Env Setup

In [2]:
# Function to process each segment
def convert_cases_to_remove_txt(text):
    lines = text.strip().split("\n")
    case_dict = {}
    current_case = None

    for line in lines:
        if line == "":
            continue
        if line.startswith('/'):
            if current_case:
                case_dict[current_case].append("Dataset-arrays-4/"+ "/".join(line.split('/')[-3:]))
        else:
            current_case = line.strip()
            case_dict[current_case] = []

    case_dict = {k: v for k, v in case_dict.items() if v}
    
    return case_dict

In [3]:
cases_to_remove_txt = """BNB1172(DF)
/content/drive/MyDrive/Tesi/Dataset-arrays/BNB1172(DF)/images/BNB1172(DF)_130.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/BNB1172(DF)/images/BNB1172(DF)_131.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/BNB1172(DF)/images/BNB1172(DF)_196.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/BNB1172(DF)/images/BNB1172(DF)_197.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/BNB1172(DF)/images/BNB1172(DF)_198.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/BNB1172(DF)/images/BNB1172(DF)_199.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/BNB1172(DF)/images/BNB1172(DF)_200.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/BNB1172(DF)/images/BNB1172(DF)_201.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/BNB1172(DF)/images/BNB1172(DF)_202.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/BNB1172(DF)/images/BNB1172(DF)_203.npy

D1AP5(VR)
/content/drive/MyDrive/Tesi/Dataset-arrays/ D1AP5(VR)/images/D1AP5(VR)_91.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/ D1AP5(VR)/images/D1AP5(VR)_92.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/ D1AP5(VR)/images/D1AP5(VR)_97.npy

D1AP7(VR)
/content/drive/MyDrive/Tesi/Dataset-arrays/D1AP7(VR)/images/D1AP7(VR)_56.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D1AP7(VR)/images/D1AP7(VR)_78.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D1AP7(VR)/images/D1AP7(VR)_79.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D1AP7(VR)/images/D1AP7(VR)_80.npy

D1AP12(VR)
/content/drive/MyDrive/Tesi/Dataset-arrays/D1AP12(VR)/images/D1AP12(VR)_21.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D1AP12(VR)/images/D1AP12(VR)_22.npy

D2MP1(VR)
/content/drive/MyDrive/Tesi/Dataset-arrays/D2MP1(VR)/images/D2MP1(VR)_47.npy
BRUTTINO

D2MP3(VR)
/content/drive/MyDrive/Tesi/Dataset-arrays/D2MP3(VR)/images/D2MP3(VR)_133.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D2MP3(VR)/images/D2MP3(VR)_146.npy
C’è massa non segmentata, probabilmente benigna ma chiarire

D2MP4(VR)
/content/drive/MyDrive/Tesi/Dataset-arrays/D2MP4(VR)/images/D2MP4(VR)_60.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D2MP4(VR)/images/D2MP4(VR)_61.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D2MP4(VR)/images/D2MP4(VR)_84.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D2MP4(VR)/images/D2MP4(VR)_86.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D2MP4(VR)/images/D2MP4(VR)_128.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D2MP4(VR)/images/D2MP4(VR)_129.npy

D2MP6(VR)
/content/drive/MyDrive/Tesi/Dataset-arrays/D2MP6(VR)/images/D2MP6(VR)_24.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D2MP6(VR)/images/D2MP6(VR)_31.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D2MP6(VR)/images/D2MP6(VR)_35.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D2MP6(VR)/images/D2MP6(VR)_52.npy

D3MP7 (VR)
/content/drive/MyDrive/Tesi/Dataset-arrays/D3MP7 (VR)/images/D3MP7 (VR)_38.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D3MP7 (VR)/images/D3MP7 (VR)_39.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D3MP7 (VR)/images/D3MP7 (VR)_40.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D3MP7 (VR)/images/D3MP7 (VR)_41.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D3MP7 (VR)/images/D3MP7 (VR)_42.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D3MP7 (VR)/images/D3MP7 (VR)_64.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D3MP7 (VR)/images/D3MP7 (VR)_65.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D3MP7 (VR)/images/D3MP7 (VR)_66.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D3MP7 (VR)/images/D3MP7 (VR)_67.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D3MP7 (VR)/images/D3MP7 (VR)_68.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D3MP7 (VR)/images/D3MP7 (VR)_69.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D3MP7 (VR)/images/D3MP7 (VR)_70.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D3MP7 (VR)/images/D3MP7 (VR)_71.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D3MP7 (VR)/images/D3MP7 (VR)_72.npy

DFC1168(DF)
/content/drive/MyDrive/Tesi/Dataset-arrays/DFC1168(DF)/images/DFC1168(DF)_173.npy

DTM0772(1,5)
/content/drive/MyDrive/Tesi/Dataset-arrays/DTM0772(1,5)/images/DTM0772(1,5)_21.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/DTM0772(1,5)/images/DTM0772(1,5)_47.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/DTM0772(1,5)/images/DTM0772(1,5)_50.npy

GMG0961(3)
/content/drive/MyDrive/Tesi/Dataset-arrays/GMG0961(3)/images/GMG0961(3)_71.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/GMG0961(3)/images/GMG0961(3)_81.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/GMG0961(3)/images/GMG0961(3)_82.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/GMG0961(3)/images/GMG0961(3)_83.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/GMG0961(3)/images/GMG0961(3)_87.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/GMG0961(3)/images/GMG0961(3)_101.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/GMG0961(3)/images/GMG0961(3)_102.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/GMG0961(3)/images/GMG0961(3)_104.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/GMG0961(3)/images/GMG0961(3)_109.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/GMG0961(3)/images/GMG0961(3)_110.npy
LGM0159(1,5)
/content/drive/MyDrive/Tesi/Dataset-arrays/LGM0159(1,5)/images/LGM0159(1,5)_41.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/LGM0159(1,5)/images/LGM0159(1,5)_42.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/LGM0159(1,5)/images/LGM0159(1,5)_43.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/LGM0159(1,5)/images/LGM0159(1,5)_44.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/LGM0159(1,5)/images/LGM0159(1,5)_45.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/LGM0159(1,5)/images/LGM0159(1,5)_46.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/LGM0159(1,5)/images/LGM0159(1,5)_47.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/LGM0159(1,5)/images/LGM0159(1,5)_48.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/LGM0159(1,5)/images/LGM0159(1,5)_49.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/LGM0159(1,5)/images/LGM0159(1,5)_70.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/LGM0159(1,5)/images/LGM0159(1,5)_71.npy
Bruttino

MD0773(DF)
/content/drive/MyDrive/Tesi/Dataset-arrays/MD0773(DF)/images/MD0773(DF)_115.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/MD0773(DF)/images/MD0773(DF)_116.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/MD0773(DF)/images/MD0773(DF)_153.npy

PF0473(1,5)
/content/drive/MyDrive/Tesi/Dataset-arrays/PF0473(1,5)/images/PF0473(1,5)_50.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PF0473(1,5)/images/PF0473(1,5)_51.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PF0473(1,5)/images/PF0473(1,5)_52.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PF0473(1,5)/images/PF0473(1,5)_59.npy
Brutto

PMG0761(3)
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_73.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_85.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_89.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_90.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_91.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_92.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_93.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_93.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_95.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_96.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_97.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_98.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_99.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_112.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_113.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_124.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_125.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_126.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_133.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_134.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_135.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_136.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_137.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_145.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_146.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_147.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_148.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_149.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_150.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_151.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_152.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_153.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PMG0761(3)/images/PMG0761(3)_154.npy
Bruttissimo

RD0175
/content/drive/MyDrive/Tesi/Dataset-arrays/RD0175/images/RD0175_82.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/RD0175/images/RD0175_83.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/RD0175/images/RD0175_123.npy

SM0972(DF)
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_104.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_126.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_127.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_128.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_129.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_135.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_142.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_143.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_144.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_145.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_146.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_147.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_148.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_149.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_150.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_151.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_152.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_153.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_154.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_155.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_156.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_157.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_158.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_159.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_160.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_161.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_162.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_163.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM0972(DF)/images/SM0972(DF)_164.npy
Usare in validation

UFR0987
/content/drive/MyDrive/Tesi/Dataset-arrays/UFR0987/images/UFR0987_71.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/UFR0987/images/UFR0987_72.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/UFR0987/images/UFR0987_73.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/UFR0987/images/UFR0987_89.npy

ZT0279(3)
/content/drive/MyDrive/Tesi/Dataset-arrays/ZT0279(3)/images/ZT0279(3)_137.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/ZT0279(3)/images/ZT0279(3)_138.npy

BC1179B
/content/drive/MyDrive/Tesi/Dataset-arrays/BC1179B-merged/images/BC1179B_57.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/BC1179B-merged/images/BC1179B_71.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/BC1179B-merged/images/BC1179B_96.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/BC1179B-merged/images/BC1179B_97.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/BC1179B-merged/images/BC1179B_115.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/BC1179B-merged/images/BC1179B_116.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/BC1179B-merged/images/BC1179B_117.npy

D2MP9b(VR)-merged
/content/drive/MyDrive/Tesi/Dataset-arrays/D2MP9b(VR) merged/images/D2MP9b(VR)_35.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D2MP9b(VR) merged/images/D2MP9b(VR)_36.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D2MP9b(VR) merged/images/D2MP9b(VR)_37.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D2MP9b(VR)-merged/images/D2MP9b(VR)_45.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D2MP9b(VR)-merged/images/D2MP9b(VR)_63.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/D2MP9b(VR)-merged/images/D2MP9b(VR)_69.npy

GA07(DF)B-merged
/content/drive/MyDrive/Tesi/Dataset-arrays/GA07(DF)B-merged/images/GA07(DF)B_62.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/GA07(DF)B-merged/images/GA07(DF)B_83.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/GA07(DF)B-merged/images/GA07(DF)B_104.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/GA07(DF)B-merged/images/GA07(DF)B_105.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/GA07(DF)B-merged/images/GA07(DF)B_106.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/GA07(DF)B-merged/images/GA07(DF)B_119.npy

DCC0340(1,5)-merged
/content/drive/MyDrive/Tesi/Dataset-arrays/DCC0340(1,5)-merged/images/DCC0340(1,5)_27.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/DCC0340(1,5)-merged/images/DCC0340(1,5)_28.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/DCC0340(1,5)-merged/images/DCC0340(1,5)_29.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/DCC0340(1,5)-merged/images/DCC0340(1,5)_30.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/DCC0340(1,5)-merged/images/DCC0340(1,5)_31.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/DCC0340(1,5)-merged/images/DCC0340(1,5)_32.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/DCC0340(1,5)-merged/images/DCC0340(1,5)_33.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/DCC0340(1,5)-merged/images/DCC0340(1,5)_46.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/DCC0340(1,5)-merged/images/DCC0340(1,5)_47.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/DCC0340(1,5)-merged/images/DCC0340(1,5)_48.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/DCC0340(1,5)-merged/images/DCC0340(1,5)_49.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/DCC0340(1,5)-merged/images/DCC0340(1,5)_50.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/DCC0340(1,5)-merged/images/DCC0340(1,5)_51.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/DCC0340(1,5)-merged/images/DCC0340(1,5)_55.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/DCC0340(1,5)-merged/images/DCC0340(1,5)_56.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/DCC0340(1,5)-merged/images/DCC0340(1,5)_57.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/DCC0340(1,5)-merged/images/DCC0340(1,5)_58.npy

HV1263(1,5)-merged
/content/drive/MyDrive/Tesi/Dataset-arrays/HV1263(1,5)-merged/images/HV1263(1,5)_5.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/HV1263(1,5)-merged/images/HV1263(1,5)_6.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/HV1263(1,5)-merged/images/HV1263(1,5)_7.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/HV1263(1,5)-merged/images/HV1263(1,5)_17.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/HV1263(1,5)-merged/images/HV1263(1,5)_18.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/HV1263(1,5)-merged/images/HV1263(1,5)_19.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/HV1263(1,5)-merged/images/HV1263(1,5)_20.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/HV1263(1,5)-merged/images/HV1263(1,5)_21.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/HV1263(1,5)-merged/images/HV1263(1,5)_22.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/HV1263(1,5)-merged/images/HV1263(1,5)_23.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/HV1263(1,5)-merged/images/HV1263(1,5)_24.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/HV1263(1,5)-merged/images/HV1263(1,5)_21.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/HV1263(1,5)-merged/images/HV1263(1,5)_35.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/HV1263(1,5)-merged/images/HV1263(1,5)_36.npy

LA0248B-merged
/content/drive/MyDrive/Tesi/Dataset-arrays/LA0248B-merged/images/LA0248B_115.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/LA0248B-merged/images/LA0248B_120.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/LA0248B-merged/images/LA0248B_128.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/LA0248B-merged/images/LA0248B_146.npy

PE0468(1,5)B-merged
/content/drive/MyDrive/Tesi/Dataset-arrays/PE0468(1,5)B-merged/images/PE0468(1,5)B_17.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PE0468(1,5)B-merged/images/PE0468(1,5)B_18.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PE0468(1,5)B-merged/images/PE0468(1,5)B_19.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PE0468(1,5)B-merged/images/PE0468(1,5)B_28.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PE0468(1,5)B-merged/images/PE0468(1,5)B_29.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PE0468(1,5)B-merged/images/PE0468(1,5)B_30.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PE0468(1,5)B-merged/images/PE0468(1,5)B_38.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PE0468(1,5)B-merged/images/PE0468(1,5)B_48.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PE0468(1,5)B-merged/images/PE0468(1,5)B_49.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PE0468(1,5)B-merged/images/PE0468(1,5)B_50.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PE0468(1,5)B-merged/images/PE0468(1,5)B_51.npy

MV1276(1,5)B-merged
/content/drive/MyDrive/Tesi/Dataset-arrays/MV1276(1,5)B-merged/images/MV1276(1,5)B_41.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/MV1276(1,5)B-merged/images/MV1276(1,5)B_42.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/MV1276(1,5)B-merged/images/MV1276(1,5)B_43.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/MV1276(1,5)B-merged/images/MV1276(1,5)B_44.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/MV1276(1,5)B-merged/images/MV1276(1,5)B_45.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/MV1276(1,5)B-merged/images/MV1276(1,5)B_46.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/MV1276(1,5)B-merged/images/MV1276(1,5)B_58.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/MV1276(1,5)B-merged/images/MV1276(1,5)B_59.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/MV1276(1,5)B-merged/images/MV1276(1,5)B_60.npy

VS0976(1,5)B-merged
/content/drive/MyDrive/Tesi/Dataset-arrays/VS0976(1,5)B-merged/images/VS0976(1,5)B_20.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VS0976(1,5)B-merged/images/VS0976(1,5)B_21.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VS0976(1,5)B-merged/images/VS0976(1,5)B_22.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VS0976(1,5)B-merged/images/VS0976(1,5)B_27.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VS0976(1,5)B-merged/images/VS0976(1,5)B_28.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VS0976(1,5)B-merged/images/VS0976(1,5)B_29.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VS0976(1,5)B-merged/images/VS0976(1,5)B_36.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VS0976(1,5)B-merged/images/VS0976(1,5)B_37.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VS0976(1,5)B-merged/images/VS0976(1,5)B_43.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VS0976(1,5)B-merged/images/VS0976(1,5)B_44.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VS0976(1,5)B-merged/images/VS0976(1,5)B_45.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VS0976(1,5)B-merged/images/VS0976(1,5)B_46.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VS0976(1,5)B-merged/images/VS0976(1,5)B_47.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VS0976(1,5)B-merged/images/VS0976(1,5)B_55.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VS0976(1,5)B-merged/images/VS0976(1,5)B_58.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VS0976(1,5)B-merged/images/VS0976(1,5)B_59.npy
BRUTTO

VDMB0751(DF)B-merged
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_74.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_75.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_78.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_85.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_86.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_87.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_88.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_89.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_90.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_91.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_92.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_93.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_94.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_95.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_96.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_97.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_98.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_99.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_100.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_101.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_102.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_103.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_104.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_105.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_106.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_107.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_108.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_109.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_110.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_111.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_112.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_130.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_136.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_146.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_147.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_148.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_149.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_152.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/VDMB0751(DF)B-merged/images/VDMB0751(DF)B_157.npy
VALIDATION

SM1232B-merged
/content/drive/MyDrive/Tesi/Dataset-arrays/SM1232B-merged/images/SM1232B_84.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM1232B-merged/images/SM1232B_85.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM1232B-merged/images/SM1232B_86.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/SM1232B-merged/images/SM1232B_138.npy

PS0446(1.5)B-merged
/content/drive/MyDrive/Tesi/Dataset-arrays/PS0446(1.5)B-merged/images/PS0446(1.5)B_28.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PS0446(1.5)B-merged/images/PS0446(1.5)B_29.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PS0446(1.5)B-merged/images/PS0446(1.5)B_34.npy
/content/drive/MyDrive/Tesi/Dataset-arrays/PS0446(1.5)B-merged/images/PS0446(1.5)B_40.npy"""
to_remove_dict = convert_cases_to_remove_txt(cases_to_remove_txt)

In [4]:
checkpoints_dir="checkpoints"
wandb.login(key = "2bc18e4744fb0771a16fd009b7aa2c98c79efc49")


train_ratio = 0.8
validation_ratio = 0.2
test_ratio = 1 - train_ratio
SEED = 200
n_cpu = os.cpu_count()

def seed_worker(worker_id):
    worker_seed = 200
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def reseed():
    SEED = 200
    print(f'Using random seed {SEED}...')

    g = torch.Generator()
    g.manual_seed(SEED)

    seed_everything(SEED, workers=True)
    os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
    os.environ["PYTHONHASHSEED"] = str(SEED)
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic=True
    torch.use_deterministic_algorithms(True)

    return g


g = reseed()

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mpablo-giaccaglia[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
Seed set to 200


Using random seed 200...


In [5]:
dataset_base_path = "Dataset-arrays-4-FINAL"
all_folders = os.listdir(dataset_base_path)
len(all_folders)

103

# Utility functions

In [6]:
import numpy as np
from skimage.measure import label as LABEL, regionprops
from scipy.spatial.distance import cdist
import ttach as tta
from ttach.base import Merger

from monai.transforms import KeepLargestConnectedComponent, RemoveSmallObjects


import cv2
import numpy as np

import torch
import torch.nn.functional as F

import numpy as np
from scipy.ndimage import label as labell, generate_binary_structure

def compute_iou_imagewise_from_cumulator(TPs, FPs, FNs, TNs, exclude_empty=False, exclude_empty_only_gt = False,return_std=False):

    tp = torch.cat([tp for tp in TPs])
    fp = torch.cat([fp for fp in FPs])
    fn = torch.cat([fn for fn in FNs])
    tn = torch.cat([tn for tn in TNs])

    if return_std:

        mean_iou, std_iou = compute_iou_from_metrics(tp, fp, tn, fn, reduction='micro-imagewise',exclude_empty=exclude_empty, exclude_empty_only_gt =exclude_empty_only_gt, return_std=return_std)
        return mean_iou.item(), std_iou.item()

    else:

        return compute_iou_from_metrics(tp, fp, tn, fn, reduction='micro-imagewise',exclude_empty=exclude_empty).item()

def compute_dice_imagewise_from_cumulator(TPs, FPs, FNs, TNs, exclude_empty=False, exclude_empty_only_gt = False, return_std=False):

    tp = torch.cat([tp for tp in TPs])
    fp = torch.cat([fp for fp in FPs])
    fn = torch.cat([fn for fn in FNs])
    tn = torch.cat([tn for tn in TNs])

    if return_std:
        mean_dice, std_dice =  compute_dice_from_metrics(tp, fp, tn, fn, reduction='micro-imagewise',exclude_empty=exclude_empty, exclude_empty_only_gt=exclude_empty_only_gt, return_std=return_std)
        return mean_dice.item(), std_dice.item()

    else:

        return compute_dice_from_metrics(tp, fp, tn, fn, reduction='micro-imagewise',exclude_empty=exclude_empty).item()


def compute_mean_iou_imagewise_from_cumulator(TPs, FPs, FNs, TNs, exclude_empty=False, return_std=False, reduce_mean=True):
    # Concatenate tensors for each metric

    try:
        tp = torch.cat([tp for tp in TPs])
        fp = torch.cat([fp for fp in FPs])
        fn = torch.cat([fn for fn in FNs])
        tn = torch.cat([tn for tn in TNs])
    except:
        tp = TPs
        fp = FPs
        fn = FNs
        tn = TNs

    if exclude_empty:
        # Calculate IOU per image excluding empty cases
        iou1_per_image_no_empty = compute_iou_from_metrics(tp, fp, tn, fn, reduction='none', exclude_empty=True)
        iou0_per_image_no_empty = compute_iou_from_metrics(tn, fn, tp, fp, reduction='none', exclude_empty=True)
        
        # Combine and filter valid IOU scores
        combined_iou_scores = np.hstack((iou0_per_image_no_empty, iou1_per_image_no_empty))
        valid_pairs = ~np.isnan(combined_iou_scores).any(axis=1)
        
        # Compute mean and optionally standard deviation
        mean_iou_per_image_no_empty = np.nanmean(combined_iou_scores[valid_pairs], axis=1)

        if not reduce_mean:
            return mean_iou_per_image_no_empty
        if return_std:
            std_iou_per_image_no_empty = np.nanstd(mean_iou_per_image_no_empty)
            return np.mean(mean_iou_per_image_no_empty), std_iou_per_image_no_empty
        else:
            return np.mean(mean_iou_per_image_no_empty)

    else:
        # Calculate IOU per image including empty cases
        iou1_per_image = compute_iou_from_metrics(tp, fp, tn, fn, reduction='none')
        iou0_per_image = compute_iou_from_metrics(tn, fn, tp, fp, reduction='none')
        
        # Compute mean and optionally standard deviation
        combined_iou_scores = np.array([iou0_per_image.cpu().numpy(), iou1_per_image.cpu().numpy()])
        mean_iou_per_image = np.nanmean(combined_iou_scores, axis=0)
        
        if not reduce_mean:
            return mean_iou_per_image
        if return_std:
            std_iou_per_image = np.nanstd(mean_iou_per_image)
            return np.mean(mean_iou_per_image), std_iou_per_image
        else:
            return np.mean(mean_iou_per_image)

def compute_mean_dice_imagewise_from_cumulator(TPs, FPs, FNs, TNs, exclude_empty=False, return_std=False, reduce_mean=True):
    try:
        tp = torch.cat([tp for tp in TPs])
        fp = torch.cat([fp for fp in FPs])
        fn = torch.cat([fn for fn in FNs])
        tn = torch.cat([tn for tn in TNs])
    except:
        tp = TPs
        fp = FPs
        fn = FNs
        tn = TNs

    if exclude_empty:
        dice1_per_image_no_empty = compute_dice_from_metrics(tp, fp, tn, fn, reduction='none', exclude_empty=True)
        dice0_per_image_no_empty = compute_dice_from_metrics(tn, fn, tp, fp, reduction='none', exclude_empty=True)
        combined_dice_scores = np.hstack((dice0_per_image_no_empty, dice1_per_image_no_empty))
        valid_pairs = ~np.isnan(combined_dice_scores).any(axis=1)
        mean_dice_per_image_no_empty = np.nanmean(combined_dice_scores[valid_pairs], axis=1)
        if not reduce_mean:
            return mean_dice_per_image_no_empty
        if return_std:
            std_dice_per_image_no_empty = np.std(np.nanmean(combined_dice_scores[valid_pairs], axis=1))
            return np.mean(mean_dice_per_image_no_empty), std_dice_per_image_no_empty
        else:
            return np.mean(mean_dice_per_image_no_empty)
    else:
        dice1_per_image = compute_dice_from_metrics(tp, fp, tn, fn, reduction='none')
        dice0_per_image = compute_dice_from_metrics(tn, fn, tp, fp, reduction='none')
        combined_dice_scores = np.array([dice0_per_image.cpu().numpy(), dice1_per_image.cpu().numpy()])
        mean_dice_per_image = np.nanmean(combined_dice_scores, axis=0)
        if not reduce_mean:
            return mean_dice_per_image
        if return_std:
            std_dice_per_image = np.std(np.nanmean(combined_dice_scores, axis=0))
            return np.mean(mean_dice_per_image), std_dice_per_image
        else:
            return np.mean(mean_dice_per_image)

def plot_slices_side_by_side(volume1, volume2):
    """
    Plot corresponding slices from two CxHxW volumes side by side.
    
    :param volume1: First volume with shape CxHxW.
    :param volume2: Second volume with shape CxHxW.
    """
    H, W, B = volume1.shape  # Assuming volume1 and volume2 have the same shape
    
    # Set up the figure size dynamically based on the number of slices
    plt.figure(figsize=(10, 2 * B))
    
    for b in range(B):
        # Plot slice from volume 1
        plt.subplot(B, 2, 2*b + 1)  # Rows, Columns, Index
        plt.imshow(volume1[:, :,b], cmap='gray')
        plt.title(f'Slice {b + 1} - Volume 1')
        plt.axis('off')  # Hide axes ticks
        
        # Plot corresponding slice from volume 2
        plt.subplot(B, 2, 2*b + 2)  # Rows, Columns, Index
        plt.imshow(volume2[:, :,b], cmap='gray')
        plt.title(f'Slice {b + 1} - Volume 2')
        plt.axis('off')  # Hide axes ticks
    
    plt.tight_layout()
    plt.show()

def calculate_local_agreement(prob1, prob2, kernel_size=3, agreement_threshold=0.1):
    """
    Calculate local agreement between two probability masks using average pooling to simulate
    the surrounding window effect.
    """
    # Calculate absolute difference and apply threshold
    diff = torch.abs(prob1 - prob2)
    local_diff = F.avg_pool2d(diff.unsqueeze(0), kernel_size, stride=1, padding=kernel_size//2).squeeze(0)
    local_agreement = local_diff < agreement_threshold
    return local_agreement

def fill_gaps_in_masses(binary_mask, gap_filling_kernel_size=5):
    """
    Fills gaps in segmented masses using morphological closing.
    
    Parameters:
    - binary_mask: numpy.ndarray, the binary segmentation mask with masses.
    - gap_filling_kernel_size: int, the size of the square kernel used for gap filling.
    
    Returns:
    - gap_filled_mask: numpy.ndarray, the mask after filling gaps.
    """

    print(np.unique(binary_mask))
    # Define the square kernel based on the specified size
    kernel_gap_filling = np.ones((gap_filling_kernel_size, gap_filling_kernel_size), np.uint8)
    
    # Perform the morphological closing operation
    gap_filled_mask = cv2.morphologyEx(binary_mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel_gap_filling)
    
    return gap_filled_mask


def perform_dilation(image, dilation_size=3):
    """
    Perform dilation on an image using OpenCV.
    
    :param image: Input binary image with objects to dilate.
    :param dilation_size: Determines the size of the dilation kernel. Default is 3.
    :return: Image after dilation.
    """

    image = (image > 0).astype(np.uint8)
    # Create a square structuring element for dilation
    kernel = np.ones((dilation_size, dilation_size), np.uint8)
    
    # Perform dilation
    dilated_image = cv2.dilate(image, kernel, iterations=1)
    
    return dilated_image


def remove_far_masses_based_on_largest_mass(batch_masks, distance_threshold):
    """
    Remove masses in each mask of the batch that are far from the largest mass.

    Parameters:
    - batch_masks: numpy array of shape BCHW.
    - distance_threshold: distance beyond which a mass is considered far.

    Returns:
    - processed_masks: numpy array of shape BCHW with far masses removed.
    """
    processed_masks = np.zeros_like(batch_masks)

    for i, mask in enumerate(batch_masks):
        # Ensure the mask is 2D
        if mask.ndim == 3:  # BCHW where C=1
            mask_2d = mask[0]
        elif mask.ndim == 2:  # HW
            mask_2d = mask
        else:
            raise ValueError("Mask dimension is not correct. Expected 2D or 3D with single channel.")

        # Label the connected components with 2D connectivity
        labeled_mask = LABEL(mask_2d.cpu(), connectivity=1)

        # Calculate properties of each component
        regions = regionprops(labeled_mask)

        if not regions:
            continue

        # Find the largest mass
        largest_mass = max(regions, key=lambda x: x.area)

        # Get the centroid of the largest mass
        largest_mass_centroid = largest_mass.centroid

        # Identify and keep components close to the centroid of the largest mass
        for region in regions:
            centroid = region.centroid
            distance = np.sqrt((centroid[0] - largest_mass_centroid[0]) ** 2 + (centroid[1] - largest_mass_centroid[1]) ** 2)
            if distance < distance_threshold:
                processed_masks[i, 0, region.coords[:,0], region.coords[:,1]] = 1

    return processed_masks


def save_to_json(data, filename):
    try:
        with open(filename, 'w') as file:
            json.dump(data, file, indent=4)
        print(f"Dictionary successfully saved to {filename}")
    except Exception as e:
        print(f"Error saving dictionary to file: {e}")


def load_json_file(file_path):
    """Load and return the content of a JSON file."""
    with open(file_path, 'r') as file:
        return json.load(file)

def get_filenames(suffix, base_path, patient_ids, remove_black_samples=False, get_random_samples_and_remove_black_samples=False, get_top_bottom_and_remove_black_samples=False,random_samples_indexes_list=None,
                 remove_picked_samples=False):
    filenames = []

    create_random_samples_index_list = False

    if random_samples_indexes_list is None:
        random_samples_indexes_list = []
        create_random_samples_index_list=True

    for idx, patient_id in enumerate(patient_ids):
        path = os.path.join(base_path, patient_id) + "/" + suffix + "/"
        files = [os.path.join(path, p) for p in natsorted(os.listdir(path), alg=ns.IGNORECASE)]

        if get_random_samples_and_remove_black_samples:
              files_sampled = filter_samples_sample_aware(files, patient_id)
              if remove_picked_samples:
                  files_sampled = filter_samples_to_exclude(files_sampled, patient_id)
              filenames += files_sampled
              size = int(len(files_sampled)*0.25)


              random_samples_indexes = None if create_random_samples_index_list else random_samples_indexes_list[idx]
              files_random, random_samples_indexes = get_samples_size(files=files, patient_id=patient_id, size=size, random_samples=True, random_samples_indexes=random_samples_indexes)

              if create_random_samples_index_list:
                  random_samples_indexes_list.append(random_samples_indexes)

              filenames += files_random

        elif remove_black_samples:
              files_sampled = filter_samples_sample_aware(files, patient_id)
              if remove_picked_samples:
                  files_sampled = filter_samples_to_exclude(files_sampled, patient_id)
                
              filenames += files_sampled

        elif get_top_bottom_and_remove_black_samples:
              files_sampled = filter_samples_sample_aware(files, patient_id)
              if remove_picked_samples:
                  files_sampled = filter_samples_to_exclude(files_sampled, patient_id)
              filenames += files_sampled

              size = int(len(files_sampled)*0.25)
              files_top_bottom = get_samples_size(files=files, patient_id=patient_id, size=size, random_samples=False)
              filenames += files_top_bottom

        else:
              filenames += files

    if get_random_samples_and_remove_black_samples:
      return filenames, random_samples_indexes_list
    else:
      return filenames, None

def get_samples_size(files, patient_id, size=None, random_samples=False, random_samples_indexes=None):

    top_slices_len = d[patient_id]['start']
    top_slices = files[:top_slices_len]
    sample_size_top_slices = size

    bottom_slices_len = len(files) - d[patient_id]['end']
    bottom_slices = files[d[patient_id]['end']:]
    sample_size_bottom_slices = size

    if random_samples:
        if random_samples_indexes:
          subset_top_slices_random_indexes = random_samples_indexes[0]
        else:
          if sample_size_top_slices > len(top_slices):
            sample_size_top_slices=len(top_slices)

          subset_top_slices_random_indexes = random.sample(range(len(top_slices)), sample_size_top_slices)

        subset_top_slices = [top_slices[i] for i in subset_top_slices_random_indexes]

        if random_samples_indexes:
              subset_bottom_slices_random_indexes = random_samples_indexes[1]
        else:
          if sample_size_bottom_slices > len(bottom_slices):
            sample_size_bottom_slices=len(bottom_slices)

          subset_bottom_slices_random_indexes = random.sample(range(len(bottom_slices)), sample_size_bottom_slices)
            
        subset_bottom_slices=  [bottom_slices[i] for i in subset_bottom_slices_random_indexes]

        files_to_return = subset_top_slices + subset_bottom_slices
        return files_to_return, [subset_top_slices_random_indexes, subset_bottom_slices_random_indexes]



    else:
        if sample_size_top_slices > len(top_slices):
            sample_size_top_slices=len(top_slices)
        if sample_size_bottom_slices > len(bottom_slices):
            sample_size_bottom_slices=len(bottom_slices)

        subset_top_slices = top_slices[-sample_size_top_slices:]
        subset_bottom_slices = bottom_slices[:sample_size_bottom_slices]
        files_to_return = subset_top_slices + subset_bottom_slices
        return files_to_return



def filter_samples_sample_aware(files, patient_id):
    start, end = d[patient_id]['start'], d[patient_id]['end']
    return files[start+1:end]

def filter_samples_to_exclude(files, patient_id):
    filtered_list = []
    if patient_id not in to_remove_dict:
        return files
    files_to_exclude = to_remove_dict[patient_id]

    filtered_list = []
    for file in files:
        file_clean = file.replace("mask_","")
        file_clean = file_clean.replace("masks","images")
        
        if file_clean not in files_to_exclude:
            filtered_list.append(file)

    return filtered_list


def reconstruct_label(label, original_shape, thorax_crop_coords, bottom_crop_coords, crop_coords, resize_dims, trim_breast_coords):
    # First, resize the label back to the size before the final crop
    label_resized = F.interpolate(label.unsqueeze(0), size=resize_dims, mode='nearest-exact').squeeze(0)

    # Now, we need to reverse the crop operations in the correct order
    # Start with the most recent crop and work backward to the original state

    # Reverse the final crop to get to the state before bottom crop
    y1, y2, x1, x2 = crop_coords
    crop_reversed_label = torch.zeros((original_shape[0], y2-y1, x2-x1), dtype=label.dtype)
    crop_reversed_label = label_resized

    # Reverse the bottom crop to get to the state before breast trim
    x1, y1, x2, y2 = bottom_crop_coords
    bottom_crop_reversed_label = torch.zeros((original_shape[0], original_shape[1], crop_reversed_label.shape[2]), dtype=label.dtype)
    bottom_crop_reversed_label[:, y1:y2, :] = crop_reversed_label

    # Reverse the breast trim to get to the state before thorax crop
    start, end = trim_breast_coords
    breast_trim_reversed_label = torch.zeros((original_shape[0], original_shape[1], original_shape[2]), dtype=label.dtype)
    breast_trim_reversed_label[:, :, start:end] = bottom_crop_reversed_label

    # Reverse the thorax crop to get to the original state
    x1, y1, x2, y2 = thorax_crop_coords
    thorax_crop_reversed_label = torch.zeros(original_shape, dtype=label.dtype)
    thorax_crop_reversed_label[:, y1:y2, :] = breast_trim_reversed_label

    return thorax_crop_reversed_label


def filter_fn(image, max_ratio):
    c, h, w = image.shape

    if h >= w:
        if w == 0 or max_ratio < h/w:
          return False
    elif w >= h:
        if h==0 or max_ratio < w/h:
            return False
    return True

def train_custom_collate(batch):
    # Filter out None samples
    augmentations = Compose([monai.transforms.RandHistogramShiftd(keys=['image'], prob=0.2, num_control_points=4), 
                                      monai.transforms.RandRotated(keys=['image', 'label'],mode='nearest-exact', range_x=[0.1, 0.1], prob=0.3),
                                      monai.transforms.RandZoomd(keys=['image', 'label'],mode='nearest-exact', min_zoom = 1.3, max_zoom = 1.5, prob=0.3),
                                      #monai.transforms.RandCoarseDropoutd(keys=['image', 'label'], prob=0.3, holes=20, spatial_size=20, fill_value =0)
                                     ]
        
        )

    
    
    # Filter out None samples
    batch = [augmentations({
                'image' : copy.deepcopy(item['image']),
                'label' : copy.deepcopy(item['label']),
                'boundary': copy.deepcopy(item["boundary"])
            }) for sublist in batch for item in sublist if item['keep_sample']]
    
    if len(batch)>0:
        batch = default_collate(batch)
        return batch
    return None

def train_custom_collate_no_patches(batch):

    augmentations = Compose([monai.transforms.RandHistogramShiftd(keys=['image'], prob=0.2, num_control_points=4), 
                                      monai.transforms.RandRotated(keys=['image', 'label'],mode='nearest-exact', range_x=[0.1, 0.1], prob=0.3),
                                      monai.transforms.RandZoomd(keys=['image', 'label'],mode='nearest-exact', min_zoom = 1.3, max_zoom = 1.5, prob=0.3),
                                      #monai.transforms.RandCoarseDropoutd(keys=['image', 'label'], prob=0.3, holes=20, spatial_size=20, fill_value =0)
                                     ]
        
        )

    # Filter out None samples
    batch = [augmentations({
                'image' : copy.deepcopy(item['image']),
                'label' : copy.deepcopy(item['label']),
                'boundary': copy.deepcopy(item["boundary"])
            }) for item in batch if item['keep_sample']]

    
    if len(batch)>0:
        batch = default_collate(batch)
        return batch
    return None


def custom_collate(batch):

    
    batch = [item for sublist in batch for item in sublist if item['keep_sample']]

    
    if len(batch)>0:
        batch = default_collate(batch)
        return batch
    return None

def  custom_collate_no_patches(batch):

    # Filter out None samples
    batch = [item for item in batch if item['keep_sample']]
    
    if len(batch)>0:
        batch = default_collate(batch)
        return batch
    return None
    
def reverse_transformations(d, processed_label, mode='patches'):
    # Extract the processed label and transformation coordinates
    y1_crop, y2_crop, x1_crop, x2_crop = d['crop_coords']
    x1_bottom, y1_bottom, x2_bottom, y2_bottom = d['bottom_crop_coords']

    if mode=='patches':
        start_breast, end_breast = d['trim_breast_coords']
    x1_thorax, y1_thorax, x2_thorax, y2_thorax = d['thorax_crop_coords']
    intermediate_spatial_dim = d['dim_before_resize_final']

    # Step 0: Resize to the original spatial dimensions before the final crop
    # Assuming 'resize' was a downscaling operation and the original_spatial_dim is the target size
    label_resized = F.interpolate(processed_label.unsqueeze(0).unsqueeze(0).float(),
                                  size=intermediate_spatial_dim.tolist(),  # Excluding the batch size dimension
                                  mode='nearest-exact').squeeze(0).squeeze(0)  # Removing the added batch and channel dimensions

    
    
    pad_post_crop_coords = d['pad_post_crop_coords'].tolist()
    before1, before2, before3 = pad_post_crop_coords[0],pad_post_crop_coords[1],pad_post_crop_coords[2]

    original_height = label_resized.shape[1] - before2[1]
    original_width = label_resized.shape[2] - before3[1]

    # Slice the image to remove the padding
    reversed_pad = label_resized[:, :original_height, :original_width]

    # Step 1: Reverse final crop
    crop_height, crop_width = d['dim_before_crop'][1:]
    padded_label = torch.zeros((1, crop_height, crop_width), dtype=label_resized.dtype)
    padded_label[:, y1_crop:y2_crop, x1_crop:x2_crop] = reversed_pad

    # Step 2: Reverse bottom crop
    bottom_height = d['dim_before_bottom_crop'][1]
    bottom_width = d['dim_before_bottom_crop'][2]
    bottom_padded_label = torch.zeros((1, bottom_height, x2_bottom), dtype=padded_label.dtype)
    bottom_padded_label[:, :y2_bottom, :x2_bottom] = padded_label

    # Conditional steps based on whether breast trim was applied
    if mode == 'patches':
        # Step 3: Reverse breast trim
        trim_width = d['dim_before_breast_crop'][2]
        trim_padded_label = torch.zeros((bottom_height, trim_width), dtype=bottom_padded_label.dtype)
        trim_padded_label[:, start_breast:end_breast] = bottom_padded_label
    else:
        # In other modes, use the bottom_padded_label directly for thorax crop reversal
        trim_padded_label = bottom_padded_label
        trim_width = bottom_width  # This assumes no trimming, hence the width is unchanged

    
    # Step 4: Reverse thorax crop
    thorax_height = d['dim_before_thorax_crop'][1]
    thorax_padded_label = torch.zeros((1, thorax_height, trim_width), dtype=trim_padded_label.dtype)
    thorax_padded_label[:, y1_thorax:, :] = trim_padded_label


    original_spatial_dim = d['dim_before_resize_preliminary']

    # Step 0: Resize to the original spatial dimensions before the final crop
    # Assuming 'resize' was a downscaling operation and the original_spatial_dim is the target size
    reconstructed_mask = F.interpolate(thorax_padded_label.unsqueeze(0).unsqueeze(0).float(),
                                  size=original_spatial_dim.tolist(),  # Excluding the batch size dimension
                                  mode='nearest-exact').squeeze(0).squeeze(0)  # Removing the added batch and channel dimensions


    # Update the original label in the dictionary
    return reconstructed_mask

def get_mean_std_dataloader(dataloader, masked=False):
  # Variables to store sum and sum of squares
  sum_of_images = 0.0
  sum_of_squares = 0.0
  num_pixels = 0

  # Iterate over the DataLoader
  for batch in tqdm(dataloader):
      image = batch["image"]

      if masked:
        mask = image > 0.0
        image = image[mask]

      sum_of_images += image.sum()
      sum_of_squares += (image ** 2).sum()
      num_pixels += image.numel()

  # Calculate the mean and standard deviation
  mean = sum_of_images / num_pixels
  std_dev = (sum_of_squares / num_pixels - mean ** 2) ** 0.5

  print(f'Mean: {mean}, Standard Deviation: {std_dev}')
  return mean, std_dev

# Transform functions

In [7]:
class EnhanceLesionsSelective(MapTransform):
    """
    A MONAI MapTransform to enhance lesions selectively in post-contrast images using a soft mask derived from the subtracted image.
    """
    def __init__(self, keys, threshold=0.9):
        super().__init__(keys)
        self.threshold = threshold

    def create_soft_mask(self, subtracted_norm, threshold=0.9):
        """
        Create a soft mask where values are scaled between 0 and max_value,
        with intensities above 'threshold' in the subtracted image being closer to max_value.
        """
        subtracted_norm = np.where(subtracted_norm > threshold, threshold, subtracted_norm/threshold)
        return subtracted_norm

    def __call__(self, data):
        d = dict(data)
        
        for key in self.keys:
            
            subtracted = d[key]  # Assuming d[key] is a tuple (post_contrast, subtracted)

            fourth_image = d['processed_image'] # 4th sequence
            subtracted = np.array(subtracted)
            
            # Normalize the subtracted image to [0, 1]
            subtracted_norm = cv2.normalize(subtracted[0], None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
            
            # Create a soft mask from the normalized subtracted image
            soft_mask = self.create_soft_mask(subtracted_norm, threshold=self.threshold)
            
            # Normalize the post-contrast image to [0, 1] and apply the soft mask
            fouth_image_norm = cv2.normalize(fourth_image[0], None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
            enhanced_image = soft_mask*fouth_image_norm
            
            # Combine enhanced image with the normalized subtracted image for final enhancement
            enhanced_image_final = subtracted_norm + enhanced_image

            enhanced_image_final=cv2.normalize(enhanced_image_final, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
            
            # Assuming the result should maintain the original shape as 1 x H x W
            enhanced_image_final = np.expand_dims(enhanced_image_final, 0)

            d[key] = monai.data.MetaTensor(enhanced_image_final)
        
        return d



class RemoveThorax(MapTransform):

    def __init__(self, threshold=250, value=0, margin=0,**kwargs):
        super(RemoveThorax, self).__init__(**kwargs)
        self.threshold = threshold
        self.value = value
        self.margin = margin

    def remove_upper_portion_get_coords(self, image):
        # IMAGE IS C, H, W
        # Step 1: Find the vertical middle line of the image
        middle_x = image.shape[2] // 2

        non_zero_y = 0
        for y in reversed(range(image.shape[1])):
            if image[:,y, middle_x] > 0:
                non_zero_y = y
                break

        return non_zero_y  # Also return the y-coordinate for reference

    def __call__(self, data):
        d = dict(data)
        image_to_threshold = d['processed_image']
        dim_before_thorax_crop = torch.tensor(image_to_threshold.shape)
        image_to_threshold = np.where(image_to_threshold < self.threshold, self.value, image_to_threshold)
        y_coord = self.remove_upper_portion_get_coords(image_to_threshold)-self.margin
        thorax_crop_coords = torch.tensor([0, y_coord, image_to_threshold.shape[2], image_to_threshold.shape[1]], dtype=torch.int16)  # (x1, y1, x2, y2)
        image = d['processed_image']
        image = image[:,y_coord:, :]

        if data['has_mask']:
            mask = d['processed_label']
            mask = mask[:,y_coord:, :]
            d['processed_label'] = mask

        d['processed_image'] = image
        d['thorax_crop_coords'] = torch.cat((d['thorax_crop_coords'], thorax_crop_coords), dim=0)
        d['dim_before_thorax_crop'] = torch.cat((d['dim_before_thorax_crop'], dim_before_thorax_crop), dim=0)
        return d


class RemoveBottom(MapTransform):

    def __init__(self, threshold=250, value=0, margin=0, **kwargs):
        super(RemoveBottom, self).__init__(**kwargs)
        self.threshold = threshold
        self.value = value
        self.margin = margin

    def remove_lower_portion_get_coords(self, image):

        # Step 2: Starting from the bottom, find the first non-zero pixel
        non_zero_y = image.shape[1]-1
        for y in reversed(range(image.shape[1])):  # Start from the bottom
            if np.sum(image[:,y,:]) > 0:
                non_zero_y = y
                break

        return non_zero_y  # Also return the y-coordinate for reference

    def __call__(self, data):
        d = dict(data)
        image_to_threshold = d['processed_image']

        """print("ciao prima")
        plt.imshow(image_to_threshold[0], cmap='gray')
        plt.show()"""

        image_to_threshold = np.where(image_to_threshold < self.threshold, self.value, image_to_threshold)
        y_coord = self.remove_lower_portion_get_coords(image_to_threshold)+self.margin

        bottom_crop_coords = torch.tensor([0, 0, image_to_threshold.shape[2], y_coord], dtype=torch.int16)  # (x1, y1, x2, y2)
        image = d['processed_image']
        dim_before_bottom_crop = torch.tensor(image.shape)
        image = image[:,:y_coord, :]

        if d['has_mask']:
            mask = d['processed_label']
            mask = mask[:,:y_coord, :]
            d['processed_label'] = mask

        d['processed_image'] = image
        d['bottom_crop_coords'] = torch.cat((d['bottom_crop_coords'], bottom_crop_coords), dim=0)
        d['dim_before_bottom_crop'] = torch.cat((d['dim_before_bottom_crop'], dim_before_bottom_crop), dim=0)

        return d
class FilterBySize(MapTransform):

    def __init__(self, max_ratio, **kwargs):
        super(FilterBySize, self).__init__(**kwargs)
        self.max_ratio = max_ratio
        self.delete = monai.transforms.DeleteItemsd(keys = ['image', 'label'])

    def __call__(self, data):
        d = dict(data)
        c, h, w = d['image'].shape

        if h >= w:

          if w == 0 or self.max_ratio < h/w:
            return self.delete(d)
        elif w >= h:
          if h==0 or self.max_ratio < w/h:
            return self.delete(d)

        return d


class MedianSmooth(MapTransform):

    def __init__(self, radius, **kwargs):
        super(MedianSmooth, self).__init__(**kwargs)
        self.median_smooth = monai.transforms.MedianSmooth(radius=radius)


    def __call__(self, data):
        d = dict(data)
        d['processed_image'] = self.median_smooth(d['processed_image'])
        return d


class TrimSides(MapTransform):

    def __init__(self, keys, threshold, tolerance, **kwargs):
        super(TrimSides, self).__init__(keys, **kwargs)
        self.threshold = threshold
        self.tolerance = tolerance

    def trim_sides(self, image_data, threshold=0, tolerance=0):
        # Calculate the sum of pixel values across the channel axis for each column
        col_sum = np.sum(image_data, axis=0).sum(axis=0)


        # Find indices where the sum exceeds the threshold
        x_start = np.argmax(col_sum > threshold)
        x_end = len(col_sum) - np.argmax(col_sum[::-1] > threshold) - 1

        # Apply tolerance
        x_start = max(0, x_start - tolerance)
        x_end = min(len(col_sum) - 1, x_end + tolerance)

        return x_start, x_end

    def __call__(self, data):
        d = dict(data)
        image = d['processed_image']


        x_start, x_end = self.trim_sides(image_data=image, threshold=self.threshold, tolerance=self.tolerance)

        # Crop the image and mask
        cropped_image = image[:, :, x_start:x_end+1]



        trim_coords = torch.tensor([x_start, x_end+1], dtype=torch.int16)

        # Update the dictionary
        d['processed_image'] = cropped_image

        if d['has_mask']:
            mask = d['processed_label']
            cropped_mask = mask[:, :, x_start:x_end+1]
            d['processed_label'] = cropped_mask


        d['trim_coords'] = torch.cat((d['trim_coords'], trim_coords), dim=0)

        return d

class RelativeThresholding(MapTransform):
    def __init__(self, keys, relative_threshold):
        super().__init__(keys)
        self.relative_threshold = relative_threshold

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            # assuming the image's pixel values are in the range [0, 1]
            # if your image has a different range, adjust the max_intensity and threshold_value calculation accordingly
            max_intensity = np.max(d[key])
            threshold_value = max_intensity * self.relative_threshold

            # apply the thresholding
            d[key] = torch.tensor(np.where(d[key] >= threshold_value, 1, 0))
        return d


class RelativeThresholdingSingleChannel(MapTransform):
    def __init__(self, keys, relative_threshold, channel_index):
        super().__init__(keys)
        self.relative_threshold = relative_threshold
        self.channel_index = channel_index

    def __call__(self, data):
        d = dict(data)
        for key in self.keys:
            original_image = d[key]

            # Extract the specific channel
            single_channel = original_image[self.channel_index]

            # assuming the image's pixel values are in the range [0, 1]
            # if your image has a different range, adjust the max_intensity and threshold_value calculation accordingly
            max_intensity = torch.max(single_channel)
            threshold_value = max_intensity * self.relative_threshold

            # apply the thresholding to the single channel
            transformed_channel = torch.where(single_channel >= threshold_value, 1, 0)

            # Replace the channel in the original image
            transformed_image = torch.clone(original_image)  # make a copy of the original image
            transformed_image[self.channel_index] = transformed_channel
            d[key] = transformed_image

        return d

class ThresholdBlack(MapTransform):

    def __init__(self, threshold, value, **kwargs):
        super(ThresholdBlack, self).__init__(**kwargs)
        self.threshold = threshold
        self.value = value

    def __call__(self, data):
        d = dict(data)

        d['processed_image'] = monai.data.MetaTensor(np.where(d['processed_image'] < self.threshold, self.value, d['processed_image']))

        if d['has_mask']:
            d['processed_label'] = monai.data.MetaTensor(np.where(d['processed_label'] < self.threshold, self.value, d['processed_label']))


        return d


def get_crop_coordinates(image):
    image = image[200:,:]
    # Convert the image to uint8 type for compatibility with OpenCV functions
    image_uint8 = (image * 255).astype(np.uint8)

    # Apply a binary threshold to segment the breasts from the background
    _, thresholded = cv2.threshold(image_uint8, 1, 255, cv2.THRESH_BINARY)

    # Find contours in the binary image
    contours, _ = cv2.findContours(thresholded, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Sort the contours by area (largest first)
    contours = sorted(contours, key=cv2.contourArea, reverse=True)

    # If there are not enough contours found, return the original image
    if len(contours) < 2:
        return image

    # Get the bounding boxes of the two largest contours (likely the breasts)
    x1, y1, w1, h1 = cv2.boundingRect(contours[0])
    x2, y2, w2, h2 = cv2.boundingRect(contours[1])

    # Compute the combined bounding box
    x = min(x1, x2)
    y = min(y1, y2)
    w = max(x1 + w1, x2 + w2) - x
    h = max(y1 + h1, y2 + h2) - y

    y_start, y_end, x_start, x_end = y, y+h, x, x+w

    return y_start, y_end, x_start, x_end

def remove_black_borders_get_coordinates(image):
    """
    Remove black borders on the left and right of the image.
    """
    # Sum the pixel values across rows to get a profile of the image's columns
    column_sum = image.sum(axis=0)

    # Set a threshold (small value) to identify black regions
    threshold = 500

    # Identify where the profile transitions from nearly zero to non-zero
    non_black_columns = np.where(column_sum > threshold)[0]

    y_start, y_end = non_black_columns[0], non_black_columns[-1]+1

    return y_start, y_end

def remove_black_borders_2D_get_coordinates(image):
    """
    Remove black borders on all sides of the image.
    """
    # Sum the pixel values across columns and rows to get profiles of the image's columns and rows
    column_sum = image.sum(axis=0)
    row_sum = image.sum(axis=1)

    # Set a threshold (small value) to identify black regions
    col_threshold = 500
    row_threshold = 500

    # Identify where the profile transitions from nearly zero to non-zero for both columns and rows
    non_black_columns = np.where(column_sum > col_threshold)[0]
    non_black_rows = np.where(row_sum > row_threshold)[0]

    # Crop the image using the identified transition points
    y_start, y_end = non_black_columns[0], non_black_columns[-1]+1
    x_start, x_end = non_black_rows[0], non_black_rows[-1]+1
    return y_start, y_end, x_start, x_end

class RemoveBlack(MapTransform):

    def __init__(self, **kwargs):
        super(RemoveBlack, self).__init__(**kwargs)

    def __call__(self, data):

        d = dict(data)
        y_start, y_end, x_start, x_end = get_crop_coordinates(d['image'][0])
        if y_end-y_start < 100:
            return d
        d['image'] = d['image'][:,200:]
        d['label'] = d['label'][:,200:]

        d['image'] = d['image'][:, y_start: y_end+50]
        d['label'] = d['label'][:, y_start: y_end+50]

        y_start, y_end, x_start, x_end = remove_black_borders_2D_get_coordinates(d['image'][0])

        d['image'] = d['image'][:, x_start:x_end, y_start: y_end]
        d['label'] = d['label'][:, x_start:x_end, y_start: y_end]

        return d

class PrepareSample(MapTransform):

    def __init__(self, target_size, subtracted_images, patches, **kwargs):
        super(PrepareSample, self).__init__(**kwargs)
        self.resize = monai.transforms.Resized(keys=['image', 'label'],spatial_size=target_size, mode='nearest-exact')
        self.patches = patches
        self.resize_original = monai.transforms.Resized(keys=['original_image', 'original_label'],spatial_size=target_size,  mode='nearest-exact')
        self.subtracted_images = subtracted_images
        self.loadimage = monai.transforms.LoadImage(ensure_channel_first=True, reader=monai.data.PILReader(converter=lambda image: image.convert("L")))
        
                
    def prepare_with_patches(self,data):
        trim_breast_coords = data['trim_breast_coords'].tolist()
        #trim_coords = data['trim_coords'].tolist()
        thorax_crop_coords= data['thorax_crop_coords'].tolist()
        bottom_crop_coords= data['bottom_crop_coords'].tolist()

        crop_coords = data['crop_coords'].tolist()
        pad_post_crop_coords = data['pad_post_crop_coords'].tolist()

        if self.subtracted_images:
            image_path = data['image_meta_dict']['subtracted_filename_or_obj']
            if image_path.endswith(".npy"):
                image = np.load(image_path)
                image = np.expand_dims(image, 0)
                image = monai.data.MetaTensor(image)

            else:
                label_path = data['label_meta_dict']['subtracted_filename_or_obj']
                image = self.loadimage(image_path)
                image = monai.transforms.Rotate90()(image)
                image = monai.data.MetaTensor(image)
                
                label = self.loadimage(label_path)
                label = monai.transforms.Rotate90()(label)
                label = monai.data.MetaTensor(label)
                data['label'] = label
                

            data['image'] = image

       
        data = self.resize(data)

        data = self.resize_original(data)

        image = data['image']

        original_image = data['original_image']

        target_size = data['preliminary_target_size'].tolist()

        x1, y1, x2, y2 = thorax_crop_coords
        
        image = image[:,y1:, :]
        original_image = original_image[:,y1:, :]


        start, end = trim_breast_coords

        image = image[:,:, start:end]
        original_image = original_image[:,:, start:end]


        x1, y1, x2, y2 = bottom_crop_coords
        image = image[:,:y2, :]
        original_image = original_image[:,:y2, :]


        """start, end = trim_coords

        image = image[:, :, start:end]
        label = label[:, :,start: end]
        """
        y_min, y_max, x_min, x_max = crop_coords

        image = image[:, y_min:y_max, x_min:x_max]
        original_image = original_image[:, y_min:y_max, x_min:x_max]

        before1, before2, before3 = pad_post_crop_coords[0],pad_post_crop_coords[1],pad_post_crop_coords[2]

        image = np.pad(image, ((before1[0], before1[1]), (before2[0], before2[1]), (before3[0], before3[1])), 'constant')
        original_image = np.pad(original_image, ((before1[0], before1[1]), (before2[0], before2[1]), (before3[0], before3[1])), 'constant')

        data['image']=image
        data['processed_image'] = original_image


        if data['has_mask']:
            label = data['label']
            original_label = data['original_label']
            
            x1, y1, x2, y2 = thorax_crop_coords
            
            #label =  monai.transforms.Resize(spatial_size=target_size)(label)
            label = label[:,y1:, :]
            original_label = original_label[:,y1:, :]

            start, end = trim_breast_coords
            label = label[:,:, start:end]
            original_label = original_label[:,:, start:end]

            x1, y1, x2, y2 = bottom_crop_coords
            label = label[:,:y2, :]
            original_label = original_label[:,:y2, :]

            y_min, y_max, x_min, x_max = crop_coords
            label = label[:, y_min:y_max, x_min:x_max]
            original_label = original_label[:, y_min:y_max, x_min:x_max]

            before1, before2, before3 = pad_post_crop_coords[0],pad_post_crop_coords[1],pad_post_crop_coords[2]

            label = np.pad(label, ((before1[0], before1[1]), (before2[0], before2[1]), (before3[0], before3[1])), 'constant')
            original_label = np.pad(original_label, ((before1[0], before1[1]), (before2[0], before2[1]), (before3[0], before3[1])), 'constant')

            data['label']=label
            data['processed_label'] = original_label

        return data

    def prepare_without_patches(self,data):

        trim_breast_coords = data['trim_breast_coords'].tolist()
        #trim_coords = data['trim_coords'].tolist()
        thorax_crop_coords= data['thorax_crop_coords'].tolist()
        bottom_crop_coords= data['bottom_crop_coords'].tolist()

        crop_coords = data['crop_coords'].tolist()
        pad_post_crop_coords = data['pad_post_crop_coords'].tolist()

        if self.subtracted_images:
            image_path = data['image_meta_dict']['subtracted_filename_or_obj']
            if image_path.endswith(".npy"):
                image = np.load(image_path)
                image = np.expand_dims(image, 0)
                image = monai.data.MetaTensor(image)
                data['image'] = image

                 

            else: #FOR BRADM
                label_path = data['label_meta_dict']['subtracted_filename_or_obj']
                image = self.loadimage(image_path)
                image = monai.transforms.Rotate90()(image)
                image = monai.data.MetaTensor(image)
                
                label = self.loadimage(label_path)
                label = monai.transforms.Rotate90()(label)
                label = monai.data.MetaTensor(label)
                data['label'] = label
                data['image'] = image

        data = self.resize(data)
        data = self.resize_original(data)

        image = data['image']
        original_image = data['original_image']

        target_size = data['preliminary_target_size'].tolist()
        
        x1, y1, x2, y2 = thorax_crop_coords
        
        image = image[:,y1:, :]
        original_image = original_image[:,y1:, :]

        #start, end = trim_breast_coords

        #image = image[:,:, start:end]
        #original_image = original_image[:,:, start:end]


        x1, y1, x2, y2 = bottom_crop_coords
        image = image[:,:y2, :]
        original_image = original_image[:,:y2, :]

        y_min, y_max, x_min, x_max = crop_coords

        image = image[:, y_min:y_max, x_min:x_max]
        original_image = original_image[:, y_min:y_max, x_min:x_max]

        before1, before2, before3 = pad_post_crop_coords[0],pad_post_crop_coords[1],pad_post_crop_coords[2]

        image = np.pad(image, ((before1[0], before1[1]), (before2[0], before2[1]), (before3[0], before3[1])), 'constant')
        original_image = np.pad(original_image, ((before1[0], before1[1]), (before2[0], before2[1]), (before3[0], before3[1])), 'constant')

        data['image']=image
        data['processed_image'] = original_image


        
        if data['has_mask']:
            label = data['label']
            original_label = data['original_label']
            
            x1, y1, x2, y2 = thorax_crop_coords
            
            #label =  monai.transforms.Resize(spatial_size=target_size)(label)
            label = label[:,y1:, :]
            original_label = original_label[:,y1:, :]

            #start, end = trim_breast_coords
            #label = label[:,:, start:end]
            #original_label = original_label[:,:, start:end]

            x1, y1, x2, y2 = bottom_crop_coords
            label = label[:,:y2, :]
            original_label = original_label[:,:y2, :]

            y_min, y_max, x_min, x_max = crop_coords
            label = label[:, y_min:y_max, x_min:x_max]
            original_label = original_label[:, y_min:y_max, x_min:x_max]

            before1, before2, before3 = pad_post_crop_coords[0],pad_post_crop_coords[1],pad_post_crop_coords[2]

            label = np.pad(label, ((before1[0], before1[1]), (before2[0], before2[1]), (before3[0], before3[1])), 'constant')
            original_label = np.pad(original_label, ((before1[0], before1[1]), (before2[0], before2[1]), (before3[0], before3[1])), 'constant')

            data['label']=label
            data['processed_label'] = original_label

        return data
        
    def __call__(self, data):

        if self.patches:
            return self.prepare_with_patches(data)
        else:
            return self.prepare_without_patches(data)

    

class ForegroundMaskdSingleChannel(MapTransform):
    def __init__(self, keys, channel_index, num_bins=10):
        super().__init__(keys)
        self.foregroundMask= monai.transforms.ForegroundMask(invert=True)
        self.channel_index = channel_index

    def __call__(self, data):
        d = dict(data)

        for key in self.keys:

            original_image = d[key]

            # Extract the specific channel
            single_channel = original_image[self.channel_index]

            single_channel = np.expand_dims(single_channel, 0)

            # Apply the transform to this channel
            transformed_channel = self.foregroundMask(single_channel)

            # Replace the channel in the original image
            transformed_image = np.array(original_image)  # make a copy of the original image
            transformed_image[self.channel_index] = transformed_channel
            d[key] = torch.tensor(transformed_image)

        return d

class NormalizedSingleChannel(MapTransform):
    def __init__(self, keys, channel_index):
        super().__init__(keys)
        self.normalize = monai.transforms.NormalizeIntensity(subtrahend  = 0.1046, divisor=335.7632)
        self.channel_index = channel_index

    def __call__(self, data):
        d = dict(data)

        for key in self.keys:

            original_image = d[key]

            # Extract the specific channel
            single_channel = original_image[self.channel_index]

            single_channel = np.expand_dims(single_channel, 0)

            # Apply the transform to this channel
            transformed_channel = self.normalize(single_channel)[0]


            # Replace the channel in the original image
            transformed_image = np.array(original_image)  # make a copy of the original image

            transformed_image[self.channel_index] = transformed_channel
            d[key] = torch.tensor(transformed_image)

        return d


class Convert3D(MapTransform):

    def __init__(self, **kwargs):
        super( Convert3D, self).__init__(**kwargs)

    def __call__(self, data):

        d = data

        image = d['image']
        label = d['label']

        image = monai.transforms.utils_pytorch_numpy_unification.repeat(image,(3, 1, 1), axis=0)
        d['image']=image
        return d


class Convert3DEnhanced(MapTransform):

    def __init__(self, keys,**kwargs):
        super( Convert3DEnhanced, self).__init__(keys, **kwargs)
        self.relativeThresholding = RelativeThresholdingSingleChannel(keys=['image'], relative_threshold=0.4, channel_index = 1)
        self.foregroundMaskdSingleChannel = ForegroundMaskdSingleChannel(keys = ['image'], channel_index = 2)

    def __call__(self, data):

        d = data

        image = d['image']
        label = d['label']

        # Convert MetaTensor to torch.Tensor
        image_tensor = convert_to_numpy(image)

        # Perform the repeat operation
        image_tensor = np.repeat(image_tensor, 3, 0)

        # Convert back to MetaTensor if necessary
        image = torch.tensor(image_tensor)


        d['image'] = image
        d = self.relativeThresholding(d)
        d = self.foregroundMaskdSingleChannel(d)
        return d

class BoundingBoxSplit(MapTransform):
    def __init__(self, keys=("image", "label"), allow_missing_keys=False, bbox_size=(256, 256)):
        super().__init__(keys, allow_missing_keys)
        self.bbox_size = bbox_size

    def pad_image(self, image):
        """
        Pad the image to the desired bounding box size if it's smaller.

        Parameters:
        - image (numpy.ndarray): The image to be padded.

        Returns:
        - numpy.ndarray: The padded image.
        """
        channels, height, width = image.shape
        pad_height = max(0, self.bbox_size[0] - height)
        pad_width = max(0, self.bbox_size[1] - width)

        # Padding format should be [(0, 0), (pad_height, 0), (pad_width, 0)] to maintain the channel dimension
        padded_image = np.pad(image, [(0, 0), (0, pad_height), (0, pad_width)], mode='constant', constant_values=0)
        return padded_image

    def _positive_bounding_box(self, mask):
        """
        Computes the bounding box for a region of interest in a binary mask.

        Parameters:
        - mask (numpy.ndarray): A binary mask.

        Returns:
        - tuple: (y_min, y_max, x_min, x_max) coordinates of the bounding box.
        """
        # Find the row and column indices where the mask is 1.
        mask = mask[0]
        rows, cols = np.where(mask == 1)

        # If no ROI is found, return None.
        if len(rows) == 0 or len(cols) == 0:
            return None

        y_min, y_max = np.min(rows), np.max(rows)
        x_min, x_max = np.min(cols), np.max(cols)

        return y_min, y_max, x_min, x_max

    def _negative_bounding_box(self, mask, num_boxes=1):
        """
        Extracts two random bounding boxes of negative regions from a binary mask.

        Parameters:
        - mask (numpy.ndarray): A binary mask of shape (1, H, W).

        Returns:
        - list: Two tuples with (y_min, y_max, x_min, x_max) coordinates of the bounding boxes of the negative regions.
        """
        height, width = self.bbox_size[0], self.bbox_size[1]
        mask = mask[0]  # Remove the singleton dimension: (1, H, W) -> (H, W)

        H, W = mask.shape

        step_y = height // 2
        step_x = width // 2

        bboxes = []
        trials = 0
        max_trials = 50  # To avoid infinite loops, though this value can be adjusted

        while len(bboxes) < num_boxes and trials < max_trials:
            # Randomly sample a starting point
            y = np.random.randint(0, H - height + 1, 1)[0]
            x = np.random.randint(0, W - width + 1, 1)[0]

            # Align the sampled point to the nearest half-sized step grid
            y = (y // step_y) * step_y
            x = (x // step_x) * step_x

            window = mask[y:y+height, x:x+width]
            if np.sum(window) == 0 and (y, y+height-1, x, x+width-1) not in bboxes:
                bboxes.append((x, x+width-1, y, y+height-1))
            trials += 1

        return bboxes


    def _get_bboxes(self, mask):
        if mask.sum() == 0:
            return self._negative_bounding_box(mask, num_boxes=1)
        else:
            bbox_negative = self._negative_bounding_box(mask, num_boxes=1)
            bbox_positive = self._positive_bounding_box(mask)
            if not bbox_positive:
                return bbox_negative

            y_min, y_max, x_min, x_max = bbox_positive
            width, height = self.bbox_size

            # Calculate the sizes of the positive bounding box
            pos_width = x_max - x_min + 1
            pos_height = y_max - y_min + 1

            # Ensure the new bounding box includes the positive bounding box
            x_min_new = max(x_min - (width - pos_width) // 2, 0)
            y_min_new = max(y_min - (height - pos_height) // 2, 0)

            x_max_new = x_min_new + width - 1
            y_max_new = y_min_new + height - 1

            # Adjust the bounding box if it extends beyond the mask's boundaries
            if y_max_new >= mask.shape[1]:
                y_max_new = mask.shape[1] - 1
                y_min_new = max(y_max_new - height + 1, 0)  # Ensure it doesn't go negative
            if x_max_new >= mask.shape[2]:
                x_max_new = mask.shape[2] - 1
                x_min_new = max(x_max_new - width + 1, 0)  # Ensure it doesn't go negative

            # Ensure the positive region is included in the new bounding box
            x_min_new = min(x_min_new, x_min)
            y_min_new = min(y_min_new, y_min)
            x_max_new = max(x_max_new, x_max)
            y_max_new = max(y_max_new, y_max)

            bbox_positive = [(x_min_new, x_max_new, y_min_new, y_max_new)]
            return bbox_negative + bbox_positive

    def __call__(self, data):
        d = dict(data)


        d['image'] = self.pad_image(d['image'])
        d['label'] = self.pad_image(d['label'])

        data = []

        label = d['label']
        bboxes = self._get_bboxes(label)

        for bbox in bboxes:
            xmin, xmax, ymin, ymax = bbox
            new_d= d.copy()
            # Crop using bounding box
            new_d['image'] = torch.tensor(d["image"][:, ymin:ymax+1, xmin:xmax+1])
            new_d['label'] = torch.tensor(label[:, ymin:ymax+1, xmin:xmax+1])


            # Adjust meta-data for cropped image and label
            new_d["image_meta_dict"] = dict(d["image_meta_dict"])
            new_d["image_meta_dict"]["original_affine"] = d["image_meta_dict"]["affine"]
            new_d["image_meta_dict"]["original_affine"] = monai.data.MetaTensor(new_d["image_meta_dict"]["original_affine"])

            affine_adjust = np.array([[1, 0, 0, xmin], [0, 1, 0, ymin], [0, 0, 1, 0], [0, 0, 0, 1]])
            new_d["image_meta_dict"]["affine"] = d["image_meta_dict"]["affine"] @ affine_adjust
            new_d["image_meta_dict"]["affine"] = monai.data.MetaTensor(new_d["image_meta_dict"]['affine'])



            new_d["label_meta_dict"] = dict(d["label_meta_dict"])
            new_d["label_meta_dict"]["original_affine"] = d["label_meta_dict"]["affine"]
            new_d["label_meta_dict"]["original_affine"] = monai.data.MetaTensor(new_d["label_meta_dict"]["original_affine"])


            new_d["label_meta_dict"]["affine"] = d["label_meta_dict"]["affine"] @affine_adjust
            new_d["label_meta_dict"]["affine"] = monai.data.MetaTensor(new_d["label_meta_dict"]['affine'])

            data.append(new_d)
        if len(data) == 0:
          return d
        return data



class AdaptiveCropBreasts2(MapTransform):

    def __init__(self, keys=['processed_image','processed_label'], strict_boundary_perc=0.001):
        """
        Initializes the adaptive crop transform.

        :param keys: The data keys to apply the transform to.
        """
        super().__init__(keys)
        self.strict_boundary_perc = strict_boundary_perc

    def find_strict_breast_region(self, half_image_sum, peak_index, total_width):
        # Set a stricter percentage of the peak value to consider as the breast boundary
        peak_value = half_image_sum[peak_index]
        boundary_threshold = peak_value * self.strict_boundary_perc

        # Find the left boundary of the breast region
        left_boundary = peak_index
        while left_boundary > 0 and half_image_sum[left_boundary] > boundary_threshold:
            left_boundary -= 1

        # Find the right boundary of the breast region
        right_boundary = peak_index
        while right_boundary < total_width and half_image_sum[right_boundary] > boundary_threshold:
            right_boundary += 1

        return left_boundary, right_boundary

    def __call__(self, data):
        d = dict(data)
        data_list = []

        x1, y1, x2, y2 = d['thorax_crop_coords']

        image = copy.deepcopy(d['image'])
        mask = copy.deepcopy(d['label'])

        processed_image = copy.deepcopy(d['processed_image'])
        processed_mask = copy.deepcopy(d['processed_label'])

        dim_before_breast_crop = torch.tensor(image.shape)
        image = image[:,y1:, :]
        mask = mask[:,y1:, :]

        image_for_check = image[:, 20:, :]

        # Calculate the vertical sum for the left and right halves
        mid_point = image_for_check.shape[2] // 2

        left_half_sum = image_for_check.sum(axis=(0, 1))[:mid_point]
        right_half_sum = image_for_check.sum(axis=(0, 1))[mid_point:]

        # Find the peak in each half
        left_peak_index = np.argmax(left_half_sum)
        right_peak_index = np.argmax(right_half_sum) + mid_point

        # Find the breast regions
        left_breast_boundaries = self.find_strict_breast_region(left_half_sum, left_peak_index, mid_point)
        right_breast_boundaries = self.find_strict_breast_region(right_half_sum, right_peak_index - mid_point, image.shape[2] - mid_point)

        # Extract the breast regions
        left_breast_region_image_strict = processed_image[:, :, left_breast_boundaries[0]:left_breast_boundaries[1]]
        left_breast_region_mask_strict = processed_mask[:, :, left_breast_boundaries[0]:left_breast_boundaries[1]]

        right_breast_region_mask_strict = processed_mask[:, :, right_breast_boundaries[0] + mid_point:right_breast_boundaries[1] + mid_point]
        right_breast_region_image_strict = processed_image[:, :, right_breast_boundaries[0] + mid_point:right_breast_boundaries[1] + mid_point]

        left_breast_trim_coords = torch.tensor([left_breast_boundaries[0], left_breast_boundaries[1]], dtype=torch.int16)
        right_breast_trim_coords = torch.tensor([right_breast_boundaries[0] + mid_point, right_breast_boundaries[1] + mid_point], dtype=torch.int16)

        regions = [
            (left_breast_region_image_strict, left_breast_region_mask_strict, left_breast_trim_coords),
            (right_breast_region_image_strict, right_breast_region_mask_strict, right_breast_trim_coords)
        ]

        # Loop through the two largest regions to crop the image and mask
        for i, region in enumerate(regions):
            new_d = d.copy()

            new_d['processed_image'] = region[0]
            new_d['processed_label'] = region[1]

            # Adjust meta-data for cropped image and label if meta-data is available
            if "image_meta_dict" in d and "label_meta_dict" in d:
                left_boundary = left_breast_boundaries[0] if i == 0 else right_breast_boundaries[0] + mid_point
                affine_adjust = np.array([[1, 0, 0, -left_boundary], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]])

                new_d["image_meta_dict"] = dict(d["image_meta_dict"])
                new_d["image_meta_dict"]["original_affine"] = d["image_meta_dict"]["affine"]
                new_d["image_meta_dict"]["original_affine"] = monai.data.MetaTensor(new_d["image_meta_dict"]["original_affine"])

                new_d["image_meta_dict"]["affine"] = d["image_meta_dict"]["affine"] @ affine_adjust
                new_d["image_meta_dict"]["affine"] = monai.data.MetaTensor(new_d["image_meta_dict"]["affine"])

                new_d["label_meta_dict"] = dict(d["label_meta_dict"])
                new_d["label_meta_dict"]["original_affine"] = d["label_meta_dict"]["affine"]
                new_d["label_meta_dict"]["original_affine"] = monai.data.MetaTensor(new_d["label_meta_dict"]["original_affine"])

                new_d["label_meta_dict"]["affine"] = d["label_meta_dict"]["affine"] @ affine_adjust
                new_d["label_meta_dict"]["affine"]= monai.data.MetaTensor(new_d["label_meta_dict"]["affine"])


                new_d['trim_breast_coords'] = torch.cat((new_d['trim_breast_coords'], region[2]), dim=0)
                new_d['dim_before_breast_crop'] = torch.cat((d['dim_before_breast_crop'], dim_before_breast_crop), dim=0)
            data_list.append(new_d)

        return data_list



class AdaptiveCropBreasts(MapTransform):

    def __init__(self,  keys=["image", "label"], margin_size=100, min_size=200, threshold=1):
        """
        Initializes the adaptive crop transform.

        :param margin_size: The maximum margin size to add
        :param min_size: The minimum size of regions to keep
        :param threshold: Threshold value to create a binary image
        """

        super().__init__(keys)
        self.margin_size = margin_size
        self.min_size = min_size
        self.threshold = threshold
        self.label_fn = label_fn

    def __call__(self, data):
        """
        Applies the adaptive crop to the image and mask in the input data.

        :param data: Dictionary containing image and mask
        :return: Dictionary with cropped image and mask
        """
        d = dict(data)
        data_list = []
        image = d['image']
        mask = d['label']

        image_mean, image_std = image.mean(), image.std()

        # Create a binary image based on the threshold
        threshold = int(image_mean + image_std)
        binary_image = image[0] > threshold

        # Label the connected components in the binary image
        labeled_image = self.label_fn(binary_image)

        # Remove small objects
        cleaned_image = morphology.remove_small_objects(labeled_image, min_size=self.min_size)

        # Calculate region properties
        regions = regionprops(cleaned_image)

        # Sort regions based on area size, keeping the two largest
        largest_regions = sorted(regions, key=lambda x: -x.area)[1:3]

        # Initialize lists to store the cropped images and masks
        cropped_images = []
        cropped_masks = []

        # Loop through the two largest regions to crop the image and mask
        for region in largest_regions:
            new_d= d.copy()
            bbox = region.bbox
            top, left, bottom, right = bbox

            # Add margin
            top = max(top - self.margin_size, 0)
            left = max(left - self.margin_size, 0)
            bottom = min(bottom + self.margin_size, image.shape[1])
            right = min(right + self.margin_size, image.shape[2])

            # Crop the image and mask
            cropped_image = image[: ,top:bottom, left:right]
            cropped_mask = mask[:,top:bottom, left:right]

            new_d['image'] = torch.tensor(cropped_image)
            new_d['label'] = torch.tensor(cropped_mask)

            # Adjust meta-data for cropped image and label
            new_d["image_meta_dict"] = dict(d["image_meta_dict"])

            new_d["image_meta_dict"]["original_affine"] = d["image_meta_dict"]["affine"]
            new_d["image_meta_dict"]["original_affine"] = monai.data.MetaTensor(new_d["image_meta_dict"]["original_affine"])

            affine_adjust = np.array([[1, 0, 0, left], [0, 1, 0, top], [0, 0, 1, 0], [0, 0, 0, 1]])
            new_d["image_meta_dict"]["affine"] = d["image_meta_dict"]["affine"] @ affine_adjust
            new_d["image_meta_dict"]["affine"] = monai.data.MetaTensor(new_d["image_meta_dict"]["affine"])

            new_d["label_meta_dict"] = dict(d["label_meta_dict"])
            new_d["label_meta_dict"]["original_affine"] = d["label_meta_dict"]["affine"]
            new_d["label_meta_dict"]["original_affine"] = monai.data.MetaTensor(new_d["label_meta_dict"]["original_affine"])


            new_d["label_meta_dict"]["affine"] = d["label_meta_dict"]["affine"] @affine_adjust
            new_d["label_meta_dict"]["affine"] = monai.data.MetaTensor(new_d["label_meta_dict"]["affine"])


            data_list.append(new_d)

        if len(data_list) < 2:
            return [data]

        return data_list


from monai.transforms import Resize
from monai.transforms import SpatialCrop


import torch
import numpy as np
import matplotlib.pyplot as plt

import torch
import numpy as np
import matplotlib.pyplot as plt

class CropToSquare(MapTransform):

    def __init__(self, keys=["image", "label"], shrink_factor=10, black_threshold=400):
        super().__init__(keys)
        self.shrink_factor = shrink_factor
        self.black_threshold = black_threshold  # Pixel intensity threshold for 'almost black'

    def __call__(self, data):
        d = dict(data)
        image = d['processed_image']
        label = d['processed_label']

        """print("prima 22")
        plt.imshow(image[0], cmap='gray')
        plt.show()"""

        d['dim_before_crop'] = torch.cat((d['dim_before_crop'], torch.tensor(image.shape)), dim=0)

        # Crop the image along the longest dimension to remove almost black regions
        max_intensity = np.max(image[0], axis=0)  # Max intensity for each column

        valid_columns = np.argwhere(max_intensity > self.black_threshold).flatten()

        if len(valid_columns) == 0:
            x_min = 0
            x_max = image.shape[2]
        else:
            x_min, x_max = valid_columns[0], valid_columns[-1]

            if x_min-30 >=0:
                x_min = x_min-30
            if x_max+30 < image.shape[2]:
                x_max = x_max+30

        y_min = self.shrink_factor
        y_max = image.shape[1]

        # Crop both image and label
        image = image[:, y_min:y_max, x_min:x_max]
        label = label[:, y_min:y_max, x_min:x_max]

        # Determine the new longest and shortest dimensions
        new_longest_dim = max(image.shape[1], image.shape[2])
        new_shortest_dim = min(image.shape[1], image.shape[2])

        # Pad to make a square
        pad_bottom = new_longest_dim - image.shape[1] if image.shape[1] < new_longest_dim else 0
        pad_right = new_longest_dim - image.shape[2] if image.shape[2] < new_longest_dim else 0

        d['processed_image'] = np.pad(image, ((0, 0), (0, pad_bottom), (0, pad_right)), 'constant')
        d['processed_label'] = np.pad(label, ((0, 0), (0, pad_bottom), (0, pad_right)), 'constant')

        # Save the crop coordinates and dimensions
        crop_coords = torch.tensor([y_min, y_max, x_min, x_max], dtype=torch.int16)
        pad_post_crop_coords = torch.tensor([[0, 0], [0, pad_bottom], [0, pad_right]], dtype=torch.int16)
        d['crop_coords'] = torch.cat((d['crop_coords'], crop_coords), dim=0)
        d['pad_post_crop_coords'] = torch.cat((d['pad_post_crop_coords'], pad_post_crop_coords), dim=0)
        
        return d



        
class Resize(MapTransform):

    def __init__(self,  step, keys=["image", "label"], spatial_size=256):

        super().__init__(keys)
        self.spatial_size = spatial_size
        self.resize = monai.transforms.Resize(spatial_size=spatial_size , mode='nearest-exact')
        self.step = step

    def __call__(self, data):
        d = dict(data)

        image_key = 'processed_image' if self.step =="preliminary" else 'image'
        label_key = 'processed_label' if self.step =="preliminary" else 'label'

        image_shape = d[image_key].shape
        dim_before_resize = torch.tensor(image_shape)

        dim_before_resize_dict_key = "dim_before_resize_preliminary" if self.step =="preliminary" else "dim_before_resize_final"
        spatial_size_info_dict_key = "spatial_size_info_preliminary" if self.step =="preliminary" else "spatial_size_info_final"

        original_spatial_dim = torch.tensor([image_shape[1], image_shape[2]], dtype=torch.int16)

        #print(self.step)
        #print(image_shape)

        #print(d[image_key])
        #print(self.resize)
        d[image_key] = self.resize(d[image_key])
        if self.step!='preliminary':
            d['processed_image'] = self.resize(d['processed_image'])
        if d['has_mask']:
            d[label_key] = self.resize(d[label_key])
            if self.step!='preliminary':
                d['processed_label'] = self.resize(d['processed_label'])

        d[spatial_size_info_dict_key] = torch.cat((d[spatial_size_info_dict_key], original_spatial_dim), dim=0)
        d[dim_before_resize_dict_key ] = torch.cat((d[dim_before_resize_dict_key ], dim_before_resize), dim=0)

        if self.step=='preliminary':
            d['preliminary_target_size'] = torch.cat((d['preliminary_target_size'], torch.tensor(self.spatial_size)), dim=0)
        return d

class FilterByDim(MapTransform):
    def __init__(self, keys):
        super().__init__(keys)

        # Define a simple filter function for demonstration
    def filter_by_dim(self, data):
        # Implement your filtering condition here
        # For example, filter out images with a certain property:

        # OLD VALUE WAS 180
        keep_sample = data['processed_label'].shape[1] > 100
        return keep_sample

    def __call__(self, data):
        # Apply the filter function to determine if the sample should be kept
        keep_sample = torch.tensor([self.filter_by_dim(data)])
        data['keep_sample'] = torch.cat((data['keep_sample'], keep_sample), dim=0)
        return data

class FilterByMean(MapTransform):
    def __init__(self, keys, mean_threshold, start_pos):
        super().__init__(keys)
        self.mean_threshold = mean_threshold
        self.start_pos = start_pos

    def filter_by_mean(self, data):
        #print(data['processed_image'].std())
        keep_sample = data['processed_image'][:,self.start_pos:,:].std() > self.mean_threshold
        return keep_sample

    def __call__(self, data):
        # Apply the filter function to determine if the sample should be kept
        keep_sample = torch.tensor([self.filter_by_mean(data)])
        if data['keep_sample'] and not keep_sample:
            data['keep_sample'] = keep_sample

        return data


"""              median_smooth_radius=2,
                 hist_norm_num_bins=20,
                 rm_thorax_threshold=140,
                 rm_thorax_margin=80,
                 rm_bottom_threshold=140,
                 rm_bottom_margin=20,
                 threshold_black_threshold=500,
                 threshold_black_value=0,
                 trim_sides_threshold=50000,
                 trim_sides_tolerance=40,
                 bbox_size=(384,384),
                 pad_spatial_size=[512,512],
                 target_size_preliminary=(512,512),
                 target_size_final = (256,256),
                 mode='train',
                 has_mask=True,
                 **kwargs):





                     def __init__(self,
                 median_smooth_radius=2,
                 hist_norm_num_bins=20,
                 rm_thorax_threshold=10, #80
                 rm_thorax_margin=40,
                 rm_bottom_threshold=140,
                 rm_bottom_margin=50,
                 threshold_black_threshold=800,
                 threshold_black_value=0,
                 trim_sides_threshold=20000,
                 trim_sides_tolerance=20,
                 bbox_size=(384,384),
                 pad_spatial_size=[512,512],
                 target_size_preliminary=(512,512),
                 target_size_final = (256,256),
                 mode='train',
                 has_mask=True,
                 **kwargs):"""


import copy
class Preprocess(MapTransform):

    def __init__(self,
                 subtrahend=48.85898971557617, 
                 divisor=123.9007568359375,
                 median_smooth_radius=2,
                 hist_norm_num_bins=20,
                 rm_thorax_threshold=150, #80
                 rm_thorax_margin=60,
                 rm_bottom_threshold=120,
                 rm_bottom_margin=50,
                 threshold_black_threshold=1300,
                 threshold_black_value=0,
                 trim_sides_threshold=20000,
                 trim_sides_tolerance=20,
                 bbox_size=(384,384),
                 pad_spatial_size=[512,512],
                 target_size_preliminary=(512,512),
                 target_size_final = (256,256),
                 mode='train',
                 subtracted_images_path_prefixes = None,
                 has_mask=True, dataset = "private",
                 get_patches = False,
                 get_boundaryloss=False,
                 **kwargs):

        super(Preprocess, self).__init__(**kwargs)

        self.subtracted_images_path_prefixes = subtracted_images_path_prefixes
        self.subtrahend = subtrahend
        self.divisor = divisor
        self.get_patches = get_patches

        if dataset == "BRADM":
            self.subtracted_images = True
        else:
            if subtracted_images_path_prefixes:
                self.subtracted_images = True
            else:
                self.subtracted_images = False

            if get_patches:
                mean_threshold = 50
                start_pos = 75
            else:
                mean_threshold = 35
                start_pos = 40
            
        self.median_smooth = MedianSmooth(keys=['processed_image'], radius=median_smooth_radius)
        self.histogram_normalized = HistogramNormalized(keys=['processed_image'], num_bins = hist_norm_num_bins)
        self.remove_thorax = RemoveThorax(keys=['processed_image', 'processed_label'], threshold= rm_thorax_threshold, margin=rm_thorax_margin)
        self.remove_bottom = RemoveBottom(keys=['processed_image', 'processed_label'], threshold = rm_bottom_threshold,margin=rm_bottom_margin)
        self.threshold_black = ThresholdBlack(keys=['processed_image'], threshold = threshold_black_threshold, value=threshold_black_value)
        self.trim_sides = TrimSides(keys=['processed_image','processed_label'], threshold=trim_sides_threshold, tolerance=trim_sides_tolerance)
        self.normalize = monai.transforms.NormalizeIntensityd(keys=['image'], subtrahend  = self.subtrahend, divisor=self.divisor)
        self.bbox_split = BoundingBoxSplit(keys=['image','label'], bbox_size=bbox_size)
        self.pad = monai.transforms.SpatialPadd(keys=['image','label'], spatial_size=pad_spatial_size)
        self.convert3d =  monai.transforms.RepeatChanneld(keys=['image'], repeats=3)
        self.convert3denhanced = Convert3DEnhanced(keys=['image', 'label'])
        self.normalizedSingleChannel = NormalizedSingleChannel(keys=['image'], channel_index = 0)
        self.prepare_image = PrepareSample(keys=None, target_size=target_size_preliminary, subtracted_images=self.subtracted_images, patches=self.get_patches)
        self.adaptiveCropBreasts = AdaptiveCropBreasts2(keys=['processed_image','processed_label'])
        self.crop = CropToSquare(keys=['processed_image','processed_label'], shrink_factor=15)
        self.resizePreliminary = Resize(step='preliminary', keys=['image', 'label'], spatial_size=target_size_preliminary)
        self.resizePreliminaryCleanImg = monai.transforms.Resized(keys=['image', 'label'],spatial_size=target_size_preliminary, mode='nearest-exact')
        self.resizeFinal = Resize(step='final', keys=['image', 'label'], spatial_size=target_size_final)
        self.foreground_mask = monai.transforms.ForegroundMaskd(keys=['image'], invert=True)
        self.filterbyDim = FilterByDim(keys=['processed_image', 'processed_label'])
        self.filterbyMean = FilterByMean(keys=['image', 'label'], mean_threshold=mean_threshold, start_pos = start_pos)
        self.has_mask = has_mask
        self.randVFlip = RandFlipd(prob=0.3, spatial_axis=1, keys=['image', 'label'])
        self.randGridDistortion = RandGridDistortiond(keys=['image', 'label'], num_cells=10, prob=0.3, distort_limit=(-0.1, 0.1))
        self.randHistShift = RandHistogramShiftd(keys=['image', 'label'], num_control_points=15, prob=0.3)
        self.randomRotate = RandRotated(range_x=0.15, keys = ['image', 'label'], prob=0.3, keep_size=True)
        self.enhance =EnhanceLesionsSelective(keys=['image'])
        self.get_boundaryloss = get_boundaryloss
        

        self.train_functions_patches=[
            self.resizePreliminary,
            self.median_smooth,
            self.histogram_normalized,
            self.resizePreliminaryCleanImg,
            self.remove_thorax,
            self.adaptiveCropBreasts,
            self.remove_bottom,
            self.filterbyDim,
            self.crop,
            self.prepare_image,
            self.resizeFinal,
            self.filterbyMean,
            self.normalize,
            
        ]

        self.train_functions_no_patches=[
            self.resizePreliminary,
            self.median_smooth,
            self.histogram_normalized,
            self.resizePreliminaryCleanImg,
            self.remove_thorax,
            self.remove_bottom,
            self.filterbyDim,
            self.crop,
            self.prepare_image,
            self.resizeFinal,
            self.filterbyMean,
            self.normalize,
        ]


        self.test_functions_patches=[
            self.resizePreliminary,
            self.median_smooth,
            self.histogram_normalized,
            self.resizePreliminaryCleanImg,
            self.remove_thorax,
            self.adaptiveCropBreasts,
            self.remove_bottom,
            self.filterbyDim,
            self.crop,
            self.prepare_image,
            self.resizeFinal,
            self.filterbyMean,
            self.normalize,
        ]

        self.test_functions_no_patches=[
            self.resizePreliminary,
            self.median_smooth,
            self.histogram_normalized,
            self.resizePreliminaryCleanImg,
            self.remove_thorax,
            self.remove_bottom,
            self.filterbyDim,
            self.crop,
            self.prepare_image,
            self.resizeFinal,
            self.filterbyMean,
            self.normalize,
        ]

        self.statistics_functions_patches=[
            self.resizePreliminary,
            self.median_smooth,
            self.histogram_normalized,
            self.resizePreliminaryCleanImg,
            self.remove_thorax,
            self.adaptiveCropBreasts,  # todo separate original image restoring!
            self.remove_bottom,
            self.filterbyDim,
            self.crop,
            self.prepare_image,
            self.resizeFinal,
            self.filterbyMean,
        ]

        self.statistics_functions_no_patches=[
            self.resizePreliminary,
            self.median_smooth,
            self.histogram_normalized,
            self.resizePreliminaryCleanImg,
            self.remove_thorax,
            self.remove_bottom,
            self.filterbyDim,
            self.crop,
            self.prepare_image,
            self.resizeFinal,
            self.filterbyMean,
        ]

        if mode=='train':
          if self.get_patches:
              self.transforms = Compose(self.train_functions_patches)
          else:
              self.transforms = Compose(self.train_functions_no_patches)
            
        elif mode=="test":
            if self.get_patches:
                  self.transforms = Compose(self.test_functions_patches)
            else:
                self.transforms = Compose(self.test_functions_no_patches)
        elif mode=='statistics':
            if self.get_patches:
                self.transforms = Compose(self.statistics_functions_patches)
            else:
                self.transforms = Compose(self.statistics_functions_no_patches)
                
                
        self.disttransform = dist_map_transform([1,1], 2)
                

    def __call__(self, data):

        data['thorax_crop_coords']=torch.tensor([], dtype=torch.int16)
        data['dim_before_thorax_crop']=torch.tensor([], dtype=torch.int16)

        data['trim_breast_coords']=torch.tensor([],dtype=torch.int16)
        data['dim_before_breast_crop']=torch.tensor([], dtype=torch.int16)

        data['bottom_crop_coords']=torch.tensor([],dtype=torch.int16)
        data['dim_before_bottom_crop']=torch.tensor([], dtype=torch.int16)

        data['pad_post_crop_coords'] = torch.tensor([], dtype=torch.int16)

        data['crop_coords']=torch.tensor([],dtype=torch.int16)
        data['dim_before_crop']=torch.tensor([], dtype=torch.int16)
        data['preliminary_target_size']=torch.tensor([], dtype=torch.int16)

        #data['trim_coords']=torch.tensor([], dtype=torch.int16)

        data['spatial_size_info_preliminary']=torch.tensor([],dtype=torch.int16)
        data['spatial_size_info_final']=torch.tensor([],dtype=torch.int16)
        data['dim_before_resize_final']=torch.tensor([], dtype=torch.int16)
        data['dim_before_resize_preliminary']=torch.tensor([], dtype=torch.int16)

        data['processed_image']=copy.deepcopy(data['image'])
        data['original_image'] = copy.deepcopy(data['image'])

        data['keep_sample'] = torch.tensor([], dtype=torch.bool)
        data['has_mask'] = monai.data.MetaTensor(self.has_mask)

        if self.has_mask:
             data['processed_label']=copy.deepcopy(data['label'])
             data['original_label'] = copy.deepcopy(data['label'])

        else:
            data['processed_label']= np.zeros_like(data['image'])
            data['original_label'] =  np.zeros_like(data['image'])
            data['label'] =  np.zeros_like(data['image'])

            data['processed_label']=monai.data.MetaTensor(data['image'])
            data['original_label'] =  monai.data.MetaTensor(data['image'])
            data['label'] =  monai.data.MetaTensor(data['image'])

            
            
            


        if self.subtracted_images_path_prefixes:
            #path = path.replace("Dataset-arrays-images-only-4", "Dataset-arrays-images-only")
            pfx1, pfx2 = self.subtracted_images_path_prefixes[0], self.subtracted_images_path_prefixes[1]
            data['image_meta_dict']['subtracted_filename_or_obj'] = data['image_meta_dict']['filename_or_obj'].replace(pfx1, pfx2)
            if self.has_mask:
                data['label_meta_dict']['subtracted_filename_or_obj'] = data['label_meta_dict']['filename_or_obj'].replace(pfx1, pfx2)

        
        data = self.transforms(data)

        if self.get_patches:
            c, h, w = data[0]['image'].shape
    
            data[0]['image_meta_dict']['spatial_shape'] = np.array([h,w])
            data[0]['label_meta_dict']['spatial_shape'] = np.array([h,w])
    
            data[1]['image_meta_dict']['spatial_shape'] = np.array([h,w])
            data[1]['label_meta_dict']['spatial_shape'] = np.array([h,w])
    
            data[0]['image_meta_dict']['original_channel_dim'] = monai.data.MetaTensor(data[0]['image_meta_dict']['original_channel_dim'])
            data[0]['label_meta_dict']['original_channel_dim'] =  monai.data.MetaTensor(data[0]['label_meta_dict']['original_channel_dim'])
    
            data[1]['image_meta_dict']['original_channel_dim'] =  monai.data.MetaTensor(data[1]['image_meta_dict']['original_channel_dim'])
            data[1]['label_meta_dict']['original_channel_dim'] =  monai.data.MetaTensor(data[1]['label_meta_dict']['original_channel_dim'])
    
            """del data[0]['image_meta_dict']
            del data[0]['label_meta_dict']
    
            del data[1]['image_meta_dict']
            del data[1]['label_meta_dict']"""
            
            #del data[0]['processed_image']
            #del data[1]['processed_image']

                
            del data[0]['original_image']
            del data[1]['original_image']
            
    
            if self.has_mask:
                #del data[0]['processed_label']
                #del data[1]['processed_label']

                if self.get_boundaryloss:

                    boundary = self.disttransform(data[0]['label'][0])
                    data[0]['boundary']=boundary 
        
                    boundary = self.disttransform(data[1]['label'][0])
                    data[1]['boundary']=boundary 
                del data[0]['original_label']
                del data[1]['original_label']

                data[0]['has_mass'] = monai.data.MetaTensor(np.sum(data[0]['label']) != 0)
                data[1]['has_mass'] = monai.data.MetaTensor(np.sum(data[1]['label']) != 0)

    

        else:
            c, h, w = data['image'].shape
    
            data['image_meta_dict']['spatial_shape'] = np.array([h,w])
            data['label_meta_dict']['spatial_shape'] = np.array([h,w])
    
            data['image_meta_dict']['original_channel_dim'] = monai.data.MetaTensor(data['image_meta_dict']['original_channel_dim'])
            data['label_meta_dict']['original_channel_dim'] =  monai.data.MetaTensor(data['label_meta_dict']['original_channel_dim'])
    
            #del data['processed_image']
    
            if self.has_mask:
                #del data['processed_label']

                if self.get_boundaryloss:
                    boundary = self.disttransform(data['label'][0])
                    data['boundary']=boundary 
                data['has_mass'] = monai.data.MetaTensor(np.sum(data['label']) != 0)
                del data['original_label']
    
            del data['original_image']

        return data

# LOSSES & METRICS

In [8]:
# Helper function to identify the axis for aggregation
def identify_axis(shape):
    if len(shape) == 5:
        return [2, 3, 4]
    elif len(shape) == 4:
        return [2, 3]
    else:
        raise ValueError('Shape of tensor is neither 2D or 3D.')



def compute_dice_from_metrics(tp, fp, tn, fn, reduction='micro', exclude_empty=False, exclude_empty_only_gt=False, return_std=False):
    dice_denominator = 2 * tp + fp + fn
    dice_numerator = 2 * tp

    if reduction == 'micro':
        dice_score = dice_numerator.sum() / dice_denominator.sum()
        if return_std:
            # Standard deviation does not apply in micro mode because it is a single value
            return dice_score, torch.tensor(0.0)
        return dice_score

    elif reduction == 'micro-imagewise':
        dice_per_sample = dice_numerator / dice_denominator
        if exclude_empty:
            dice_per_sample = torch.where(dice_denominator == 0, torch.tensor(float('nan')), dice_per_sample)
            if exclude_empty_only_gt:
                # Exclude samples where there are no positives in the ground truth
                exclude_mask = (tp + fn) == 0  # Ground truth empty cases
                dice_per_sample = torch.where(exclude_mask, torch.tensor(float('nan')), dice_per_sample)

            mean_dice = torch.nanmean(dice_per_sample)
            if return_std:
                std_dice = np.nanstd(dice_per_sample)
                return mean_dice, std_dice
            return mean_dice
        else:
            dice_per_sample = torch.where(dice_denominator == 0, 1, dice_per_sample)
            mean_dice = torch.nanmean(dice_per_sample)
            if return_std:
                std_dice = np.nanstd(dice_per_sample)
                return mean_dice, std_dice
            return mean_dice

    elif reduction == 'none':
        dice_score = dice_numerator / dice_denominator
        if exclude_empty:
            dice_score = torch.where(dice_denominator == 0, torch.tensor(float('nan')), dice_score)
            if exclude_empty_only_gt:
                # Exclude samples where there are no positives in the ground truth
                exclude_mask = (tp + fn) == 0  # Ground truth empty cases
                dice_score = torch.where(exclude_mask, torch.tensor(float('nan')), dice_score)
        else:
            dice_score = torch.where(dice_denominator == 0, 1, dice_score)
        return dice_score  # Standard deviation is not applicable for 'none' reduction

    else:
        raise ValueError("Reduction method must be either 'micro', 'micro-imagewise', or 'none'.")

    # Ensure Dice scores are within the [0, 1] range
    dice_score = torch.clamp(dice_score, min=0, max=1)

    return dice_score 

def compute_iou_from_metrics(tp, fp, tn, fn, reduction='micro', exclude_empty=False, exclude_empty_only_gt=False, return_std=False):
    denominator = tp + fp + fn
    with torch.no_grad():  # Avoid tracking these operations in the autograd graph
        if reduction == 'micro':
            # Sum the counts across all samples and compute IoU
            iou = tp.sum() / denominator.sum()
            if return_std:
                # Standard deviation does not apply in micro mode because it is a single value
                return iou, torch.tensor(0.0)
            return iou

        elif reduction == 'micro-imagewise':
            # Avoid division by zero; set IoU to NaN for samples with denominator == 0
            valid = denominator != 0
            iou_per_sample = torch.zeros_like(tp, dtype=torch.float)
            iou_per_sample[valid] = tp[valid] / denominator[valid]
            if exclude_empty:
                if exclude_empty_only_gt:
                    # Exclude samples with no positives in the ground truth
                    exclude_mask = (tp + fn) == 0
                    iou_per_sample = torch.where(exclude_mask, torch.tensor(float('nan')), iou_per_sample)

                # Compute mean and optionally standard deviation only for valid samples
                mean_iou = torch.nanmean(iou_per_sample[valid])
                if return_std:
                    std_iou = np.nanstd(iou_per_sample[valid])
                    return mean_iou, std_iou
                return mean_iou
            else:
                iou_per_sample[~valid] = torch.tensor(1.0)  # Set invalid samples to 1
                mean_iou = torch.mean(iou_per_sample)
                if return_std:
                    std_iou = torch.std(iou_per_sample)
                    return mean_iou, std_iou
                return mean_iou

        elif reduction == 'none':
            # Compute IoU for each sample, handling division by zero
            iou = torch.zeros_like(tp, dtype=torch.float)
            valid = denominator != 0
            iou[valid] = tp[valid] / denominator[valid]
            if exclude_empty:
                iou[~valid] = torch.tensor(float('nan'))  # Mark invalid samples as NaN
                if exclude_empty_only_gt:
                    # Exclude samples with no positives in the ground truth
                    exclude_mask = (tp + fn) == 0
                    iou = torch.where(exclude_mask, torch.tensor(float('nan')), iou)
            else:
                iou[~valid] = torch.tensor(1.0)  # Optionally: Set a default value for invalid samples if not excluding them
            return iou  # Standard deviation is not applicable for 'none' reduction

        else:
            raise ValueError("Reduction method must be either 'micro', 'micro-imagewise', or 'none'.")

    return iou  # Return IoU, handle std outside this condition if needed

def compute_iou(y_true, y_pred, class_id, reduction='micro', exclude_empty=False):
    """
    Compute Intersection over Union for a specific class

    Args:
    y_true (torch.Tensor): batch of ground truth, 4D tensor (first dimension is batch size)
    y_pred (torch.Tensor): batch of prediction, 4D tensor (first dimension is batch size)
    class_id (int): the class to compute IoU for
    reduction (str): the method of reduction across the batch, can be 'micro' or 'micro image-wise'

    Returns:
    torch.Tensor: IoU score
    """

    def compute_iou_single(y_true_single, y_pred_single, class_id_single, exclude_empty=False):
        y_true_class = torch.where(y_true_single == class_id_single, 1, 0)
        y_pred_class = torch.where(y_pred_single == class_id_single, 1, 0)

        intersection = torch.logical_and(y_true_class, y_pred_class)
        union = torch.logical_or(y_true_class, y_pred_class)

        union_sum = torch.sum(union)
        if union_sum == 0:
            if exclude_empty:
                iou_score = float('nan')
            else:
                iou_score = 1.0
        else:
            iou_score = torch.sum(intersection).float() / union_sum.float()

        return iou_score

    assert reduction in ['micro', 'micro_image_wise'], "Reduction method should be either 'micro' or 'micro_image_wise'"

    if reduction == 'micro':
        y_true = y_true.view(-1)
        y_pred = y_pred.view(-1)
        return torch.tensor(compute_iou_single(y_true, y_pred, class_id,exclude_empty)).float()

    elif reduction == 'micro_image_wise':
        iou_scores = torch.tensor([compute_iou_single(y, p, class_id, exclude_empty) for y, p in zip(y_true, y_pred)], dtype=torch.float32)
        return torch.nanmean(iou_scores)  # Using nanmean to ignore NaN values



def compute_dice_score(y_true, y_pred, class_id=1, reduction='micro', exclude_empty=False):
    """
    Compute Dice Score for a specific class with reduction options, for input tensors in HxWxB format using PyTorch.
    
    Args:
    y_true (torch.Tensor): Ground truth, a 3D tensor (height, width, batch size).
    y_pred (torch.Tensor): Predictions, a 3D tensor (height, width, batch size).
    class_id (int): The class ID for which to compute the Dice Score.
    reduction (str): Method of reduction across the batch, either 'micro' or 'micro_image_wise'.
    
    Returns:
    torch.Tensor: The Dice Score.
    """

    def compute_dice_score_single(y_true_single, y_pred_single, class_id_single, exclude_empty=False):
        y_true_class = (y_true_single == class_id_single).float()
        y_pred_class = (y_pred_single == class_id_single).float()

        intersection = torch.sum(y_true_class * y_pred_class)
        union = torch.sum(y_true_class) + torch.sum(y_pred_class)

        if union == 0:
            if exclude_empty:
                dice_score = torch.tensor(float('nan'))  # Assuming need to handle NaN explicitly
            else:
                dice_score = torch.tensor(1.0)  # Assuming perfect IoU score when both prediction and GT are empty
        else:
            dice_score = (2. * intersection) / (union)
        
        return dice_score

    if reduction == 'micro':
        y_true_flat = y_true.view(-1)
        y_pred_flat = y_pred.view(-1)
        dice_score = torch.tensor(compute_dice_score_single(y_true_flat, y_pred_flat, class_id, exclude_empty)).float()
        return dice_score

    elif reduction == 'micro_image_wise':
        dice_scores = torch.tensor([compute_dice_score_single(y, p, class_id, exclude_empty) for y, p in zip(y_true, y_pred)], dtype=torch.float32)
        return torch.nanmean(dice_scores)  # Using nanmean to ignore NaN values in case of empty classes

    else:
        raise ValueError("Reduction method should be either 'micro' or 'micro_image_wise'")

def compute_mean_precision(tp, fp, fn, tn):
    """
    Compute the mean precision for binary classification across two classes.

    Args:
        tp (torch.Tensor): True Positives, tensor of shape (B, 1).
        fp (torch.Tensor): False Positives, tensor of shape (B, 1).
        tn (torch.Tensor): True Negatives, tensor of shape (B, 1).
        fn (torch.Tensor): False Negatives, tensor of shape (B, 1).

    Returns:
        torch.Tensor: The mean precision across classes.
    """

    # Precision for class 1
    precision_class_1 = torch.div(tp, tp + fp)
    precision_class_1[torch.isnan(precision_class_1)] = 1  # Handle division by zero

    # Precision for class 0 (inverting perspective)
    precision_class_0 = torch.div(tn, tn + fn)
    precision_class_0[torch.isnan(precision_class_0)] = 1  # Handle division by zero

    # Mean precision across both classes
    mean_precision = (precision_class_1 + precision_class_0) / 2

    # Average across the batch
    mean_precision = torch.mean(mean_precision)

    return mean_precision

def compute_mean_recall(tp, fp, fn, tn):
    """
    Compute the mean recall for binary classification across two classes.

    Args:
        tp (torch.Tensor): True Positives, tensor of shape (B, 1).
        fp (torch.Tensor): False Positives, tensor of shape (B, 1).
        tn (torch.Tensor): True Negatives, tensor of shape (B, 1).
        fn (torch.Tensor): False Negatives, tensor of shape (B, 1).

    Returns:
        torch.Tensor: The mean recall across classes.
    """
    recall_class_1 = torch.div(tp, tp + fn)
    recall_class_1[torch.isnan(recall_class_1)] = 1  # Handle division by zero

    recall_class_0 = torch.div(tn, tn + fp)
    recall_class_0[torch.isnan(recall_class_0)] = 1  # Handle division by zero

    mean_recall = (recall_class_1 + recall_class_0) / 2
    mean_recall = torch.mean(mean_recall)

    return mean_recall


def compute_dice_score_from_cm(tp, fp, fn, tn, reduction='micro', exclude_empty=False):
    # Convert to float for division
    tp = tp.float()
    fp = fp.float()
    fn = fn.float()
    
    if reduction == 'micro':
        # Sum across all classes and samples for micro averaging
        tp_sum = tp.sum()
        fp_sum = fp.sum()
        fn_sum = fn.sum()
        
        # Compute Dice score, handling division by zero
        denominator = 2 * tp_sum + fp_sum + fn_sum
        dice_score = 2 * tp_sum / denominator if denominator != 0 else torch.tensor(1.0)
        
    elif reduction == 'micro-imagewise':
        # Compute Dice Score per sample, then average across samples
        denominator = 2 * tp + fp + fn
        valid = denominator != 0
        dice_scores = torch.zeros_like(tp)
        dice_scores[valid] = 2 * tp[valid] / denominator[valid]

        if exclude_empty:
            dice_scores[~valid] = torch.tensor(float('nan'))  # Assuming perfect score for invalid cases
            dice_score = dice_scores.nanmean(dim=0)  # Average across samples
        else:
            dice_scores[~valid] = torch.tensor(1.0)  # Assuming perfect score for invalid cases
            dice_score = dice_scores.mean(dim=0)  # Average across samples
        
    else:
        raise ValueError("Reduction method must be either 'micro' or 'micro-imagewise'")
    
    return dice_score

def ra_iou(tp, fp, fn, beta=0.5):
    
    ra_iou = tp / (tp + beta * fn + fp + 1e-6)  # Adding epsilon to avoid division by zero
    
    return ra_iou.mean()

# Helper function to identify the axis for aggregation
def identify_axis(shape):
    if len(shape) == 5:
        return [2, 3, 4]
    elif len(shape) == 4:
        return [2, 3]
    else:
        raise ValueError('Shape of tensor is neither 2D or 3D.')

# Asymmetric Focal Loss for single-channel output
class AsymmetricFocalLoss(nn.Module):
    def __init__(self, delta=0.7, gamma=2., epsilon=1e-07):
        super(AsymmetricFocalLoss, self).__init__()
        self.delta = delta
        self.gamma = gamma
        self.epsilon = epsilon

    def forward(self, logits, y_true):
        y_pred = torch.sigmoid(logits)  # Applying sigmoid to convert logits to probabilities
        y_pred = torch.clamp(y_pred, self.epsilon, 1. - self.epsilon)
        cross_entropy = -y_true * torch.log(y_pred) - (1 - y_true) * torch.log(1 - y_pred)

        # Calculate the loss for positive and negative classes
        pos_loss = torch.pow(1 - y_pred, self.gamma) * cross_entropy
        neg_loss = torch.pow(y_pred, self.gamma) * cross_entropy

        # Weighted sum of the losses
        loss = torch.mean((self.delta * pos_loss) + ((1 - self.delta) * neg_loss))
        return loss

# Asymmetric Focal Tversky Loss for single-channel output
class AsymmetricFocalTverskyLoss(nn.Module):
    def __init__(self, delta=0.7, gamma=0.75, epsilon=1e-07):
        super(AsymmetricFocalTverskyLoss, self).__init__()
        self.delta = delta
        self.gamma = gamma
        self.epsilon = epsilon

    def forward(self, logits, y_true):
        y_pred = torch.sigmoid(logits)  # Applying sigmoid to convert logits to probabilities
        y_pred = torch.clamp(y_pred, self.epsilon, 1. - self.epsilon)
        axis = identify_axis(y_true.size())

        tp = torch.sum(y_true * y_pred, axis=axis)
        fn = torch.sum(y_true * (1 - y_pred), axis=axis)
        fp = torch.sum((1 - y_true) * y_pred, axis=axis)

        tversky_index = (tp + self.epsilon) / (tp + self.delta * fn + (1 - self.delta) * fp + self.epsilon)
        loss = (1 - tversky_index) * torch.pow(1 - tversky_index, -self.gamma)

        return torch.mean(loss)

# Asymmetric Unified Focal Loss for single-channel output
class AsymmetricUnifiedFocalLoss(nn.Module):
    def __init__(self, weight=0.5, delta=0.6, gamma=0.2):
        super(AsymmetricUnifiedFocalLoss, self).__init__()
        self.weight = weight
        self.delta = delta
        self.gamma = gamma

    def forward(self, logits, y_true):
        asymmetric_ftl = AsymmetricFocalTverskyLoss(delta=self.delta, gamma=self.gamma)(logits, y_true)
        asymmetric_fl = AsymmetricFocalLoss(delta=self.delta, gamma=self.gamma)(logits, y_true)

        if self.weight is not None:
            return (self.weight * asymmetric_ftl) + ((1 - self.weight) * asymmetric_fl)
        else:
            return asymmetric_ftl + asymmetric_fl

class SurfaceLossBinary(nn.Module):
    def __init__(self, idc):
        super(SurfaceLossBinary, self).__init__()
        # Self.idc is used to filter out some classes of the target mask. Use fancy indexing
        self.idc = idc
        print(f"Initialized {self.__class__.__name__} with {idc}")

    def forward(self, probs, dist_maps):

        pc = probs[:, 0, ...].type(torch.float32)
        dc = dist_maps[:, 1, ...].type(torch.float32)

        multipled = einsum("bwh,bwh->bwh", pc, dc)

        loss = multipled.mean()

        return loss

class CABFL(nn.Module):
    def __init__(self, idc, weight_aufl=0.5, delta=0.6, gamma=0.2):
        super(CABFL, self).__init__()
        self.boundaryLoss = SurfaceLossBinary(idc=idc)
        self.aufl =  AsymmetricUnifiedFocalLoss(delta=delta, gamma=gamma, weight=weight_aufl)
        self.alpha = 0.01
        self.current_epoch = 0

    def norm_distmap(self,distmap):
        _m: float = torch.abs(distmap).max()
        return distmap / _m

    def forward(self, logits, probs, dist_maps, gts, current_epoch):
        if current_epoch != self.current_epoch:
            self.current_epoch = current_epoch
            self.alpha = min(self.alpha+0.01, 0.99)

        bl = self.boundaryLoss(probs, self.norm_distmap(dist_maps))
        aufl = self.aufl(logits, gts)
        
        return (1-self.alpha)*aufl + self.alpha*bl

def compute_dice_score_npy(y_true, y_pred, class_id=1, reduction='micro',exclude_empty=False):
    """
    Compute Dice Score for a specific class with reduction options, for input arrays in HxWxB format using NumPy.
    
    Args:
    y_true (np.array): Ground truth, a 3D array (height, width, batch size).
    y_pred (np.array): Predictions, a 3D array (height, width, batch size).
    class_id (int): The class ID for which to compute the Dice Score.
    reduction (str): Method of reduction across the batch, either 'micro' or 'micro-imagewise'.
    
    Returns:
    float: The Dice Score.
    """

    def compute_dice_score_single(y_true_single, y_pred_single, class_id_single, exclude_empty=False):
        y_true_class = (y_true_single == class_id_single).astype(np.float32)
        y_pred_class = (y_pred_single == class_id_single).astype(np.float32)

        intersection = np.sum(y_true_class * y_pred_class)
        union = np.sum(y_true_class) + np.sum(y_pred_class)

        if union == 0:
            # Both prediction and ground truth are empty for this class
            if exclude_empty:
                dice_score =  float('nan')
            else:
                dice_score= torch.tensor(1.0)  # Assuming perfect IoU score when both prediction and GT are empty
        else:
            dice_score = (2. * intersection) / (union)  # Adding epsilon to avoid division by zero
        
        return dice_score

    if reduction == 'micro':
        # Reshape to combine height, width, and batch into a single dimension
        y_true_flat = y_true.reshape(-1)
        y_pred_flat = y_pred.reshape(-1)
        dice_score = compute_dice_score_single(y_true_flat, y_pred_flat, class_id, exclude_empty)
        return dice_score

    elif reduction == 'micro-imagewise':
        # Compute Dice Score for each image separately and then average
        dice_scores = np.array([compute_dice_score_single(y_true[:, :, i], y_pred[:, :, i], class_id,exclude_empty) for i in range(y_true.shape[2])])
        return np.nanmean(dice_scores)  # Using nanmean to ignore NaN values in case of empty classes

    elif reduction == 'none':
        # Compute and return Dice Score for each image separately without averaging
        return np.array([compute_dice_score_single(y_true[:, :, i], y_pred[:, :, i], class_id, exclude_empty) for i in range(y_true.shape[2])])


    else:
        raise ValueError("Reduction method should be either 'micro' or 'micro-imagewise'")

def compute_iou_npy(y_true, y_pred, class_id=1, reduction='micro', exclude_empty=False):
    """
    Compute Intersection over Union for a specific class using NumPy.

    Args:
    y_true (np.ndarray): batch of ground truth, 4D tensor (first dimension is batch size)
    y_pred (np.ndarray): batch of prediction, 4D tensor (first dimension is batch size)
    class_id (int): the class to compute IoU for
    reduction (str): the method of reduction across the batch, can be 'micro' or 'micro-imagewise'

    Returns:
    np.ndarray: IoU score
    """

    def compute_iou_single(y_true_single, y_pred_single, class_id_single, exclude_empty=False):
        y_true_class = np.where(y_true_single == class_id_single, 1, 0)
        y_pred_class = np.where(y_pred_single == class_id_single, 1, 0)

        intersection = np.logical_and(y_true_class, y_pred_class)
        union = np.logical_or(y_true_class, y_pred_class)

        union_sum = np.sum(union)
        if union_sum == 0:
            # Both prediction and ground truth are empty
            if exclude_empty:
                iou_score = float('nan')
            else:
                iou_score = 1.0
        else:
            iou_score = np.sum(intersection).astype(float) / union_sum.astype(float)

        return iou_score

    assert reduction in ['micro', 'micro-imagewise'], "Reduction method should be either 'micro' or 'micro_image_wise'"

    if reduction == 'micro':
        y_true = y_true.reshape(-1)
        y_pred = y_pred.reshape(-1)
        return compute_iou_single(y_true, y_pred, class_id, exclude_empty)

    elif reduction == 'micro-imagewise':
        iou_scores = np.array([compute_iou_single(y, p, class_id, exclude_empty) for y, p in zip(y_true, y_pred)], dtype=np.float32)
        return np.nanmean(iou_scores)  # Using nanmean to ignore NaN values



def select_slices_based_on_gt(gt_volume, pred_volume):
    # Select slices with positive pixels in ground truth
    positive_slices = (gt_volume > 0)
    return gt_volume[positive_slices], pred_volume[positive_slices]

def compute_classwise_volumetric_iou(gt_volume, pred_volume, num_classes, exclude_empty=False):
    # Initialize IoU scores for each class
    iou_scores = np.zeros(num_classes)

    for class_id in range(num_classes):
        # Consider only voxels belonging to the current class
        gt_class_bool = gt_volume == class_id
        pred_class_bool = pred_volume == class_id

        intersection = np.logical_and(gt_class_bool, pred_class_bool)
        union = np.logical_or(gt_class_bool, pred_class_bool)

        # Calculate IoU, setting it to 1 if the union is zero
        if np.sum(union) == 0:
            if exclude_empty:
                iou_scores[class_id] = float('nan')
            else:
                iou_scores[class_id] = 1.0
        else:
            iou_scores[class_id] = np.sum(intersection) / np.sum(union)

    return iou_scores


def ra_iou(tp, fp, fn, beta=0.5):
    
    ra_iou = tp / (tp + beta * fn + fp + 1e-6)  # Adding epsilon to avoid division by zero
    
    return ra_iou.mean()

In [9]:
class PairedDataLoader(DataLoader):
    def __init__(self, dataset1, dataset2, batch_size, shuffle, worker_init_fn, generator, drop_last,augment=False):
        paired_dataset = PairedDataset(dataset1, dataset2,augment=augment)
        super().__init__(paired_dataset, batch_size=batch_size, shuffle=shuffle, worker_init_fn=worker_init_fn, generator=generator, drop_last=drop_last)

class PairedDataset(Dataset):
    def __init__(self, dataset1, dataset2, augment):
        self.dataset1 = dataset1
        self.dataset2 = dataset2
        self.augmentations = Compose([monai.transforms.RandHistogramShiftd(keys=['image'], prob=0.2, num_control_points=4), 
                                      monai.transforms.RandRotated(keys=['image', 'label'],mode='nearest-exact', range_x=[0.1, 0.1], prob=0.3),
                                      monai.transforms.RandZoomd(keys=['image', 'label'],mode='nearest-exact', min_zoom = 1.3, max_zoom = 1.5, prob=0.3),
                                      #monai.transforms.RandCoarseDropoutd(keys=['image', 'label'], prob=0.3, holes=20, spatial_size=20, fill_value =0)
                                     ]
            
        
        )

        self.augment=augment
    
        
    def __len__(self):
        return min(len(self.dataset1), len(self.dataset2))

    def filter_data(self,data):
        if not data['keep_sample']:
            data['image']=monai.data.MetaTensor(torch.zeros_like(data['image']))
        return data
    
    def __getitem__(self, idx):

        data1 = self.dataset1[idx]
        data2 = self.dataset2[idx]

        data2_flat = [self.filter_data(item) for item in data2]
        data2_1 = data2_flat[0]
        data2_2 = data2_flat[1]

        if self.augment:
            self.augmentations.set_random_state(seed=idx)

            data2_1 = {
                'image' : copy.deepcopy(data2_1['image']),
                'label' : copy.deepcopy(data2_1['label']),
                'boundary' : copy.deepcopy(data2_1['boundary'])
            }
            data2_2 = {
                'image' : copy.deepcopy(data2_2['image']),
                'label' : copy.deepcopy(data2_2['label']),
                'boundary' : copy.deepcopy(data2_2['boundary'])
            }
            data1 = {
                'image' : copy.deepcopy(data1['image']),
                'label' : copy.deepcopy(data1['label']),
                'boundary' : copy.deepcopy(data1['boundary'])
            }                                    
                                        
            data2_1 = self.augmentations(data2_1)
            
            self.augmentations.set_random_state(seed=idx)
            data2_2 = self.augmentations(data2_2)
            
            self.augmentations.set_random_state(seed=idx)
            data1= self.augmentations(data1)

        return data1, data2_1, data2_2

# Model

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class AttentionGate(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi

class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.conv = ConvBlock(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv(x)
        p = self.pool(x)
        return x, p

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels, F_int, use_attention=True):
        super(DecoderBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, mid_channels, kernel_size=2, stride=2)
        # Adjust the attention gate to handle the combined skip connections
        if use_attention:
            self.attention = AttentionGate(F_g=mid_channels, F_l=mid_channels, F_int=F_int)
        self.conv = ConvBlock(mid_channels + mid_channels, out_channels)  # Adjust for concatenated skip connection size
        self.use_attention=use_attention

    def forward(self, x, combined_skip):

        x = self.up(x)
        # Apply attention to the combined skip connections
        if self.use_attention:
            combined_skip = self.attention(g=x, x=combined_skip)
        
        x = torch.cat([x, combined_skip], dim=1)
        x = self.conv(x)
        return x
    

class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Linear(in_channels, in_channels // reduction_ratio, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(in_channels // reduction_ratio, in_channels, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        avg_pooled = self.avg_pool(x).view(b, c)
        max_pooled = self.max_pool(x).view(b, c)
        avg_out = self.fc2(self.relu(self.fc1(avg_pooled)))
        max_out = self.fc2(self.relu(self.fc1(max_pooled)))
        out = self.sigmoid(avg_out + max_out).view(b, c, 1, 1)
        return x * out
        
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        avg_out = torch.mean(x, dim=1, keepdim=True)
        x = torch.cat([max_out, avg_out], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)

class FeatureFusionBlock(nn.Module):
    def __init__(self, in_channels, out_channels, reduction=16):
        super(FeatureFusionBlock, self).__init__()
        self.channel_attention_local = ChannelAttention(in_channels)
        self.spatial_attention_local = SpatialAttention()

        self.channel_attention_global = ChannelAttention(in_channels)
        self.spatial_attention_global = SpatialAttention()
        
        self.fusion_conv = nn.Conv2d(in_channels*3, out_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, global_feat, local_feat1, local_feat2):
        # Apply channel attention to each feature map
        global_ca = global_feat * self.channel_attention_global(global_feat)
        local_ca1 = local_feat1 * self.channel_attention_local(local_feat1)
        local_ca2 = local_feat2 * self.channel_attention_local(local_feat2)

        # Apply spatial attention to each feature map
        global_sa = global_ca * self.spatial_attention_global(global_ca)
        local_sa1 = local_ca1 * self.spatial_attention_local(local_ca1)
        local_sa2 = local_ca2 * self.spatial_attention_local(local_ca2)

        # Concatenate the feature maps
        fused_features = torch.cat((global_sa, local_sa1, local_sa2), dim=1)
        
        # Fuse them using a convolutional layer
        fused_features = self.fusion_conv(fused_features)
        fused_features = self.relu(fused_features)
        
        return fused_features

class SimpleFeatureFusionBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SimpleFeatureFusionBlock, self).__init__()
        # Since we're concatenating three feature maps, the input to the fusion_conv will be 3 times in_channels
        self.fusion_conv = nn.Conv2d(in_channels * 3, out_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, global_feat, local_feat1, local_feat2):
        # Concatenate the feature maps along the channel dimension
        fused_features = torch.cat((global_feat, local_feat1, local_feat2), dim=1)
        
        # Apply a convolutional layer to reduce dimensions
        fused_features = self.fusion_conv(fused_features)
        fused_features = self.relu(fused_features)
        
        return fused_features



class MultiInputUNet(nn.Module):
    def __init__(self, n_channels, n_classes, use_simple_fusion=False, use_decoder_attention=True):
        super(MultiInputUNet, self).__init__()

        if use_simple_fusion:
            self.fusion_skip1 = SimpleFeatureFusionBlock(64, 64)
            self.fusion_skip2 = SimpleFeatureFusionBlock(128, 128)
            self.fusion_skip3 = SimpleFeatureFusionBlock(256, 256)
            self.fusion_skip4 = SimpleFeatureFusionBlock(512, 512)
        else:
          # Initialize fusion blocks with channel reduction
            self.fusion_skip1 = FeatureFusionBlock(64, 64)
            self.fusion_skip2 = FeatureFusionBlock(128, 128)
            self.fusion_skip3 = FeatureFusionBlock(256, 256)
            self.fusion_skip4 = FeatureFusionBlock(512,512)
        
        # Encoders for each input stream
        self.encoder1 = nn.ModuleList([EncoderBlock(n_channels, 64), EncoderBlock(64, 128), EncoderBlock(128, 256), EncoderBlock(256,512)])
        self.encoder2 = nn.ModuleList([EncoderBlock(n_channels, 64), EncoderBlock(64, 128), EncoderBlock(128, 256), EncoderBlock(256,512)])
        self.encoder3 = nn.ModuleList([EncoderBlock(n_channels, 64), EncoderBlock(64, 128), EncoderBlock(128, 256), EncoderBlock(256,512)])

        if use_simple_fusion:
            self.deep_feature_fusion = SimpleFeatureFusionBlock(512, 512)
        else:
             self.deep_feature_fusion = FeatureFusionBlock(512,512)
        
        
        # Decoder Blocks
        self.decoder1 = DecoderBlock(512,512, 256,256, use_attention=use_decoder_attention)  # Input channels adjusted for merged features
        self.decoder2 = DecoderBlock(256, 256, 128, 128,use_attention=use_decoder_attention)  # Input channels adjusted for merged features
        self.decoder3 = DecoderBlock(128, 128, 64, 64,use_attention=use_decoder_attention)
        self.decoder4 = DecoderBlock(64, 64, 32, 32,use_attention=use_decoder_attention)

        self.final_conv = nn.Conv2d(32, n_classes, kernel_size=1)

    def forward(self, x1, x2, x3):
        # Process each input through its respective encoders
        skips1, p1 = self.process_through_encoders(x1, self.encoder1)
        skips2, p2 = self.process_through_encoders(x2, self.encoder2)
        skips3, p3 = self.process_through_encoders(x3, self.encoder3)

        fused_skips1 = self.fusion_skip1(skips1[0], skips2[0], skips3[0])
        fused_skips2 = self.fusion_skip2(skips1[1], skips2[1], skips3[1])
        fused_skips3 = self.fusion_skip3(skips1[2], skips2[2], skips3[2])
        fused_skips4 = self.fusion_skip4(skips1[3], skips2[3], skips3[3])

        fused_features = self.deep_feature_fusion(p1, p2, p3)
        
        # Decode the combined features
        # Note: Attention mechanism applies to concatenated skip connections from corresponding layers of each input
        d1 = self.decoder1(fused_features, fused_skips4)
        d2 = self.decoder2(d1, fused_skips3)
        d3 = self.decoder3(d2, fused_skips2)
        d4 = self.decoder4(d3, fused_skips1)

        return self.final_conv(d4)

    def process_through_encoders(self, x, encoders):
        skips = []
        p = x
        for encoder in encoders:
            x, p = encoder(p)
            skips.append(x)
        return skips, p  # Reverse skips for correct order in decoding

## SegNet

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

class SegNet(nn.Module):
    def __init__(self,input_nbr,label_nbr):
        super(SegNet, self).__init__()

        batchNorm_momentum = 0.1

        self.conv11 = nn.Conv2d(input_nbr, 64, kernel_size=3, padding=1)
        self.bn11 = nn.BatchNorm2d(64, momentum= batchNorm_momentum)
        self.conv12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn12 = nn.BatchNorm2d(64, momentum= batchNorm_momentum)

        self.conv21 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn21 = nn.BatchNorm2d(128, momentum= batchNorm_momentum)
        self.conv22 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn22 = nn.BatchNorm2d(128, momentum= batchNorm_momentum)

        self.conv31 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn31 = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
        self.conv32 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn32 = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
        self.conv33 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn33 = nn.BatchNorm2d(256, momentum= batchNorm_momentum)

        self.conv41 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn41 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv42 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn42 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv43 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn43 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)

        self.conv51 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn51 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv52 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn52 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv53 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn53 = nn.BatchNorm2d(512, momentum= batchNorm_momentum)

        self.conv53d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn53d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv52d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn52d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv51d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn51d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)

        self.conv43d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn43d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv42d = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn42d = nn.BatchNorm2d(512, momentum= batchNorm_momentum)
        self.conv41d = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.bn41d = nn.BatchNorm2d(256, momentum= batchNorm_momentum)

        self.conv33d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn33d = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
        self.conv32d = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn32d = nn.BatchNorm2d(256, momentum= batchNorm_momentum)
        self.conv31d = nn.Conv2d(256,  128, kernel_size=3, padding=1)
        self.bn31d = nn.BatchNorm2d(128, momentum= batchNorm_momentum)

        self.conv22d = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn22d = nn.BatchNorm2d(128, momentum= batchNorm_momentum)
        self.conv21d = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.bn21d = nn.BatchNorm2d(64, momentum= batchNorm_momentum)

        self.conv12d = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn12d = nn.BatchNorm2d(64, momentum= batchNorm_momentum)
        self.conv11d = nn.Conv2d(64, label_nbr, kernel_size=3, padding=1)


    def forward(self, x):

        # Stage 1
        x11 = F.relu(self.bn11(self.conv11(x)))
        x12 = F.relu(self.bn12(self.conv12(x11)))
        x1p, id1 = F.max_pool2d(x12,kernel_size=2, stride=2,return_indices=True)

        # Stage 2
        x21 = F.relu(self.bn21(self.conv21(x1p)))
        x22 = F.relu(self.bn22(self.conv22(x21)))
        x2p, id2 = F.max_pool2d(x22,kernel_size=2, stride=2,return_indices=True)

        # Stage 3
        x31 = F.relu(self.bn31(self.conv31(x2p)))
        x32 = F.relu(self.bn32(self.conv32(x31)))
        x33 = F.relu(self.bn33(self.conv33(x32)))
        x3p, id3 = F.max_pool2d(x33,kernel_size=2, stride=2,return_indices=True)

        # Stage 4
        x41 = F.relu(self.bn41(self.conv41(x3p)))
        x42 = F.relu(self.bn42(self.conv42(x41)))
        x43 = F.relu(self.bn43(self.conv43(x42)))
        x4p, id4 = F.max_pool2d(x43,kernel_size=2, stride=2,return_indices=True)

        # Stage 5
        x51 = F.relu(self.bn51(self.conv51(x4p)))
        x52 = F.relu(self.bn52(self.conv52(x51)))
        x53 = F.relu(self.bn53(self.conv53(x52)))
        x5p, id5 = F.max_pool2d(x53,kernel_size=2, stride=2,return_indices=True)


        # Stage 5d
        torch.use_deterministic_algorithms(False)
        x5d = F.max_unpool2d(x5p, id5, kernel_size=2, stride=2)
        torch.use_deterministic_algorithms(True)
        x53d = F.relu(self.bn53d(self.conv53d(x5d)))
        x52d = F.relu(self.bn52d(self.conv52d(x53d)))
        x51d = F.relu(self.bn51d(self.conv51d(x52d)))

        # Stage 4d
        torch.use_deterministic_algorithms(False)
        x4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=2)
        torch.use_deterministic_algorithms(True)
        x43d = F.relu(self.bn43d(self.conv43d(x4d)))
        x42d = F.relu(self.bn42d(self.conv42d(x43d)))
        x41d = F.relu(self.bn41d(self.conv41d(x42d)))

        # Stage 3d
        torch.use_deterministic_algorithms(False)
        x3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=2)
        torch.use_deterministic_algorithms(True)
        x33d = F.relu(self.bn33d(self.conv33d(x3d)))
        x32d = F.relu(self.bn32d(self.conv32d(x33d)))
        x31d = F.relu(self.bn31d(self.conv31d(x32d)))

        # Stage 2d
        torch.use_deterministic_algorithms(False)
        x2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=2)
        torch.use_deterministic_algorithms(True)
        x22d = F.relu(self.bn22d(self.conv22d(x2d)))
        x21d = F.relu(self.bn21d(self.conv21d(x22d)))

        # Stage 1d
        torch.use_deterministic_algorithms(False)
        x1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=2)
        torch.use_deterministic_algorithms(True)
        x12d = F.relu(self.bn12d(self.conv12d(x1d)))
        x11d = self.conv11d(x12d)

        return x11d

    def load_from_segnet(self, model_path):
        s_dict = self.state_dict()# create a copy of the state dict
        th = torch.load(model_path).state_dict() # load the weigths
        # for name in th:
            # s_dict[corresp_name[name]] = th[name]
        self.load_state_dict(th)

## FCFFNET

In [12]:
import operator

class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)


    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
    
class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)
    
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)

class NeighConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.neigh_conv = nn.Conv2d(in_channels, out_channels, kernel_size=2, padding=1)
    
    def forward(self, x):
        return self.neigh_conv(x)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
    

def compute_class_weight(labels):
    """
    Compute class weights for binary classification with labels of shape (B, 1, C, H, W).
    """
    labels = labels.cpu()
    labels = labels.squeeze(1).reshape(-1)  # Flatten labels to (B * C * H * W)
    unique_labels = np.unique(labels.numpy())  # Get unique class labels
    
    class_freq = {}
    for label in unique_labels:
        class_freq[label] = np.sum(labels.numpy() == label)  # Count occurrences of each class
    
    max_value = max(class_freq.values())  # Get the largest class frequency
    
    class_weights = {}
    for label in unique_labels:
        class_weights[label] = max_value / class_freq[label]  # Compute class weights
    
    # Convert to a PyTorch tensor
    return torch.tensor([class_weights[label] for label in sorted(class_weights.keys())]).float()



def CrossEntropy2d(input, target, weight=None, reduction='mean'):
    """
    Binary cross entropy for 2D inputs with shape (B, H, W).
    `weight` is applied per class.
    """
    # Ensure input and target have the correct dimensions
    target = target.float()  # Convert target to float for BCEWithLogitsLoss
    
    # Handle class weights
    if weight is not None:
        # Expand weight to match the shape of the input/target
        weight = weight[target.long()]  # Select weights based on the class labels
    
    return F.binary_cross_entropy_with_logits(input, target, weight=weight, reduction=reduction)


class FcnnFnet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=True):
        super(FcnnFnet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.fc1 = OutConv(256, n_classes)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.fc2 = OutConv(128, n_classes)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.fc3 = OutConv(64, n_classes)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)
        self.neigh = NeighConv(n_classes, n_classes)

    def forward(self, x):

        activations = []
        out_fc = []
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        out1 = self.fc1(x)
        out_fc.append(out1)
        activations.append(x)
        x = self.up2(x, x3)
        out2 = self.fc2(x)
        out_fc.append(out2)
        activations.append(x)
        x = self.up3(x, x2)
        out3 = self.fc3(x)
        out_fc.append(out3)
        activations.append(x)
        x = self.up4(x, x1)
        logits = self.outc(x)
        neigh = self.neigh(logits)[:,:,:-1, :-1]
        return logits, out_fc, neigh, activations

## Skinny

In [13]:
def get_filters_count(level: int, base_filters: int) -> int:
    """Calculate the number of filters at each level."""
    return base_filters * (2 ** (level - 1))

###############################################################################
# Inception Module
###############################################################################
class InceptionModule(nn.Module):
    """Inception module with 4 parallel branches."""
    def __init__(self, in_channels: int, out_channels: int,
                 activation=nn.LeakyReLU(0.3, inplace=True)):
        super().__init__()
        self.activation = activation
        branch_ch = out_channels // 4

        # Branch 1: 1×1
        self.branch1_1x1 = nn.Conv2d(in_channels, branch_ch, kernel_size=1)

        # Branch 2: 1×1 -> 3×3
        self.branch2_1x1 = nn.Conv2d(in_channels, branch_ch, kernel_size=1)
        self.branch2_3x3 = nn.Conv2d(branch_ch, branch_ch, kernel_size=3, padding=1)

        # Branch 3: 1×1 -> 5×5
        self.branch3_1x1 = nn.Conv2d(in_channels, branch_ch, kernel_size=1)
        self.branch3_5x5 = nn.Conv2d(branch_ch, branch_ch, kernel_size=5, padding=2)

        # Branch 4: MaxPool -> 1×1
        self.pool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.branch4_1x1 = nn.Conv2d(in_channels, branch_ch, kernel_size=1)

    def forward(self, x):
        b1 = self.activation(self.branch1_1x1(x))
        b2 = self.activation(self.branch2_3x3(self.branch2_1x1(x)))
        b3 = self.activation(self.branch3_5x5(self.branch3_1x1(x)))
        b4 = self.activation(self.branch4_1x1(self.pool(x)))
        return torch.cat([b1, b2, b3, b4], dim=1)



class SkinnyNet(nn.Module):
    """
    'Skinny' U-Net
    """
    def __init__(self, image_channels=3, levels=6, base_filters=19):
        super().__init__()
        self.levels = levels
        self.base_filters = base_filters
        self.activation = nn.LeakyReLU(0.3, inplace=True)

        # ---------------------------------------------------------------------
        # Contracting (Down) Path
        # ---------------------------------------------------------------------
        self.down_convs = nn.ModuleList()
        self.down_inceptions = nn.ModuleList()
        self.pools = nn.ModuleList()

        # Record the actual # of output channels at each level
        self.down_channels = []

        in_ch = image_channels
        for lvl in range(1, levels + 1):
            out_ch = get_filters_count(lvl, base_filters)

            # Conv -> BN -> Activation
            conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
            bn = nn.BatchNorm2d(out_ch)
            self.down_convs.append(nn.Sequential(conv, bn, self.activation))

            # Inception
            inc = InceptionModule(in_channels=out_ch, out_channels=out_ch,
                                  activation=self.activation)
            self.down_inceptions.append(inc)

            # Actual output channels after Inception
            actual_out_ch = (out_ch // 4) * 4
            self.down_channels.append(actual_out_ch)

            # MaxPool except at the last (bottom) level
            if lvl < levels:
                self.pools.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                self.pools.append(None)

            # Update input channels for the next level
            in_ch = actual_out_ch

        # ---------------------------------------------------------------------
        # Expanding (Up) Path
        # ---------------------------------------------------------------------
        self.up_convs = nn.ModuleList()
        self.up_inceptions = nn.ModuleList()

        for lvl in range(levels, 1, -1):
            # Skip connection channels from contracting path
            skip_channels = self.down_channels[lvl - 2]
            # Bottom (upsampled) feature channels from the contracting path
            bottom_channels = self.down_channels[lvl - 1]
            # Total input channels = skip + bottom
            total_channels = bottom_channels + skip_channels

            # Create convolutional layers for expanding path
            up_conv = nn.Conv2d(total_channels, skip_channels, kernel_size=3, padding=1)
            bn_up = nn.BatchNorm2d(skip_channels)
            self.up_convs.append(nn.Sequential(up_conv, bn_up, self.activation))

            # Create InceptionModule for expanding path
            up_inception = InceptionModule(in_channels=skip_channels, out_channels=skip_channels,
                                           activation=self.activation)
            self.up_inceptions.append(up_inception)

        # ---------------------------------------------------------------------
        # Final: 1×1 or 3×3 conv to 1 channel
        # ---------------------------------------------------------------------
        self.final_conv = nn.Conv2d(self.down_channels[0], 1, kernel_size=3, padding=1)

    def forward(self, x):
        """
        1) Contracting path:
           for i in [0..levels-1]:
             x -> conv -> inception -> store skip
             if i < levels-1: pool
        2) Expanding path:
           for i in [levels-1..1]:
             upsample
             cat with skip
             conv + inception
        3) Final conv -> sigmoid
        """
        # -------------------------------------
        # Contracting Path
        # -------------------------------------
        downs = []
        curr = x
        for i in range(self.levels):
            curr = self.down_convs[i](curr)       # Conv
            curr = self.down_inceptions[i](curr) # Inception
            downs.append(curr)
            if self.pools[i] is not None:
                curr = self.pools[i](curr)

        # -------------------------------------
        # Expanding Path
        # -------------------------------------
        for i, (up_conv, up_inception) in enumerate(zip(self.up_convs, self.up_inceptions)):
            curr = F.interpolate(curr, scale_factor=2, mode='nearest')
            skip = downs[-(i + 2)]  # Skip connection from contracting path
            curr = torch.cat([curr, skip], dim=1)  # Concatenate along channels
            curr = up_conv(curr)  # Apply Conv -> BN -> Activation
            curr = up_inception(curr)  # Apply Inception

        # -------------------------------------
        # Final Layer
        # -------------------------------------
        curr = self.final_conv(curr)
        return curr


## BMS

In [14]:
from torch.autograd import Variable

class BreastModel2(L.LightningModule):

    def __init__(self,
                 arch,
                 encoder_name,
                 in_channels,
                 out_classes,
                 batch_size,
                 len_train_loader,
                 threshold = 0.4,
                 t_loss = AsymmetricUnifiedFocalLoss(delta=0.4, gamma=0.1),
                 boundaryloss=False,
                 img_size = None,
                 **kwargs
                 ):
        super().__init__()


        self.mode = "base"

        if arch == 'base_unet':
            self.model = UNet(n_channels=in_channels, n_classes = out_classes)

        elif arch == "segnet":
            self.model = SegNet(input_nbr=in_channels, label_nbr=out_classes)
            
        elif arch == "swin_unetr":
            self.model = SwinUNETR(img_size = img_size, in_channels=in_channels, out_channels = out_classes, spatial_dims=2, use_v2=True, downsample="mergingv2")

        elif arch == "unetplusplus":
            self.model = BasicUNetPlusPlus(in_channels=in_channels, out_channels = out_classes, spatial_dims=2)

        elif arch == "skinny":
            self.model = SkinnyNet(image_channels=in_channels, levels=6, base_filters=19)

        elif arch == "fcn_ffnet":
            self.model = FcnnFnet(n_channels=in_channels, n_classes=out_classes, bilinear=True)
            self.mode = "fcn_ffnet"

            
        else:
            aux_params = dict(
                pooling = 'avg',  # one of 'avg', 'max'
                dropout = 0.5,  # dropout ratio, default is None
                activation = None,  # activation function, default is None
                classes = out_classes,  # define number of output labels
            )
        
            self.model = smp.create_model(
                   arch, encoder_name=encoder_name,
                    aux_params = aux_params, 
                   in_channels = in_channels, encoder_weights=None,**kwargs)

        self.t_loss = t_loss
        self.train_outputs = []
        self.val_outputs = []
        self.test_outputs = []
        self.save_hyperparameters()
        self.threshold = threshold
        self.batch_size = batch_size
        self.len_train_loader = len_train_loader
        self.augment_inference=False
        self.val=80
        self.boundaryloss = boundaryloss

    def ttaug(self, model, image):
        transforms = tta.Compose(
        [
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Rotate90(angles=[0, 180]),
            #tta.Scale(scales=[1, 2, 4]),
            #tta.Multiply(factors=[0.8]),        
        ]
        )
    
        model.eval()
        merger = Merger(type="tsharpen", n=len(transforms))
        for transformer in transforms: # custom transforms or e.g. tta.aliases.d4_transform() 
        
            # augment image
            augmented_image = transformer.augment_image(image)
            augmented_image = augmented_image.to("cuda")
    
            model_output = model(augmented_image)[0]
            
            # reverse augmentation for mask and label
            deaug_mask = transformer.deaugment_mask(model_output)
            #deaug_mask = deaug_mask.sigmoid()
            
            merger.append(deaug_mask)
    
        masks = merger.result
        #masks = masks.sigmoid()

        return masks


    def base_step(self,stage, image, mask, batch):
        if stage=='test' and self.augment_inference:
            prob_mask = self.ttaug(self.model, image)
            logits_mask = self.forward(image)
            if isinstance(logits_mask, List):
                logits_mask = logits_mask[0]

        else:
            logits_mask = self.forward(image)
            if isinstance(logits_mask, List):
                logits_mask = logits_mask[0]
            prob_mask = logits_mask.sigmoid()
            
        if self.boundaryloss:
            dist_map = batch["boundary"].to("cuda")
            t_loss = self.t_loss(logits_mask, prob_mask, dist_map, mask, self.current_epoch)

        else:
            t_loss = self.t_loss(logits_mask, mask)
            
        loss = t_loss

        return loss, prob_mask

    def fcn_ffnet_step(self,stage, image, mask, batch):
        # Ensure mask has the correct shape
        mask = mask.data.cpu().numpy()  # Convert to numpy
        mask = np.squeeze(mask, axis=1)  # Remove the singleton channel dimension (B, H, W)
        
        # Resize and prepare the masks at different scales
        mask3 = np.stack([cv2.resize(m, dsize=(128, 128), interpolation=cv2.INTER_NEAREST) for m in mask], axis=0)
        mask2 = np.stack([cv2.resize(m, dsize=(64, 64), interpolation=cv2.INTER_NEAREST) for m in mask], axis=0)
        mask1 = np.stack([cv2.resize(m, dsize=(32, 32), interpolation=cv2.INTER_NEAREST) for m in mask], axis=0)
        
        # Convert resized masks back to tensors and adjust dimensions
        mask3 = torch.from_numpy(mask3).unsqueeze(1).type(torch.LongTensor).cuda()  # (B, 1, 128, 128)
        mask2 = torch.from_numpy(mask2).unsqueeze(1).type(torch.LongTensor).cuda()  # (B, 1, 64, 64)
        mask1 = torch.from_numpy(mask1).unsqueeze(1).type(torch.LongTensor).cuda()  # (B, 1, 32, 32)
        
        # Prepare the original mask
        mask = torch.from_numpy(mask).unsqueeze(1).type(torch.LongTensor).cuda()  # (B, 1, H, W)
        
        # Add weights for MLP loss
        image = Variable(image.cuda())
        mask = Variable(mask)
        
        # Forward pass
        output, out_fc, out_neigh = self.model(image)[:3]

        # Compute losses
        loss = self.t_loss(output, mask, weight=None)

        loss_fc1 = CrossEntropy2d(out_fc[0], Variable(mask1), weight=compute_class_weight(mask1).cuda())
        loss_fc2 = CrossEntropy2d(out_fc[1], Variable(mask2), weight=compute_class_weight(mask2).cuda())
        loss_fc3 = CrossEntropy2d(out_fc[2], Variable(mask3), weight=compute_class_weight(mask3).cuda())
        pairwise_loss = self.t_loss(out_neigh, mask, weight=None)
        
        # Combine losses
        loss = (loss + loss_fc1 + loss_fc2 + loss_fc3) / 4 + pairwise_loss
        
        # Output probabilities
        prob_mask = output.sigmoid()
        
        return loss, prob_mask


            
    def step(self, batch, batch_idx, stage, mode='base'):
        
        image = batch["image"].to("cuda")

        assert image.ndim == 4
        h, w = image.shape[2:]
        assert h % 32 == 0 and w % 32 == 0

        mask = batch["label"].to("cuda")
        assert mask.ndim == 4

        epsilon = 1e-6  # Define a small epsilon value for margin of error

        assert mask.max() <= 1 + epsilon and mask.min() >= -epsilon

        if mode == "fcn_ffnet":
            loss, prob_mask= self.fcn_ffnet_step(stage, image, mask, batch)
        else:
            loss, prob_mask = self.base_step(stage, image, mask, batch)

        pred_mask = (prob_mask > self.threshold).float()

        if stage=='test' and self.augment_inference:
            pred_mask = (prob_mask > self.threshold).int()
            pred_mask = remove_far_masses_based_on_largest_mass(pred_mask, distance_threshold=10)
            pred_mask = RemoveSmallObjects(connectivity=1, min_size=140)(pred_mask)
            pred_mask = torch.Tensor(pred_mask).to("cuda")

        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), mask.long(), mode = "binary")

        iou_per_image_mass = compute_iou(y_true = mask, y_pred = pred_mask, class_id = 1,
                                         reduction = "micro_image_wise")
        iou_per_image_background = compute_iou(y_true = mask, y_pred = pred_mask, class_id = 0,
                                               reduction = "micro_image_wise")

        iou_per_image_mass_no_empty = compute_iou(y_true = mask, y_pred = pred_mask, class_id = 1,
                                         reduction = "micro_image_wise", exclude_empty=True)
        iou_per_image_background_no_empty = compute_iou(y_true = mask, y_pred = pred_mask, class_id = 0,
                                               reduction = "micro_image_wise",exclude_empty=True)

        iou_per_dataset_mass = compute_iou(y_true = mask, y_pred = pred_mask, class_id = 1, reduction = "micro")
        iou_per_dataset_background = compute_iou(y_true = mask, y_pred = pred_mask, class_id = 0, reduction = "micro")

        iou_per_dataset_mass_no_empty = compute_iou(y_true = mask, y_pred = pred_mask, class_id = 1, reduction = "micro",exclude_empty=True)
        iou_per_dataset_background_no_empty = compute_iou(y_true = mask, y_pred = pred_mask, class_id = 0, reduction = "micro",exclude_empty=True)

        ###

        dice_per_image_mass = compute_dice_score(y_true = mask, y_pred = pred_mask, class_id = 1,
                                         reduction = "micro_image_wise")
        dice_per_image_background = compute_dice_score(y_true = mask, y_pred = pred_mask, class_id = 0,
                                               reduction = "micro_image_wise")

        dice_per_image_mass_no_empty = compute_dice_score(y_true = mask, y_pred = pred_mask, class_id = 1,
                                         reduction = "micro_image_wise", exclude_empty=True)
        dice_per_image_background_no_empty = compute_dice_score(y_true = mask, y_pred = pred_mask, class_id = 0,
                                               reduction = "micro_image_wise",exclude_empty=True)

        dice_per_dataset_mass = compute_dice_score(y_true = mask, y_pred = pred_mask, class_id = 1, reduction = "micro")
        dice_per_dataset_background = compute_dice_score(y_true = mask, y_pred = pred_mask, class_id = 0, reduction = "micro")

        dice_per_dataset_mass_no_empty = compute_dice_score(y_true = mask, y_pred = pred_mask, class_id = 1, reduction = "micro",exclude_empty=True)
        dice_per_dataset_background_no_empty = compute_dice_score(y_true = mask, y_pred = pred_mask, class_id = 0, reduction = "micro",exclude_empty=True)


        ###

        acc_per_image_mass = class_specific_accuracy_score(preds = mask, targets = pred_mask, class_id = 1,
                                                           averaging = "micro_image_wise")
        acc_per_image_background = class_specific_accuracy_score(preds = mask, targets = pred_mask, class_id = 0,
                                                                 averaging = "micro_image_wise")

        acc_per_dataset_mass = class_specific_accuracy_score(preds = mask, targets = pred_mask, class_id = 1,
                                                              averaging = "micro")
        acc_per_dataset_background = class_specific_accuracy_score(preds = mask, targets = pred_mask, class_id = 0,
                                                                   averaging = "micro")

        loss = loss.to('cpu')

        output = {
            "loss":                       loss,
            "tp":                         tp,
            "fp":                         fp,
            "fn":                         fn,
            "tn":                         tn,
            "iou_per_image_mass":         iou_per_image_mass,
            "iou_per_image_background":   iou_per_image_background,
            "iou_per_dataset_mass":      iou_per_dataset_mass,
            "iou_per_dataset_background": iou_per_dataset_background,
            "iou_per_image_mass_no_empty":         iou_per_image_mass_no_empty,
            "iou_per_image_background_no_empty":   iou_per_image_background_no_empty,
            "iou_per_dataset_mass_no_empty":      iou_per_dataset_mass_no_empty,
            "iou_per_dataset_background_no_empty": iou_per_dataset_background_no_empty, ###
            "dice_per_image_mass":         dice_per_image_mass,
            "dice_per_image_background":   dice_per_image_background,
            "dice_per_dataset_mass":      dice_per_dataset_mass,
            "dice_per_dataset_background": dice_per_dataset_background,
            "dice_per_image_mass_no_empty":         dice_per_image_mass_no_empty,
            "dice_per_image_background_no_empty":   dice_per_image_background_no_empty,
            "dice_per_dataset_mass_no_empty":      dice_per_dataset_mass_no_empty,
            "dice_per_dataset_background_no_empty": dice_per_dataset_background_no_empty,
            "acc_per_image_mass":         acc_per_image_mass,
            "acc_per_image_background":   acc_per_image_background,
            "acc_per_dataset_mass":      acc_per_dataset_mass,
            "acc_per_dataset_background": acc_per_dataset_background
        }

        if stage=='train':
            self.train_outputs.append(output)
        if stage=='valid':
            self.val_outputs.append(output)
        if stage=='test':
            self.test_outputs.append(output)
        self.log(f'{stage}_loss', loss, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        if batch:
            return self.step(batch, batch_idx, "valid",self.mode)

    def test_step(self, batch, batch_idx):
        if batch:
            return self.step(batch, batch_idx, "test", self.mode)

    def training_step(self, batch, batch_idx):
        if batch:
            return self.step(batch, batch_idx, "train",self.mode)

    def forward(self, image):
        mask=self.model(image)
        return mask

    def single_predict_sliding_window(self, image, roi_size=(128,128), sw_batch_size=128, overlap=0):
        return sliding_window_inference(image, roi_size, sw_batch_size, self.model, overlap=0,
                                              mode="gaussian")


    def shared_epoch_end(self, outputs, stage):
        # aggregate step metics
        if not outputs:
            return

        tp = torch.cat([x["tp"] for x in outputs])
        fp = torch.cat([x["fp"] for x in outputs])
        fn = torch.cat([x["fn"] for x in outputs])
        tn = torch.cat([x["tn"] for x in outputs])
        
        iou_per_image_mass = torch.nanmean(torch.Tensor([x["iou_per_image_mass"] for x in outputs]))
        iou_per_image_background = torch.nanmean(torch.Tensor([x["iou_per_image_background"] for x in outputs]))
        iou_per_dataset_image = torch.nanmean(torch.Tensor([x["iou_per_dataset_mass"] for x in outputs]))
        iou_per_dataset_background = torch.nanmean(torch.Tensor([x["iou_per_dataset_background"] for x in outputs]))

        iou_per_image_mass_no_empty = torch.nanmean(torch.Tensor([x["iou_per_image_mass_no_empty"] for x in outputs]))
        iou_per_image_background_no_empty = torch.nanmean(torch.Tensor([x["iou_per_image_background_no_empty"] for x in outputs]))
        iou_per_dataset_image_no_empty = torch.nanmean(torch.Tensor([x["iou_per_dataset_mass_no_empty"] for x in outputs]))
        iou_per_dataset_background_no_empty = torch.nanmean(torch.Tensor([x["iou_per_dataset_background_no_empty"] for x in outputs]))


        dice_per_image_mass = torch.nanmean(torch.Tensor([x["dice_per_image_mass"] for x in outputs]))
        dice_per_image_background = torch.nanmean(torch.Tensor([x["dice_per_image_background"] for x in outputs]))
        dice_per_dataset_image = torch.nanmean(torch.Tensor([x["dice_per_dataset_mass"] for x in outputs]))
        dice_per_dataset_background = torch.nanmean(torch.Tensor([x["dice_per_dataset_background"] for x in outputs]))

        dice_per_image_mass_no_empty = torch.nanmean(torch.Tensor([x["dice_per_image_mass_no_empty"] for x in outputs]))
        dice_per_image_background_no_empty = torch.nanmean(torch.Tensor([x["dice_per_image_background_no_empty"] for x in outputs]))
        dice_per_dataset_image_no_empty = torch.nanmean(torch.Tensor([x["dice_per_dataset_mass_no_empty"] for x in outputs]))
        dice_per_dataset_background_no_empty = torch.nanmean(torch.Tensor([x["dice_per_dataset_background_no_empty"] for x in outputs]))


        acc_per_image_mass = torch.nanmean(torch.Tensor([x["acc_per_image_mass"] for x in outputs]))
        acc_per_image_background = torch.nanmean(torch.Tensor([x["acc_per_image_background"] for x in outputs]))
        acc_per_dataset_image = torch.nanmean(torch.Tensor([x["acc_per_dataset_mass"] for x in outputs]))
        acc_per_dataset_background = torch.nanmean(torch.Tensor([x["acc_per_dataset_background"] for x in outputs]))


        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction = "micro-imagewise")
        per_dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction = "micro")

        per_image_dice = compute_dice_score_from_cm(tp, fp, fn, tn, reduction = "micro-imagewise")
        per_dataset_dice = compute_dice_score_from_cm(tp, fp, fn, tn, reduction = "micro")


        # MACRO AVG
        precision = compute_mean_precision(tp, fp, fn, tn)
        recall = compute_mean_recall(tp, fp, fn, tn)

        accuracy = smp.metrics.accuracy(tp, fp, fn, tn, reduction='micro-imagewise')

        # MACRO IMAGEWISE MEAN DICE WITH EMPTY
        dice1_per_image = compute_dice_from_metrics(tp, fp, tn, fn, reduction='none')
        dice0_per_image = compute_dice_from_metrics(tn, fn, tp, fp,  reduction='none')
        mean_dice_per_image = np.mean(np.nanmean(np.array([dice0_per_image.cpu().numpy(), dice1_per_image.cpu().numpy()]), axis=0))

        # MACRO MEAN DICE WITH EMPTY
        mean_dice1_per_dataset = compute_dice_from_metrics(tp, fp, tn, fn, reduction='micro-imagewise')
        mean_dice0_per_dataset = compute_dice_from_metrics(tn, fn, tp, fp,  reduction='micro-imagewise')
        mean_dice_per_dataset = np.nanmean(np.array([mean_dice0_per_dataset.cpu().numpy(), mean_dice1_per_dataset.cpu().numpy()]))

        # MACRO IMAGEWISE MEAN DICE NO EMPTY
        dice1_per_image_no_empty = compute_dice_from_metrics(tp, fp, tn, fn, reduction='none',exclude_empty=True)
        dice0_per_image_no_empty = compute_dice_from_metrics(tn, fn, tp, fp,  reduction='none',exclude_empty=True)
        combined_dice_scores = np.hstack((dice0_per_image_no_empty, dice1_per_image_no_empty ))
        valid_pairs = ~np.isnan(combined_dice_scores).any(axis=1)
        mean_dice_per_image_no_empty = np.mean(np.nanmean(combined_dice_scores[valid_pairs], axis=1))

        if mean_dice_per_image_no_empty.size == 0:
             mean_dice_per_image_no_empty = float('nan')

        # MACRO MEAN DICE NO EMPTY
        mean_dice1_per_dataset_no_empty = compute_dice_from_metrics(tp, fp, tn, fn, reduction='micro-imagewise',exclude_empty=True)
        mean_dice0_per_dataset_no_empty = compute_dice_from_metrics(tn, fn, tp, fp,  reduction='micro-imagewise',exclude_empty=True)
        mean_dice_per_dataset_no_empty = np.mean(np.array([mean_dice0_per_dataset_no_empty.cpu().numpy(), mean_dice1_per_dataset_no_empty.cpu().numpy()]))

        # MACRO IMAGEWISE MEAN IOU WITH EMPTY

        iou1_per_image = compute_iou_from_metrics(tp, fp, tn, fn, reduction='none')
        iou0_per_image = compute_iou_from_metrics(tn, fn, tp, fp,  reduction='none')
        mean_iou_per_image = np.mean(np.nanmean(np.array([iou0_per_image.cpu().numpy(), iou1_per_image.cpu().numpy()]), axis=0))

        # MACRO MEAN IOU WITH EMPTY
        mean_iou1_per_dataset = compute_iou_from_metrics(tp, fp, tn, fn, reduction='micro-imagewise')
        mean_iou0_per_dataset = compute_iou_from_metrics(tn, fn, tp, fp,  reduction='micro-imagewise')
        mean_iou_per_dataset = np.nanmean(np.array([mean_iou0_per_dataset.cpu().numpy(), mean_iou1_per_dataset.cpu().numpy()]))

        # MACRO IMAGEWISE MEAN IOU NO EMPTY
        iou1_per_image_no_empty = compute_iou_from_metrics(tp, fp, tn, fn, reduction='none',exclude_empty=True)
        iou0_per_image_no_empty = compute_iou_from_metrics(tn, fn, tp, fp,  reduction='none',exclude_empty=True)
        
        combined_iou_scores = np.hstack((iou0_per_image_no_empty, iou1_per_image_no_empty))
        valid_pairs = ~np.isnan(combined_iou_scores).any(axis=1) #10, 2
        mean_iou_per_image_no_empty = np.mean(np.nanmean(combined_iou_scores[valid_pairs], axis=1))

        if mean_iou_per_image_no_empty.size == 0:
             mean_iou_per_image_no_empty = float('nan')

        # MACRO MEAN IOU NO EMPTY
        mean_iou1_per_dataset_no_empty = compute_iou_from_metrics(tp, fp, tn, fn, reduction='micro-imagewise',exclude_empty=True)
        mean_iou0_per_dataset_no_empty = compute_iou_from_metrics(tn, fn, tp, fp,  reduction='micro-imagewise',exclude_empty=True)
        mean_iou_per_dataset_no_empty = np.mean(np.array([mean_iou0_per_dataset_no_empty.cpu().numpy(), mean_iou1_per_dataset_no_empty.cpu().numpy()]))
        

       
        self.log(f"{stage}_per_image_iou", per_image_iou, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f"{stage}_per_dataset_iou", per_dataset_iou, sync_dist = True, prog_bar=True, batch_size=batch_size)

        self.log(f"{stage}_per_image_dice", per_image_dice, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f"{stage}_per_dataset_dice", per_dataset_dice, sync_dist = True, prog_bar=True, batch_size=batch_size)
        
        self.log(f"{stage}_mean_iou_per_image", mean_iou_per_image, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f"{stage}_mean_iou_per_dataset", mean_iou_per_dataset, sync_dist = True, prog_bar=True, batch_size=batch_size)

        self.log(f"{stage}_mean_iou_per_image_no_empty", mean_iou_per_image_no_empty, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f"{stage}_mean_iou_per_dataset_no_empty", mean_iou_per_dataset_no_empty, sync_dist = True, prog_bar=True, batch_size=batch_size)

        self.log(f"{stage}_mean_dice_per_image", mean_dice_per_image, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f"{stage}_mean_dice_per_dataset", mean_dice_per_dataset, sync_dist = True, prog_bar=True, batch_size=batch_size)

        self.log(f"{stage}_mean_dice_per_image_no_empty", mean_dice_per_image_no_empty, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f"{stage}_mean_dice_per_dataset_no_empty", mean_dice_per_dataset_no_empty, sync_dist = True, prog_bar=True, batch_size=batch_size)
        
        
        self.log(f"{stage}_precision", precision, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f"{stage}_recall", recall, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f"{stage}_accuracy", accuracy, sync_dist = True, prog_bar=True, batch_size=batch_size)

        self.log(f'{stage}_iou_per_image_mass', iou_per_image_mass, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f'{stage}_iou_per_image_background', iou_per_image_background, sync_dist = True, batch_size=batch_size)
        self.log(f'{stage}_iou_per_dataset_mass', iou_per_dataset_image, sync_dist = True, batch_size=batch_size)
        self.log(f'{stage}_iou_per_dataset_background', iou_per_dataset_background, sync_dist = True, batch_size=batch_size)

        self.log(f'{stage}_iou_per_image_mass_no_empty', iou_per_image_mass_no_empty, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f'{stage}_iou_per_image_background_no_empty', iou_per_image_background_no_empty, sync_dist = True, batch_size=batch_size)
        self.log(f'{stage}_iou_per_dataset_mass_no_empty', iou_per_dataset_image_no_empty, sync_dist = True, batch_size=batch_size)
        self.log(f'{stage}_iou_per_dataset_background_no_empty', iou_per_dataset_background_no_empty, sync_dist = True, batch_size=batch_size)

        self.log(f'{stage}_dice_per_image_mass', dice_per_image_mass, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f'{stage}_dice_per_image_background', dice_per_image_background, sync_dist = True, batch_size=batch_size)
        self.log(f'{stage}_dice_per_dataset_mass', dice_per_dataset_image, sync_dist = True, batch_size=batch_size)
        self.log(f'{stage}_dice_per_dataset_background', dice_per_dataset_background, sync_dist = True, batch_size=batch_size)

        self.log(f'{stage}_dice_per_image_mass_no_empty', dice_per_image_mass_no_empty, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f'{stage}_dice_per_image_background_no_empty', dice_per_image_background_no_empty, sync_dist = True, batch_size=batch_size)
        self.log(f'{stage}_dice_per_dataset_mass_no_empty', dice_per_dataset_image_no_empty, sync_dist = True, batch_size=batch_size)
        self.log(f'{stage}_dice_per_dataset_background_no_empty', dice_per_dataset_background_no_empty, sync_dist = True, batch_size=batch_size)


        self.log(f'{stage}_acc_per_image_mass', acc_per_image_mass, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f'{stage}_acc_per_image_background', acc_per_image_background, sync_dist = True, prog_bar=True,batch_size=batch_size)
        self.log(f'{stage}_acc_per_dataset_mass', acc_per_dataset_image, sync_dist = True, prog_bar=True,batch_size=batch_size)
        self.log(f'{stage}_acc_per_dataset_background', acc_per_dataset_background, sync_dist = True, prog_bar=True,batch_size=batch_size)
        self.log(f'{stage}_acc_per_dataset_background', acc_per_dataset_background, sync_dist = True, prog_bar=True,batch_size=batch_size)



    def on_train_epoch_end(self):
        self.shared_epoch_end(outputs = self.train_outputs, stage = "train")
        self.train_outputs.clear()

    def on_validation_epoch_end(self):
        self.shared_epoch_end(outputs = self.val_outputs, stage = "valid")
        self.val_outputs.clear()

    def on_test_epoch_end(self):
        self.shared_epoch_end(outputs = self.test_outputs, stage = "test")
        self.test_outputs.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr = 1e-4, weight_decay=1e-4)

        iterations_per_epoch = self.len_train_loader  # Number of iterations per epoch
        step_size_up = iterations_per_epoch // 2  # Half an epoch for the increasing phase
        gamma = 0.99 

        base_lr = 3e-5  # Increased base learning rate
        max_lr = 9e-4   # Increased maximum learning rate

        #base_lr = 1e-4  # Increased base learning rate
        #max_lr = 1e-3   # Increased maximum learning rate
        
        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,
                             base_lr=base_lr,  # Minimum learning rate
                             max_lr=max_lr,   # Maximum learning rate
                             step_size_up=step_size_up,
                             mode='triangular',
                             cycle_momentum=False)  # Set to True if using an optimizer with momentum

        return [optimizer], [scheduler]
        
class BreastModel(L.LightningModule):

    def __init__(self,
                 arch,
                 encoder_name,
                 in_channels,
                 out_classes,
                 batch_size,
                 len_train_loader,
                 threshold = 0.4,
                 t_loss = AsymmetricUnifiedFocalLoss(delta=0.4, gamma=0.1),
                 boundaryloss=False, use_decoder_attention=True,use_simple_fusion=False,
                 **kwargs
                 ):
        super().__init__()

        aux_params = dict(
                pooling = 'avg',  # one of 'avg', 'max'
                dropout = 0.5,  # dropout ratio, default is None
                activation = None,  # activation function, default is None
                classes = out_classes,  # define number of output labels
        )
        self.model = MultiInputUNet(n_channels=in_channels, n_classes=out_classes, use_simple_fusion=use_simple_fusion,use_decoder_attention=use_decoder_attention)
        """self.model = smp.create_model(
               arch, encoder_name=encoder_name,
                aux_params = aux_params, 
               in_channels = in_channels, encoder_weights=None,**kwargs)"""

        self.t_loss = t_loss
    
        self.train_outputs = []
        self.val_outputs = []
        self.test_outputs = []
        self.save_hyperparameters()
        self.threshold = threshold
        self.batch_size = batch_size
        self.len_train_loader = len_train_loader
        self.augment_inference=False
        self.val=80
        self.boundaryloss=boundaryloss

    def ttaug(self, model, image):
        transforms = tta.Compose(
        [
            tta.HorizontalFlip(),
            tta.VerticalFlip(),
            tta.Rotate90(angles=[0, 180]),
            #tta.Scale(scales=[1, 2, 4]),
            #tta.Multiply(factors=[0.8]),        
        ]
        )
    
        model.eval()
        merger = Merger(type="tsharpen", n=len(transforms))
        for transformer in transforms: # custom transforms or e.g. tta.aliases.d4_transform() 
        
            # augment image
            augmented_image = transformer.augment_image(image)
            augmented_image = augmented_image.to("cuda")
    
            model_output = model(augmented_image)[0]
            
            # reverse augmentation for mask and label
            deaug_mask = transformer.deaugment_mask(model_output)
            #deaug_mask = deaug_mask.sigmoid()
            
            merger.append(deaug_mask)
    
        masks = merger.result
        #masks = masks.sigmoid()

        return masks


    def step(self, batch, batch_idx, stage):
        
        image1 = batch[0]["image"].to("cuda")
        image2 = batch[1]["image"].to("cuda")
        image3 = batch[2]["image"].to("cuda")

        assert image1.ndim == 4
        h, w = image1.shape[2:]
        assert h % 32 == 0 and w % 32 == 0

        mask = batch[0]["label"].to("cuda")
        assert mask.ndim == 4

        mask = mask
        assert mask.max() <= 1 and mask.min() >= 0

        if stage=='test' and self.augment_inference:
            prob_mask = self.ttaug(self.model, image)
            logits_mask = self.forward(image)

        else:
            logits_mask = self.forward(image1,image2,image3)
            prob_mask = logits_mask.sigmoid()

        if self.boundaryloss:
            dist_map = batch[0]["boundary"].to("cuda")
            t_loss = self.t_loss(logits_mask, prob_mask, dist_map, mask, self.current_epoch)

        else:
            t_loss = self.t_loss(logits_mask, mask)

        loss = t_loss

        pred_mask = (prob_mask > self.threshold).float()

        if stage=='test' and self.augment_inference:
            pred_mask = (prob_mask > self.threshold).int()
            pred_mask = remove_far_masses_based_on_largest_mass(pred_mask, distance_threshold=10)
            pred_mask = RemoveSmallObjects(connectivity=1, min_size=140)(pred_mask)
            pred_mask = torch.Tensor(pred_mask).to("cuda")


        tp, fp, fn, tn = smp.metrics.get_stats(pred_mask.long(), mask.long(), mode = "binary")

        iou_per_image_mass = compute_iou(y_true = mask, y_pred = pred_mask, class_id = 1,
                                         reduction = "micro_image_wise")
        iou_per_image_background = compute_iou(y_true = mask, y_pred = pred_mask, class_id = 0,
                                               reduction = "micro_image_wise")

        iou_per_image_mass_no_empty = compute_iou(y_true = mask, y_pred = pred_mask, class_id = 1,
                                         reduction = "micro_image_wise", exclude_empty=True)
        iou_per_image_background_no_empty = compute_iou(y_true = mask, y_pred = pred_mask, class_id = 0,
                                               reduction = "micro_image_wise",exclude_empty=True)

        iou_per_dataset_mass = compute_iou(y_true = mask, y_pred = pred_mask, class_id = 1, reduction = "micro")
        iou_per_dataset_background = compute_iou(y_true = mask, y_pred = pred_mask, class_id = 0, reduction = "micro")

        iou_per_dataset_mass_no_empty = compute_iou(y_true = mask, y_pred = pred_mask, class_id = 1, reduction = "micro",exclude_empty=True)
        iou_per_dataset_background_no_empty = compute_iou(y_true = mask, y_pred = pred_mask, class_id = 0, reduction = "micro",exclude_empty=True)

        ###

        dice_per_image_mass = compute_dice_score(y_true = mask, y_pred = pred_mask, class_id = 1,
                                         reduction = "micro_image_wise")
        dice_per_image_background = compute_dice_score(y_true = mask, y_pred = pred_mask, class_id = 0,
                                               reduction = "micro_image_wise")

        dice_per_image_mass_no_empty = compute_dice_score(y_true = mask, y_pred = pred_mask, class_id = 1,
                                         reduction = "micro_image_wise", exclude_empty=True)
        dice_per_image_background_no_empty = compute_dice_score(y_true = mask, y_pred = pred_mask, class_id = 0,
                                               reduction = "micro_image_wise",exclude_empty=True)

        dice_per_dataset_mass = compute_dice_score(y_true = mask, y_pred = pred_mask, class_id = 1, reduction = "micro")
        dice_per_dataset_background = compute_dice_score(y_true = mask, y_pred = pred_mask, class_id = 0, reduction = "micro")

        dice_per_dataset_mass_no_empty = compute_dice_score(y_true = mask, y_pred = pred_mask, class_id = 1, reduction = "micro",exclude_empty=True)
        dice_per_dataset_background_no_empty = compute_dice_score(y_true = mask, y_pred = pred_mask, class_id = 0, reduction = "micro",exclude_empty=True)


        ###

        acc_per_image_mass = class_specific_accuracy_score(preds = mask, targets = pred_mask, class_id = 1,
                                                           averaging = "micro_image_wise")
        acc_per_image_background = class_specific_accuracy_score(preds = mask, targets = pred_mask, class_id = 0,
                                                                 averaging = "micro_image_wise")

        acc_per_dataset_mass = class_specific_accuracy_score(preds = mask, targets = pred_mask, class_id = 1,
                                                              averaging = "micro")
        acc_per_dataset_background = class_specific_accuracy_score(preds = mask, targets = pred_mask, class_id = 0,
                                                                   averaging = "micro")

        loss = loss.to('cpu')

        output = {
            "loss":                       loss,
            "tp":                         tp,
            "fp":                         fp,
            "fn":                         fn,
            "tn":                         tn,
            "iou_per_image_mass":         iou_per_image_mass,
            "iou_per_image_background":   iou_per_image_background,
            "iou_per_dataset_mass":      iou_per_dataset_mass,
            "iou_per_dataset_background": iou_per_dataset_background,
            "iou_per_image_mass_no_empty":         iou_per_image_mass_no_empty,
            "iou_per_image_background_no_empty":   iou_per_image_background_no_empty,
            "iou_per_dataset_mass_no_empty":      iou_per_dataset_mass_no_empty,
            "iou_per_dataset_background_no_empty": iou_per_dataset_background_no_empty, ###
            "dice_per_image_mass":         dice_per_image_mass,
            "dice_per_image_background":   dice_per_image_background,
            "dice_per_dataset_mass":      dice_per_dataset_mass,
            "dice_per_dataset_background": dice_per_dataset_background,
            "dice_per_image_mass_no_empty":         dice_per_image_mass_no_empty,
            "dice_per_image_background_no_empty":   dice_per_image_background_no_empty,
            "dice_per_dataset_mass_no_empty":      dice_per_dataset_mass_no_empty,
            "dice_per_dataset_background_no_empty": dice_per_dataset_background_no_empty,
            "acc_per_image_mass":         acc_per_image_mass,
            "acc_per_image_background":   acc_per_image_background,
            "acc_per_dataset_mass":      acc_per_dataset_mass,
            "acc_per_dataset_background": acc_per_dataset_background
        }

        if stage=='train':
            self.train_outputs.append(output)
        if stage=='valid':
            self.val_outputs.append(output)
        if stage=='test':
            self.test_outputs.append(output)
        self.log(f'{stage}_loss', loss, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
        return loss

    def validation_step(self, batch, batch_idx):
        if batch:
            return self.step(batch, batch_idx, "valid")

    def test_step(self, batch, batch_idx):
        if batch:
            return self.step(batch, batch_idx, "test")

    def training_step(self, batch, batch_idx):
        if batch:
            return self.step(batch, batch_idx, "train")


    def forward(self, image1,image2, image3):
        mask=self.model(image1, image2, image3)
        return mask

    def single_predict_sliding_window(self, image, roi_size=(128,128), sw_batch_size=128, overlap=0):
        return sliding_window_inference(image, roi_size, sw_batch_size, self.model, overlap=0,
                                              mode="gaussian")

    def shared_epoch_end(self, outputs, stage):
        # aggregate step metics
        if not outputs:
            return

        tp = torch.cat([x["tp"] for x in outputs])
        fp = torch.cat([x["fp"] for x in outputs])
        fn = torch.cat([x["fn"] for x in outputs])
        tn = torch.cat([x["tn"] for x in outputs])
        
        iou_per_image_mass = torch.nanmean(torch.Tensor([x["iou_per_image_mass"] for x in outputs]))
        iou_per_image_background = torch.nanmean(torch.Tensor([x["iou_per_image_background"] for x in outputs]))
        iou_per_dataset_image = torch.nanmean(torch.Tensor([x["iou_per_dataset_mass"] for x in outputs]))
        iou_per_dataset_background = torch.nanmean(torch.Tensor([x["iou_per_dataset_background"] for x in outputs]))

        iou_per_image_mass_no_empty = torch.nanmean(torch.Tensor([x["iou_per_image_mass_no_empty"] for x in outputs]))
        iou_per_image_background_no_empty = torch.nanmean(torch.Tensor([x["iou_per_image_background_no_empty"] for x in outputs]))
        iou_per_dataset_image_no_empty = torch.nanmean(torch.Tensor([x["iou_per_dataset_mass_no_empty"] for x in outputs]))
        iou_per_dataset_background_no_empty = torch.nanmean(torch.Tensor([x["iou_per_dataset_background_no_empty"] for x in outputs]))


        dice_per_image_mass = torch.nanmean(torch.Tensor([x["dice_per_image_mass"] for x in outputs]))
        dice_per_image_background = torch.nanmean(torch.Tensor([x["dice_per_image_background"] for x in outputs]))
        dice_per_dataset_image = torch.nanmean(torch.Tensor([x["dice_per_dataset_mass"] for x in outputs]))
        dice_per_dataset_background = torch.nanmean(torch.Tensor([x["dice_per_dataset_background"] for x in outputs]))

        dice_per_image_mass_no_empty = torch.nanmean(torch.Tensor([x["dice_per_image_mass_no_empty"] for x in outputs]))
        dice_per_image_background_no_empty = torch.nanmean(torch.Tensor([x["dice_per_image_background_no_empty"] for x in outputs]))
        dice_per_dataset_image_no_empty = torch.nanmean(torch.Tensor([x["dice_per_dataset_mass_no_empty"] for x in outputs]))
        dice_per_dataset_background_no_empty = torch.nanmean(torch.Tensor([x["dice_per_dataset_background_no_empty"] for x in outputs]))


        acc_per_image_mass = torch.nanmean(torch.Tensor([x["acc_per_image_mass"] for x in outputs]))
        acc_per_image_background = torch.nanmean(torch.Tensor([x["acc_per_image_background"] for x in outputs]))
        acc_per_dataset_image = torch.nanmean(torch.Tensor([x["acc_per_dataset_mass"] for x in outputs]))
        acc_per_dataset_background = torch.nanmean(torch.Tensor([x["acc_per_dataset_background"] for x in outputs]))


        per_image_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction = "micro-imagewise")
        per_dataset_iou = smp.metrics.iou_score(tp, fp, fn, tn, reduction = "micro")

        per_image_dice = compute_dice_score_from_cm(tp, fp, fn, tn, reduction = "micro-imagewise")
        per_dataset_dice = compute_dice_score_from_cm(tp, fp, fn, tn, reduction = "micro")


        # MACRO AVG
        precision = compute_mean_precision(tp, fp, fn, tn)
        recall = compute_mean_recall(tp, fp, fn, tn)

        accuracy = smp.metrics.accuracy(tp, fp, fn, tn, reduction='micro-imagewise')

        
        # MACRO IMAGEWISE MEAN DICE WITH EMPTY
        dice1_per_image = compute_dice_from_metrics(tp, fp, tn, fn, reduction='none')
        dice0_per_image = compute_dice_from_metrics(tn, fn, tp, fp,  reduction='none')
        mean_dice_per_image = np.mean(np.nanmean(np.array([dice0_per_image.cpu().numpy(), dice1_per_image.cpu().numpy()]), axis=0))

        # MACRO MEAN DICE WITH EMPTY
        mean_dice1_per_dataset = compute_dice_from_metrics(tp, fp, tn, fn, reduction='micro-imagewise')
        mean_dice0_per_dataset = compute_dice_from_metrics(tn, fn, tp, fp,  reduction='micro-imagewise')
        mean_dice_per_dataset = np.nanmean(np.array([mean_dice0_per_dataset.cpu().numpy(), mean_dice1_per_dataset.cpu().numpy()]))

        # MACRO IMAGEWISE MEAN DICE NO EMPTY
        dice1_per_image_no_empty = compute_dice_from_metrics(tp, fp, tn, fn, reduction='none',exclude_empty=True)
        dice0_per_image_no_empty = compute_dice_from_metrics(tn, fn, tp, fp,  reduction='none',exclude_empty=True)
        combined_dice_scores = np.hstack((dice0_per_image_no_empty, dice1_per_image_no_empty ))
        valid_pairs = ~np.isnan(combined_dice_scores).any(axis=1)
        mean_dice_per_image_no_empty = np.mean(np.nanmean(combined_dice_scores[valid_pairs], axis=1))

        if mean_dice_per_image_no_empty.size == 0:
             mean_dice_per_image_no_empty = float('nan')

        # MACRO MEAN DICE NO EMPTY
        mean_dice1_per_dataset_no_empty = compute_dice_from_metrics(tp, fp, tn, fn, reduction='micro-imagewise',exclude_empty=True)
        mean_dice0_per_dataset_no_empty = compute_dice_from_metrics(tn, fn, tp, fp,  reduction='micro-imagewise',exclude_empty=True)
        mean_dice_per_dataset_no_empty = np.mean(np.array([mean_dice0_per_dataset_no_empty.cpu().numpy(), mean_dice1_per_dataset_no_empty.cpu().numpy()]))

        # MACRO IMAGEWISE MEAN IOU WITH EMPTY

        iou1_per_image = compute_iou_from_metrics(tp, fp, tn, fn, reduction='none')
        iou0_per_image = compute_iou_from_metrics(tn, fn, tp, fp,  reduction='none')
        mean_iou_per_image = np.mean(np.nanmean(np.array([iou0_per_image.cpu().numpy(), iou1_per_image.cpu().numpy()]), axis=0))

        # MACRO MEAN IOU WITH EMPTY
        mean_iou1_per_dataset = compute_iou_from_metrics(tp, fp, tn, fn, reduction='micro-imagewise')
        mean_iou0_per_dataset = compute_iou_from_metrics(tn, fn, tp, fp,  reduction='micro-imagewise')
        mean_iou_per_dataset = np.nanmean(np.array([mean_iou0_per_dataset.cpu().numpy(), mean_iou1_per_dataset.cpu().numpy()]))

        # MACRO IMAGEWISE MEAN IOU NO EMPTY
        iou1_per_image_no_empty = compute_iou_from_metrics(tp, fp, tn, fn, reduction='none',exclude_empty=True)
        iou0_per_image_no_empty = compute_iou_from_metrics(tn, fn, tp, fp,  reduction='none',exclude_empty=True)
        
        combined_iou_scores = np.hstack((iou0_per_image_no_empty, iou1_per_image_no_empty))
        valid_pairs = ~np.isnan(combined_iou_scores).any(axis=1) #10, 2
        mean_iou_per_image_no_empty = np.mean(np.nanmean(combined_iou_scores[valid_pairs], axis=1))

        if mean_iou_per_image_no_empty.size == 0:
             mean_iou_per_image_no_empty = float('nan')

        # MACRO MEAN IOU NO EMPTY
        mean_iou1_per_dataset_no_empty = compute_iou_from_metrics(tp, fp, tn, fn, reduction='micro-imagewise',exclude_empty=True)
        mean_iou0_per_dataset_no_empty = compute_iou_from_metrics(tn, fn, tp, fp,  reduction='micro-imagewise',exclude_empty=True)
        mean_iou_per_dataset_no_empty = np.mean(np.array([mean_iou0_per_dataset_no_empty.cpu().numpy(), mean_iou1_per_dataset_no_empty.cpu().numpy()]))
    
        self.log(f"{stage}_per_image_iou", per_image_iou, sync_dist = True, prog_bar=True, batch_size=batch_size)
        
        self.log(f"{stage}_per_dataset_iou", per_dataset_iou, sync_dist = True, prog_bar=True, batch_size=batch_size)

        self.log(f"{stage}_per_image_dice", per_image_dice, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f"{stage}_per_dataset_dice", per_dataset_dice, sync_dist = True, prog_bar=True, batch_size=batch_size)
        
        self.log(f"{stage}_mean_iou_per_image", mean_iou_per_image, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f"{stage}_mean_iou_per_dataset", mean_iou_per_dataset, sync_dist = True, prog_bar=True, batch_size=batch_size)

        self.log(f"{stage}_mean_iou_per_image_no_empty", mean_iou_per_image_no_empty, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f"{stage}_mean_iou_per_dataset_no_empty", mean_iou_per_dataset_no_empty, sync_dist = True, prog_bar=True, batch_size=batch_size)

        self.log(f"{stage}_mean_dice_per_image", mean_dice_per_image, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f"{stage}_mean_dice_per_dataset", mean_dice_per_dataset, sync_dist = True, prog_bar=True, batch_size=batch_size)

        self.log(f"{stage}_mean_dice_per_image_no_empty", mean_dice_per_image_no_empty, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f"{stage}_mean_dice_per_dataset_no_empty", mean_dice_per_dataset_no_empty, sync_dist = True, prog_bar=True, batch_size=batch_size)
        
        
        self.log(f"{stage}_precision", precision, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f"{stage}_recall", recall, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f"{stage}_accuracy", accuracy, sync_dist = True, prog_bar=True, batch_size=batch_size)

        self.log(f'{stage}_iou_per_image_mass', iou_per_image_mass, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f'{stage}_iou_per_image_background', iou_per_image_background, sync_dist = True, batch_size=batch_size)
        self.log(f'{stage}_iou_per_dataset_mass', iou_per_dataset_image, sync_dist = True, batch_size=batch_size)
        self.log(f'{stage}_iou_per_dataset_background', iou_per_dataset_background, sync_dist = True, batch_size=batch_size)

        self.log(f'{stage}_iou_per_image_mass_no_empty', iou_per_image_mass_no_empty, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f'{stage}_iou_per_image_background_no_empty', iou_per_image_background_no_empty, sync_dist = True, batch_size=batch_size)
        self.log(f'{stage}_iou_per_dataset_mass_no_empty', iou_per_dataset_image_no_empty, sync_dist = True, batch_size=batch_size)
        self.log(f'{stage}_iou_per_dataset_background_no_empty', iou_per_dataset_background_no_empty, sync_dist = True, batch_size=batch_size)

        self.log(f'{stage}_dice_per_image_mass', dice_per_image_mass, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f'{stage}_dice_per_image_background', dice_per_image_background, sync_dist = True, batch_size=batch_size)
        self.log(f'{stage}_dice_per_dataset_mass', dice_per_dataset_image, sync_dist = True, batch_size=batch_size)
        self.log(f'{stage}_dice_per_dataset_background', dice_per_dataset_background, sync_dist = True, batch_size=batch_size)

        self.log(f'{stage}_dice_per_image_mass_no_empty', dice_per_image_mass_no_empty, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f'{stage}_dice_per_image_background_no_empty', dice_per_image_background_no_empty, sync_dist = True, batch_size=batch_size)
        self.log(f'{stage}_dice_per_dataset_mass_no_empty', dice_per_dataset_image_no_empty, sync_dist = True, batch_size=batch_size)
        self.log(f'{stage}_dice_per_dataset_background_no_empty', dice_per_dataset_background_no_empty, sync_dist = True, batch_size=batch_size)


        self.log(f'{stage}_acc_per_image_mass', acc_per_image_mass, sync_dist = True, prog_bar=True, batch_size=batch_size)
        self.log(f'{stage}_acc_per_image_background', acc_per_image_background, sync_dist = True, prog_bar=True,batch_size=batch_size)
        self.log(f'{stage}_acc_per_dataset_mass', acc_per_dataset_image, sync_dist = True, prog_bar=True,batch_size=batch_size)
        self.log(f'{stage}_acc_per_dataset_background', acc_per_dataset_background, sync_dist = True, prog_bar=True,batch_size=batch_size)

    def on_train_epoch_end(self):
        self.shared_epoch_end(outputs = self.train_outputs, stage = "train")
        self.train_outputs.clear()

    def on_validation_epoch_end(self):
        self.shared_epoch_end(outputs = self.val_outputs, stage = "valid")
        self.val_outputs.clear()

    def on_test_epoch_end(self):
        self.shared_epoch_end(outputs = self.test_outputs, stage = "test")
        self.test_outputs.clear()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr = 1e-4, weight_decay=1e-4)



        iterations_per_epoch = self.len_train_loader  # Number of iterations per epoch
        step_size_up = iterations_per_epoch // 2  # Half an epoch for the increasing phase
        gamma = 0.99 

        base_lr = 3e-5  # Increased base learning rate
        max_lr = 9e-4   # Increased maximum learning rate

        #base_lr = 1e-4  # Increased base learning rate
        #max_lr = 1e-3   # Increased maximum learning rate
        
        scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,
                             base_lr=base_lr,  # Minimum learning rate
                             max_lr=max_lr,   # Maximum learning rate
                             step_size_up=step_size_up,
                             mode='triangular',
                             cycle_momentum=False)  # Set to True if using an optimizer with momentum

        return [optimizer], [scheduler]

## SOFTDICELOSS

In [15]:
import torch
import torch.nn as nn

class SoftDiceLoss(nn.Module):
    """
    Soft Dice loss for binary (single-channel) segmentation.
    Expects predictions of shape (B, 1, H, W) already in [0,1].
    Expects ground truth of shape (B, 1, H, W), either 0/1 or float in [0,1].
    """
    def __init__(self, smooth=1e-6):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        """
        Args:
            probs: model outputs as probabilities in [0,1], (B, 1, H, W)
            targets: ground truth masks, (B, 1, H, W) in {0,1} or floats
        Returns:
            Dice loss (scalar).
        """

        probs = logits.sigmoid()
        
        # Flatten: (B, 1, H, W) -> (B, H*W)
        probs_flat = probs.view(probs.size(0), -1)
        targets_flat = targets.view(targets.size(0), -1)

        # Numerator = 2 * Σ (p_i * t_i)
        intersection = 2.0 * torch.sum(probs_flat * targets_flat, dim=1)

        # Denominator = Σ (p_i^2) + Σ (t_i^2)
        denominator = torch.sum(probs_flat * probs_flat, dim=1) + \
                      torch.sum(targets_flat * targets_flat, dim=1)

        # Dice coefficient (per sample), then average over batch
        dice_per_sample = (intersection + self.smooth) / (denominator + self.smooth)
        dice = dice_per_sample.mean()

        # Dice loss = 1 - dice coefficient
        return 1.0 - dice

# Data split

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class AttentionGate(nn.Module):
    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)
        return x * psi

class EncoderBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(EncoderBlock, self).__init__()
        self.conv = ConvBlock(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        x = self.conv(x)
        p = self.pool(x)
        return x, p

class DecoderBlock(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels, F_int, use_attention=True):
        super(DecoderBlock, self).__init__()
        self.up = nn.ConvTranspose2d(in_channels, mid_channels, kernel_size=2, stride=2)
        # Adjust the attention gate to handle the combined skip connections
        if use_attention:
            self.attention = AttentionGate(F_g=mid_channels, F_l=mid_channels, F_int=F_int)
        self.conv = ConvBlock(mid_channels + mid_channels, out_channels)  # Adjust for concatenated skip connection size
        self.use_attention=use_attention

    def forward(self, x, combined_skip):

        x = self.up(x)
        # Apply attention to the combined skip connections
        if self.use_attention:
            combined_skip = self.attention(g=x, x=combined_skip)
        
        x = torch.cat([x, combined_skip], dim=1)
        x = self.conv(x)
        return x
    

class ChannelAttention(nn.Module):
    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Linear(in_channels, in_channels // reduction_ratio, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(in_channels // reduction_ratio, in_channels, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        avg_pooled = self.avg_pool(x).view(b, c)
        max_pooled = self.max_pool(x).view(b, c)
        avg_out = self.fc2(self.relu(self.fc1(avg_pooled)))
        max_out = self.fc2(self.relu(self.fc1(max_pooled)))
        out = self.sigmoid(avg_out + max_out).view(b, c, 1, 1)
        return x * out
        
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        avg_out = torch.mean(x, dim=1, keepdim=True)
        x = torch.cat([max_out, avg_out], dim=1)
        x = self.conv(x)
        return self.sigmoid(x)

class FeatureFusionBlock(nn.Module):
    def __init__(self, in_channels, out_channels, reduction=16):
        super(FeatureFusionBlock, self).__init__()
        self.channel_attention_local = ChannelAttention(in_channels)
        self.spatial_attention_local = SpatialAttention()

        self.channel_attention_global = ChannelAttention(in_channels)
        self.spatial_attention_global = SpatialAttention()
        
        self.fusion_conv = nn.Conv2d(in_channels*3, out_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, global_feat, local_feat1, local_feat2):
        # Apply channel attention to each feature map
        global_ca = global_feat * self.channel_attention_global(global_feat)
        local_ca1 = local_feat1 * self.channel_attention_local(local_feat1)
        local_ca2 = local_feat2 * self.channel_attention_local(local_feat2)

        # Apply spatial attention to each feature map
        global_sa = global_ca * self.spatial_attention_global(global_ca)
        local_sa1 = local_ca1 * self.spatial_attention_local(local_ca1)
        local_sa2 = local_ca2 * self.spatial_attention_local(local_ca2)

        # Concatenate the feature maps
        fused_features = torch.cat((global_sa, local_sa1, local_sa2), dim=1)
        
        # Fuse them using a convolutional layer
        fused_features = self.fusion_conv(fused_features)
        fused_features = self.relu(fused_features)
        
        return fused_features

class SimpleFeatureFusionBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(SimpleFeatureFusionBlock, self).__init__()
        # Since we're concatenating three feature maps, the input to the fusion_conv will be 3 times in_channels
        self.fusion_conv = nn.Conv2d(in_channels * 3, out_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, global_feat, local_feat1, local_feat2):
        # Concatenate the feature maps along the channel dimension
        fused_features = torch.cat((global_feat, local_feat1, local_feat2), dim=1)
        
        # Apply a convolutional layer to reduce dimensions
        fused_features = self.fusion_conv(fused_features)
        fused_features = self.relu(fused_features)
        
        return fused_features



class MultiInputUNet(nn.Module):
    def __init__(self, n_channels, n_classes, use_simple_fusion=False, use_decoder_attention=True):
        super(MultiInputUNet, self).__init__()

        if use_simple_fusion:
            self.fusion_skip1 = SimpleFeatureFusionBlock(64, 64)
            self.fusion_skip2 = SimpleFeatureFusionBlock(128, 128)
            self.fusion_skip3 = SimpleFeatureFusionBlock(256, 256)
            self.fusion_skip4 = SimpleFeatureFusionBlock(512, 512)
        else:
          # Initialize fusion blocks with channel reduction
            self.fusion_skip1 = FeatureFusionBlock(64, 64)
            self.fusion_skip2 = FeatureFusionBlock(128, 128)
            self.fusion_skip3 = FeatureFusionBlock(256, 256)
            self.fusion_skip4 = FeatureFusionBlock(512,512)
        
        # Encoders for each input stream
        self.encoder1 = nn.ModuleList([EncoderBlock(n_channels, 64), EncoderBlock(64, 128), EncoderBlock(128, 256), EncoderBlock(256,512)])
        self.encoder2 = nn.ModuleList([EncoderBlock(n_channels, 64), EncoderBlock(64, 128), EncoderBlock(128, 256), EncoderBlock(256,512)])
        self.encoder3 = nn.ModuleList([EncoderBlock(n_channels, 64), EncoderBlock(64, 128), EncoderBlock(128, 256), EncoderBlock(256,512)])

        if use_simple_fusion:
            self.deep_feature_fusion = SimpleFeatureFusionBlock(512, 512)
        else:
             self.deep_feature_fusion = FeatureFusionBlock(512,512)
        
        
        # Decoder Blocks
        self.decoder1 = DecoderBlock(512,512, 256,256, use_attention=use_decoder_attention)  # Input channels adjusted for merged features
        self.decoder2 = DecoderBlock(256, 256, 128, 128,use_attention=use_decoder_attention)  # Input channels adjusted for merged features
        self.decoder3 = DecoderBlock(128, 128, 64, 64,use_attention=use_decoder_attention)
        self.decoder4 = DecoderBlock(64, 64, 32, 32,use_attention=use_decoder_attention)

        self.final_conv = nn.Conv2d(32, n_classes, kernel_size=1)

    def forward(self, x1, x2, x3):
        # Process each input through its respective encoders
        skips1, p1 = self.process_through_encoders(x1, self.encoder1)
        skips2, p2 = self.process_through_encoders(x2, self.encoder2)
        skips3, p3 = self.process_through_encoders(x3, self.encoder3)

        fused_skips1 = self.fusion_skip1(skips1[0], skips2[0], skips3[0])
        fused_skips2 = self.fusion_skip2(skips1[1], skips2[1], skips3[1])
        fused_skips3 = self.fusion_skip3(skips1[2], skips2[2], skips3[2])
        fused_skips4 = self.fusion_skip4(skips1[3], skips2[3], skips3[3])

        fused_features = self.deep_feature_fusion(p1, p2, p3)
        
        # Decode the combined features
        # Note: Attention mechanism applies to concatenated skip connections from corresponding layers of each input
        d1 = self.decoder1(fused_features, fused_skips4)
        d2 = self.decoder2(d1, fused_skips3)
        d3 = self.decoder3(d2, fused_skips2)
        d4 = self.decoder4(d3, fused_skips1)

        return self.final_conv(d4)

    def process_through_encoders(self, x, encoders):
        skips = []
        p = x
        for encoder in encoders:
            x, p = encoder(p)
            skips.append(x)
        return skips, p  # Reverse skips for correct order in decoding

# Compute volume metrics

In [34]:
def compute_precision(gt_mask,prediction_mask):
    """
    Compute precision between prediction and ground truth masks, with consideration
    for cases where no positives are retrieved and no relevant elements exist.
    
    :param prediction_mask: Binary mask of predictions (HxW).
    :param gt_mask: Binary mask of ground truth (HxW).
    :return: Precision as a float.
    """
    prediction_mask = (prediction_mask > 0).astype(np.uint8)
    gt_mask = (gt_mask > 0).astype(np.uint8)
    
    TP = np.logical_and(prediction_mask, gt_mask).sum()
    FP = np.logical_and(prediction_mask, np.logical_not(gt_mask)).sum()
    
    if TP + FP == 0:
        # If no predictions are made and the ground truth has no positives,
        # precision is perfect if ground truth also has no positives.
        return 1.0 if np.sum(gt_mask) == 0 else 0
    else:
        precision = TP / (TP + FP)
    
    return precision

def compute_recall(gt_mask,prediction_mask):
    """
    Compute recall between prediction and ground truth masks, with consideration
    for cases where no positives are identified and no relevant elements exist.
    
    :param prediction_mask: Binary mask of predictions (HxW).
    :param gt_mask: Binary mask of ground truth (HxW).
    :return: Recall as a float.
    """
    prediction_mask = (prediction_mask > 0).astype(np.uint8)
    gt_mask = (gt_mask > 0).astype(np.uint8)
    
    TP = np.logical_and(prediction_mask, gt_mask).sum()
    FN = np.logical_and(np.logical_not(prediction_mask), gt_mask).sum()
    
    if TP + FN == 0:
        # If the ground truth has no positives, recall is perfect.
        return 1.0
    else:
        recall = TP / (TP + FN)
    
    return recall


def compute_precision_from_cumulator(TPs, FPs, FNs, TNs, exclude_empty=False, is_mean=True, return_std=False):
    tp = torch.cat([tp for tp in TPs])
    fp = torch.cat([fp for fp in FPs])
    fn = torch.cat([fn for fn in FNs])
    tn = torch.cat([tn for tn in TNs])
    
    # Denominators for precision
    denom_class_1 = tp + fp
    denom_class_0 = tn + fn
    
    # Precision for class 1 (per sample)
    precision_class_1 = torch.where(
        denom_class_1 > 0, tp / denom_class_1, torch.tensor(1.0 if not exclude_empty else float('nan'))
    )
    mean_precision_class_1 = torch.nanmean(precision_class_1).item()
    stddev_class_1 = np.nanstd(precision_class_1.cpu().numpy())
    
    # Precision for class 0 (per sample)
    precision_class_0 = torch.where(
        denom_class_0 > 0, tn / denom_class_0, torch.tensor(1.0 if not exclude_empty else float('nan'))
    )
    mean_precision_class_0 = torch.nanmean(precision_class_0).item()
    stddev_class_0 = np.nanstd(precision_class_0.cpu().numpy())
    
    if not is_mean:
        if return_std:
            return mean_precision_class_1, stddev_class_1
        else:
            return mean_precision_class_1
    
    # Overall mean precision (mean between the two classes)
    overall_mean_precision = (mean_precision_class_1 + mean_precision_class_0) / 2

    overall_std_precision = (stddev_class_1 + stddev_class_0) / 2

    if return_std:
        # Include std deviations when `return_std` is True
        return overall_mean_precision, overall_std_precision
    else:
        # Default behavior
        return overall_mean_precision



def compute_recall_from_cumulator(TPs, FPs, FNs, TNs, exclude_empty=False, is_mean=True, return_std=False):
    tp = torch.cat([tp for tp in TPs])
    fp = torch.cat([fp for fp in FPs])
    fn = torch.cat([fn for fn in FNs])
    tn = torch.cat([tn for tn in TNs])
    
    # Denominators for recall
    denom_class_1 = tp + fn
    denom_class_0 = tn + fp
    
    # Recall for class 1 (per sample)
    recall_class_1 = torch.where(
        denom_class_1 > 0, tp / denom_class_1, torch.tensor(1.0 if not exclude_empty else float('nan'))
    )
    mean_recall_class_1 = torch.nanmean(recall_class_1).item()
    stddev_class_1 = np.nanstd(recall_class_1.cpu().numpy())
    
    # Recall for class 0 (per sample)
    recall_class_0 = torch.where(
        denom_class_0 > 0, tn / denom_class_0, torch.tensor(1.0 if not exclude_empty else float('nan'))
    )
    mean_recall_class_0 = torch.nanmean(recall_class_0).item()
    stddev_class_0 = np.nanstd(recall_class_0.cpu().numpy())
    
    if not is_mean:
        if return_std:
            # Return per-class recall and standard deviations
            return mean_recall_class_1, stddev_class_1
        else:
            # Return per-class recall only
            return mean_recall_class_1
    
    # Overall mean recall (mean between the two classes)
    overall_mean_recall = (mean_recall_class_1 + mean_recall_class_0) / 2

    overall_std_recall = (stddev_class_1 +  stddev_class_0) / 2

    if return_std:
        # Return overall mean recall and standard deviations
        return overall_mean_recall, overall_std_recall
    else:
        # Return overall mean recall and class-wise mean recalls
        return overall_mean_recall



def compute_f1_from_cumulator(
    TPs, FPs, FNs, TNs,
    exclude_empty=False,
    reduce_mean=True,
    is_mean=True,
    return_std=False
):
    """
    Compute the F1-score for both class 1 and class 0 from cumulative TPs, FPs, FNs, and TNs.

    Arguments:
    ----------
    TPs, FPs, FNs, TNs : list of torch.Tensor
        Lists of True/False Positives/Negatives for each batch/segment.
    exclude_empty : bool, optional
        If True, denominators of zero yield NaN; if False, they fall back to 1.0.
    reduce_mean : bool, optional
        If False, returns per-image F1 array; if True, returns aggregated mean (and std).
    is_mean : bool, optional
        If True, averages class 0 and class 1 F1’s; if False, returns only class 1’s.
    return_std : bool, optional
        If True and reduce_mean=True, also returns the standard deviation.

    Returns:
    --------
    If reduce_mean == False:
        - per_image_f1 (np.ndarray)
          * If is_mean: overall (class-0+class-1)/2 per image
          * Else: class-1 F1 per image
    If reduce_mean == True:
        If return_std:
            - mean_f1 (float), std_f1 (float)
        Else:
            - mean_f1 (float)
    """
    # concatenate
    tp = torch.cat(TPs) if isinstance(TPs, (list, tuple)) else TPs
    fp = torch.cat(FPs) if isinstance(FPs, (list, tuple)) else FPs
    fn = torch.cat(FNs) if isinstance(FNs, (list, tuple)) else FNs
    tn = torch.cat(TNs) if isinstance(TNs, (list, tuple)) else TNs

    # class 1
    denom1 = 2*tp + fp + fn
    fallback = torch.tensor(float('nan') if exclude_empty else 1.0, device=tp.device)
    f1_1 = torch.where(denom1 > 0, (2.0*tp)/denom1, fallback)

    # class 0
    denom0 = 2*tn + fn + fp
    f1_0 = torch.where(denom0 > 0, (2.0*tn)/denom0, fallback)

    # per-image arrays (CPU numpy)
    f1_1_np = f1_1.cpu().numpy()
    f1_0_np = f1_0.cpu().numpy()
    if is_mean:
        per_image = (f1_1_np + f1_0_np) / 2.0
    else:
        per_image = f1_1_np

    # if user wants the raw per-image values…
    if not reduce_mean:
        return per_image

    # otherwise aggregate across images
    mean_f1 = np.nanmean(per_image)
    if return_std:
        std_f1 = np.nanstd(per_image)
        return mean_f1, std_f1
    return mean_f1



def compute_accuracy_from_cumulator(TPs, FPs, FNs, TNs, exclude_empty=False, is_mean=True, return_std=False):
    tp = torch.cat([tp for tp in TPs])
    fp = torch.cat([fp for fp in FPs])
    fn = torch.cat([fn for fn in FNs])
    tn = torch.cat([tn for tn in TNs])
    
    # Denominators for accuracy
    denom_class_1 = tp + fp + fn
    denom_class_0 = tn + fp + fn
    
    # Accuracy for class 1 (foreground)
    accuracy_class_1 = torch.where(
        denom_class_1 > 0, tp / denom_class_1, torch.tensor(1.0 if not exclude_empty else float('nan'))
    )
    mean_accuracy_class_1 = torch.nanmean(accuracy_class_1).item()
    stddev_class_1 = np.nanstd(accuracy_class_1.cpu().numpy())
    
    # Accuracy for class 0 (background)
    accuracy_class_0 = torch.where(
        denom_class_0 > 0, tn / denom_class_0, torch.tensor(1.0 if not exclude_empty else float('nan'))
    )
    mean_accuracy_class_0 = torch.nanmean(accuracy_class_0).item()
    stddev_class_0 = np.nanstd(accuracy_class_0.cpu().numpy())
    
    if not is_mean:
        if return_std:
            # Return per-class accuracy and standard deviations
            return mean_accuracy_class_1, stddev_class_1
        else:
            # Return per-class accuracy only
            return mean_accuracy_class_1
    
    # Overall mean accuracy (mean between the two classes)
    overall_mean_accuracy = (mean_accuracy_class_1 + mean_accuracy_class_0) / 2

    overall_std_accuracy = (stddev_class_1 + stddev_class_0) / 2

    if return_std:
        # Return overall mean accuracy and standard deviations
        return overall_mean_accuracy, overall_std_accuracy
    else:
        # Return overall mean accuracy and class-wise mean accuracies
        return overall_mean_accuracy, mean_accuracy_class_0, mean_accuracy_class_1

def compute_accuracy_excluding_cases(TPs, FPs, FNs, TNs, return_std=False, exclude_blank_case=False):
    """
    Computes accuracy, excluding cases with a zero denominator or no ground truth positives.

    Args:
        TPs: List of tensors for true positives across batches.
        FPs: List of tensors for false positives across batches.
        FNs: List of tensors for false negatives across batches.
        TNs: List of tensors for true negatives across batches.
        return_std: Boolean indicating whether to return the standard deviation.

    Returns:
        mean_accuracy: Mean accuracy across valid cases.
        (optional) stddev_accuracy: Standard deviation of accuracy across valid cases if return_std is True.
    """
    tp = torch.cat([tp for tp in TPs])
    fp = torch.cat([fp for fp in FPs])
    fn = torch.cat([fn for fn in FNs])
    tn = torch.cat([tn for tn in TNs])

    # Compute the denominator for accuracy (tp + fp + fn + tn)
    denominator = tp + fp + fn + tn


    if exclude_blank_case:
        valid_mask = ((tp + fp + fn) != 0)
    else:
        # Exclude cases with zero denominator or no ground truth positives
        valid_mask = ((tp + fp + fn) != 0) & ((tp + fn) != 0)

    # Compute accuracy only for valid cases
    accuracy = torch.zeros_like(denominator, dtype=torch.float)
    accuracy[valid_mask] = (tp[valid_mask] + tn[valid_mask]) / denominator[valid_mask]

    # Exclude invalid cases by setting them to NaN
    accuracy[~valid_mask] = torch.tensor(float('nan'))

    # Compute mean accuracy
    mean_accuracy = torch.nanmean(accuracy).item()

    if return_std:
        # Compute standard deviation, ignoring NaN values
        stddev_accuracy = np.nanstd(accuracy)
        return mean_accuracy, stddev_accuracy

    return mean_accuracy


def compute_precision_excluding_cases_from_cumulator(TPs, FPs, FNs, TNs, return_std=False, exclude_only_zero_denominator=False):
    """
    Computes precision for cases with a non-zero denominator and excludes cases where there are no ground truth positives.

    Args:
        TPs: List of tensors for true positives across batches.
        FPs: List of tensors for false positives across batches.
        FNs: List of tensors for false negatives across batches.
        TNs: List of tensors for true negatives across batches.
        return_std: Boolean indicating whether to return the standard deviation.

    Returns:
        mean_precision: Mean precision across valid cases.
        (optional) stddev_precision: Standard deviation of precision across valid cases if return_std is True.
    """
    tp = torch.cat([tp for tp in TPs])
    fp = torch.cat([fp for fp in FPs])
    fn = torch.cat([fn for fn in FNs])
    tn = torch.cat([tn for tn in TNs])

    # Compute the denominator for precision (tp + fp)
    denominator = tp + fp

    if exclude_only_zero_denominator:
        valid_mask = denominator != 0
    else:
        # Exclude cases with zero denominator or no ground truth positives
        valid_mask = (denominator != 0) & ((tp + fn) != 0)

    # Compute precision only for valid cases
    precision = torch.zeros_like(denominator, dtype=torch.float)
    precision[valid_mask] = tp[valid_mask] / denominator[valid_mask]

    # Exclude invalid cases by setting them to NaN
    precision[~valid_mask] = torch.tensor(float('nan'))

    # Compute mean precision
    mean_precision = torch.nanmean(precision).item()

    if return_std:
        # Compute standard deviation, ignoring NaN values
        stddev_precision = np.nanstd(precision)
        return mean_precision, stddev_precision

    return mean_precision



def compute_recall_excluding_cases_from_cumulator(TPs, FPs, FNs, TNs, return_std=False, exclude_only_zero_denominator=False):
    """
    Computes recall for class 1, excluding cases with a zero denominator or no ground truth positives.

    Args:
        TPs: List of tensors for true positives across batches.
        FPs: List of tensors for false positives across batches.
        FNs: List of tensors for false negatives across batches.
        TNs: List of tensors for true negatives across batches.
        return_std: Boolean indicating whether to return the standard deviation.

    Returns:
        mean_recall: Mean recall across valid cases.
        (optional) stddev_recall: Standard deviation of recall across valid cases if return_std is True.
    """
    tp = torch.cat([tp for tp in TPs])
    fp = torch.cat([fp for fp in FPs])
    fn = torch.cat([fn for fn in FNs])
    tn = torch.cat([tn for tn in TNs])

    # Compute the denominator for recall (tp + fn)
    denominator = tp + fn

    if exclude_only_zero_denominator:
        valid_mask = denominator != 0
    else:
        # Exclude cases with zero denominator or no ground truth positives
        valid_mask = (denominator != 0) & ((tp + fn) != 0)

    # Compute recall only for valid cases
    recall = torch.zeros_like(denominator, dtype=torch.float)
    recall[valid_mask] = tp[valid_mask] / denominator[valid_mask]

    # Exclude invalid cases by setting them to NaN
    recall[~valid_mask] = torch.tensor(float('nan'))

    # Compute mean recall
    mean_recall = torch.nanmean(recall).item()

    if return_std:
        # Compute standard deviation, ignoring NaN values
        stddev_recall = np.nanstd(recall)
        return mean_recall, stddev_recall

    return mean_recall


def compute_f1_excluding_cases_from_cumulator(TPs, FPs, FNs, TNs, return_std=False, exclude_only_zero_denominator=False):
    """
    Computes F1 score for each 'case' (sample), excluding invalid cases.
    Invalid cases may be:
      - Those with zero denominator (tp + 0.5*(fp + fn) = 0).
      - Those with no ground-truth positives (tp + fn = 0), depending on the 'exclude_only_zero_denominator' flag.

    Args:
        TPs: List of tensors for true positives across batches.
        FPs: List of tensors for false positives across batches.
        FNs: List of tensors for false negatives across batches.
        TNs: List of tensors for true negatives across batches.
        return_std: Boolean indicating whether to return standard deviation.
        exclude_only_zero_denominator: If True, we only exclude cases where (2*tp + fp + fn) = 0.
                                       If False, we also exclude cases with no ground-truth positives (tp + fn = 0).

    Returns:
        mean_f1: Mean F1 across valid cases (float).
        (optional) stddev_f1: Standard deviation of F1 across valid cases (float),
                              only returned if return_std is True.
    """
    # 1. Concatenate all batches
    tp = torch.cat([tp for tp in TPs])
    fp = torch.cat([fp for fp in FPs])
    fn = torch.cat([fn for fn in FNs])
    tn = torch.cat([tn for tn in TNs])

    # 2. Compute the per-case denominator for F1 = 2*TP / (2*TP + FP + FN)
    denominator = tp + 0.5*(fp + fn)

    # 3. Determine valid cases
    #    If 'exclude_only_zero_denominator' is True, only exclude denominator == 0.
    #    Otherwise, also exclude cases where there are no positives in ground truth (tp+fn=0).
    if exclude_only_zero_denominator:
        valid_mask = (denominator != 0)
    else:
        valid_mask = (denominator != 0) & ((tp + fn) != 0)

    # 4. Allocate a tensor to hold the per-case F1
    f1 = torch.zeros_like(denominator, dtype=torch.float)

    # 5. Compute F1 only for valid cases
    f1[valid_mask] = (tp[valid_mask]) / denominator[valid_mask]

    # 6. Mark invalid cases as NaN for later exclusion in mean/std computations
    f1[~valid_mask] = torch.tensor(float('nan'))

    # 7. Compute mean F1 (ignoring NaNs)
    mean_f1 = torch.nanmean(f1).item()

    if return_std:
        # 8. Compute standard deviation (ignoring NaNs)
        stddev_f1 = np.nanstd(f1.cpu().numpy())
        return mean_f1, stddev_f1

    return mean_f1


# Dataset & Dataloader building

In [18]:
batch_size = 24
dataset_base_path = 'Dataset-arrays-4-FINAL'

In [19]:

g = reseed()
with torch.no_grad():
    torch.cuda.empty_cache()

Seed set to 200


Using random seed 200...


In [20]:
has_mask = False
get_boundaryloss=True
num_workers = os.cpu_count()*0.2
image_only = False

## PATCHES AND SUBTRACTED 3RD

In [21]:
sub_third_images_path_prefixes = ("Dataset-arrays-4", "Dataset-arrays")

In [22]:
mean_patches_sub= 86.13536834716797
std_patches_sub= 238.13461303710938

In [23]:
test_transforms_patches_sub = Compose(
        [
            LoadImaged(keys = ["image", "label"], image_only = image_only,reader=monai.data.NumpyReader()),
            EnsureChannelFirstd(keys = ["image", "label"]),
            Preprocess(has_mask =has_mask,keys=None, mode='test', get_boundaryloss=get_boundaryloss, subtracted_images_path_prefixes=sub_third_images_path_prefixes, subtrahend = mean_patches_sub, divisor = std_patches_sub,get_patches=True)
        ]
)

## NO THORAX AND SUBTRACTED

In [24]:
sub_third_images_path_prefixes = ("Dataset-arrays-4", "Dataset-arrays")

In [25]:
mean_no_thorax_third_sub= 43.1498
std_no_thorax_third_sub= 172.6704

In [26]:
test_transforms_no_thorax_third_sub  = Compose(
        [
            LoadImaged(keys = ["image", "label"], image_only = image_only,reader=monai.data.NumpyReader()),
            EnsureChannelFirstd(keys = ["image", "label"]),
            Preprocess(has_mask =has_mask,keys=None, mode='test', get_boundaryloss=get_boundaryloss, subtracted_images_path_prefixes=sub_third_images_path_prefixes, subtrahend = mean_no_thorax_third_sub, divisor = std_no_thorax_third_sub, get_patches=False)
        ]
    )

# Merge utils

In [27]:
import numpy as np
from scipy.ndimage import label as labell, find_objects
from skimage.measure import regionprops

def calculate_iou_for_mass_detection(mask1, mask2):
        # Calculate intersection over union for two masks
        intersection = np.logical_and(mask1, mask2).sum()
        union = np.logical_or(mask1, mask2).sum()
        if union == 0:
            return 0
        return intersection / union


def calculate_mass_detection_iou(y_pred,y_true):
    structure = np.ones((3,3), dtype=np.bool_)  # 2D connectivity
    labels_true, num_true = labell(y_true, structure=structure)
    labels_pred, num_pred = labell(y_pred, structure=structure)

    if num_true == 0:
        return float('nan')
    
    detected_masses = 0
    true_objects = find_objects(labels_true)
    pred_objects = find_objects(labels_pred)
    
    for i, true_slice in enumerate(true_objects):
        for j, pred_slice in enumerate(pred_objects):
            if not check_overlap(true_slice, pred_slice):
                continue  # Skip if bounding boxes don't overlap
            # Calculate the IoU only for the overlapping region
            overlap_region = tuple(
                slice(max(t.start, p.start), min(t.stop, p.stop))
                for t, p in zip(true_slice, pred_slice)
            )
            true_mass = labels_true == (i + 1)
            pred_mass = labels_pred == (j + 1)
            if np.any(true_mass[overlap_region]) and np.any(pred_mass[overlap_region]):
                    detected_masses += 1
                    break  # Found an overlapping mass, move to the next true mass
    
    detection_rate = detected_masses / num_true
    return detection_rate

def check_overlap(slice1, slice2):
    # Check if two slices overlap
    for dim1, dim2 in zip(slice1, slice2):
        if dim1.stop <= dim2.start or dim2.stop <= dim1.start:
            return False
    return True

def calculate_mass_detection_imagewise_volume(y_pred, y_true):

    detection_rates = []
    for idx in range(0, y_pred.shape[-1]):
        slice_pred = y_pred[:,:, idx]
        slice_true = y_true[:,:, idx]

    
        structure = np.ones((3, 3), dtype=np.bool_)  # 3D connectivity
        labels_true, num_true = labell(slice_true, structure=structure)
        labels_pred, num_pred = labell(slice_pred, structure=structure)

        if num_true != 0:
            detected_masses = 0
            true_objects = find_objects(labels_true)
            pred_objects = find_objects(labels_pred)
            
            for i, true_slice in enumerate(true_objects):
                for j, pred_slice in enumerate(pred_objects):
                    if not check_overlap(true_slice, pred_slice):
                        continue  # Skip if bounding boxes don't overlap
                    # Calculate the IoU only for the overlapping region
                    overlap_region = tuple(
                        slice(max(t.start, p.start), min(t.stop, p.stop))
                        for t, p in zip(true_slice, pred_slice)
                    )
                    true_mass = labels_true == (i + 1)
                    pred_mass = labels_pred == (j + 1)
                    if np.any(true_mass[overlap_region]) and np.any(pred_mass[overlap_region]):
                            detected_masses += 1
                            break  # Found an overlapping mass, move to the next true mass
            
            detection_rate = detected_masses / num_true
            detection_rates.append(detection_rate)
            
        
    return detection_rates

def calculate_mass_detection_iou_volume_get_rates(y_pred, y_true):
    structure = np.ones((3, 3, 3), dtype=np.bool_)  # 3D connectivity
    labels_true, num_true = labell(y_true, structure=structure)
    labels_pred, num_pred = labell(y_pred, structure=structure)
    
    detected_masses = 0
    true_objects = find_objects(labels_true)
    pred_objects = find_objects(labels_pred)
    
    for i, true_slice in enumerate(true_objects):
        for j, pred_slice in enumerate(pred_objects):
            if not check_overlap(true_slice, pred_slice):
                continue  # Skip if bounding boxes don't overlap
            # Calculate the IoU only for the overlapping region
            overlap_region = tuple(
                slice(max(t.start, p.start), min(t.stop, p.stop))
                for t, p in zip(true_slice, pred_slice)
            )
            true_mass = labels_true == (i + 1)
            pred_mass = labels_pred == (j + 1)
            if np.any(true_mass[overlap_region]) and np.any(pred_mass[overlap_region]):
                    detected_masses += 1
                    break  # Found an overlapping mass, move to the next true mass
    
    return detected_masses, num_true

In [28]:
patient_ids = os.listdir(dataset_base_path)

In [29]:
datasets = {}

In [30]:
# Split the dataset into train+val (80%) and test (20%)
x_train_val, x_test= train_test_split(patient_ids, test_size=0.2, random_state=SEED)

# Split the train+val dataset into train (60% of the original dataset) and val (20% of the original dataset)
x_train, x_val = train_test_split(x_train_val, test_size=0.25, random_state=SEED)

In [31]:
for patient_id in x_test:
    print(patient_id)


    patient_id = [patient_id]
    images_fnames, _ = get_filenames(suffix="images",
                                       base_path='Dataset-arrays-4-FINAL',
                                       patient_ids=patient_id,
                                       remove_black_samples=False,
                                       get_random_samples_and_remove_black_samples=False,
                                       random_samples_indexes_list=None)

    labels_fnames, _ = get_filenames(suffix="masks",
                                      base_path='Dataset-arrays-4-FINAL',
                                      patient_ids=patient_id,
                                      remove_black_samples=False,
                                      get_random_samples_and_remove_black_samples=False,
                                      random_samples_indexes_list=None, remove_picked_samples=False)
    

    

    test_dicts = [{"image": image_name, "label":label_name} for image_name, label_name in zip(images_fnames,labels_fnames)]

    #no_thorax_third_test_ds = CacheDataset(data=test_dicts, transform=test_transforms_no_thorax_third,num_workers)
    no_thorax_sub_test_ds = CacheDataset(data=test_dicts, transform=test_transforms_no_thorax_third_sub,num_workers=32)
    #patches_third_test_ds = CacheDataset(data=test_dicts, transform=test_transforms_patches_third, num_workers)
    patches_sub_test_ds = CacheDataset(data=test_dicts, transform=test_transforms_patches_sub,num_workers=32)

    datasets[patient_id[0]]={
        #"no_thorax_third_test_ds":no_thorax_third_test_ds,
        "no_thorax_sub_test_ds": no_thorax_sub_test_ds,
        #"patches_third_test_ds": patches_third_test_ds,
        "patches_sub_test_ds": patches_sub_test_ds
        
    }

LAXXX


Loading dataset: 100%|██████████| 228/228 [00:16<00:00, 13.86it/s]
Loading dataset: 100%|██████████| 228/228 [00:48<00:00,  4.73it/s]


AS0170


Loading dataset: 100%|██████████| 192/192 [00:14<00:00, 13.52it/s]
Loading dataset: 100%|██████████| 192/192 [00:39<00:00,  4.89it/s]


PR0760


Loading dataset: 100%|██████████| 228/228 [00:16<00:00, 13.80it/s]
Loading dataset: 100%|██████████| 228/228 [00:50<00:00,  4.52it/s]


GF0380


Loading dataset: 100%|██████████| 204/204 [00:15<00:00, 12.82it/s]
Loading dataset: 100%|██████████| 204/204 [00:44<00:00,  4.55it/s]


FP211261


Loading dataset: 100%|██████████| 188/188 [00:13<00:00, 13.89it/s]
Loading dataset: 100%|██████████| 188/188 [00:40<00:00,  4.64it/s]


D2MP3(VR)


Loading dataset: 100%|██████████| 228/228 [00:17<00:00, 12.97it/s]
Loading dataset: 100%|██████████| 228/228 [00:48<00:00,  4.70it/s]


MG0477


Loading dataset: 100%|██████████| 132/132 [00:10<00:00, 12.58it/s]
Loading dataset: 100%|██████████| 132/132 [00:29<00:00,  4.42it/s]


OL1062R


Loading dataset: 100%|██████████| 228/228 [00:16<00:00, 14.23it/s]
Loading dataset: 100%|██████████| 228/228 [00:48<00:00,  4.67it/s]


BV1252


Loading dataset: 100%|██████████| 188/188 [00:12<00:00, 14.77it/s]
Loading dataset: 100%|██████████| 188/188 [00:38<00:00,  4.87it/s]


D1AP7(VR)


Loading dataset: 100%|██████████| 228/228 [00:18<00:00, 12.44it/s]
Loading dataset: 100%|██████████| 228/228 [00:48<00:00,  4.65it/s]


SD080569


Loading dataset: 100%|██████████| 236/236 [00:17<00:00, 13.62it/s]
Loading dataset: 100%|██████████| 236/236 [00:50<00:00,  4.69it/s]


CC0167


Loading dataset: 100%|██████████| 228/228 [00:18<00:00, 12.54it/s]
Loading dataset: 100%|██████████| 228/228 [00:50<00:00,  4.54it/s]


RHCL031174


Loading dataset: 100%|██████████| 188/188 [00:15<00:00, 11.82it/s]
Loading dataset: 100%|██████████| 188/188 [00:40<00:00,  4.66it/s]


LA0248


Loading dataset: 100%|██████████| 284/284 [00:22<00:00, 12.82it/s]
Loading dataset: 100%|██████████| 284/284 [01:00<00:00,  4.71it/s]


RP271052


Loading dataset: 100%|██████████| 96/96 [00:08<00:00, 11.19it/s]
Loading dataset: 100%|██████████| 96/96 [00:22<00:00,  4.35it/s]


SL191251


Loading dataset: 100%|██████████| 80/80 [00:08<00:00,  8.92it/s]
Loading dataset: 100%|██████████| 80/80 [00:18<00:00,  4.34it/s]


LGM0159(1,5)


Loading dataset: 100%|██████████| 112/112 [00:09<00:00, 12.33it/s]
Loading dataset: 100%|██████████| 112/112 [00:24<00:00,  4.66it/s]


PA150139


Loading dataset: 100%|██████████| 96/96 [00:08<00:00, 11.39it/s]
Loading dataset: 100%|██████████| 96/96 [00:21<00:00,  4.57it/s]


HF230274


Loading dataset: 100%|██████████| 204/204 [00:16<00:00, 12.68it/s]
Loading dataset: 100%|██████████| 204/204 [00:45<00:00,  4.52it/s]


CF160366


Loading dataset: 100%|██████████| 212/212 [00:18<00:00, 11.72it/s]
Loading dataset: 100%|██████████| 212/212 [00:48<00:00,  4.38it/s]


GLA1074


Loading dataset: 100%|██████████| 112/112 [00:10<00:00, 10.47it/s]
Loading dataset: 100%|██████████| 112/112 [00:24<00:00,  4.66it/s]


In [32]:
os.cpu_count()

256

In [33]:
datasets.keys()

dict_keys(['LAXXX', 'AS0170', 'PR0760', 'GF0380', 'FP211261', 'D2MP3(VR)', 'MG0477', 'OL1062R', 'BV1252', 'D1AP7(VR)', 'SD080569', 'CC0167', 'RHCL031174', 'LA0248', 'RP271052', 'SL191251', 'LGM0159(1,5)', 'PA150139', 'HF230274', 'CF160366', 'GLA1074'])

# PATIENT AWARE TEST FUNCTIONS

In [40]:
def fuse_segmentations(model1_prob, model2_prob, prob_threshold=0.5, boost_factor=1.5, penalty_factor=0.5, kernel_size=3):

    
    model1_prob = np.squeeze(model1_prob)
    model2_prob = np.squeeze(model2_prob)
    # Step 1: Check where both models agree above the probability threshold
    agreement = np.logical_and(model1_prob > prob_threshold, model2_prob > prob_threshold)# Step 2: Create a kernel for dilation
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    
    # Step 3: Dilate the agreement area to enlarge it
    enlarged_agreement = cv2.dilate(agreement.astype(np.uint8), kernel)
    
    # Step 4: Sum the probabilities of both models
    prob_sum = model1_prob + model2_prob
    
    # Step 5: Boost the probability sum where there is agreement
    prob_sum[enlarged_agreement > 0] *= boost_factor
    
    # Step 6: Identify disagreement (where there is no enlarged agreement)
    disagreement = enlarged_agreement == 0
    
    # Step 7: Apply penalty factor where there is disagreement
    prob_sum[disagreement] *= penalty_factor
    
    # Step 8: Normalize the probability sum to get the fused probability, ensuring it's within [0, 1]
    fused_prob = np.clip(prob_sum / 2.0, 0, 1)
    
    return fused_prob

def filter_masses(volume, min_slices=7, window_size=3):
    """
    Filters out masses in a 3D volume (HxWxB) that do not consecutively appear in at least 'min_slices' slices,
    considering a window around each mass. This function uses cv2 for dilation and assumes binary masks as input.
    
    :param volume: 3D numpy array of shape (H, W, B) representing a volume of binary masks.
    :param min_slices: Minimum number of consecutive slices a mass must appear in to be kept.
    :param window_size: Diameter of the window used for dilation to connect varying shapes.
    :return: Filtered 3D volume.
    """
    volume_copy = np.copy(volume)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (window_size, window_size))
    
    for i in range(volume_copy.shape[2]):
        volume_copy[:, :, i] = cv2.dilate(volume_copy[:, :, i].astype(np.uint8), kernel, iterations=1)
    
    volume_copy_transposed = np.transpose(volume_copy, (2, 0, 1))
    volume_transposed = np.transpose(volume, (2, 0, 1))
    
    structure = generate_binary_structure(3, 1)
    labeled_volume, num_features = labell(volume_copy_transposed, structure=structure)
    print(f"num features: {num_features}")
    for feature_id in range(1, num_features + 1):
        print(feature_id)
        feature_mask = labeled_volume == feature_id
        slice_presence_count = np.sum(np.any(feature_mask, axis=(1, 2)))
    
        if slice_presence_count < min_slices:
            volume_transposed[labeled_volume == feature_id] = 0  # Filter out the feature
    
    filtered_volume = np.transpose(volume_transposed, (1, 2, 0))
    print("fine")
    return filtered_volume

In [41]:
def test_patient_aware_no_patches(model_path, patient_ids, datasets, dataset_key, filter=False):

    model = BreastModel2.load_from_checkpoint(model_path, strict=False)

    model_class_mean_iou = []
    model_class_mean_dice = []
    model_detection_iou = []
    
    model_iou_mass_volume = []
    model_iou_mass_volume_no_empty = []
    
    model_dice_mass_volume = []
    model_dice_mass_volume_no_empty = []

    model_accuracy = []
    model_precision = []
    model_recall = []
    
    for patient_id in patient_ids:
        
         predicted_label_slices = []
         gt_label_slices=[]
         image_slices=[]
    
         TP = []
         FP = []
         FN = []
         TN = []

         detection_rates =  []
    
         print(patient_id)
         dataset = datasets[patient_id][dataset_key]
        
         for idx, e in tqdm(enumerate(dataset), total = len(dataset)):
            original_image = np.load(e['image_meta_dict']['filename_or_obj'])
            original_image = np.expand_dims(original_image,0)

            gt_label = np.load(e['label_meta_dict']['filename_or_obj'])
            gt_label= np.expand_dims(gt_label,0)

            if e['keep_sample']:
                image = torch.unsqueeze(e['image'], 0)
                    
                with torch.no_grad():
                    model = model.to("cuda")
                    model.eval()
                    masks = model(image.to("cuda"))[0]
                    masks = masks.sigmoid()
                    
                pred_label = masks[0]
                pred_label = (pred_label > 0.4).int()
                pred_label = reverse_transformations(dataset[idx], pred_label, mode='whole')
                
                pred_label = monai.transforms.Resize(spatial_size=(original_image.shape[1], original_image.shape[2]), mode='neares-exact')(pred_label)
            else:
                pred_label = torch.zeros(original_image.shape, dtype=torch.uint8)

            if not filter:
                tp, fp, fn, tn = smp.metrics.get_stats(torch.tensor(np.expand_dims(pred_label,0).astype(int)), torch.tensor(np.expand_dims(gt_label,0).astype(int)), mode = "binary")
                TP.append(tp)
                FP.append(fp)
                FN.append(fn)
                TN.append(tn)
                
            """plt.figure(figsize=(15, 10))
    
            plt.subplot(2, 2, 1)
            plt.imshow(original_image.squeeze(),  cmap='gray')  # convert CHW -> HWC
            plt.title("Image")
            plt.axis("off")
        
            plt.subplot(2, 2, 2)
            plt.imshow(label_whole, cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("Whole")
            plt.axis("off")
        
            plt.subplot(2, 2, 3)
            plt.imshow(label_patches, cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("Patch")
            plt.axis("off")
    
            plt.subplot(2, 2, 4)
            plt.imshow(gt_label, cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("GT")
            plt.axis("off")
        
            plt.show()
    
            plt.imshow(fusion , cmap='gray')
            plt.show()"""
    
    
            predicted_label_slices.append(pred_label.squeeze())
            gt_label_slices.append(gt_label.squeeze())
            image_slices.append(original_image.squeeze())

         predicted_label_volume = np.stack(predicted_label_slices, axis=-1)  # Stack along the first axis to create a 3D volume
                    
         gt_label_volume = np.stack(gt_label_slices, axis=-1)
         images_volume = np.stack(image_slices, axis=-1)
    
         if filter:
             predicted_label_volume = filter_masses(predicted_label_volume, min_slices=3, window_size=3) # H x W x N
             # H x W x N -> N x H x W -> N x 1 x H x W
             predicted_label_volume_for_stats = np.transpose(predicted_label_volume, (2, 0, 1))
             predicted_label_volume_for_stats = np.expand_dims(predicted_label_volume_for_stats, 1)  # N x 1 x H x W

             gt_label_volume_for_stats = np.transpose(gt_label_volume, (2, 0, 1))
             gt_label_volume_for_stats = np.expand_dims(gt_label_volume_for_stats, 1)  # N x 1 x H x W             

             
             tp, fp, fn, tn = smp.metrics.get_stats(torch.tensor(predicted_label_volume_for_stats.astype(int)), torch.tensor(gt_label_volume_for_stats.astype(int)), mode = "binary")
             TP =  [torch.tensor([[elem]]) for elem in tp.squeeze()]
             FP =  [torch.tensor([[elem]]) for elem in fp.squeeze()]
             FN =  [torch.tensor([[elem]]) for elem in fn.squeeze()]
             TN =  [torch.tensor([[elem]]) for elem in tn.squeeze()]

    
         detection_iou =  np.array(calculate_mass_detection_imagewise_volume(predicted_label_volume.astype(int), gt_label_volume)).mean()
         
         mean_iou = compute_mean_iou_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True)
         mean_dice = compute_mean_dice_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True)

         iou_mass_volume = compute_iou_npy(gt_label_volume, predicted_label_volume, class_id=1, reduction='micro-imagewise')
         iou_mass_volume_no_empty = compute_iou_npy(gt_label_volume, predicted_label_volume, class_id=1, reduction='micro-imagewise', exclude_empty=True)
         dice_mass_volume = compute_dice_score_npy(gt_label_volume, predicted_label_volume,reduction='micro-imagewise', class_id=1)
         dice_mass_volume_no_empty = compute_dice_score_npy(gt_label_volume, predicted_label_volume, reduction='micro-imagewise',class_id=1, exclude_empty=True)
        
         accuracy = compute_accuracy_from_cumulator(TP, FP, FN, TN)
         precision = compute_precision_from_cumulator(TP, FP, FN, TN)
         recall = compute_recall_from_cumulator(TP, FP, FN, TN)
        
         print("CLASS MEAN IOU", mean_iou)
         print("CLASS MEAN DICE", mean_dice)
         print("DIOU", detection_iou)
         print("IOU MASS VOLUME", iou_mass_volume)
         print("IOU MASS VOLUME NO EMPTY", iou_mass_volume_no_empty)
         print("DICE MASS VOLUME ", dice_mass_volume)
         print("DICE MASS VOLUME NO EMPTY ", dice_mass_volume_no_empty)

         print("ACCURACY ", accuracy)
         print("PRECISION ", precision)
         print("RECALL", recall)
         
         print()
         model_class_mean_iou.append(mean_iou)
         model_class_mean_dice.append(mean_dice)
         model_detection_iou.append(detection_iou)
        
         model_iou_mass_volume.append(iou_mass_volume)
         model_iou_mass_volume_no_empty.append(iou_mass_volume_no_empty)
        
         model_dice_mass_volume.append(dice_mass_volume)
         model_dice_mass_volume_no_empty.append(dice_mass_volume_no_empty)

         model_accuracy.append(accuracy)
         model_precision.append(precision)
         model_recall.append(recall)

    model_class_mean_iou = np.array(model_class_mean_iou).mean()
    model_class_mean_dice = np.array(model_class_mean_dice).mean()
    model_detection_iou = np.array(model_detection_iou).mean()
    
    model_iou_mass_volume = np.array(model_iou_mass_volume).mean()
    model_iou_mass_volume_no_empty = np.array(model_iou_mass_volume_no_empty).mean()
    
    model_dice_mass_volume = np.array(model_dice_mass_volume).mean()
    model_dice_mass_volume_no_empty = np.array(model_dice_mass_volume_no_empty).mean()

    model_accuracy = np.array(model_accuracy).mean()
    model_precision = np.array(model_precision).mean()
    model_recall = np.array(model_recall).mean()

    print("MODEL CLASS MEAN IOU", model_class_mean_iou)
    print("MODEL CLASS MEAN DICE", model_class_mean_dice)
    print("MODEL DIOU", model_detection_iou)
        
    print("MODEL IOU MASS VOLUME", model_iou_mass_volume)
    print("MODEL IOU MASS VOLUME NO EMPTY", model_iou_mass_volume_no_empty)
    
    print("MODEL DICE MASS VOLUME ", model_dice_mass_volume)
    print("MODEL DICE MASS VOLUME NO EMPTY ", model_dice_mass_volume_no_empty)

    print("MODEL ACCURACY ", model_accuracy)
    print("MODEL PRECISION ", model_precision)
    print("MODEL RECALL", model_recall)

In [42]:
def test_dataset_aware_no_patches(model_path, patient_ids, datasets, dataset_key, filter=False, strict=False, get_scores_for_statistics=False,get_only_masses=False,arch_name=False):

    if arch_name:
        model = BreastModel2.load_from_checkpoint(model_path, strict=strict, arch=arch_name)

    else:

        model = BreastModel2.load_from_checkpoint(model_path, strict=strict)

    model_class_mean_iou = []
    model_class_mean_dice = []
    model_detection_iou = []
    
    model_iou_mass_volume = []
    model_iou_mass_volume_no_empty = []
    
    model_dice_mass_volume = []
    model_dice_mass_volume_no_empty = []

    model_accuracy = []
    model_precision = []
    model_recall = []

    TP = []
    FP = []
    FN = []
    TN = []
    detection_iou=[]
    
    for patient_id in patient_ids:


    
        
         predicted_label_slices = []
         gt_label_slices=[]
         image_slices=[]
        
         print(patient_id)
         dataset = datasets[patient_id][dataset_key]
        
         for idx, e in tqdm(enumerate(dataset), total = len(dataset)):
            original_image = np.load(e['image_meta_dict']['filename_or_obj'])
            original_image = np.expand_dims(original_image,0)

            gt_label = np.load(e['label_meta_dict']['filename_or_obj'])
            gt_label= np.expand_dims(gt_label,0)

            if e['keep_sample']:
                image = torch.unsqueeze(e['image'], 0)
                    
                with torch.no_grad():
                    model = model.to("cuda")
                    model.eval()
                    if arch_name:
                            masks = model(image.to("cuda"))[0]
                    else:
                            masks = model(image.to("cuda"))[0]
                    masks = masks.sigmoid()
                    
                pred_label = masks[0]
                pred_label = (pred_label > 0.4).int()
                pred_label = torch.squeeze(pred_label)
                pred_label = torch.unsqueeze(pred_label,0)
                pred_label = reverse_transformations(dataset[idx], pred_label, mode='whole')
                
                pred_label = monai.transforms.Resize(spatial_size=(original_image.shape[1], original_image.shape[2]), mode='nearest-exact')(pred_label)
            else:
                pred_label = torch.zeros(original_image.shape, dtype=torch.uint8)

            if not filter:
                tp, fp, fn, tn = smp.metrics.get_stats(torch.tensor(np.expand_dims(pred_label,0).astype(int)), torch.tensor(np.expand_dims(gt_label,0).astype(int)), mode = "binary")
                TP.append(tp)
                FP.append(fp)
                FN.append(fn)
                TN.append(tn)
            
            """plt.figure(figsize=(15, 10))
    
            plt.subplot(2, 2, 1)
            plt.imshow(original_image.squeeze(),  cmap='gray')  # convert CHW -> HWC
            plt.title("Image")
            plt.axis("off")
        
            plt.subplot(2, 2, 2)
            plt.imshow(label_whole, cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("Whole")
            plt.axis("off")
        
            plt.subplot(2, 2, 3)
            plt.imshow(label_patches, cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("Patch")
            plt.axis("off")
    
            plt.subplot(2, 2, 4)
            plt.imshow(gt_label, cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("GT")
            plt.axis("off")
        
            plt.show()
    
            plt.imshow(fusion , cmap='gray')
            plt.show()"""
    
    
            predicted_label_slices.append(pred_label.squeeze())
            gt_label_slices.append(gt_label.squeeze())
            image_slices.append(original_image.squeeze())

         predicted_label_volume = np.stack(predicted_label_slices, axis=-1)  # Stack along the first axis to create a 3D volume
                    
         gt_label_volume = np.stack(gt_label_slices, axis=-1)
         images_volume = np.stack(image_slices, axis=-1)
    
         if filter:
             predicted_label_volume = filter_masses(predicted_label_volume, min_slices=3, window_size=3) # H x W x N
             # H x W x N -> N x H x W -> N x 1 x H x W
             predicted_label_volume_for_stats = np.transpose(predicted_label_volume, (2, 0, 1))
             predicted_label_volume_for_stats = np.expand_dims(predicted_label_volume_for_stats, 1)  # N x 1 x H x W

             gt_label_volume_for_stats = np.transpose(gt_label_volume, (2, 0, 1))
             gt_label_volume_for_stats = np.expand_dims(gt_label_volume_for_stats, 1)  # N x 1 x H x W             
            
             
             tp, fp, fn, tn = smp.metrics.get_stats(torch.tensor(predicted_label_volume_for_stats.astype(int)), torch.tensor(gt_label_volume_for_stats.astype(int)), mode = "binary")
             TP +=  [torch.tensor([[elem]]) for elem in tp.squeeze()]
             FP +=  [torch.tensor([[elem]]) for elem in fp.squeeze()]
             FN +=  [torch.tensor([[elem]]) for elem in fn.squeeze()]
             TN +=  [torch.tensor([[elem]]) for elem in tn.squeeze()]

         detection_iou+=calculate_mass_detection_imagewise_volume(predicted_label_volume.astype(int), gt_label_volume)

    model_detection_iou = np.array(detection_iou).mean()
    model_detection_iou_std = np.array(detection_iou).std()
    
    model_class_mean_iou, model_class_std_iou = compute_mean_iou_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True, return_std=True)
    model_class_mean_dice, model_class_std_dice = compute_mean_dice_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True, return_std=True)

    model_iou_mass_volume , model_iou_mass_volume_std = compute_iou_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=False, return_std=True)
    model_iou_mass_volume_no_empty, model_iou_mass_volume_no_empty_std =compute_iou_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True, return_std=True)
    model_iou_mass_volume_no_empty_optimistic, model_iou_mass_volume_no_empty_optimistic_std =compute_iou_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True, exclude_empty_only_gt=True, return_std=True)
    
    
    model_dice_mass_volume, model_dice_mass_volume_std = compute_dice_imagewise_from_cumulator(TP, FP, FN, TN,exclude_empty=False, return_std=True)
    model_dice_mass_volume_no_empty, model_dice_mass_volume_no_empty_std = compute_dice_imagewise_from_cumulator(TP, FP, FN, TN,exclude_empty=True, return_std=True)
    model_dice_mass_volume_no_empty_optimistic, model_dice_mass_volume_no_empty_optimistic_std = compute_dice_imagewise_from_cumulator(TP, FP, FN, TN,exclude_empty=True, exclude_empty_only_gt=True,return_std=True)
    
    model_mean_accuracy_no_empty, model_mean_accuracy_no_empty_std = compute_accuracy_from_cumulator(TP, FP, FN, TN, exclude_empty=True, is_mean=True, return_std=True)
    model_mean_precision_no_empty,model_mean_precision_no_empty_std = compute_precision_from_cumulator(TP, FP, FN, TN, exclude_empty=True, is_mean=True, return_std=True)
    model_mean_recall_no_empty, model_mean_recall_no_empty_std = compute_recall_from_cumulator(TP, FP, FN, TN, exclude_empty=True, is_mean=True, return_std=True)
    model_mean_f1_no_empty, model_mean_f1_no_empty_std = compute_f1_from_cumulator(TP, FP, FN, TN, exclude_empty=True, is_mean=True, return_std=True)
        
    model_accuracy_excluding_cases, model_accuracy_excluding_cases_std = compute_accuracy_excluding_cases(TP, FP, FN, TN, return_std=True)
    model_precision_excluding_cases,model_precision_excluding_cases_std =compute_precision_excluding_cases_from_cumulator(TP, FP, FN, TN, return_std=True)
    model_recall_excluding_cases,model_recall_excluding_cases_std  =compute_recall_excluding_cases_from_cumulator(TP, FP, FN, TN, return_std=True)

    model_accuracy_no_empty, model_accuracy_no_empty_std = compute_accuracy_from_cumulator(TP, FP, FN, TN, exclude_empty=True, is_mean=False, return_std=True)
    model_precision_no_empty,model_precision_no_empty_std =compute_precision_excluding_cases_from_cumulator(TP, FP, FN, TN, return_std=True,exclude_only_zero_denominator=True)
    model_recall_no_empty,model_recall_no_empty_std  = compute_recall_excluding_cases_from_cumulator(TP, FP, FN, TN, return_std=True,exclude_only_zero_denominator=True)

    model_f1_no_empty,model_f1_no_empty_std = compute_f1_excluding_cases_from_cumulator(TP, FP, FN, TN, return_std=True,exclude_only_zero_denominator=True)
    
    
    print("MODEL CLASS MEAN IOU ", model_class_mean_iou)
    print("MODEL CLASS STD IOU ", model_class_std_iou)
    print()
    print("MODEL CLASS MEAN DICE ", model_class_mean_dice)
    print("MODEL CLASS STD DICE ", model_class_std_dice)
    print()
    print("MODEL DIOU", model_detection_iou)
    print("MODEL DIOU STD ", model_detection_iou_std) 
    print()
    print("MODEL IOU MASS VOLUME ", model_iou_mass_volume)
    print("MODEL IOU MASS VOLUME STD ", model_iou_mass_volume_std)
    print()
    print("MODEL IOU MASS VOLUME NO EMPTY ", model_iou_mass_volume_no_empty)
    print("MODEL IOU MASS VOLUME NO EMPTY STD ", model_iou_mass_volume_no_empty_std)
    print()
    print("MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC ", model_iou_mass_volume_no_empty_optimistic)
    print("MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC STD ", model_iou_mass_volume_no_empty_optimistic_std)
    
    print("MODEL DICE MASS VOLUME ", model_dice_mass_volume)
    print("MODEL DICE MASS VOLUME STD ", model_dice_mass_volume_std)
    print()
    print("MODEL DICE MASS VOLUME NO EMPTY ", model_dice_mass_volume_no_empty)
    print("MODEL DICE MASS VOLUME NO EMPTY STD ", model_dice_mass_volume_no_empty_std)
    print()
    print("MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC ", model_dice_mass_volume_no_empty_optimistic)
    print("MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC STD ", model_dice_mass_volume_no_empty_optimistic_std)
    print() 
    print("MODEL MEAN ACCURACY NO EMPTY", model_mean_accuracy_no_empty)
    print("MODEL MEAN ACCURACY NO EMPTY STD", model_mean_accuracy_no_empty_std)
    print()                                                                              
    print("MODEL MEAN PRECISION NO EMPTY", model_mean_precision_no_empty)
    print("MODEL MEAN PRECISION NO EMPTY STD", model_mean_precision_no_empty_std)
    print()
    print("MODEL MEAN RECALL NO EMPTY", model_mean_recall_no_empty)
    print("MODEL MEAN RECALL NO EMPTY STD", model_mean_recall_no_empty_std)
    print()
    print("MODEL MEAN F1 NO EMPTY", model_mean_f1_no_empty)
    print("MODEL MEAN F1 NO EMPTY STD", model_mean_f1_no_empty_std)
    print()
    print("MODEL ACCURACY EXCLUDING CASES ",  model_accuracy_excluding_cases)
    print("MODEL ACCURACY EXCLUDING CASES STD ",  model_accuracy_excluding_cases_std)
    print()
    print("MODEL PRECISION EXCLUDING CASES ",  model_precision_excluding_cases)
    print("MODEL PRECISION EXCLUDING CASES STD ",  model_precision_excluding_cases_std)
    print()
    print("MODEL RECALL EXCLUDING CASES ", model_recall_excluding_cases)
    print("MODEL RECALL EXCLUDING CASES STD ", model_recall_excluding_cases_std)
    print()
    print("MODEL ACCURACY NO EMPTY ",  model_accuracy_no_empty)
    print("MODEL ACCURACY NO EMPTY STD ",  model_accuracy_no_empty_std)
    print()
    print("MODEL PRECISION NO EMPTY",  model_precision_no_empty)
    print("MODEL PRECISION NO EMPTY STD ",  model_precision_no_empty_std)
    print()
    print("MODEL RECALL NO EMPTY ", model_recall_no_empty)
    print("MODEL RECALL NO EMPTY STD ", model_recall_no_empty_std)
    print()
    print("MODEL F1 NO EMPTY ", model_f1_no_empty)
    print("MODEL F1 NO EMPTY STD ", model_f1_no_empty_std)
    print()
    


    if get_scores_for_statistics:
            tp = torch.cat([tp for tp in TP])
            fp = torch.cat([fp for fp in FP])
            fn = torch.cat([fn for fn in FN])
            tn = torch.cat([tn for tn in TN])


            if get_only_masses:
                # Create a mask where tp + fn is not equal to 0
                mask = (tp + fn) != 0
                
                # Apply this mask to each tensor to filter out the desired values
                tp = tp[mask]
                fp = fp[mask]
                fn = fn[mask]
                tn = tn[mask]

            miou_scores = compute_mean_iou_imagewise_from_cumulator(tp, fp, fn, tn, exclude_empty=False, return_std=False,reduce_mean=False)
            mdice_scores = compute_mean_dice_imagewise_from_cumulator(tp, fp, fn, tn, exclude_empty=False, return_std=False,reduce_mean=False)
            mf1_scores = compute_f1_from_cumulator(tp, fp, fn, tn, exclude_empty=False, is_mean=True, return_std=False,reduce_mean=False)


            scores_dict = {
                 'miou': miou_scores.squeeze().tolist(),
                 'mdice': mdice_scores.squeeze().tolist(),
                 "mf1": mf1_scores.squeeze().tolist(),
            }
            return scores_dict


In [43]:
def test_patient_aware_patches(model_path, patient_ids, datasets, dataset_key):

    model = BreastModel2.load_from_checkpoint(model_path, strict=False)

    model_class_mean_iou = []
    model_class_mean_dice = []
    model_detection_iou = []
    
    model_iou_mass_volume = []
    model_iou_mass_volume_no_empty = []
    
    model_dice_mass_volume = []
    model_dice_mass_volume_no_empty = []
    
    model_precision = []
    model_recall = []
    model_accuracy = []

    
    
    for patient_id in patient_ids:
         predicted_label_slices = []
         gt_label_slices=[]
         image_slices=[]
    
         TP = []
         FP = []
         FN =[]
         TN = []
         print(patient_id)
         dataset = datasets[patient_id][dataset_key]
        
         for idx, e in tqdm(enumerate(dataset), total = len(dataset)):
            original_image = np.load(e[0]['image_meta_dict']['filename_or_obj'])
            original_image = np.expand_dims(original_image,0)

            pred_label = torch.zeros(original_image.shape, dtype=torch.uint8)
            
            gt_label = np.load(e[0]['label_meta_dict']['filename_or_obj'])
            gt_label= np.expand_dims(gt_label,0)


            merged_label = torch.zeros(original_image.shape)
            merged_label_for_fusion = torch.zeros(original_image.shape)

            for elem in e:
                if elem['keep_sample']:
                    image = torch.unsqueeze(elem['image'], 0)
                    with torch.no_grad():
                        model = model.to("cuda")
                        model.eval()
                        logits = model(image.to("cuda"))[0]
    
                    pr_mask = logits.sigmoid()
                    pr_mask = pr_mask[0]
                    #pr_mask_to_viz = (pr_mask.cpu().numpy() > 0.4).astype(int)
    
                    if pr_mask.sum()>0:
                        #label = pr_mask
                        pr_mask = (pr_mask > 0.4).int()
                        label = reverse_transformations(elem, pr_mask, mode='patches')
                        merged_label += label
                    

                pred_label = merged_label
                pred_label = monai.transforms.Resize(spatial_size=(original_image.shape[1], original_image.shape[2]), mode='nearest-exact')(pred_label)

            tp, fp, fn, tn = smp.metrics.get_stats(torch.tensor(np.expand_dims(pred_label,0).astype(int)), torch.tensor(np.expand_dims(gt_label,0).astype(int)), mode = "binary")
            TP.append(tp)
            FP.append(fp)
            FN.append(fn)
            TN.append(tn)
             
            """plt.figure(figsize=(15, 10))
    
            plt.subplot(2, 2, 1)
            plt.imshow(original_image.squeeze(),  cmap='gray')  # convert CHW -> HWC
            plt.title("Image")
            plt.axis("off")
        
            plt.subplot(2, 2, 2)
            plt.imshow(label_whole, cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("Whole")
            plt.axis("off")
        
            plt.subplot(2, 2, 3)
            plt.imshow(label_patches, cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("Patch")
            plt.axis("off")
    
            plt.subplot(2, 2, 4)
            plt.imshow(gt_label, cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("GT")
            plt.axis("off")
        
            plt.show()
    
            plt.imshow(fusion , cmap='gray')
        
            plt.show()"""
    
    
            predicted_label_slices.append(pred_label.squeeze())
            gt_label_slices.append(gt_label.squeeze())
            image_slices.append(original_image.squeeze())

         predicted_label_volume = np.stack(predicted_label_slices, axis=-1)  # Stack along the first axis to create a 3D volume
       
         gt_label_volume = np.stack(gt_label_slices, axis=-1)
         images_volume = np.stack(image_slices, axis=-1)
    
         detection_iou =  np.array(calculate_mass_detection_imagewise_volume(predicted_label_volume.astype(int), gt_label_volume)).mean()
         
         mean_iou = compute_mean_iou_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True)
         mean_dice = compute_mean_dice_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True)

         iou_mass_volume = compute_iou_npy(gt_label_volume, predicted_label_volume, class_id=1, reduction='micro-imagewise')
         iou_mass_volume_no_empty = compute_iou_npy(gt_label_volume, predicted_label_volume, class_id=1, reduction='micro-imagewise', exclude_empty=True)
         dice_mass_volume = compute_dice_score_npy(gt_label_volume, predicted_label_volume,reduction='micro-imagewise', class_id=1)
         dice_mass_volume_no_empty = compute_dice_score_npy(gt_label_volume, predicted_label_volume, reduction='micro-imagewise',class_id=1, exclude_empty=True)

        
         accuracy = compute_accuracy_from_cumulator(TP, FP, FN, TN)
         precision = compute_precision_from_cumulator(TP, FP, FN, TN)
         recall = compute_recall_from_cumulator(TP, FP, FN, TN)
        
         print("CLASS MEAN IOU", mean_iou)
         print("CLASS MEAN DICE", mean_dice)
         print("DIOU", detection_iou)
         print("IOU MASS VOLUME", iou_mass_volume)
         print("IOU MASS VOLUME NO EMPTY", iou_mass_volume_no_empty)
         print("DICE MASS VOLUME ", dice_mass_volume)
         print("DICE MASS VOLUME NO EMPTY ", dice_mass_volume_no_empty)

         print("ACCURACY ", accuracy)
         print("PRECISION ", precision)
         print("RECALL", recall)
         
         print()
         model_class_mean_iou.append(mean_iou)
         model_class_mean_dice.append(mean_dice)
         model_detection_iou.append(detection_iou)
        
         model_iou_mass_volume.append(iou_mass_volume)
         model_iou_mass_volume_no_empty.append(iou_mass_volume_no_empty)
        
         model_dice_mass_volume.append(dice_mass_volume)
         model_dice_mass_volume_no_empty.append(dice_mass_volume_no_empty)

         model_accuracy.append(accuracy)
         model_precision.append(precision)
         model_recall.append(recall)

    model_class_mean_iou = np.array(model_class_mean_iou).mean()
    model_class_mean_dice = np.array(model_class_mean_dice).mean()
    model_detection_iou = np.array(model_detection_iou).mean()
    
    model_iou_mass_volume = np.array(model_iou_mass_volume).mean()
    model_iou_mass_volume_no_empty = np.array(model_iou_mass_volume_no_empty).mean()
    
    model_dice_mass_volume = np.array(model_dice_mass_volume).mean()
    model_dice_mass_volume_no_empty = np.array(model_dice_mass_volume_no_empty).mean()

    model_accuracy = np.array(model_accuracy).mean()
    model_precision = np.array(model_precision).mean()
    model_recall = np.array(model_recall).mean()

    print("MODEL CLASS MEAN IOU", model_class_mean_iou)
    print("MODEL CLASS MEAN DICE", model_class_mean_dice)
    print("MODEL DIOU", model_detection_iou)
        
    print("MODEL IOU MASS VOLUME", model_iou_mass_volume)
    print("MODEL IOU MASS VOLUME NO EMPTY", model_iou_mass_volume_no_empty)
    
    print("MODEL DICE MASS VOLUME ", model_dice_mass_volume)
    print("MODEL DICE MASS VOLUME NO EMPTY ", model_dice_mass_volume_no_empty)

    print("MODEL ACCURACY ", model_accuracy)
    print("MODEL PRECISION ", model_precision)
    print("MODEL RECALL", model_recall)

In [44]:
# Function to sum every two consecutive elements in a tensor
def sum_every_two(tensor):
    # Ensure the tensor has an even number of elements
    if tensor.numel() % 2 != 0:
        raise ValueError("The number of elements in the tensor must be even")
    # Reshape to have pairs of elements in the last dimension, then sum along that dimension
    return tensor.view(-1, 2).sum(dim=1)

In [None]:
def test_dataset_aware_patches(model_path, patient_ids, datasets, dataset_key, get_scores_for_statistics=False,get_only_masses=False, filter = False,  strict=False, arch_name=False):

    
    if arch_name:
        model = BreastModel2.load_from_checkpoint(model_path, strict=strict, arch=arch_name)

    else:

        model = BreastModel2.load_from_checkpoint(model_path, strict=strict)
    
    model_class_mean_iou = []
    model_class_mean_dice = []
    model_detection_iou = []
    
    model_iou_mass_volume = []
    model_iou_mass_volume_no_empty = []
    
    model_dice_mass_volume = []
    model_dice_mass_volume_no_empty = []

    model_accuracy = []
    model_precision = []
    model_recall = []
    
    TP = []
    FP = []
    FN = []
    TN = []

    
    detection_iou=[]

    for patient_id in patient_ids:
         predicted_label_slices = []
         gt_label_slices=[]
         image_slices=[]
         print(patient_id)
         dataset = datasets[patient_id][dataset_key]
        
         for idx, e in tqdm(enumerate(dataset), total = len(dataset)):
            original_image = np.load(e[0]['image_meta_dict']['filename_or_obj'])
            original_image = np.expand_dims(original_image,0)

            pred_label = torch.zeros(original_image.shape, dtype=torch.uint8)
            
            gt_label = np.load(e[0]['label_meta_dict']['filename_or_obj'])
            gt_label= np.expand_dims(gt_label,0)


            merged_label = torch.zeros(original_image.shape)
            merged_label_for_fusion = torch.zeros(original_image.shape)

            for elem in e:
                if elem['keep_sample']:
                    image = torch.unsqueeze(elem['image'], 0)
                    with torch.no_grad():
                        model = model.to("cuda")
                        model.eval()
                        if arch_name:
                            logits = model(image.to("cuda"))[0]
                        else:
                            logits = model(image.to("cuda"))[0]
    
                    pr_mask = logits.sigmoid()
                    pr_mask = pr_mask[0]
                    #pr_mask_to_viz = (pr_mask.cpu().numpy() > 0.4).astype(int)
    
                    if pr_mask.sum()>0:
                        #label = pr_mask
                        pr_mask = (pr_mask > 0.4).int()
                        label = reverse_transformations(elem, pr_mask, mode='patches')
                        merged_label += label
                    

                pred_label = merged_label
                pred_label = monai.transforms.Resize(spatial_size=(original_image.shape[1], original_image.shape[2]), mode='nearest-exact')(pred_label)

            tp, fp, fn, tn = smp.metrics.get_stats(torch.tensor(np.expand_dims(pred_label,0).astype(int)), torch.tensor(np.expand_dims(gt_label,0).astype(int)), mode = "binary")
            TP.append(tp)
            FP.append(fp)
            FN.append(fn)
            TN.append(tn)
             
            """plt.figure(figsize=(15, 10))
    
            plt.subplot(2, 2, 1)
            plt.imshow(original_image.squeeze(),  cmap='gray')  # convert CHW -> HWC
            plt.title("Image")
            plt.axis("off")
        
            plt.subplot(2, 2, 2)
            plt.imshow(label_whole, cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("Whole")
            plt.axis("off")
        
            plt.subplot(2, 2, 3)
            plt.imshow(label_patches, cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("Patch")
            plt.axis("off")
    
            plt.subplot(2, 2, 4)
            plt.imshow(gt_label, cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("GT")
            plt.axis("off")
        
            plt.show()
    
            plt.imshow(fusion , cmap='gray')
        
            plt.show()"""
    
    
            predicted_label_slices.append(pred_label.squeeze())
            gt_label_slices.append(gt_label.squeeze())
            image_slices.append(original_image.squeeze())

         predicted_label_volume = np.stack(predicted_label_slices, axis=-1)  # Stack along the first axis to create a 3D volume
                    
         gt_label_volume = np.stack(gt_label_slices, axis=-1)
         images_volume = np.stack(image_slices, axis=-1)
    
         if filter:
             predicted_label_volume = filter_masses(predicted_label_volume, min_slices=3, window_size=3) # H x W x N
             # H x W x N -> N x H x W -> N x 1 x H x W
             predicted_label_volume_for_stats = np.transpose(predicted_label_volume, (2, 0, 1))
             predicted_label_volume_for_stats = np.expand_dims(predicted_label_volume_for_stats, 1)  # N x 1 x H x W

             gt_label_volume_for_stats = np.transpose(gt_label_volume, (2, 0, 1))
             gt_label_volume_for_stats = np.expand_dims(gt_label_volume_for_stats, 1)  # N x 1 x H x W             
            
             
             tp, fp, fn, tn = smp.metrics.get_stats(torch.tensor(predicted_label_volume_for_stats.astype(int)), torch.tensor(gt_label_volume_for_stats.astype(int)), mode = "binary")
             TP +=  [torch.tensor([[elem]]) for elem in tp.squeeze()]
             FP +=  [torch.tensor([[elem]]) for elem in fp.squeeze()]
             FN +=  [torch.tensor([[elem]]) for elem in fn.squeeze()]
             TN +=  [torch.tensor([[elem]]) for elem in tn.squeeze()]

    
         detection_iou+=calculate_mass_detection_imagewise_volume(predicted_label_volume.astype(int), gt_label_volume)
    
    model_detection_iou = np.array(detection_iou).mean()
    model_detection_iou_std = np.array(detection_iou).std()
    
    model_class_mean_iou, model_class_std_iou = compute_mean_iou_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True, return_std=True)
    model_class_mean_dice, model_class_std_dice = compute_mean_dice_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True, return_std=True)

    model_iou_mass_volume , model_iou_mass_volume_std = compute_iou_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=False, return_std=True)
    model_iou_mass_volume_no_empty, model_iou_mass_volume_no_empty_std =compute_iou_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True, return_std=True)
    model_iou_mass_volume_no_empty_optimistic, model_iou_mass_volume_no_empty_optimistic_std =compute_iou_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True, exclude_empty_only_gt=True, return_std=True)
    
    
    model_dice_mass_volume, model_dice_mass_volume_std = compute_dice_imagewise_from_cumulator(TP, FP, FN, TN,exclude_empty=False, return_std=True)
    model_dice_mass_volume_no_empty, model_dice_mass_volume_no_empty_std = compute_dice_imagewise_from_cumulator(TP, FP, FN, TN,exclude_empty=True, return_std=True)
    model_dice_mass_volume_no_empty_optimistic, model_dice_mass_volume_no_empty_optimistic_std = compute_dice_imagewise_from_cumulator(TP, FP, FN, TN,exclude_empty=True, exclude_empty_only_gt=True,return_std=True)
    
    model_mean_accuracy_no_empty, model_mean_accuracy_no_empty_std = compute_accuracy_from_cumulator(TP, FP, FN, TN, exclude_empty=True, is_mean=True, return_std=True)
    model_mean_precision_no_empty,model_mean_precision_no_empty_std = compute_precision_from_cumulator(TP, FP, FN, TN, exclude_empty=True, is_mean=True, return_std=True)
    model_mean_recall_no_empty, model_mean_recall_no_empty_std = compute_recall_from_cumulator(TP, FP, FN, TN, exclude_empty=True, is_mean=True, return_std=True)
    model_mean_f1_no_empty, model_mean_f1_no_empty_std = compute_f1_from_cumulator(TP, FP, FN, TN, exclude_empty=True, is_mean=True, return_std=True)
        
    model_accuracy_excluding_cases, model_accuracy_excluding_cases_std = compute_accuracy_excluding_cases(TP, FP, FN, TN, return_std=True)
    model_precision_excluding_cases,model_precision_excluding_cases_std =compute_precision_excluding_cases_from_cumulator(TP, FP, FN, TN, return_std=True)
    model_recall_excluding_cases,model_recall_excluding_cases_std  =compute_recall_excluding_cases_from_cumulator(TP, FP, FN, TN, return_std=True)

    model_accuracy_no_empty, model_accuracy_no_empty_std = compute_accuracy_from_cumulator(TP, FP, FN, TN, exclude_empty=True, is_mean=False, return_std=True)
    model_precision_no_empty,model_precision_no_empty_std =compute_precision_excluding_cases_from_cumulator(TP, FP, FN, TN, return_std=True,exclude_only_zero_denominator=True)
    model_recall_no_empty,model_recall_no_empty_std  = compute_recall_excluding_cases_from_cumulator(TP, FP, FN, TN, return_std=True,exclude_only_zero_denominator=True)

    model_f1_no_empty,model_f1_no_empty_std = compute_f1_excluding_cases_from_cumulator(TP, FP, FN, TN, return_std=True,exclude_only_zero_denominator=True)
    
    
    print("MODEL CLASS MEAN IOU ", model_class_mean_iou)
    print("MODEL CLASS STD IOU ", model_class_std_iou)
    print()
    print("MODEL CLASS MEAN DICE ", model_class_mean_dice)
    print("MODEL CLASS STD DICE ", model_class_std_dice)
    print()
    print("MODEL DIOU", model_detection_iou)
    print("MODEL DIOU STD ", model_detection_iou_std) 
    print()
    print("MODEL IOU MASS VOLUME ", model_iou_mass_volume)
    print("MODEL IOU MASS VOLUME STD ", model_iou_mass_volume_std)
    print()
    print("MODEL IOU MASS VOLUME NO EMPTY ", model_iou_mass_volume_no_empty)
    print("MODEL IOU MASS VOLUME NO EMPTY STD ", model_iou_mass_volume_no_empty_std)
    print()
    print("MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC ", model_iou_mass_volume_no_empty_optimistic)
    print("MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC STD ", model_iou_mass_volume_no_empty_optimistic_std)
    
    print("MODEL DICE MASS VOLUME ", model_dice_mass_volume)
    print("MODEL DICE MASS VOLUME STD ", model_dice_mass_volume_std)
    print()
    print("MODEL DICE MASS VOLUME NO EMPTY ", model_dice_mass_volume_no_empty)
    print("MODEL DICE MASS VOLUME NO EMPTY STD ", model_dice_mass_volume_no_empty_std)
    print()
    print("MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC ", model_dice_mass_volume_no_empty_optimistic)
    print("MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC STD ", model_dice_mass_volume_no_empty_optimistic_std)
    print() 
    print("MODEL MEAN ACCURACY NO EMPTY", model_mean_accuracy_no_empty)
    print("MODEL MEAN ACCURACY NO EMPTY STD", model_mean_accuracy_no_empty_std)
    print()                                                                              
    print("MODEL MEAN PRECISION NO EMPTY", model_mean_precision_no_empty)
    print("MODEL MEAN PRECISION NO EMPTY STD", model_mean_precision_no_empty_std)
    print()
    print("MODEL MEAN RECALL NO EMPTY", model_mean_recall_no_empty)
    print("MODEL MEAN RECALL NO EMPTY STD", model_mean_recall_no_empty_std)
    print()
    print("MODEL MEAN F1 NO EMPTY", model_mean_f1_no_empty)
    print("MODEL MEAN F1 NO EMPTY STD", model_mean_f1_no_empty_std)
    print()
    print("MODEL ACCURACY EXCLUDING CASES ",  model_accuracy_excluding_cases)
    print("MODEL ACCURACY EXCLUDING CASES STD ",  model_accuracy_excluding_cases_std)
    print()
    print("MODEL PRECISION EXCLUDING CASES ",  model_precision_excluding_cases)
    print("MODEL PRECISION EXCLUDING CASES STD ",  model_precision_excluding_cases_std)
    print()
    print("MODEL RECALL EXCLUDING CASES ", model_recall_excluding_cases)
    print("MODEL RECALL EXCLUDING CASES STD ", model_recall_excluding_cases_std)
    print()
    print("MODEL ACCURACY NO EMPTY ",  model_accuracy_no_empty)
    print("MODEL ACCURACY NO EMPTY STD ",  model_accuracy_no_empty_std)
    print()
    print("MODEL PRECISION NO EMPTY",  model_precision_no_empty)
    print("MODEL PRECISION NO EMPTY STD ",  model_precision_no_empty_std)
    print()
    print("MODEL RECALL NO EMPTY ", model_recall_no_empty)
    print("MODEL RECALL NO EMPTY STD ", model_recall_no_empty_std)
    print()
    print("MODEL F1 NO EMPTY ", model_f1_no_empty)
    print("MODEL F1 NO EMPTY STD ", model_f1_no_empty_std)
    print()
    



    if get_scores_for_statistics:
            tp = torch.cat([tp for tp in TP])
            fp = torch.cat([fp for fp in FP])
            fn = torch.cat([fn for fn in FN])
            tn = torch.cat([tn for tn in TN])

            tp =  sum_every_two(tp.squeeze())
            fp =  sum_every_two(fp.squeeze())
            fn =  sum_every_two(fn.squeeze())
            tn = sum_every_two(tn.squeeze())

            if get_only_masses:
                # Create a mask where tp + fn is not equal to 0
                mask = (tp + fn) != 0
                
                # Apply this mask to each tensor to filter out the desired values
                tp = tp[mask]
                fp = fp[mask]
                fn = fn[mask]
                tn = tn[mask]

            miou_scores = compute_mean_iou_imagewise_from_cumulator(tp, fp, fn, tn, exclude_empty=False, return_std=False,reduce_mean=False)
            mdice_scores = compute_mean_dice_imagewise_from_cumulator(tp, fp, fn, tn, exclude_empty=False, return_std=False,reduce_mean=False)
            mf1_scores = compute_f1_from_cumulator(tp, fp, fn, tn, exclude_empty=False, is_mean=True, return_std=False,reduce_mean=False)


            scores_dict = {
                 'miou': miou_scores.squeeze().tolist(),
                 'mdice': mdice_scores.squeeze().tolist(),
                 "mf1": mf1_scores.squeeze().tolist(),
            }
            return scores_dict

In [46]:
def test_patient_aware_fusion(model_path, patient_ids, datasets, whole_dataset_key, patches_dataset_key, use_simple_fusion=False, use_decoder_attention=True, strict=False, filter=False, get_scores_for_statistics=False,get_only_masses=False):

    print(use_decoder_attention)
    print(use_simple_fusion)
    model = BreastModel.load_from_checkpoint(model_path, strict=strict, use_simple_fusion=use_simple_fusion, use_decoder_attention=use_decoder_attention)
    print(model)
    
    model_class_mean_iou = []
    model_class_mean_dice = []
    model_detection_iou = []
    
    model_iou_mass_volume = []
    model_iou_mass_volume_no_empty = []
    
    model_dice_mass_volume = []
    model_dice_mass_volume_no_empty = []
    
    model_precision = []
    model_recall = []
    model_accuracy = []
    
    for patient_id in patient_ids:

         predicted_label_slices = []
         gt_label_slices=[]
         image_slices=[]
    
         TP = []
         FP = []
         FN =[]
         TN = []
         print(patient_id)
         patches_ds = datasets[patient_id][patches_dataset_key]
         whole_image_ds = datasets[patient_id][whole_dataset_key]

         fusion_dataset = PairedDataset(whole_image_ds, patches_ds, augment=False)
        
         prev_had_mask=False

         for idx, e in tqdm(enumerate(fusion_dataset), total = len(patches_ds)):
            original_image = np.load(e[0]['image_meta_dict']['filename_or_obj'])
            original_image = np.expand_dims(original_image,0)
    
            
            pred_label = torch.zeros(original_image.shape, dtype=torch.uint8)
    
            gt_label = np.load(e[0]['label_meta_dict']['filename_or_obj'])
            gt_label= np.expand_dims(gt_label,0)

            if fusion_dataset[idx][0]['keep_sample']:
    
                whole_image = torch.unsqueeze(fusion_dataset[idx][0]['image'], 0)
                patch_image2 = torch.unsqueeze(fusion_dataset[idx][1]['image'], 0)
                patch_image3 = torch.unsqueeze(fusion_dataset[idx][2]['image'], 0)
                    
        
                with torch.no_grad():
                    masks = []
                    # pass to model
                    model = model.to("cuda")
                    model.eval()
                    
                    masks = model(whole_image.to("cuda"),patch_image2.to("cuda"),patch_image3.to("cuda"))
                    masks = masks.sigmoid()
                    
    
                pred_label = masks[0]
                pred_label = (pred_label > 0.4).int()
                pred_label = reverse_transformations(fusion_dataset[idx][0], pred_label, mode='whole')
                
                
            pred_label = monai.transforms.Resize(spatial_size=(original_image.shape[1], original_image.shape[2]), mode='nearest-exact')(pred_label)
             
            if not filter:
                tp, fp, fn, tn = smp.metrics.get_stats(torch.tensor(np.expand_dims(pred_label,0).astype(int)), torch.tensor(np.expand_dims(gt_label,0).astype(int)), mode = "binary")
                TP.append(tp)
                FP.append(fp)
                FN.append(fn)
                TN.append(tn)
             
            """plt.figure(figsize=(15, 10))
    
            plt.subplot(2, 2, 1)
            plt.imshow(original_image.squeeze(),  cmap='gray')  # convert CHW -> HWC
            plt.title("Image")
            plt.axis("off")
        
            plt.subplot(2, 2, 2)
            plt.imshow(label_whole, cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("Whole")
            plt.axis("off")
        
            plt.subplot(2, 2, 3)
            plt.imshow(label_patches, cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("Patch")
            plt.axis("off")
    
            plt.subplot(2, 2, 4)
            plt.imshow(gt_label, cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("GT")
            plt.axis("off")
        
            plt.show()
    
            plt.imshow(fusion , cmap='gray')
            plt.show()"""
    
    
            predicted_label_slices.append(pred_label.squeeze())
            gt_label_slices.append(gt_label.squeeze())
            image_slices.append(original_image.squeeze())

         predicted_label_volume = np.stack(predicted_label_slices, axis=-1)  # Stack along the first axis to create a 3D volume
                    
         gt_label_volume = np.stack(gt_label_slices, axis=-1)
         images_volume = np.stack(image_slices, axis=-1)
    
         if filter:
             predicted_label_volume = filter_masses(predicted_label_volume, min_slices=3, window_size=3) # H x W x N
             # H x W x N -> N x H x W -> N x 1 x H x W
             predicted_label_volume_for_stats = np.transpose(predicted_label_volume, (2, 0, 1))
             predicted_label_volume_for_stats = np.expand_dims(predicted_label_volume_for_stats, 1)  # N x 1 x H x W

             gt_label_volume_for_stats = np.transpose(gt_label_volume, (2, 0, 1))
             gt_label_volume_for_stats = np.expand_dims(gt_label_volume_for_stats, 1)  # N x 1 x H x W             

             
             tp, fp, fn, tn = smp.metrics.get_stats(torch.tensor(predicted_label_volume_for_stats.astype(int)), torch.tensor(gt_label_volume_for_stats.astype(int)), mode = "binary")
             TP =  [torch.tensor([[elem]]) for elem in tp.squeeze()]
             FP =  [torch.tensor([[elem]]) for elem in fp.squeeze()]
             FN =  [torch.tensor([[elem]]) for elem in fn.squeeze()]
             TN =  [torch.tensor([[elem]]) for elem in tn.squeeze()]

         detection_iou =  np.array(calculate_mass_detection_imagewise_volume(predicted_label_volume.astype(int), gt_label_volume)).mean()
         
         mean_iou = compute_mean_iou_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True)
         mean_dice = compute_mean_dice_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True)

         iou_mass_volume = compute_iou_npy(gt_label_volume, predicted_label_volume, class_id=1, reduction='micro-imagewise')
         iou_mass_volume_no_empty = compute_iou_npy(gt_label_volume, predicted_label_volume, class_id=1, reduction='micro-imagewise', exclude_empty=True)
         dice_mass_volume = compute_dice_score_npy(gt_label_volume, predicted_label_volume,reduction='micro-imagewise', class_id=1)
         dice_mass_volume_no_empty = compute_dice_score_npy(gt_label_volume, predicted_label_volume, reduction='micro-imagewise',class_id=1, exclude_empty=True)

        
         accuracy = compute_accuracy_from_cumulator(TP, FP, FN, TN)
         precision = compute_precision_from_cumulator(TP, FP, FN, TN)
         recall = compute_recall_from_cumulator(TP, FP, FN, TN)
        
         print("CLASS MEAN IOU", mean_iou)
         print("CLASS MEAN DICE", mean_dice)
         print("DIOU", detection_iou)
         print("IOU MASS VOLUME", iou_mass_volume)
         print("IOU MASS VOLUME NO EMPTY", iou_mass_volume_no_empty)
         print("DICE MASS VOLUME ", dice_mass_volume)
         print("DICE MASS VOLUME NO EMPTY ", dice_mass_volume_no_empty)

         print("ACCURACY ", accuracy)
         print("PRECISION ", precision)
         print("RECALL", recall)
         
         print()
         model_class_mean_iou.append(mean_iou)
         model_class_mean_dice.append(mean_dice)
         model_detection_iou.append(detection_iou)
        
         model_iou_mass_volume.append(iou_mass_volume)
         model_iou_mass_volume_no_empty.append(iou_mass_volume_no_empty)
        
         model_dice_mass_volume.append(dice_mass_volume)
         model_dice_mass_volume_no_empty.append(dice_mass_volume_no_empty)

         model_accuracy.append(accuracy)
         model_precision.append(precision)
         model_recall.append(recall)

    model_class_mean_iou = np.array(model_class_mean_iou).mean()
    model_class_mean_dice = np.array(model_class_mean_dice).mean()
    model_detection_iou = np.array(model_detection_iou).mean()
    
    model_iou_mass_volume = np.array(model_iou_mass_volume).mean()
    model_iou_mass_volume_no_empty = np.array(model_iou_mass_volume_no_empty).mean()
    
    model_dice_mass_volume = np.array(model_dice_mass_volume).mean()
    model_dice_mass_volume_no_empty = np.array(model_dice_mass_volume_no_empty).mean()

    model_accuracy = np.array(model_accuracy).mean()
    model_precision = np.array(model_precision).mean()
    model_recall = np.array(model_recall).mean()

    print("MODEL CLASS MEAN IOU", model_class_mean_iou)
    print("MODEL CLASS MEAN DICE", model_class_mean_dice)
    print("MODEL DIOU", model_detection_iou)
        
    print("MODEL IOU MASS VOLUME", model_iou_mass_volume)
    print("MODEL IOU MASS VOLUME NO EMPTY", model_iou_mass_volume_no_empty)
    
    print("MODEL DICE MASS VOLUME ", model_dice_mass_volume)
    print("MODEL DICE MASS VOLUME NO EMPTY ", model_dice_mass_volume_no_empty)

    print("MODEL ACCURACY ", model_accuracy)
    print("MODEL PRECISION ", model_precision)
    print("MODEL RECALL", model_recall)

In [47]:
def test_dataset_aware_fusion(model_path, patient_ids, datasets, whole_dataset_key, patches_dataset_key, use_simple_fusion=False, use_decoder_attention=True, strict=False, filter=False, get_scores_for_statistics=False,get_only_masses=False):

    model = BreastModel.load_from_checkpoint(model_path, strict=strict, use_simple_fusion=use_simple_fusion, use_decoder_attention=use_decoder_attention)
    model_class_mean_iou = []
    model_class_mean_dice = []
    model_detection_iou = []
    
    model_iou_mass_volume = []
    model_iou_mass_volume_no_empty = []
    
    model_dice_mass_volume = []
    model_dice_mass_volume_no_empty = []

    model_accuracy = []
    model_precision = []
    model_recall = []
    
    TP = []
    FP = []
    FN = []
    TN = []


    detection_iou=[]


    for patient_id in patient_ids:
         predicted_label_slices = []
         gt_label_slices=[]
         image_slices=[]
         print(patient_id)
         patches_ds = datasets[patient_id][patches_dataset_key]
         whole_image_ds = datasets[patient_id][whole_dataset_key]

         fusion_dataset = PairedDataset(whole_image_ds, patches_ds, augment=False)
        
         prev_had_mask=False

         for idx, e in tqdm(enumerate(fusion_dataset), total = len(patches_ds)):
            original_image = np.load(e[0]['image_meta_dict']['filename_or_obj'])
            original_image = np.expand_dims(original_image,0)
    
            
            pred_label = torch.zeros(original_image.shape, dtype=torch.uint8)
    
            gt_label = np.load(e[0]['label_meta_dict']['filename_or_obj'])
            gt_label= np.expand_dims(gt_label,0)

            if fusion_dataset[idx][0]['keep_sample']:
    
                whole_image = torch.unsqueeze(fusion_dataset[idx][0]['image'], 0)
                patch_image2 = torch.unsqueeze(fusion_dataset[idx][1]['image'], 0)
                patch_image3 = torch.unsqueeze(fusion_dataset[idx][2]['image'], 0)
                    
        
                with torch.no_grad():
                    masks = []
                    # pass to model
                    model = model.to("cuda")
                    model.eval()
                    
                    masks = model(whole_image.to("cuda"),patch_image2.to("cuda"),patch_image3.to("cuda"))
                    masks = masks.sigmoid()
                    
    
                pred_label = masks[0]
                pred_label = (pred_label > 0.4).int()
                pred_label = reverse_transformations(fusion_dataset[idx][0], pred_label, mode='whole')
                

                
            pred_label = monai.transforms.Resize(spatial_size=(original_image.shape[1], original_image.shape[2]), mode='nearest-exact')(pred_label)
             
            if not filter:
                tp, fp, fn, tn = smp.metrics.get_stats(torch.tensor(np.expand_dims(pred_label,0).astype(int)), torch.tensor(np.expand_dims(gt_label,0).astype(int)), mode = "binary")
                TP.append(tp)
                FP.append(fp)
                FN.append(fn)
                TN.append(tn)
             
            """plt.figure(figsize=(15, 10))
    
            plt.subplot(2, 2, 1)
            plt.imshow(original_image.squeeze(),  cmap='gray')  # convert CHW -> HWC
            plt.title("Image")
            plt.axis("off")
        
            plt.subplot(2, 2, 2)
            plt.imshow(label_whole, cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("Whole")
            plt.axis("off")
        
            plt.subplot(2, 2, 3)
            plt.imshow(label_patches, cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("Patch")
            plt.axis("off")
    
            plt.subplot(2, 2, 4)
            plt.imshow(gt_label, cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("GT")
            plt.axis("off")
        
            plt.show()
    
            plt.imshow(fusion , cmap='gray')
            plt.show()"""
    
    
            predicted_label_slices.append(pred_label.squeeze())
            gt_label_slices.append(gt_label.squeeze())
            image_slices.append(original_image.squeeze())

         predicted_label_volume = np.stack(predicted_label_slices, axis=-1)  # Stack along the first axis to create a 3D volume
                    
         gt_label_volume = np.stack(gt_label_slices, axis=-1)
         images_volume = np.stack(image_slices, axis=-1)
    
         if filter:
             predicted_label_volume = filter_masses(predicted_label_volume, min_slices=3, window_size=3) # H x W x N
             # H x W x N -> N x H x W -> N x 1 x H x W
             predicted_label_volume_for_stats = np.transpose(predicted_label_volume, (2, 0, 1))
             predicted_label_volume_for_stats = np.expand_dims(predicted_label_volume_for_stats, 1)  # N x 1 x H x W

             gt_label_volume_for_stats = np.transpose(gt_label_volume, (2, 0, 1))
             gt_label_volume_for_stats = np.expand_dims(gt_label_volume_for_stats, 1)  # N x 1 x H x W             
            
             
             tp, fp, fn, tn = smp.metrics.get_stats(torch.tensor(predicted_label_volume_for_stats.astype(int)), torch.tensor(gt_label_volume_for_stats.astype(int)), mode = "binary")
             TP +=  [torch.tensor([[elem]]) for elem in tp.squeeze()]
             FP +=  [torch.tensor([[elem]]) for elem in fp.squeeze()]
             FN +=  [torch.tensor([[elem]]) for elem in fn.squeeze()]
             TN +=  [torch.tensor([[elem]]) for elem in tn.squeeze()]



         detection_iou+=calculate_mass_detection_imagewise_volume(predicted_label_volume.astype(int), gt_label_volume)
    
    model_detection_iou = np.array(detection_iou).mean()
    model_detection_iou_std = np.array(detection_iou).std()
    
    model_class_mean_iou, model_class_std_iou = compute_mean_iou_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True, return_std=True)
    model_class_mean_dice, model_class_std_dice = compute_mean_dice_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True, return_std=True)

    model_iou_mass_volume , model_iou_mass_volume_std = compute_iou_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=False, return_std=True)
    model_iou_mass_volume_no_empty, model_iou_mass_volume_no_empty_std =compute_iou_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True, return_std=True)
    model_iou_mass_volume_no_empty_optimistic, model_iou_mass_volume_no_empty_optimistic_std =compute_iou_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True, exclude_empty_only_gt=True, return_std=True)
    
    
    model_dice_mass_volume, model_dice_mass_volume_std = compute_dice_imagewise_from_cumulator(TP, FP, FN, TN,exclude_empty=False, return_std=True)
    model_dice_mass_volume_no_empty, model_dice_mass_volume_no_empty_std = compute_dice_imagewise_from_cumulator(TP, FP, FN, TN,exclude_empty=True, return_std=True)
    model_dice_mass_volume_no_empty_optimistic, model_dice_mass_volume_no_empty_optimistic_std = compute_dice_imagewise_from_cumulator(TP, FP, FN, TN,exclude_empty=True, exclude_empty_only_gt=True,return_std=True)
    
    model_mean_accuracy_no_empty, model_mean_accuracy_no_empty_std = compute_accuracy_from_cumulator(TP, FP, FN, TN, exclude_empty=True, is_mean=True, return_std=True)
    model_mean_precision_no_empty,model_mean_precision_no_empty_std = compute_precision_from_cumulator(TP, FP, FN, TN, exclude_empty=True, is_mean=True, return_std=True)
    model_mean_recall_no_empty, model_mean_recall_no_empty_std = compute_recall_from_cumulator(TP, FP, FN, TN, exclude_empty=True, is_mean=True, return_std=True)
    model_mean_f1_no_empty, model_mean_f1_no_empty_std = compute_f1_from_cumulator(TP, FP, FN, TN, exclude_empty=True, is_mean=True, return_std=True)
        
    model_accuracy_excluding_cases, model_accuracy_excluding_cases_std = compute_accuracy_excluding_cases(TP, FP, FN, TN, return_std=True)
    model_precision_excluding_cases,model_precision_excluding_cases_std =compute_precision_excluding_cases_from_cumulator(TP, FP, FN, TN, return_std=True)
    model_recall_excluding_cases,model_recall_excluding_cases_std  =compute_recall_excluding_cases_from_cumulator(TP, FP, FN, TN, return_std=True)

    model_accuracy_no_empty, model_accuracy_no_empty_std = compute_accuracy_from_cumulator(TP, FP, FN, TN, exclude_empty=True, is_mean=False, return_std=True)
    model_precision_no_empty,model_precision_no_empty_std =compute_precision_excluding_cases_from_cumulator(TP, FP, FN, TN, return_std=True,exclude_only_zero_denominator=True)
    model_recall_no_empty,model_recall_no_empty_std  = compute_recall_excluding_cases_from_cumulator(TP, FP, FN, TN, return_std=True,exclude_only_zero_denominator=True)

    model_f1_no_empty,model_f1_no_empty_std = compute_f1_excluding_cases_from_cumulator(TP, FP, FN, TN, return_std=True,exclude_only_zero_denominator=True)
    
    
    print("MODEL CLASS MEAN IOU ", model_class_mean_iou)
    print("MODEL CLASS STD IOU ", model_class_std_iou)
    print()
    print("MODEL CLASS MEAN DICE ", model_class_mean_dice)
    print("MODEL CLASS STD DICE ", model_class_std_dice)
    print()
    print("MODEL DIOU", model_detection_iou)
    print("MODEL DIOU STD ", model_detection_iou_std) 
    print()
    print("MODEL IOU MASS VOLUME ", model_iou_mass_volume)
    print("MODEL IOU MASS VOLUME STD ", model_iou_mass_volume_std)
    print()
    print("MODEL IOU MASS VOLUME NO EMPTY ", model_iou_mass_volume_no_empty)
    print("MODEL IOU MASS VOLUME NO EMPTY STD ", model_iou_mass_volume_no_empty_std)
    print()
    print("MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC ", model_iou_mass_volume_no_empty_optimistic)
    print("MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC STD ", model_iou_mass_volume_no_empty_optimistic_std)
    
    print("MODEL DICE MASS VOLUME ", model_dice_mass_volume)
    print("MODEL DICE MASS VOLUME STD ", model_dice_mass_volume_std)
    print()
    print("MODEL DICE MASS VOLUME NO EMPTY ", model_dice_mass_volume_no_empty)
    print("MODEL DICE MASS VOLUME NO EMPTY STD ", model_dice_mass_volume_no_empty_std)
    print()
    print("MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC ", model_dice_mass_volume_no_empty_optimistic)
    print("MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC STD ", model_dice_mass_volume_no_empty_optimistic_std)
    print() 
    print("MODEL MEAN ACCURACY NO EMPTY", model_mean_accuracy_no_empty)
    print("MODEL MEAN ACCURACY NO EMPTY STD", model_mean_accuracy_no_empty_std)
    print()                                                                              
    print("MODEL MEAN PRECISION NO EMPTY", model_mean_precision_no_empty)
    print("MODEL MEAN PRECISION NO EMPTY STD", model_mean_precision_no_empty_std)
    print()
    print("MODEL MEAN RECALL NO EMPTY", model_mean_recall_no_empty)
    print("MODEL MEAN RECALL NO EMPTY STD", model_mean_recall_no_empty_std)
    print()
    print("MODEL MEAN F1 NO EMPTY", model_mean_f1_no_empty)
    print("MODEL MEAN F1 NO EMPTY STD", model_mean_f1_no_empty_std)
    print()
    print("MODEL ACCURACY EXCLUDING CASES ",  model_accuracy_excluding_cases)
    print("MODEL ACCURACY EXCLUDING CASES STD ",  model_accuracy_excluding_cases_std)
    print()
    print("MODEL PRECISION EXCLUDING CASES ",  model_precision_excluding_cases)
    print("MODEL PRECISION EXCLUDING CASES STD ",  model_precision_excluding_cases_std)
    print()
    print("MODEL RECALL EXCLUDING CASES ", model_recall_excluding_cases)
    print("MODEL RECALL EXCLUDING CASES STD ", model_recall_excluding_cases_std)
    print()
    print("MODEL ACCURACY NO EMPTY ",  model_accuracy_no_empty)
    print("MODEL ACCURACY NO EMPTY STD ",  model_accuracy_no_empty_std)
    print()
    print("MODEL PRECISION NO EMPTY",  model_precision_no_empty)
    print("MODEL PRECISION NO EMPTY STD ",  model_precision_no_empty_std)
    print()
    print("MODEL RECALL NO EMPTY ", model_recall_no_empty)
    print("MODEL RECALL NO EMPTY STD ", model_recall_no_empty_std)
    print()
    print("MODEL F1 NO EMPTY ", model_f1_no_empty)
    print("MODEL F1 NO EMPTY STD ", model_f1_no_empty_std)
    print()
    

    
    



    if get_scores_for_statistics:
            tp = torch.cat([tp for tp in TP])
            fp = torch.cat([fp for fp in FP])
            fn = torch.cat([fn for fn in FN])
            tn = torch.cat([tn for tn in TN])


            if get_only_masses:
                # Create a mask where tp + fn is not equal to 0
                mask = (tp + fn) != 0
                
                # Apply this mask to each tensor to filter out the desired values
                tp = tp[mask]
                fp = fp[mask]
                fn = fn[mask]
                tn = tn[mask]

            miou_scores = compute_mean_iou_imagewise_from_cumulator(tp, fp, fn, tn, exclude_empty=False, return_std=False,reduce_mean=False)
            mdice_scores = compute_mean_dice_imagewise_from_cumulator(tp, fp, fn, tn, exclude_empty=False, return_std=False,reduce_mean=False)
            mf1_scores = compute_f1_from_cumulator(tp, fp, fn, tn, exclude_empty=False, is_mean=True, return_std=False,reduce_mean=False)


            scores_dict = {
                 'miou': miou_scores.squeeze().tolist(),
                 'mdice': mdice_scores.squeeze().tolist(),
                 "mf1": mf1_scores.squeeze().tolist(),
            }
            return scores_dict

In [48]:
def test_patient_aware_ensemble(model_whole_path, model_patches_path, patient_ids, datasets, whole_dataset_key, patches_dataset_key, filter=False, get_scores_for_statistics=False,get_only_masses=False):

    model_whole = BreastModel.load_from_checkpoint(model_whole_path, strict=False)
    model_patches = BreastModel2.load_from_checkpoint(model_patches_path, strict=False)
    model_class_mean_iou = []
    model_class_mean_dice = []
    model_detection_iou = []
    
    model_iou_mass_volume = []
    model_iou_mass_volume_no_empty = []
    
    model_dice_mass_volume = []
    model_dice_mass_volume_no_empty = []

    model_accuracy  = []
    model_precision = []
    model_recall = []

    for patient_id in patient_ids:

         cum_iou=[]
         cum_precision=[]
         cum_recall=[]
         predicted_label_slices = []
         gt_label_slices=[]
         image_slices=[]
    
         TP = []
         FP = []
         FN =[]
         TN = []
         print(patient_id)
         patches_ds = datasets[patient_id][patches_dataset_key]
         whole_image_ds = datasets[patient_id][whole_dataset_key]

         fusion_dataset = PairedDataset(whole_image_ds, patches_ds, augment=False)
        
         prev_had_mask=False

        

         for idx, e in tqdm(enumerate(patches_ds), total = len(patches_ds)):
            original_image = np.load(e[0]['image_meta_dict']['filename_or_obj'])
            original_image = np.expand_dims(original_image,0)
    
            merged_label_for_fusion = torch.zeros(original_image.shape)
    
            gt_label = np.load(e[0]['label_meta_dict']['filename_or_obj'])
            gt_label= np.expand_dims(gt_label,0)
            
            ## FIRST MODEL
            for elem in e:
                if elem['keep_sample']:
                    image = torch.unsqueeze(elem['image'], 0)
                    with torch.no_grad():
                        model_patches = model_patches.to("cuda")
                        model_patches.eval()
                        logits = model_patches(image.to("cuda"))[0]
    
                    pr_mask = logits.sigmoid()
                    pr_mask = pr_mask[0]
                    #pr_mask_to_viz = (pr_mask.cpu().numpy() > 0.4).astype(int)
    
                    if pr_mask.sum()>0:
                        #label = pr_mask
                        label = reverse_transformations(elem, pr_mask, mode='patches')
                        merged_label_for_fusion += label
    
    
            original_image = np.transpose(original_image, (1,2,0))
    
            label_patches_for_fusion = merged_label_for_fusion[0]
    
            # SECOND MODEL

            if fusion_dataset[idx][0]['keep_sample'] or fusion_dataset[idx][1]['keep_sample'] or fusion_dataset[idx][2]['keep_sample']:
    
                whole_image = torch.unsqueeze(fusion_dataset[idx][0]['image'], 0)
                patch_image2 = torch.unsqueeze(fusion_dataset[idx][1]['image'], 0)
                patch_image3 = torch.unsqueeze(fusion_dataset[idx][2]['image'], 0)
                    
        
                with torch.no_grad():
                    masks = []
                    # pass to model
                    model_whole = model_whole.to("cuda")
                    model_whole.eval()
                    
                    masks = model_whole(whole_image.to("cuda"),patch_image2.to("cuda"),patch_image3.to("cuda"))
                    masks = masks.sigmoid()
                    
    
                label_whole = masks[0]
                label_whole = (label_whole > 0.4).int()
                label_whole = reverse_transformations(whole_image_ds[idx], label_whole, mode='whole')
                label_whole = label_whole.squeeze()
        
                label_whole_for_fusion= masks[0]
                label_whole_for_fusion = reverse_transformations(whole_image_ds[idx], label_whole_for_fusion, mode='whole')
                
                # Plot the first image
            else:
                label_whole_for_fusion = torch.zeros(original_image.shape)
                
            original_image_squeeze = np.load(e[0]['image_meta_dict']['filename_or_obj'])

            fusion = fuse_segmentations(label_whole_for_fusion.numpy(), label_patches_for_fusion.numpy(), prob_threshold=0.4, boost_factor=3, penalty_factor=0.5, kernel_size=150)
            
            fusion = (fusion > 0.4).astype(int)
    
            fusion = np.expand_dims(fusion, 0)
            pred_label=fusion

            if not filter:
                tp, fp, fn, tn = smp.metrics.get_stats(torch.tensor(np.expand_dims(fusion,0).astype(int)), torch.tensor(np.expand_dims(gt_label,0).astype(int)), mode = "binary")
                TP.append(tp)
                FP.append(fp)
                FN.append(fn)
                TN.append(tn)
             
            """plt.figure(figsize=(15, 10))
    
            plt.subplot(2, 2, 1)
            plt.imshow(original_image.squeeze(),  cmap='gray')  # convert CHW -> HWC
            plt.title("Image")
            plt.axis("off")
        
            plt.subplot(2, 2, 2)
            plt.imshow(((label_whole_for_fusion > 0.4).numpy().astype(int)).squeeze(), cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("Whole")
            plt.axis("off")
        
            plt.subplot(2, 2, 3)
            plt.imshow(((label_patches_for_fusion > 0.4).numpy().astype(int)).squeeze(), cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("Patch")
            plt.axis("off")
    
            plt.subplot(2, 2, 4)
            plt.imshow(gt_label.squeeze(), cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("GT")
            plt.axis("off")
        
            plt.show()
    
            plt.imshow(fusion.squeeze() , cmap='gray')
            plt.show()"""
    
    
            predicted_label_slices.append(pred_label.squeeze())
            gt_label_slices.append(gt_label.squeeze())
            image_slices.append(original_image.squeeze())
            


         predicted_label_volume = np.stack(predicted_label_slices, axis=-1)  # Stack along the first axis to create a 3D volume
 
         gt_label_volume = np.stack(gt_label_slices, axis=-1)
         images_volume = np.stack(image_slices, axis=-1)


         if filter:
             predicted_label_volume = filter_masses(predicted_label_volume, min_slices=3, window_size=3) # H x W x N
             # H x W x N -> N x H x W -> N x 1 x H x W
             predicted_label_volume_for_stats = np.transpose(predicted_label_volume, (2, 0, 1))
             predicted_label_volume_for_stats = np.expand_dims(predicted_label_volume_for_stats, 1)  # N x 1 x H x W

             gt_label_volume_for_stats = np.transpose(gt_label_volume, (2, 0, 1))
             gt_label_volume_for_stats = np.expand_dims(gt_label_volume_for_stats, 1)  # N x 1 x H x W             

             
             tp, fp, fn, tn = smp.metrics.get_stats(torch.tensor(predicted_label_volume_for_stats.astype(int)), torch.tensor(gt_label_volume_for_stats.astype(int)), mode = "binary")
             TP =  [torch.tensor([[elem]]) for elem in tp.squeeze()]
             FP =  [torch.tensor([[elem]]) for elem in fp.squeeze()]
             FN =  [torch.tensor([[elem]]) for elem in fn.squeeze()]
             TN =  [torch.tensor([[elem]]) for elem in tn.squeeze()]



         
         detection_iou =  np.array(calculate_mass_detection_imagewise_volume(predicted_label_volume.astype(int), gt_label_volume)).mean()

         mean_iou = compute_mean_iou_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True)
         mean_dice = compute_mean_dice_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True)

         iou_mass_volume = compute_iou_npy(gt_label_volume, predicted_label_volume, class_id=1, reduction='micro-imagewise')
         iou_mass_volume_no_empty = compute_iou_npy(gt_label_volume, predicted_label_volume, class_id=1, reduction='micro-imagewise', exclude_empty=True)
         dice_mass_volume = compute_dice_score_npy(gt_label_volume, predicted_label_volume,reduction='micro-imagewise', class_id=1)
         dice_mass_volume_no_empty = compute_dice_score_npy(gt_label_volume, predicted_label_volume, reduction='micro-imagewise',class_id=1, exclude_empty=True)

        
         accuracy = compute_accuracy_from_cumulator(TP, FP, FN, TN)
         precision = compute_precision_from_cumulator(TP, FP, FN, TN)
         recall = compute_recall_from_cumulator(TP, FP, FN, TN)
        
         print("CLASS MEAN IOU", mean_iou)
         print("CLASS MEAN DICE", mean_dice)
         print("DIOU", detection_iou)
         print("IOU MASS VOLUME", iou_mass_volume)
         print("IOU MASS VOLUME NO EMPTY", iou_mass_volume_no_empty)
         print("DICE MASS VOLUME ", dice_mass_volume)
         print("DICE MASS VOLUME NO EMPTY ", dice_mass_volume_no_empty)

         print("ACCURACY ", accuracy)
         print("PRECISION ", precision)
         print("RECALL", recall)
         
         print()
         model_class_mean_iou.append(mean_iou)
         model_class_mean_dice.append(mean_dice)
         model_detection_iou.append(detection_iou)
        
         model_iou_mass_volume.append(iou_mass_volume)
         model_iou_mass_volume_no_empty.append(iou_mass_volume_no_empty)
        
         model_dice_mass_volume.append(dice_mass_volume)
         model_dice_mass_volume_no_empty.append(dice_mass_volume_no_empty)

         model_accuracy.append(accuracy)
         model_precision.append(precision)
         model_recall.append(recall)

    model_class_mean_iou = np.array(model_class_mean_iou).mean()
    model_class_mean_dice = np.array(model_class_mean_dice).mean()
    model_detection_iou = np.array(model_detection_iou).mean()
    
    model_iou_mass_volume = np.array(model_iou_mass_volume).mean()
    model_iou_mass_volume_no_empty = np.array(model_iou_mass_volume_no_empty).mean()
    
    model_dice_mass_volume = np.array(model_dice_mass_volume).mean()
    model_dice_mass_volume_no_empty = np.array(model_dice_mass_volume_no_empty).mean()

    model_accuracy = np.array(model_accuracy).mean()
    model_precision = np.array(model_precision).mean()
    model_recall = np.array(model_recall).mean()

    print("MODEL CLASS MEAN IOU", model_class_mean_iou)
    print("MODEL CLASS MEAN DICE", model_class_mean_dice)
    print("MODEL DIOU", model_detection_iou)
        
    print("MODEL IOU MASS VOLUME", model_iou_mass_volume)
    print("MODEL IOU MASS VOLUME NO EMPTY", model_iou_mass_volume_no_empty)
    
    print("MODEL DICE MASS VOLUME ", model_dice_mass_volume)
    print("MODEL DICE MASS VOLUME NO EMPTY ", model_dice_mass_volume_no_empty)

    print("MODEL ACCURACY ", model_accuracy)
    print("MODEL PRECISION ", model_precision)
    print("MODEL RECALL", model_recall)

In [117]:
import torch

def print_model_params_and_memory(model):
    # Function to calculate memory size in MB
    def get_model_memory(model):
        total_params = sum(p.numel() for p in model.parameters())
        param_size = next(model.parameters()).element_size()  # Size of one parameter (in bytes)
        total_memory = total_params * param_size / (1024 ** 2)  # Memory in MB
        return total_memory
    
    # Calculate total parameters
    total_params = sum(p.numel() for p in model.parameters())
    
    # Calculate trainable parameters
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    # Calculate non-trainable parameters
    non_trainable_params = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    
    # Get model memory size
    model_memory = get_model_memory(model)
    
    # Print the results in millions
    print(f"Total parameters: {total_params / 1e6:.2f} million")
    print(f"Trainable parameters: {trainable_params / 1e6:.2f} million")
    print(f"Non-trainable parameters: {non_trainable_params / 1e6:.2f} million")
    print(f"Memory required (in MB): {model_memory:.2f} MB")


In [118]:
def test_dataset_aware_ensemble(model_whole_path, model_patches_path, patient_ids, datasets, whole_dataset_key, patches_dataset_key, filter=False, use_decoder_attention=True, use_simple_fusion=False,get_scores_for_statistics=False,get_only_masses=False):

    model_whole = BreastModel.load_from_checkpoint(model_whole_path, strict=False, use_simple_fusion=use_simple_fusion, use_decoder_attention=use_decoder_attention)
    model_patches = BreastModel2.load_from_checkpoint(model_patches_path, strict=False)


    print_model_params_and_memory(model_whole)


    print("aa")

    print_model_params_and_memory(model_patches)

    print("aa")

    
    model_class_mean_iou = []
    model_class_mean_dice = []
    model_detection_iou = []
    
    model_iou_mass_volume = []
    model_iou_mass_volume_no_empty = []
    
    model_dice_mass_volume = []
    model_dice_mass_volume_no_empty = []

    model_accuracy = []
    model_precision = []
    model_recall = []


    
    TP = []
    FP = []
    FN = []
    TN = []

    detection_iou =  []

    # Initialize performance metrics
    inference_times = []  # To store the time for each volume inference
    inference_times_slice = []  # To store the time for each volume inference
    memory_usage = []  # To store the memory usage for each volume

    for patient_id in patient_ids:

         predicted_label_slices = []
         gt_label_slices=[]
         image_slices=[]

         print(patient_id)
         patches_ds = datasets[patient_id][patches_dataset_key]
         whole_image_ds = datasets[patient_id][whole_dataset_key]

         fusion_dataset = PairedDataset(whole_image_ds, patches_ds, augment=False)
        
         prev_had_mask=False

         # Measure inference time per slice
         start_time = time.time()

         

         
         for idx, e in tqdm(enumerate(patches_ds), total = len(patches_ds)):

            start_time_slice = time.time()


             
            original_image = np.load(e[0]['image_meta_dict']['filename_or_obj'])
            original_image = np.expand_dims(original_image,0)
    
            merged_label_for_fusion = torch.zeros(original_image.shape)
    
            gt_label = np.load(e[0]['label_meta_dict']['filename_or_obj'])
            gt_label= np.expand_dims(gt_label,0)
            
            ## FIRST MODEL
            for elem in e:
                if elem['keep_sample']:
                    image = torch.unsqueeze(elem['image'], 0)
                    with torch.no_grad():
                        model_patches = model_patches.to("cuda")
                        model_patches.eval()
                        logits = model_patches(image.to("cuda"))[0]
    
                    pr_mask = logits.sigmoid()
                    pr_mask = pr_mask[0]
                    #pr_mask_to_viz = (pr_mask.cpu().numpy() > 0.4).astype(int)
    
                    if pr_mask.sum()>0:
                        #label = pr_mask
                        label = reverse_transformations(elem, pr_mask, mode='patches')
                        merged_label_for_fusion += label
    
    
            original_image = np.transpose(original_image, (1,2,0))
    
            label_patches_for_fusion = merged_label_for_fusion[0]
    
            # SECOND MODEL

            if fusion_dataset[idx][0]['keep_sample'] or fusion_dataset[idx][1]['keep_sample'] or fusion_dataset[idx][2]['keep_sample']:
    
                whole_image = torch.unsqueeze(fusion_dataset[idx][0]['image'], 0)
                patch_image2 = torch.unsqueeze(fusion_dataset[idx][1]['image'], 0)
                patch_image3 = torch.unsqueeze(fusion_dataset[idx][2]['image'], 0)
                    
        
                with torch.no_grad():
                    masks = []
                    # pass to model
                    model_whole = model_whole.to("cuda")
                    model_whole.eval()
                    
                    masks = model_whole(whole_image.to("cuda"),patch_image2.to("cuda"),patch_image3.to("cuda"))
                    masks = masks.sigmoid()
                    
                    
    
                label_whole = masks[0]
                label_whole = (label_whole > 0.4).int()
                label_whole = reverse_transformations(whole_image_ds[idx], label_whole, mode='whole')
                label_whole = label_whole.squeeze()
        
                label_whole_for_fusion= masks[0]
                label_whole_for_fusion = reverse_transformations(whole_image_ds[idx], label_whole_for_fusion, mode='whole')
                
                # Plot the first image
            else:
                label_whole_for_fusion = torch.zeros(original_image.shape)
                
            original_image_squeeze = np.load(e[0]['image_meta_dict']['filename_or_obj'])

            fusion = fuse_segmentations(label_whole_for_fusion.numpy(), label_patches_for_fusion.numpy(), prob_threshold=0.4, boost_factor=3, penalty_factor=0.5, kernel_size=150)
            
            fusion = (fusion > 0.4).astype(int)
    
            fusion = np.expand_dims(fusion, 0)
            pred_label=fusion

            if not filter:
                tp, fp, fn, tn = smp.metrics.get_stats(torch.tensor(np.expand_dims(fusion,0).astype(int)), torch.tensor(np.expand_dims(gt_label,0).astype(int)), mode = "binary")
                TP.append(tp)
                FP.append(fp)
                FN.append(fn)
                TN.append(tn)
             
            """plt.figure(figsize=(15, 10))
    
            plt.subplot(2, 2, 1)
            plt.imshow(original_image.squeeze(),  cmap='gray')  # convert CHW -> HWC
            plt.title("Image")
            plt.axis("off")
        
            plt.subplot(2, 2, 2)
            plt.imshow(((label_whole_for_fusion > 0.4).numpy().astype(int)).squeeze(), cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("Whole")
            plt.axis("off")
        
            plt.subplot(2, 2, 3)
            plt.imshow(((label_patches_for_fusion > 0.4).numpy().astype(int)).squeeze(), cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("Patch")
            plt.axis("off")
    
            plt.subplot(2, 2, 4)
            plt.imshow(gt_label.squeeze(), cmap='gray') # just squeeze classes dim, because we have only one class
            plt.title("GT")
            plt.axis("off")
        
            plt.show()
    
            plt.imshow(fusion.squeeze() , cmap='gray')
            plt.show()"""
    
    
            predicted_label_slices.append(pred_label.squeeze())
            gt_label_slices.append(gt_label.squeeze())
            image_slices.append(original_image.squeeze())

            # Measure inference time after processing  the voluyme
            end_time_slice = time.time()
            inference_times_slice.append(end_time_slice - start_time_slice)

         predicted_label_volume = np.stack(predicted_label_slices, axis=-1)  # Stack along the first axis to create a 3D volume
                    
         gt_label_volume = np.stack(gt_label_slices, axis=-1)
         images_volume = np.stack(image_slices, axis=-1)

        
         # Measure inference time after processing  the volume
         end_time = time.time()
         inference_times.append(end_time - start_time)
         memory_allocated = torch.cuda.memory_allocated()
         memory_usage.append(memory_allocated)
    
         if filter:
             print("filtering")
             predicted_label_volume = filter_masses(predicted_label_volume, min_slices=3, window_size=3) # H x W x N
             # H x W x N -> N x H x W -> N x 1 x H x W
             predicted_label_volume_for_stats = np.transpose(predicted_label_volume, (2, 0, 1))
             predicted_label_volume_for_stats = np.expand_dims(predicted_label_volume_for_stats, 1)  # N x 1 x H x W

             gt_label_volume_for_stats = np.transpose(gt_label_volume, (2, 0, 1))
             gt_label_volume_for_stats = np.expand_dims(gt_label_volume_for_stats, 1)  # N x 1 x H x W             
            
             
             tp, fp, fn, tn = smp.metrics.get_stats(torch.tensor(predicted_label_volume_for_stats.astype(int)), torch.tensor(gt_label_volume_for_stats.astype(int)), mode = "binary")
             TP +=  [torch.tensor([[elem]]) for elem in tp.squeeze()]
             FP +=  [torch.tensor([[elem]]) for elem in fp.squeeze()]
             FN +=  [torch.tensor([[elem]]) for elem in fn.squeeze()]
             TN +=  [torch.tensor([[elem]]) for elem in tn.squeeze()]

        


         detection_iou+=calculate_mass_detection_imagewise_volume(predicted_label_volume.astype(int), gt_label_volume)




    # Calculate mean and standard deviation for inference time and memory usage
    mean_inference_time = np.mean(inference_times)
    std_inference_time = np.std(inference_times)

    mean_inference_time_slice = np.mean(inference_times_slice)
    std_inference_time_slice = np.std(inference_times_slice)
    
    mean_memory_usage = np.mean(memory_usage)
    std_memory_usage = np.std(memory_usage)
        
    # Frames per second (inference speed)
    fps = 1 / mean_inference_time_slice
        
    # Final outputs
    print(f"Mean Inference Time per Volume: {mean_inference_time:.4f} seconds")
    print(f"Standard Deviation of Inference Time per Volume: {std_inference_time:.4f} seconds")

    print(f"Mean Inference Time per Slice: {mean_inference_time_slice:.4f} seconds")
    print(f"Standard Deviation of Inference Time per Slice: {std_inference_time_slice:.4f} seconds")

    print(f"Frames per second (FPS): {fps:.2f}")
    print(f"Mean Memory Usage per Volume: {mean_memory_usage / (1024**2):.2f} MB")  # Convert to MB
    print(f"Standard Deviation of Memory Usage: {std_memory_usage / (1024**2):.2f} MB")  # Convert to MB

    model_detection_iou = np.array(detection_iou).mean()
    model_detection_iou_std = np.array(detection_iou).std()
    
    model_class_mean_iou, model_class_std_iou = compute_mean_iou_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True, return_std=True)
    model_class_mean_dice, model_class_std_dice = compute_mean_dice_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True, return_std=True)

    model_iou_mass_volume , model_iou_mass_volume_std = compute_iou_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=False, return_std=True)
    model_iou_mass_volume_no_empty, model_iou_mass_volume_no_empty_std =compute_iou_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True, return_std=True)
    model_iou_mass_volume_no_empty_optimistic, model_iou_mass_volume_no_empty_optimistic_std =compute_iou_imagewise_from_cumulator(TP, FP, FN, TN, exclude_empty=True, exclude_empty_only_gt=True, return_std=True)
    
    
    model_dice_mass_volume, model_dice_mass_volume_std = compute_dice_imagewise_from_cumulator(TP, FP, FN, TN,exclude_empty=False, return_std=True)
    model_dice_mass_volume_no_empty, model_dice_mass_volume_no_empty_std = compute_dice_imagewise_from_cumulator(TP, FP, FN, TN,exclude_empty=True, return_std=True)
    model_dice_mass_volume_no_empty_optimistic, model_dice_mass_volume_no_empty_optimistic_std = compute_dice_imagewise_from_cumulator(TP, FP, FN, TN,exclude_empty=True, exclude_empty_only_gt=True,return_std=True)
    
    model_mean_accuracy_no_empty, model_mean_accuracy_no_empty_std = compute_accuracy_from_cumulator(TP, FP, FN, TN, exclude_empty=True, is_mean=True, return_std=True)
    model_mean_precision_no_empty,model_mean_precision_no_empty_std = compute_precision_from_cumulator(TP, FP, FN, TN, exclude_empty=True, is_mean=True, return_std=True)
    model_mean_recall_no_empty, model_mean_recall_no_empty_std = compute_recall_from_cumulator(TP, FP, FN, TN, exclude_empty=True, is_mean=True, return_std=True)
    model_mean_f1_no_empty, model_mean_f1_no_empty_std = compute_f1_from_cumulator(TP, FP, FN, TN, exclude_empty=True, is_mean=True, return_std=True)
        
    model_accuracy_excluding_cases, model_accuracy_excluding_cases_std = compute_accuracy_excluding_cases(TP, FP, FN, TN, return_std=True)
    model_precision_excluding_cases,model_precision_excluding_cases_std =compute_precision_excluding_cases_from_cumulator(TP, FP, FN, TN, return_std=True)
    model_recall_excluding_cases,model_recall_excluding_cases_std  =compute_recall_excluding_cases_from_cumulator(TP, FP, FN, TN, return_std=True)

    model_accuracy_no_empty, model_accuracy_no_empty_std = compute_accuracy_from_cumulator(TP, FP, FN, TN, exclude_empty=True, is_mean=False, return_std=True)
    model_precision_no_empty,model_precision_no_empty_std =compute_precision_excluding_cases_from_cumulator(TP, FP, FN, TN, return_std=True,exclude_only_zero_denominator=True)
    model_recall_no_empty,model_recall_no_empty_std  = compute_recall_excluding_cases_from_cumulator(TP, FP, FN, TN, return_std=True,exclude_only_zero_denominator=True)

    model_f1_no_empty,model_f1_no_empty_std = compute_f1_excluding_cases_from_cumulator(TP, FP, FN, TN, return_std=True,exclude_only_zero_denominator=True)
    
    
    print("MODEL CLASS MEAN IOU ", model_class_mean_iou)
    print("MODEL CLASS STD IOU ", model_class_std_iou)
    print()
    print("MODEL CLASS MEAN DICE ", model_class_mean_dice)
    print("MODEL CLASS STD DICE ", model_class_std_dice)
    print()
    print("MODEL DIOU", model_detection_iou)
    print("MODEL DIOU STD ", model_detection_iou_std) 
    print()
    print("MODEL IOU MASS VOLUME ", model_iou_mass_volume)
    print("MODEL IOU MASS VOLUME STD ", model_iou_mass_volume_std)
    print()
    print("MODEL IOU MASS VOLUME NO EMPTY ", model_iou_mass_volume_no_empty)
    print("MODEL IOU MASS VOLUME NO EMPTY STD ", model_iou_mass_volume_no_empty_std)
    print()
    print("MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC ", model_iou_mass_volume_no_empty_optimistic)
    print("MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC STD ", model_iou_mass_volume_no_empty_optimistic_std)
    
    print("MODEL DICE MASS VOLUME ", model_dice_mass_volume)
    print("MODEL DICE MASS VOLUME STD ", model_dice_mass_volume_std)
    print()
    print("MODEL DICE MASS VOLUME NO EMPTY ", model_dice_mass_volume_no_empty)
    print("MODEL DICE MASS VOLUME NO EMPTY STD ", model_dice_mass_volume_no_empty_std)
    print()
    print("MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC ", model_dice_mass_volume_no_empty_optimistic)
    print("MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC STD ", model_dice_mass_volume_no_empty_optimistic_std)
    print() 
    print("MODEL MEAN ACCURACY NO EMPTY", model_mean_accuracy_no_empty)
    print("MODEL MEAN ACCURACY NO EMPTY STD", model_mean_accuracy_no_empty_std)
    print()                                                                              
    print("MODEL MEAN PRECISION NO EMPTY", model_mean_precision_no_empty)
    print("MODEL MEAN PRECISION NO EMPTY STD", model_mean_precision_no_empty_std)
    print()
    print("MODEL MEAN RECALL NO EMPTY", model_mean_recall_no_empty)
    print("MODEL MEAN RECALL NO EMPTY STD", model_mean_recall_no_empty_std)
    print()
    print("MODEL MEAN F1 NO EMPTY", model_mean_f1_no_empty)
    print("MODEL MEAN F1 NO EMPTY STD", model_mean_f1_no_empty_std)
    print()
    print("MODEL ACCURACY EXCLUDING CASES ",  model_accuracy_excluding_cases)
    print("MODEL ACCURACY EXCLUDING CASES STD ",  model_accuracy_excluding_cases_std)
    print()
    print("MODEL PRECISION EXCLUDING CASES ",  model_precision_excluding_cases)
    print("MODEL PRECISION EXCLUDING CASES STD ",  model_precision_excluding_cases_std)
    print()
    print("MODEL RECALL EXCLUDING CASES ", model_recall_excluding_cases)
    print("MODEL RECALL EXCLUDING CASES STD ", model_recall_excluding_cases_std)
    print()
    print("MODEL ACCURACY NO EMPTY ",  model_accuracy_no_empty)
    print("MODEL ACCURACY NO EMPTY STD ",  model_accuracy_no_empty_std)
    print()
    print("MODEL PRECISION NO EMPTY",  model_precision_no_empty)
    print("MODEL PRECISION NO EMPTY STD ",  model_precision_no_empty_std)
    print()
    print("MODEL RECALL NO EMPTY ", model_recall_no_empty)
    print("MODEL RECALL NO EMPTY STD ", model_recall_no_empty_std)
    print()
    print("MODEL F1 NO EMPTY ", model_f1_no_empty)
    print("MODEL F1 NO EMPTY STD ", model_f1_no_empty_std)
    print()
        


    if get_scores_for_statistics:
            tp = torch.cat([tp for tp in TP])
            fp = torch.cat([fp for fp in FP])
            fn = torch.cat([fn for fn in FN])
            tn = torch.cat([tn for tn in TN])


            if get_only_masses:
                # Create a mask where tp + fn is not equal to 0
                mask = (tp + fn) != 0
                
                # Apply this mask to each tensor to filter out the desired values
                tp = tp[mask]
                fp = fp[mask]
                fn = fn[mask]
                tn = tn[mask]


            miou_scores = compute_mean_iou_imagewise_from_cumulator(tp, fp, fn, tn, exclude_empty=False, return_std=False,reduce_mean=False)
            mdice_scores = compute_mean_dice_imagewise_from_cumulator(tp, fp, fn, tn, exclude_empty=False, return_std=False,reduce_mean=False)
            mf1_scores = compute_f1_from_cumulator(tp, fp, fn, tn, exclude_empty=False, is_mean=True, return_std=False,reduce_mean=False)


            scores_dict = {
                 'miou': miou_scores.squeeze().tolist(),
                 'mdice': mdice_scores.squeeze().tolist(),
                 "mf1": mf1_scores.squeeze().tolist(),
            }
            return scores_dict

# TEST DATASET AWARE

In [50]:
from monai.networks.nets import UNet, SwinUNETR, BasicUNetPlusPlus

In [85]:
import time

### TEST MULTI-UNET SUB CABFL

In [74]:
scores_for_statistics_fusion = test_dataset_aware_fusion(model_path="PRIVATE-FUSION-SUB-CABL-FINAL.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              whole_dataset_key="no_thorax_sub_test_ds", 
                              patches_dataset_key="patches_sub_test_ds", strict=True, get_scores_for_statistics=True
                             )

LAXXX


  0%|          | 0/228 [00:00<?, ?it/s]

AS0170


  0%|          | 0/192 [00:00<?, ?it/s]

PR0760


  0%|          | 0/228 [00:00<?, ?it/s]

GF0380


  0%|          | 0/204 [00:00<?, ?it/s]

FP211261


  0%|          | 0/188 [00:00<?, ?it/s]

D2MP3(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

MG0477


  0%|          | 0/132 [00:00<?, ?it/s]

OL1062R


  0%|          | 0/228 [00:00<?, ?it/s]

BV1252


  0%|          | 0/188 [00:00<?, ?it/s]

D1AP7(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

SD080569


  0%|          | 0/236 [00:00<?, ?it/s]

CC0167


  0%|          | 0/228 [00:00<?, ?it/s]

RHCL031174


  0%|          | 0/188 [00:00<?, ?it/s]

LA0248


  0%|          | 0/284 [00:00<?, ?it/s]

RP271052


  0%|          | 0/96 [00:00<?, ?it/s]

SL191251


  0%|          | 0/80 [00:00<?, ?it/s]

LGM0159(1,5)


  0%|          | 0/112 [00:00<?, ?it/s]

PA150139


  0%|          | 0/96 [00:00<?, ?it/s]

HF230274


  0%|          | 0/204 [00:00<?, ?it/s]

CF160366


  0%|          | 0/212 [00:00<?, ?it/s]

GLA1074


  0%|          | 0/112 [00:00<?, ?it/s]

MODEL CLASS MEAN IOU  0.58887863
MODEL CLASS STD IOU  0.15059951

MODEL CLASS MEAN DICE  0.60679895
MODEL CLASS STD DICE  0.1754792

MODEL DIOU 0.8334518909549894
MODEL DIOU STD  0.29091198719192896

MODEL IOU MASS VOLUME  0.5803412795066833
MODEL IOU MASS VOLUME STD  0.4636002779006958

MODEL IOU MASS VOLUME NO EMPTY  0.1788276880979538
MODEL IOU MASS VOLUME NO EMPTY STD  0.3012268543243408

MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC  0.5764801502227783
MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC STD  0.2515328526496887
MODEL DICE MASS VOLUME  0.5983843207359314
MODEL DICE MASS VOLUME STD  0.4661327004432678

MODEL DICE MASS VOLUME NO EMPTY  0.21413369476795197
MODEL DICE MASS VOLUME NO EMPTY STD  0.35099151730537415

MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC  0.6902948617935181
MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC STD  0.2616250813007355

MODEL MEAN ACCURACY NO EMPTY 0.5891403257846832
MODEL MEAN ACCURACY NO EMPTY STD 0.15114951133728027

MODEL MEAN PRECISION NO EMPTY 0.608280591

In [75]:
scores_for_statistics_fusion.keys()

dict_keys(['miou', 'mdice', 'mf1'])

In [76]:
save_to_json(scores_for_statistics_fusion, "scores_for_statistics_fcn.json")

Dictionary successfully saved to scores_for_statistics_fcn.json


In [77]:
print("ciao")

ciao


### TEST MULTI-UNET SUB CABFL NO DECODER ATTENTION

In [51]:
scores_for_statistics_fusion_nda = test_dataset_aware_fusion(model_path="PRIVATE-FUSION-SUB-CABL-NO-DECODER-ATTENTION-FINAL.ckpt", 
                              patient_ids=x_test,
                              datasets=datasets, 
                              use_decoder_attention=False,
                              whole_dataset_key="no_thorax_sub_test_ds", 
                              patches_dataset_key="patches_sub_test_ds", strict=True, get_scores_for_statistics=True
                             )

LAXXX


  0%|          | 0/228 [00:00<?, ?it/s]

AS0170


  0%|          | 0/192 [00:00<?, ?it/s]

PR0760


  0%|          | 0/228 [00:00<?, ?it/s]

GF0380


  0%|          | 0/204 [00:00<?, ?it/s]

FP211261


  0%|          | 0/188 [00:00<?, ?it/s]

D2MP3(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

MG0477


  0%|          | 0/132 [00:00<?, ?it/s]

OL1062R


  0%|          | 0/228 [00:00<?, ?it/s]

BV1252


  0%|          | 0/188 [00:00<?, ?it/s]

D1AP7(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

SD080569


  0%|          | 0/236 [00:00<?, ?it/s]

CC0167


  0%|          | 0/228 [00:00<?, ?it/s]

RHCL031174


  0%|          | 0/188 [00:00<?, ?it/s]

LA0248


  0%|          | 0/284 [00:00<?, ?it/s]

RP271052


  0%|          | 0/96 [00:00<?, ?it/s]

SL191251


  0%|          | 0/80 [00:00<?, ?it/s]

LGM0159(1,5)


  0%|          | 0/112 [00:00<?, ?it/s]

PA150139


  0%|          | 0/96 [00:00<?, ?it/s]

HF230274


  0%|          | 0/204 [00:00<?, ?it/s]

CF160366


  0%|          | 0/212 [00:00<?, ?it/s]

GLA1074


  0%|          | 0/112 [00:00<?, ?it/s]

MODEL CLASS MEAN IOU  0.58709514
MODEL CLASS STD IOU  0.15093568

MODEL CLASS MEAN DICE  0.60390264
MODEL CLASS STD DICE  0.17522915

MODEL DIOU 0.8214560608616285
MODEL DIOU STD  0.3010394225147611

MODEL IOU MASS VOLUME  0.5586973428726196
MODEL IOU MASS VOLUME STD  0.46707409620285034

MODEL IOU MASS VOLUME NO EMPTY  0.17504799365997314
MODEL IOU MASS VOLUME NO EMPTY STD  0.30202382802963257

MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC  0.5906805992126465
MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC STD  0.24960079789161682
MODEL DICE MASS VOLUME  0.5764503479003906
MODEL DICE MASS VOLUME STD  0.4708462059497833

MODEL DICE MASS VOLUME NO EMPTY  0.20823468267917633
MODEL DICE MASS VOLUME NO EMPTY STD  0.3505549132823944

MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC  0.7026655077934265
MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC STD  0.25933638215065

MODEL MEAN ACCURACY NO EMPTY 0.5872945711016655
MODEL MEAN ACCURACY NO EMPTY STD 0.1514940857887268

MODEL MEAN PRECISION NO EMPTY 0.607078686

In [52]:
save_to_json(scores_for_statistics_fusion_nda, "scores_for_statistics_fusion_final_nda.json")

Dictionary successfully saved to scores_for_statistics_fusion_final_nda.json


### TEST MULTI-UNET SUB CABFL USE SIMPLE FUSION

In [78]:
scores_for_statistics_fusion_sf = test_dataset_aware_fusion(model_path="PRIVATE-FUSION-SUB-CABL-NO-FUSION-FINAL.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              whole_dataset_key="no_thorax_sub_test_ds", 
                              use_simple_fusion=True,
                              patches_dataset_key="patches_sub_test_ds", strict=True, get_scores_for_statistics=True
                             )

LAXXX


  0%|          | 0/228 [00:00<?, ?it/s]

AS0170


  0%|          | 0/192 [00:00<?, ?it/s]

PR0760


  0%|          | 0/228 [00:00<?, ?it/s]

GF0380


  0%|          | 0/204 [00:00<?, ?it/s]

FP211261


  0%|          | 0/188 [00:00<?, ?it/s]

D2MP3(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

MG0477


  0%|          | 0/132 [00:00<?, ?it/s]

OL1062R


  0%|          | 0/228 [00:00<?, ?it/s]

BV1252


  0%|          | 0/188 [00:00<?, ?it/s]

D1AP7(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

SD080569


  0%|          | 0/236 [00:00<?, ?it/s]

CC0167


  0%|          | 0/228 [00:00<?, ?it/s]

RHCL031174


  0%|          | 0/188 [00:00<?, ?it/s]

LA0248


  0%|          | 0/284 [00:00<?, ?it/s]

RP271052


  0%|          | 0/96 [00:00<?, ?it/s]

SL191251


  0%|          | 0/80 [00:00<?, ?it/s]

LGM0159(1,5)


  0%|          | 0/112 [00:00<?, ?it/s]

PA150139


  0%|          | 0/96 [00:00<?, ?it/s]

HF230274


  0%|          | 0/204 [00:00<?, ?it/s]

CF160366


  0%|          | 0/212 [00:00<?, ?it/s]

GLA1074


  0%|          | 0/112 [00:00<?, ?it/s]

MODEL CLASS MEAN IOU  0.6434665
MODEL CLASS STD IOU  0.1672613

MODEL CLASS MEAN DICE  0.67304754
MODEL CLASS STD DICE  0.19436733

MODEL DIOU 0.7120013631668699
MODEL DIOU STD  0.36402384669588006

MODEL IOU MASS VOLUME  0.7919582724571228
MODEL IOU MASS VOLUME STD  0.37096554040908813

MODEL IOU MASS VOLUME NO EMPTY  0.2878640294075012
MODEL IOU MASS VOLUME NO EMPTY STD  0.33460259437561035

MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC  0.5304722785949707
MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC STD  0.2786027789115906
MODEL DICE MASS VOLUME  0.8091059327125549
MODEL DICE MASS VOLUME STD  0.36394983530044556

MODEL DICE MASS VOLUME NO EMPTY  0.3465612828731537
MODEL DICE MASS VOLUME NO EMPTY STD  0.38880589604377747

MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC  0.6386389136314392
MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC STD  0.303384393453598

MODEL MEAN ACCURACY NO EMPTY 0.6437960863113403
MODEL MEAN ACCURACY NO EMPTY STD 0.16775724291801453

MODEL MEAN PRECISION NO EMPTY 0.716367319

In [79]:
scores_for_statistics_fusion_sf.keys()

dict_keys(['miou', 'mdice', 'mf1'])

In [80]:
save_to_json(scores_for_statistics_fusion_sf, "scores_for_statistics_fusion_final_sf.json")

Dictionary successfully saved to scores_for_statistics_fusion_final_sf.json


### TEST MULTI-UNET SUB CABFL SIMPLE FUSION + NO DECODER ATTENTION

In [81]:
scores_for_statistics_fusion_sf_nda = test_dataset_aware_fusion(model_path="PRIVATE-FUSION-SUB-CABL-NO-DECODER-ATTENTION-NO-FUSION.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                                                         use_simple_fusion=True,
                                                         use_decoder_attention=False,
                              whole_dataset_key="no_thorax_sub_test_ds", 
                              patches_dataset_key="patches_sub_test_ds", strict=True, get_scores_for_statistics=True
                             )

LAXXX


  0%|          | 0/228 [00:00<?, ?it/s]

AS0170


  0%|          | 0/192 [00:00<?, ?it/s]

PR0760


  0%|          | 0/228 [00:00<?, ?it/s]

GF0380


  0%|          | 0/204 [00:00<?, ?it/s]

FP211261


  0%|          | 0/188 [00:00<?, ?it/s]

D2MP3(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

MG0477


  0%|          | 0/132 [00:00<?, ?it/s]

OL1062R


  0%|          | 0/228 [00:00<?, ?it/s]

BV1252


  0%|          | 0/188 [00:00<?, ?it/s]

D1AP7(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

SD080569


  0%|          | 0/236 [00:00<?, ?it/s]

CC0167


  0%|          | 0/228 [00:00<?, ?it/s]

RHCL031174


  0%|          | 0/188 [00:00<?, ?it/s]

LA0248


  0%|          | 0/284 [00:00<?, ?it/s]

RP271052


  0%|          | 0/96 [00:00<?, ?it/s]

SL191251


  0%|          | 0/80 [00:00<?, ?it/s]

LGM0159(1,5)


  0%|          | 0/112 [00:00<?, ?it/s]

PA150139


  0%|          | 0/96 [00:00<?, ?it/s]

HF230274


  0%|          | 0/204 [00:00<?, ?it/s]

CF160366


  0%|          | 0/212 [00:00<?, ?it/s]

GLA1074


  0%|          | 0/112 [00:00<?, ?it/s]

MODEL CLASS MEAN IOU  0.5779559
MODEL CLASS STD IOU  0.13250133

MODEL CLASS MEAN DICE  0.5989
MODEL CLASS STD DICE  0.16153026

MODEL DIOU 0.8688531439746999
MODEL DIOU STD  0.27509447742095167

MODEL IOU MASS VOLUME  0.582230269908905
MODEL IOU MASS VOLUME STD  0.4607909023761749

MODEL IOU MASS VOLUME NO EMPTY  0.15753380954265594
MODEL IOU MASS VOLUME NO EMPTY STD  0.26509344577789307

MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC  0.4927718937397003
MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC STD  0.23372112214565277
MODEL DICE MASS VOLUME  0.6026008725166321
MODEL DICE MASS VOLUME STD  0.4607812464237213

MODEL DICE MASS VOLUME NO EMPTY  0.198612779378891
MODEL DICE MASS VOLUME NO EMPTY STD  0.323122501373291

MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC  0.6212685108184814
MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC STD  0.2530020475387573

MODEL MEAN ACCURACY NO EMPTY 0.5783647522330284
MODEL MEAN ACCURACY NO EMPTY STD 0.13333997130393982

MODEL MEAN PRECISION NO EMPTY 0.589352354407310

In [82]:
scores_for_statistics_fusion_sf_nda.keys()

dict_keys(['miou', 'mdice', 'mf1'])

In [83]:
save_to_json(scores_for_statistics_fusion_sf_nda, "scores_for_statistics_fusion_sf_nda.json")

Dictionary successfully saved to scores_for_statistics_fusion_sf_nda.json


### ENSEMBLE MULTI-UNET CABFL + CABFL PATCHES

In [104]:
scores_for_statistics_ensemble = test_dataset_aware_ensemble(
                              model_whole_path="PRIVATE-FUSION-SUB-CABL-FINAL.ckpt", 
                              model_patches_path="RESNET-PATCHES-PRIVATE.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              whole_dataset_key="no_thorax_sub_test_ds",
                              patches_dataset_key="patches_sub_test_ds", get_scores_for_statistics=True
                             )

LAXXX


  0%|          | 0/228 [00:00<?, ?it/s]

AS0170


  0%|          | 0/192 [00:00<?, ?it/s]

PR0760


  0%|          | 0/228 [00:00<?, ?it/s]

GF0380


  0%|          | 0/204 [00:00<?, ?it/s]

FP211261


  0%|          | 0/188 [00:00<?, ?it/s]

D2MP3(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

MG0477


  0%|          | 0/132 [00:00<?, ?it/s]

OL1062R


  0%|          | 0/228 [00:00<?, ?it/s]

BV1252


  0%|          | 0/188 [00:00<?, ?it/s]

D1AP7(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

SD080569


  0%|          | 0/236 [00:00<?, ?it/s]

CC0167


  0%|          | 0/228 [00:00<?, ?it/s]

RHCL031174


  0%|          | 0/188 [00:00<?, ?it/s]

LA0248


  0%|          | 0/284 [00:00<?, ?it/s]

RP271052


  0%|          | 0/96 [00:00<?, ?it/s]

SL191251


  0%|          | 0/80 [00:00<?, ?it/s]

LGM0159(1,5)


  0%|          | 0/112 [00:00<?, ?it/s]

PA150139


  0%|          | 0/96 [00:00<?, ?it/s]

HF230274


  0%|          | 0/204 [00:00<?, ?it/s]

CF160366


  0%|          | 0/212 [00:00<?, ?it/s]

GLA1074


  0%|          | 0/112 [00:00<?, ?it/s]

MODEL CLASS MEAN IOU  0.62867385
MODEL CLASS STD IOU  0.16370061

MODEL CLASS MEAN DICE  0.65572965
MODEL CLASS STD DICE  0.19129984

MODEL DIOU 0.8682983429003069
MODEL DIOU STD  0.271822030176516

MODEL IOU MASS VOLUME  0.7384881377220154
MODEL IOU MASS VOLUME STD  0.40412038564682007

MODEL IOU MASS VOLUME NO EMPTY  0.2587006390094757
MODEL IOU MASS VOLUME NO EMPTY STD  0.32734376192092896

MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC  0.5756822824478149
MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC STD  0.23657383024692535
MODEL DICE MASS VOLUME  0.7573391199111938
MODEL DICE MASS VOLUME STD  0.3995882570743561

MODEL DICE MASS VOLUME NO EMPTY  0.31213682889938354
MODEL DICE MASS VOLUME NO EMPTY STD  0.38259050250053406

MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC  0.6945930123329163
MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC STD  0.24510321021080017

MODEL MEAN ACCURACY NO EMPTY 0.6291116625070572
MODEL MEAN ACCURACY NO EMPTY STD 0.16422304511070251

MODEL MEAN PRECISION NO EMPTY 0.647310

In [105]:
save_to_json(scores_for_statistics_ensemble, "scores_for_statistics_ensemble.json")

Dictionary successfully saved to scores_for_statistics_ensemble.json


### ENSEMBLE MULTI-UNET CABFL + CABFL PATCHES -filtered

In [106]:
scores_for_statistics_ensemble_filtered = test_dataset_aware_ensemble(
                               model_whole_path="PRIVATE-FUSION-SUB-CABL-FINAL.ckpt", 
                              model_patches_path="RESNET-PATCHES-PRIVATE.ckpt", 
                              patient_ids=x_test,
                              datasets=datasets, 
                              whole_dataset_key="no_thorax_sub_test_ds", 
                              patches_dataset_key="patches_sub_test_ds", get_scores_for_statistics=True, filter=True
                             )

LAXXX


  0%|          | 0/228 [00:00<?, ?it/s]

filtering
num features: 82
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
fine
AS0170


  0%|          | 0/192 [00:00<?, ?it/s]

filtering
num features: 290
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270


  0%|          | 0/228 [00:00<?, ?it/s]

filtering
num features: 56
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
fine
GF0380


  0%|          | 0/204 [00:00<?, ?it/s]

filtering
num features: 31
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
fine
FP211261


  0%|          | 0/188 [00:00<?, ?it/s]

filtering
num features: 12
1
2
3
4
5
6
7
8
9
10
11
12
fine
D2MP3(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

filtering
num features: 151
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
fine
MG0477


  0%|          | 0/132 [00:00<?, ?it/s]

filtering
num features: 34
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
fine
OL1062R


  0%|          | 0/228 [00:00<?, ?it/s]

filtering
num features: 84
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
fine
BV1252


  0%|          | 0/188 [00:00<?, ?it/s]

filtering
num features: 9
1
2
3
4
5
6
7
8
9
fine
D1AP7(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

filtering
num features: 138
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
fine
SD080569


  0%|          | 0/236 [00:00<?, ?it/s]

filtering
num features: 25
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
fine
CC0167


  0%|          | 0/228 [00:00<?, ?it/s]

filtering
num features: 21
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
fine
RHCL031174


  0%|          | 0/188 [00:00<?, ?it/s]

filtering
num features: 66
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
fine
LA0248


  0%|          | 0/284 [00:00<?, ?it/s]

filtering
num features: 36
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
fine
RP271052


  0%|          | 0/96 [00:00<?, ?it/s]

filtering
num features: 24
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
fine
SL191251


  0%|          | 0/80 [00:00<?, ?it/s]

filtering
num features: 9
1
2
3
4
5
6
7
8
9
fine
LGM0159(1,5)


  0%|          | 0/112 [00:00<?, ?it/s]

filtering
num features: 20
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
fine
PA150139


  0%|          | 0/96 [00:00<?, ?it/s]

filtering
num features: 23
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
fine
HF230274


  0%|          | 0/204 [00:00<?, ?it/s]

filtering
num features: 124
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
fine
CF160366


  0%|          | 0/212 [00:00<?, ?it/s]

filtering
num features: 41
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
fine
GLA1074


  0%|          | 0/112 [00:00<?, ?it/s]

filtering
num features: 66
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
fine
MODEL CLASS MEAN IOU  0.64619195
MODEL CLASS STD IOU  0.16802692

MODEL CLASS MEAN DICE  0.67626256
MODEL CLASS STD DICE  0.19542588

MODEL DIOU 0.8661127920923898
MODEL DIOU STD  0.2729700412436063

MODEL IOU MASS VOLUME  0.777884840965271
MODEL IOU MASS VOLUME STD  0.3782494366168976

MODEL IOU MASS VOLUME NO EMPTY  0.29373177886009216
MODEL IOU MASS VOLUME NO EMPTY STD  0.33596840500831604

MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC  0.5827028751373291
MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC STD  0.23566047847270966
MODEL DICE MASS VOLUME  0.7965869903564453
MODEL DICE MASS VOLUME STD  0.37179121375083923

MODEL DICE MASS VOLUME NO EMPTY  0.3532000184059143
MODEL DICE MASS VOLUME NO EMPTY STD  0.39083045721054077

MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC  0.70067

In [107]:
save_to_json(scores_for_statistics_ensemble_filtered, "scores_for_statistics_ensemble_filtered.json")

Dictionary successfully saved to scores_for_statistics_ensemble_filtered.json


### ENSEMBLE MULTI-UNET CABFL + CABFL PATCHES NO FUSION

In [90]:
scores_for_statistics_ensemble_nf = test_dataset_aware_ensemble(
                              model_whole_path="PRIVATE-FUSION-SUB-CABL-NO-FUSION-FINAL.ckpt", 
                              model_patches_path="RESNET-PATCHES-PRIVATE.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              use_simple_fusion=True,
                              whole_dataset_key="no_thorax_sub_test_ds",
                              patches_dataset_key="patches_sub_test_ds", get_scores_for_statistics=True
                             )

LAXXX


  0%|          | 0/228 [00:00<?, ?it/s]

AS0170


  0%|          | 0/192 [00:00<?, ?it/s]

PR0760


  0%|          | 0/228 [00:00<?, ?it/s]

GF0380


  0%|          | 0/204 [00:00<?, ?it/s]

FP211261


  0%|          | 0/188 [00:00<?, ?it/s]

D2MP3(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

MG0477


  0%|          | 0/132 [00:00<?, ?it/s]

OL1062R


  0%|          | 0/228 [00:00<?, ?it/s]

BV1252


  0%|          | 0/188 [00:00<?, ?it/s]

D1AP7(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

SD080569


  0%|          | 0/236 [00:00<?, ?it/s]

CC0167


  0%|          | 0/228 [00:00<?, ?it/s]

RHCL031174


  0%|          | 0/188 [00:00<?, ?it/s]

LA0248


  0%|          | 0/284 [00:00<?, ?it/s]

RP271052


  0%|          | 0/96 [00:00<?, ?it/s]

SL191251


  0%|          | 0/80 [00:00<?, ?it/s]

LGM0159(1,5)


  0%|          | 0/112 [00:00<?, ?it/s]

PA150139


  0%|          | 0/96 [00:00<?, ?it/s]

HF230274


  0%|          | 0/204 [00:00<?, ?it/s]

CF160366


  0%|          | 0/212 [00:00<?, ?it/s]

GLA1074


  0%|          | 0/112 [00:00<?, ?it/s]

MODEL CLASS MEAN IOU  0.67957485
MODEL CLASS STD IOU  0.17598999

MODEL CLASS MEAN DICE  0.71292734
MODEL CLASS STD DICE  0.20100383

MODEL DIOU 0.8016295954445445
MODEL DIOU STD  0.32977545214953796

MODEL IOU MASS VOLUME  0.8356460332870483
MODEL IOU MASS VOLUME STD  0.3316093385219574

MODEL IOU MASS VOLUME NO EMPTY  0.36033472418785095
MODEL IOU MASS VOLUME NO EMPTY STD  0.35189858078956604

MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC  0.584010899066925
MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC STD  0.26471084356307983
MODEL DICE MASS VOLUME  0.8526331186294556
MODEL DICE MASS VOLUME STD  0.3229965269565582

MODEL DICE MASS VOLUME NO EMPTY  0.4264479875564575
MODEL DICE MASS VOLUME NO EMPTY STD  0.4019955098628998

MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC  0.6911637187004089
MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC STD  0.28098264336586

MODEL MEAN ACCURACY NO EMPTY 0.6800151616334915
MODEL MEAN ACCURACY NO EMPTY STD 0.17638441920280457

MODEL MEAN PRECISION NO EMPTY 0.7182077467

In [None]:
scores_for_statistics_ensemble_nf.keys()

In [91]:
save_to_json(scores_for_statistics_ensemble_nf, "scores_for_statistics_ensemble_nf.json")

Dictionary successfully saved to scores_for_statistics_ensemble_nf.json


### ENSEMBLE MULTI-UNET CABFL + CABFL PATCHES -filtered NO FUSION

In [119]:
scores_for_statistics_ensemble_sf_filtered = test_dataset_aware_ensemble(
                              model_whole_path="PRIVATE-FUSION-SUB-CABL-NO-FUSION-FINAL.ckpt", 
                              model_patches_path="RESNET-PATCHES-PRIVATE.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              use_simple_fusion=True,
                              filter=True,
                              whole_dataset_key="no_thorax_sub_test_ds",
                              patches_dataset_key="patches_sub_test_ds", get_scores_for_statistics=True
                             )

Total parameters: 21.56 million
Trainable parameters: 21.56 million
Non-trainable parameters: 0.00 million
Memory required (in MB): 82.25 MB
aa
Total parameters: 32.52 million
Trainable parameters: 32.52 million
Non-trainable parameters: 0.00 million
Memory required (in MB): 124.04 MB
aa
LAXXX


  0%|          | 0/228 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [101]:
scores_for_statistics_ensemble_nf_filtered.keys()

dict_keys(['miou', 'mdice', 'mf1'])

In [102]:
save_to_json(scores_for_statistics_ensemble_nf_filtered, "scores_for_statistics_ensemble_sf_filtered.json")

Dictionary successfully saved to scores_for_statistics_ensemble_sf_filtered.json


### ENSEMBLE MULTI-UNET CABFL + CABFL PATCHES NO DECODER ATTENTION

In [105]:
scores_for_statistics_ensemble_nda = test_dataset_aware_ensemble(
                              model_whole_path="PRIVATE-FUSION-SUB-CABL-NO-DECODER-ATTENTION-FINAL.ckpt", 
                              model_patches_path="RESNET-PATCHES-PRIVATE.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              use_simple_fusion=False,
                              use_decoder_attention=False,
                              whole_dataset_key="no_thorax_sub_test_ds",
                              patches_dataset_key="patches_sub_test_ds", get_scores_for_statistics=True
                             )

LAXXX


  0%|          | 0/228 [00:00<?, ?it/s]

AS0170


  0%|          | 0/192 [00:00<?, ?it/s]

PR0760


  0%|          | 0/228 [00:00<?, ?it/s]

GF0380


  0%|          | 0/204 [00:00<?, ?it/s]

FP211261


  0%|          | 0/188 [00:00<?, ?it/s]

D2MP3(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

MG0477


  0%|          | 0/132 [00:00<?, ?it/s]

OL1062R


  0%|          | 0/228 [00:00<?, ?it/s]

BV1252


  0%|          | 0/188 [00:00<?, ?it/s]

D1AP7(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

SD080569


  0%|          | 0/236 [00:00<?, ?it/s]

CC0167


  0%|          | 0/228 [00:00<?, ?it/s]

RHCL031174


  0%|          | 0/188 [00:00<?, ?it/s]

LA0248


  0%|          | 0/284 [00:00<?, ?it/s]

RP271052


  0%|          | 0/96 [00:00<?, ?it/s]

SL191251


  0%|          | 0/80 [00:00<?, ?it/s]

LGM0159(1,5)


  0%|          | 0/112 [00:00<?, ?it/s]

PA150139


  0%|          | 0/96 [00:00<?, ?it/s]

HF230274


  0%|          | 0/204 [00:00<?, ?it/s]

CF160366


  0%|          | 0/212 [00:00<?, ?it/s]

GLA1074


  0%|          | 0/112 [00:00<?, ?it/s]

Mean Inference Time per Volume: 47.6399 seconds
Standard Deviation of Inference Time per Volume: 15.6155 seconds
Mean Inference Time per Slice: 0.2495 seconds
Standard Deviation of Inference Time per Slice: 0.1094 seconds
Frames per second (FPS): 4.01
Mean Memory Usage per Volume: 542.38 MB
Standard Deviation of Memory Usage: 254.26 MB
MODEL CLASS MEAN IOU  0.62876403
MODEL CLASS STD IOU  0.16610433

MODEL CLASS MEAN DICE  0.6545451
MODEL CLASS STD DICE  0.19311067

MODEL DIOU 0.8522614071450948
MODEL DIOU STD  0.2834916392194654

MODEL IOU MASS VOLUME  0.732399046421051
MODEL IOU MASS VOLUME STD  0.40822914242744446

MODEL IOU MASS VOLUME NO EMPTY  0.25871676206588745
MODEL IOU MASS VOLUME NO EMPTY STD  0.3322497010231018

MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC  0.5891362428665161
MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC STD  0.23814257979393005
MODEL DICE MASS VOLUME  0.7507985830307007
MODEL DICE MASS VOLUME STD  0.4047059118747711

MODEL DICE MASS VOLUME NO EMPTY  0.30968543887

In [108]:
torch.cuda.memory_allocated()

264774144

In [106]:
scores_for_statistics_ensemble_nda.keys()

dict_keys(['miou', 'mdice', 'mf1'])

In [107]:
save_to_json(scores_for_statistics_ensemble_nda, "scores_for_statistics_ensemble_nda.json")

Dictionary successfully saved to scores_for_statistics_ensemble_nda.json


### ENSEMBLE MULTI-UNET CABFL + CABFL PATCHES -filtered NO DECODER ATTENTION (!)

In [104]:
scores_for_statistics_ensemble_nda_filtered = test_dataset_aware_ensemble(
                              model_whole_path="PRIVATE-FUSION-SUB-CABL-NO-DECODER-ATTENTION-FINAL.ckpt", 
                              model_patches_path="RESNET-PATCHES-PRIVATE.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              use_simple_fusion=False,
                              use_decoder_attention=False,
                              whole_dataset_key="no_thorax_sub_test_ds",
                              filter=True,
                              patches_dataset_key="patches_sub_test_ds", get_scores_for_statistics=True
                             )

The history saving thread hit an unexpected error (OperationalError('attempt to write a readonly database')).History will not be written to the database.
LAXXX


  0%|          | 0/228 [00:00<?, ?it/s]


KeyboardInterrupt



In [None]:
scores_for_statistics_ensemble_nda_filtered.keys()

In [None]:
save_to_json(scores_for_statistics_ensemble_nda_filtered, "scores_for_statistics_ensemble_nda_filtered.json")

### ENSEMBLE MULTI-UNET CABFL + CABFL PATCHES NO DECODER ATTENTION NO FUSION

In [97]:
scores_for_statistics_ensemble_nda_nf = test_dataset_aware_ensemble(
                              model_whole_path="PRIVATE-FUSION-SUB-CABL-NO-DECODER-ATTENTION-NO-FUSION.ckpt", 
                              model_patches_path="RESNET-PATCHES-PRIVATE.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              use_simple_fusion=True,
                              use_decoder_attention=False,
                              whole_dataset_key="no_thorax_sub_test_ds",
                              patches_dataset_key="patches_sub_test_ds", get_scores_for_statistics=True
                             )

LAXXX


  0%|          | 0/228 [00:00<?, ?it/s]

AS0170


  0%|          | 0/192 [00:00<?, ?it/s]

PR0760


  0%|          | 0/228 [00:00<?, ?it/s]

GF0380


  0%|          | 0/204 [00:00<?, ?it/s]

FP211261


  0%|          | 0/188 [00:00<?, ?it/s]

D2MP3(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

MG0477


  0%|          | 0/132 [00:00<?, ?it/s]

OL1062R


  0%|          | 0/228 [00:00<?, ?it/s]

BV1252


  0%|          | 0/188 [00:00<?, ?it/s]

D1AP7(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

SD080569


  0%|          | 0/236 [00:00<?, ?it/s]

CC0167


  0%|          | 0/228 [00:00<?, ?it/s]

RHCL031174


  0%|          | 0/188 [00:00<?, ?it/s]

LA0248


  0%|          | 0/284 [00:00<?, ?it/s]

RP271052


  0%|          | 0/96 [00:00<?, ?it/s]

SL191251


  0%|          | 0/80 [00:00<?, ?it/s]

LGM0159(1,5)


  0%|          | 0/112 [00:00<?, ?it/s]

PA150139


  0%|          | 0/96 [00:00<?, ?it/s]

HF230274


  0%|          | 0/204 [00:00<?, ?it/s]

CF160366


  0%|          | 0/212 [00:00<?, ?it/s]

GLA1074


  0%|          | 0/112 [00:00<?, ?it/s]

MODEL CLASS MEAN IOU  0.61095583
MODEL CLASS STD IOU  0.14788763

MODEL CLASS MEAN DICE  0.6396286
MODEL CLASS STD DICE  0.17811231

MODEL DIOU 0.8842068610301999
MODEL DIOU STD  0.26653173263055463

MODEL IOU MASS VOLUME  0.7201789617538452
MODEL IOU MASS VOLUME STD  0.4128887951374054

MODEL IOU MASS VOLUME NO EMPTY  0.22376088798046112
MODEL IOU MASS VOLUME NO EMPTY STD  0.2957236170768738

MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC  0.5088112354278564
MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC STD  0.2319953292608261
MODEL DICE MASS VOLUME  0.7405183911323547
MODEL DICE MASS VOLUME STD  0.40643733739852905

MODEL DICE MASS VOLUME NO EMPTY  0.28018370270729065
MODEL DICE MASS VOLUME NO EMPTY STD  0.35622110962867737

MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC  0.6371114253997803
MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC STD  0.24726735055446625

MODEL MEAN ACCURACY NO EMPTY 0.6115471422672272
MODEL MEAN ACCURACY NO EMPTY STD 0.14861223101615906

MODEL MEAN PRECISION NO EMPTY 0.623858

In [None]:
scores_for_statistics_ensemble_nda_nf.keys()

In [None]:
save_to_json(scores_for_statistics_ensemble_nda_nf, "scores_for_statistics_ensemble_nda_nf.json")

### ENSEMBLE MULTI-UNET CABFL + CABFL PATCHES -filtered NO DECODER ATTENTION NO FUSION

In [103]:
scores_for_statistics_ensemble_nda_nf_filtered = test_dataset_aware_ensemble(
                              model_whole_path="PRIVATE-FUSION-SUB-CABL-NO-DECODER-ATTENTION-NO-FUSION.ckpt", 
                              model_patches_path="RESNET-PATCHES-PRIVATE.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              use_simple_fusion=True,
                              use_decoder_attention=False,
                              filter=True,
                              whole_dataset_key="no_thorax_sub_test_ds",
                              patches_dataset_key="patches_sub_test_ds", get_scores_for_statistics=True
                             )

LAXXX


  0%|          | 0/228 [00:00<?, ?it/s]

filtering
num features: 118
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
fine
AS0170


  0%|          | 0/192 [00:00<?, ?it/s]


KeyboardInterrupt



In [None]:
scores_for_statistics_ensemble_nda_nf_filtered.keys()

In [None]:
save_to_json(scores_for_statistics_ensemble_nda_nf_filtered, "scores_for_statistics_ensemble_nda_nf_filtered.json")

### RESNET

In [68]:
scores_for_statistics_resnet = test_dataset_aware_no_patches(model_path="RESNET-PRIVATE.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              dataset_key="no_thorax_sub_test_ds", 
                              filter=False, 
                              arch_name="UNet", get_scores_for_statistics=True)

LAXXX


  0%|          | 0/228 [00:00<?, ?it/s]

AS0170


  0%|          | 0/192 [00:00<?, ?it/s]

PR0760


  0%|          | 0/228 [00:00<?, ?it/s]

GF0380


  0%|          | 0/204 [00:00<?, ?it/s]

FP211261


  0%|          | 0/188 [00:00<?, ?it/s]

D2MP3(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

MG0477


  0%|          | 0/132 [00:00<?, ?it/s]

OL1062R


  0%|          | 0/228 [00:00<?, ?it/s]

BV1252


  0%|          | 0/188 [00:00<?, ?it/s]

D1AP7(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

SD080569


  0%|          | 0/236 [00:00<?, ?it/s]

CC0167


  0%|          | 0/228 [00:00<?, ?it/s]

RHCL031174


  0%|          | 0/188 [00:00<?, ?it/s]

LA0248


  0%|          | 0/284 [00:00<?, ?it/s]

RP271052


  0%|          | 0/96 [00:00<?, ?it/s]

SL191251


  0%|          | 0/80 [00:00<?, ?it/s]

LGM0159(1,5)


  0%|          | 0/112 [00:00<?, ?it/s]

PA150139


  0%|          | 0/96 [00:00<?, ?it/s]

HF230274


  0%|          | 0/204 [00:00<?, ?it/s]

CF160366


  0%|          | 0/212 [00:00<?, ?it/s]

GLA1074


  0%|          | 0/112 [00:00<?, ?it/s]

MODEL CLASS MEAN IOU  0.6541247
MODEL CLASS STD IOU  0.16112955

MODEL CLASS MEAN DICE  0.69052994
MODEL CLASS STD DICE  0.18732928

MODEL DIOU 0.48139805449367845
MODEL DIOU STD  0.4263533941207131

MODEL IOU MASS VOLUME  0.8684301972389221
MODEL IOU MASS VOLUME STD  0.3053703308105469

MODEL IOU MASS VOLUME NO EMPTY  0.3098789155483246
MODEL IOU MASS VOLUME NO EMPTY STD  0.3217676877975464

MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC  0.37265825271606445
MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC STD  0.3179851770401001
MODEL DICE MASS VOLUME  0.8821564316749573
MODEL DICE MASS VOLUME STD  0.2927244305610657

MODEL DICE MASS VOLUME NO EMPTY  0.3818768262863159
MODEL DICE MASS VOLUME NO EMPTY STD  0.37446126341819763

MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC  0.45924246311187744
MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC STD  0.36482805013656616

MODEL MEAN ACCURACY NO EMPTY 0.654784083366394
MODEL MEAN ACCURACY NO EMPTY STD 0.16150318086147308

MODEL MEAN PRECISION NO EMPTY 0.83282315

In [69]:
scores_for_statistics_resnet.keys()

dict_keys(['miou', 'mdice', 'mf1'])

In [70]:
save_to_json(scores_for_statistics_resnet, "scores_for_statistics_resnet.json")

Dictionary successfully saved to scores_for_statistics_resnet.json


### RESNET PATCHES

In [63]:
scores = test_dataset_aware_patches(model_path="RESNET-PATCHES-PRIVATE.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              dataset_key="patches_sub_test_ds", 
                              filter=False, 
                              arch_name="UNet", get_scores_for_statistics=True)

LAXXX


  0%|          | 0/228 [00:00<?, ?it/s]

AS0170


  0%|          | 0/192 [00:00<?, ?it/s]

PR0760


  0%|          | 0/228 [00:00<?, ?it/s]

GF0380


  0%|          | 0/204 [00:00<?, ?it/s]

FP211261


  0%|          | 0/188 [00:00<?, ?it/s]

D2MP3(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

MG0477


  0%|          | 0/132 [00:00<?, ?it/s]

OL1062R


  0%|          | 0/228 [00:00<?, ?it/s]

BV1252


  0%|          | 0/188 [00:00<?, ?it/s]

D1AP7(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

SD080569


  0%|          | 0/236 [00:00<?, ?it/s]

CC0167


  0%|          | 0/228 [00:00<?, ?it/s]

RHCL031174


  0%|          | 0/188 [00:00<?, ?it/s]

LA0248


  0%|          | 0/284 [00:00<?, ?it/s]

RP271052


  0%|          | 0/96 [00:00<?, ?it/s]

SL191251


  0%|          | 0/80 [00:00<?, ?it/s]

LGM0159(1,5)


  0%|          | 0/112 [00:00<?, ?it/s]

PA150139


  0%|          | 0/96 [00:00<?, ?it/s]

HF230274


  0%|          | 0/204 [00:00<?, ?it/s]

CF160366


  0%|          | 0/212 [00:00<?, ?it/s]

GLA1074


  0%|          | 0/112 [00:00<?, ?it/s]

MODEL CLASS MEAN IOU  0.6039572
MODEL CLASS STD IOU  0.1559185

MODEL CLASS MEAN DICE  0.62576866
MODEL CLASS STD DICE  0.18208908

MODEL DIOU 0.8400445631211193
MODEL DIOU STD  0.2847975325890597

MODEL IOU MASS VOLUME  0.659709095954895
MODEL IOU MASS VOLUME STD  0.44199663400650024

MODEL IOU MASS VOLUME NO EMPTY  0.20883384346961975
MODEL IOU MASS VOLUME NO EMPTY STD  0.3120490610599518

MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC  0.5665929913520813
MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC STD  0.24796366691589355
MODEL DICE MASS VOLUME  0.6782744526863098
MODEL DICE MASS VOLUME STD  0.44071510434150696

MODEL DICE MASS VOLUME NO EMPTY  0.2519976794719696
MODEL DICE MASS VOLUME NO EMPTY STD  0.36431097984313965

MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC  0.6837019920349121
MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC STD  0.2548254728317261

MODEL MEAN ACCURACY NO EMPTY 0.6042191907763481
MODEL MEAN ACCURACY NO EMPTY STD 0.15651041269302368

MODEL MEAN PRECISION NO EMPTY 0.623275727

### UNETPLUSPLUS

In [None]:
from typing import List
from monai.networks.nets import UNet, BasicUNetPlusPlus,SwinUNETR
scores_for_statistics_unetplusplus = test_dataset_aware_no_patches(model_path="unetplusplus_model_final.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              dataset_key="no_thorax_sub_test_ds", 
                              filter=False, 
                              arch_name="unetplusplus", get_scores_for_statistics=True)

In [71]:
scores_for_statistics_unetplusplus.keys()

dict_keys(['miou', 'mdice', 'mf1'])

In [72]:
save_to_json(scores_for_statistics_unetplusplus, "scores_for_statistics_unetplusplus.json")

Dictionary successfully saved to scores_for_statistics_unetplusplus.json


In [73]:
print("ciao")

ciao


### Skinny

In [59]:
from typing import List
scores_for_statistics_skinny = test_dataset_aware_no_patches(model_path="skinny_model_private.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              dataset_key="no_thorax_sub_test_ds", 
                              filter=False, 
                              arch_name="skinny", get_scores_for_statistics=True)

LAXXX


  0%|          | 0/228 [00:00<?, ?it/s]

AS0170


  0%|          | 0/192 [00:00<?, ?it/s]

PR0760


  0%|          | 0/228 [00:00<?, ?it/s]

GF0380


  0%|          | 0/204 [00:00<?, ?it/s]

FP211261


  0%|          | 0/188 [00:00<?, ?it/s]

D2MP3(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

MG0477


  0%|          | 0/132 [00:00<?, ?it/s]

OL1062R


  0%|          | 0/228 [00:00<?, ?it/s]

BV1252


  0%|          | 0/188 [00:00<?, ?it/s]

D1AP7(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

SD080569


  0%|          | 0/236 [00:00<?, ?it/s]

CC0167


  0%|          | 0/228 [00:00<?, ?it/s]

RHCL031174


  0%|          | 0/188 [00:00<?, ?it/s]

LA0248


  0%|          | 0/284 [00:00<?, ?it/s]

RP271052


  0%|          | 0/96 [00:00<?, ?it/s]

SL191251


  0%|          | 0/80 [00:00<?, ?it/s]

LGM0159(1,5)


  0%|          | 0/112 [00:00<?, ?it/s]

PA150139


  0%|          | 0/96 [00:00<?, ?it/s]

HF230274


  0%|          | 0/204 [00:00<?, ?it/s]

CF160366


  0%|          | 0/212 [00:00<?, ?it/s]

GLA1074


  0%|          | 0/112 [00:00<?, ?it/s]

MODEL CLASS MEAN IOU  0.590411
MODEL CLASS STD IOU  0.1497511

MODEL CLASS MEAN DICE  0.6091358
MODEL CLASS STD DICE  0.1755867

MODEL DIOU 0.7936827036693562
MODEL DIOU STD  0.3156589441909492

MODEL IOU MASS VOLUME  0.5905758738517761
MODEL IOU MASS VOLUME STD  0.46091538667678833

MODEL IOU MASS VOLUME NO EMPTY  0.1815723031759262
MODEL IOU MASS VOLUME NO EMPTY STD  0.2996978461742401

MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC  0.5729680061340332
MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC STD  0.243257537484169
MODEL DICE MASS VOLUME  0.6091229319572449
MODEL DICE MASS VOLUME STD  0.4629926085472107

MODEL DICE MASS VOLUME NO EMPTY  0.218647301197052
MODEL DICE MASS VOLUME NO EMPTY STD  0.35129040479660034

MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC  0.6899616122245789
MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC STD  0.2534303367137909

MODEL MEAN ACCURACY NO EMPTY 0.590598426759243
MODEL MEAN ACCURACY NO EMPTY STD 0.15027032792568207

MODEL MEAN PRECISION NO EMPTY 0.6124422252178192


In [60]:
scores_for_statistics_skinny.keys()

dict_keys(['miou', 'mdice', 'mf1'])

In [61]:
save_to_json(scores_for_statistics_skinny, "scores_for_statistics_skinny.json")

Dictionary successfully saved to scores_for_statistics_skinny.json


### FNC-FFNET

In [None]:
from typing import List
scores_for_statistics_fcn= test_dataset_aware_no_patches(model_path="fcn_ffnet_model_final.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              dataset_key="no_thorax_sub_test_ds", 
                              filter=False, 
                              arch_name="fcn_ffnet", get_scores_for_statistics=True)

In [None]:
save_to_json(scores_for_statistics_fcn, "scores_for_statistics_fcn.json")

In [None]:
print("ciao")

In [None]:
print("ciao")

### SEGNET

In [55]:
from typing import List
from monai.networks.nets import UNet, BasicUNetPlusPlus,SwinUNETR
scores_for_statistics_segnet= test_dataset_aware_no_patches(model_path="segnet_model_private.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              dataset_key="no_thorax_sub_test_ds", 
                              filter=False, 
                              arch_name="segnet", get_scores_for_statistics=True)

LAXXX


  0%|          | 0/228 [00:00<?, ?it/s]

AS0170


  0%|          | 0/192 [00:00<?, ?it/s]

PR0760


  0%|          | 0/228 [00:00<?, ?it/s]

GF0380


  0%|          | 0/204 [00:00<?, ?it/s]

FP211261


  0%|          | 0/188 [00:00<?, ?it/s]

D2MP3(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

MG0477


  0%|          | 0/132 [00:00<?, ?it/s]

OL1062R


  0%|          | 0/228 [00:00<?, ?it/s]

BV1252


  0%|          | 0/188 [00:00<?, ?it/s]

D1AP7(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

SD080569


  0%|          | 0/236 [00:00<?, ?it/s]

CC0167


  0%|          | 0/228 [00:00<?, ?it/s]

RHCL031174


  0%|          | 0/188 [00:00<?, ?it/s]

LA0248


  0%|          | 0/284 [00:00<?, ?it/s]

RP271052


  0%|          | 0/96 [00:00<?, ?it/s]

SL191251


  0%|          | 0/80 [00:00<?, ?it/s]

LGM0159(1,5)


  0%|          | 0/112 [00:00<?, ?it/s]

PA150139


  0%|          | 0/96 [00:00<?, ?it/s]

HF230274


  0%|          | 0/204 [00:00<?, ?it/s]

CF160366


  0%|          | 0/212 [00:00<?, ?it/s]

GLA1074


  0%|          | 0/112 [00:00<?, ?it/s]

MODEL CLASS MEAN IOU  0.584107
MODEL CLASS STD IOU  0.14451733

MODEL CLASS MEAN DICE  0.6023499
MODEL CLASS STD DICE  0.17040415

MODEL DIOU 0.7599096539629479
MODEL DIOU STD  0.329311818890917

MODEL IOU MASS VOLUME  0.5851113200187683
MODEL IOU MASS VOLUME STD  0.4631311297416687

MODEL IOU MASS VOLUME NO EMPTY  0.16894152760505676
MODEL IOU MASS VOLUME NO EMPTY STD  0.2892600893974304

MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC  0.5320152044296265
MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC STD  0.2651973366737366
MODEL DICE MASS VOLUME  0.6031447649002075
MODEL DICE MASS VOLUME STD  0.46477100253105164

MODEL DICE MASS VOLUME NO EMPTY  0.20506411790847778
MODEL DICE MASS VOLUME NO EMPTY STD  0.3409436345100403

MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC  0.6457691788673401
MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC STD  0.28542381525039673

MODEL MEAN ACCURACY NO EMPTY 0.5842891782522202
MODEL MEAN ACCURACY NO EMPTY STD 0.14511148631572723

MODEL MEAN PRECISION NO EMPTY 0.61282841861

In [56]:
scores_for_statistics_segnet.keys()

dict_keys(['miou', 'mdice', 'mf1'])

In [57]:
save_to_json(scores_for_statistics_segnet, "scores_for_statistics_segnet_private.json")

Dictionary successfully saved to scores_for_statistics_segnet_private.json


In [58]:
print("ciao")

ciao


### SWIN-UNET

In [51]:
from typing import List
scores_for_statistics_swin= test_dataset_aware_no_patches(model_path="swin_model_final.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              dataset_key="no_thorax_sub_test_ds", 
                              filter=False, 
                              arch_name="swin_unetr", get_scores_for_statistics=True)



LAXXX


/opt/conda/lib/python3.10/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 't_loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['t_loss'])`.


  0%|          | 0/228 [00:00<?, ?it/s]

AS0170


  0%|          | 0/192 [00:00<?, ?it/s]

PR0760


  0%|          | 0/228 [00:00<?, ?it/s]

GF0380


  0%|          | 0/204 [00:00<?, ?it/s]

FP211261


  0%|          | 0/188 [00:00<?, ?it/s]

D2MP3(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

MG0477


  0%|          | 0/132 [00:00<?, ?it/s]

OL1062R


  0%|          | 0/228 [00:00<?, ?it/s]

BV1252


  0%|          | 0/188 [00:00<?, ?it/s]

D1AP7(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

SD080569


  0%|          | 0/236 [00:00<?, ?it/s]

CC0167


  0%|          | 0/228 [00:00<?, ?it/s]

RHCL031174


  0%|          | 0/188 [00:00<?, ?it/s]

LA0248


  0%|          | 0/284 [00:00<?, ?it/s]

RP271052


  0%|          | 0/96 [00:00<?, ?it/s]

SL191251


  0%|          | 0/80 [00:00<?, ?it/s]

LGM0159(1,5)


  0%|          | 0/112 [00:00<?, ?it/s]

PA150139


  0%|          | 0/96 [00:00<?, ?it/s]

HF230274


  0%|          | 0/204 [00:00<?, ?it/s]

CF160366


  0%|          | 0/212 [00:00<?, ?it/s]

GLA1074


  0%|          | 0/112 [00:00<?, ?it/s]

MODEL CLASS MEAN IOU  0.6107712
MODEL CLASS STD IOU  0.16355146

MODEL CLASS MEAN DICE  0.6315172
MODEL CLASS STD DICE  0.18841812

MODEL DIOU 0.8711695413621241
MODEL DIOU STD  0.26683839405871324

MODEL IOU MASS VOLUME  0.6665454506874084
MODEL IOU MASS VOLUME STD  0.4405349791049957

MODEL IOU MASS VOLUME NO EMPTY  0.22240541875362396
MODEL IOU MASS VOLUME NO EMPTY STD  0.3272370994091034

MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC  0.6016120314598083
MODEL IOU MASS VOLUME NO EMPTY OPTIMISTIC STD  0.2480519711971283
MODEL DICE MASS VOLUME  0.6841536164283752
MODEL DICE MASS VOLUME STD  0.44022512435913086

MODEL DICE MASS VOLUME NO EMPTY  0.2634667456150055
MODEL DICE MASS VOLUME NO EMPTY STD  0.37692826986312866

MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC  0.712683916091919
MODEL DICE MASS VOLUME NO EMPTY OPTIMISTIC STD  0.25331050157546997

MODEL MEAN ACCURACY NO EMPTY 0.6110176518559456
MODEL MEAN ACCURACY NO EMPTY STD 0.16412875056266785

MODEL MEAN PRECISION NO EMPTY 0.634587451

In [52]:
scores_for_statistics_swin.keys()

dict_keys(['miou', 'mdice', 'mf1'])

In [None]:
scores_for_statistics_swin

In [53]:
save_to_json(scores_for_statistics_swin, "scores_for_statistics_swin.json")

Dictionary successfully saved to scores_for_statistics_swin.json


In [54]:
print("ciao")

ciao


# TEST PATIENT AWARE

### TEST MULTI-UNET SUB CABFL LARGE

In [None]:
patient_scores_for_statistics_fusion_mid = test_patient_aware_fusion(model_path="FUSION-SUB-CABL-FINAL.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              whole_dataset_key="no_thorax_sub_test_ds", 
                              patches_dataset_key="patches_sub_test_ds", subtracted=True, strict=False, get_scores_for_statistics=True, base_channels=64
                             )

In [None]:
save_to_json(patient_scores_for_statistics_fusion_mid, "patient_scores_for_statistics_fusion_mid.json")

In [None]:
patient_scores_for_statistics_fusion_tiny = test_patient_aware_fusion(model_path="FUSION-SUB-CABL-tiny.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              whole_dataset_key="no_thorax_sub_test_ds", 
                              patches_dataset_key="patches_sub_test_ds", subtracted=True, strict=False, get_scores_for_statistics=True, base_channels=16
                             )

In [None]:
save_to_json(patient_scores_for_statistics_fusion_tiny, "patient_scores_for_statistics_fusion_tiny.json")

### ENSEMBLE MULTI-UNET CABFL + CABFL PATCHES

In [None]:
patient_scores_for_statistics_ensemble = test_patient_aware_ensemble(
                              model_whole_path="FUSION-SUB-CABL-tiny.ckpt", 
                              model_patches_path="PATCHES-SUB-CABL-unetplusplus.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              whole_dataset_key="no_thorax_sub_test_ds", base_channels=16,
                              patches_dataset_key="patches_sub_test_ds", subtracted=True, get_scores_for_statistics=True
                             )

In [None]:
print("ciao")

In [None]:
save_to_json(patient_scores_for_statistics_ensemble, "patient_scores_for_statistics_ensemble.json")

### ENSEMBLE MULTI-UNET CABFL + CABFL PATCHES -filtered

In [None]:
patient_scores_for_statistics_ensemble_filtered = test_patient_aware_ensemble(
                              model_whole_path="FUSION-SUB-CABL-tiny.ckpt", 
                              model_patches_path="PATCHES-SUB-CABL-unetplusplus.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, base_channels=16,
                              whole_dataset_key="no_thorax_sub_test_ds", 
                              patches_dataset_key="patches_sub_test_ds", subtracted=True, get_scores_for_statistics=True, filter=True
                             )

In [None]:
save_to_json(patient_scores_for_statistics_ensemble_filtered, "patient_scores_for_statistics_ensemble_filtered.json")

### UNETPLUSPLUS

In [None]:
from typing import List
patient_scores_for_statistics_unetplusplus = test_patient_aware_no_patches(model_path="unetplusplus_model.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              dataset_key="no_thorax_sub_test_ds", 
                              filter=False, 
                              arch_name="unetplusplus", #get_scores_for_statistics=True,
                              subtracted=True)

In [None]:
save_to_json(patient_scores_for_statistics_unetplusplus, "patient_scores_for_statistics_unetplusplus.json")

### Skinny

In [None]:
from typing import List
patient_scores_for_statistics_skinny = test_patient_aware_no_patches(model_path="skinny_model.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              dataset_key="no_thorax_sub_test_ds", 
                              filter=False, 
                              arch_name="skinny", #get_scores_for_statistics=True,
                              subtracted=True)

In [None]:
#save_to_json(patient_scores_for_statistics_skinny, "patient_scores_for_statistics_skinny.json")

### FNC-FFNET

In [None]:
from typing import List
patient_scores_for_statistics_fcn= test_patient_aware_no_patches(model_path="fcn_ffnet_model.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              dataset_key="no_thorax_sub_test_ds", 
                              filter=False, 
                              arch_name="fcn_ffnet", #get_scores_for_statistics=True,
                              subtracted=True)

In [None]:
#save_to_json(patient_scores_for_statistics_fcn, "patient_scores_for_statistics_fcn.json")

In [None]:
print("ciao")

In [None]:
print("ciao")

### SEGNET

In [None]:
from typing import List, Tuple
from monai.networks.nets import UNet, BasicUNetPlusPlus,SwinUNETR
patient_scores_for_statistics_segnet= test_patient_aware_no_patches(model_path="segnet_model_large.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              dataset_key="no_thorax_sub_test_ds", 
                              filter=False, 
                              arch_name="segnet", #get_scores_for_statistics=True,
                              subtracted=True)

In [None]:
#save_to_json(patient_scores_for_statistics_segnet, "patient_scores_for_statistics_segnet_large.json")

In [None]:
print("ciao")

### SWIN-UNET

In [None]:
from typing import List, Tuple
scores_for_statistics_swin= test_patient_aware_no_patches(model_path="swin_model.ckpt", 
                              patient_ids=x_test, 
                              datasets=datasets, 
                              dataset_key="no_thorax_sub_test_ds", 
                              filter=False, 
                              arch_name="swin_unetr", #get_scores_for_statistics=True,
                              subtracted=True)

# VIZ for Back analysis

In [46]:
import numpy as np

def calculate_bbox(mask, desired_size=(150, 150)):
    """Calculate a fixed-size bounding box centered around the mask's non-zero region or return a bounding box equal to the image size if the mask is empty."""
    # Check if the mask is empty
    if not np.any(mask):
        # Return a bounding box equal to the image size
        return 0, 0, mask.shape[1], mask.shape[0]

    # Continue with original logic if the mask is not empty
    rows = np.any(mask, axis=1)
    cols = np.any(mask, axis=0)
    ymin, ymax = np.where(rows)[0][[0, -1]]
    xmin, xmax = np.where(cols)[0][[0, -1]]

    # Calculate center of original bounding box
    center_x = xmin + (xmax - xmin) // 2
    center_y = ymin + (ymax - ymin) // 2

    # Determine new bounding box dimensions
    desired_width, desired_height = desired_size

    # Calculate new bounding box coordinates, centered around the original bounding box
    new_xmin = max(center_x - desired_width // 2, 0)
    new_xmax = new_xmin + desired_width
    new_ymin = max(center_y - desired_height // 2, 0)
    new_ymax = new_ymin + desired_height

    # Adjust if the new bounding box exceeds image boundaries
    if new_xmax > mask.shape[1]:
        new_xmax = mask.shape[1]
        new_xmin = new_xmax - desired_width
    if new_ymax > mask.shape[0]:
        new_ymax = mask.shape[0]
        new_ymin = new_ymax - desired_height

    return new_xmin, new_ymin, new_xmax, new_ymax


def crop_to_bbox(img, bbox):
    """Crop the image to the specified bounding box."""
    xmin, ymin, xmax, ymax = bbox
    return img[ymin:ymax+1, xmin:xmax+1]

In [47]:
def plot_compare_2_models(iter1, iter2, model1, model2, device, prefix_fname=None, crop_to_size=None, extension='svg', patches=False):
    model1 = model1.to(device).eval()
    model2 = model2.to(device).eval()

    data_iter1 = next(iter1)
    data_iter2 = next(iter2)

    if patches:

        images1 = data_iter1["image"]
        images2 =data_iter2["image"]
        label1 =data_iter1["label"]
        label2 = data_iter2["label"]

    else:
        images1 = data_iter1["image"]
        images2 =data_iter2["image"]
        label1 =data_iter1["label"]
        label2 = data_iter2["label"]
        

    with torch.no_grad():
        # Perform inference with both models
        logits1 = model1(images1.to(device))[0]
        logits2 = model2(images2.to(device))[0]
    
        # Apply sigmoid to the logits to get the probabilities
        pr_masks1 = logits1.sigmoid()
        pr_masks2 = logits2.sigmoid()

    saved_counter = 0
    
    for image1,image2, gt_mask1,gt_mask2, pr_mask1, pr_mask2 in zip(images1,images2, label1,label2, pr_masks1, pr_masks2):
        # Threshold the probabilities to get binary masks
        pr_mask1 = (pr_mask1.cpu().numpy().squeeze() > 0.8).astype(int)
        pr_mask2 = (pr_mask2.cpu().numpy().squeeze() > 0.8).astype(int)
    
        gt_mask1 = gt_mask1.squeeze()
        gt_mask2 = gt_mask2.squeeze()

        if crop_to_size:
            # Calculate bounding boxes for the ground truth masks
            bbox1 = calculate_bbox(gt_mask1, crop_to_size)
            bbox2 = calculate_bbox(gt_mask2,crop_to_size)
        
            # Crop images, predictions, and ground truths to the bounding box
            image1 = crop_to_bbox(image1.cpu().numpy().transpose(1, 2, 0), bbox1)
            image2 = crop_to_bbox(image2.cpu().numpy().transpose(1, 2, 0), bbox2)
            pr_mask1 = crop_to_bbox(pr_mask1, bbox1)
            pr_mask2 = crop_to_bbox(pr_mask2, bbox2)
            gt_mask1 = crop_to_bbox(gt_mask1, bbox1)
            gt_mask2 = crop_to_bbox(gt_mask2, bbox2)

        else:
            image1 = image1.cpu().numpy().transpose(1, 2, 0)
            image2 = image2.cpu().numpy().transpose(1, 2, 0)
        
        plt.figure(figsize=(18, 10))
        # Original Image 1
        plt.subplot(2, 3, 1)
        plt.imshow(image1, cmap='gray')
        plt.title("Post Contrast Image")
        plt.axis("off")
        
        # Prediction from Model 1
        plt.subplot(2, 3, 2)
        plt.imshow(pr_mask1, cmap='gray')
        plt.title("Prediction")
        plt.axis("off")
        
        # Ground Truth 1
        plt.subplot(2, 3, 3)
        plt.imshow(gt_mask1, cmap='gray')
        plt.title("Ground Truth")
        plt.axis("off")
        
        # Original Image 2
        plt.subplot(2, 3, 4)
        plt.imshow(image2, cmap='gray')
        plt.title("Post Contrast Image (Subtracted)")
        plt.axis("off")
        
        # Prediction from Model 2
        plt.subplot(2, 3, 5)
        plt.imshow(pr_mask2, cmap='gray')
        plt.title("Prediction")
        plt.axis("off")
        
        # Ground Truth 2
        plt.subplot(2, 3, 6)
        plt.imshow(gt_mask2, cmap='gray')
        plt.title("Ground Truth")
        plt.axis("off")
        
        plt.tight_layout()

        if prefix_fname:
            if saved_counter<30:
                # Save the plot to a file
                plt.savefig(f'{prefix_fname}_{saved_counter}.{extension}', dpi=300)  # Adjust the filename and DPI as needed
                saved_counter+=1
            
        plt.show()

        print()
        print()
        print()

In [48]:
def plot_compare_2_models_fusion(iter1, iter2, model1, model2, device, prefix_fname=None, crop_to_size=None, extension='svg'):
    model1 = model1.to(device).eval()
    model2 = model2.to(device).eval()

    data_iter1 = next(iter1)
    data_iter2 = next(iter2)

    images11 = data_iter1[0]["image"]
    images12 = data_iter1[1]["image"]
    images13 = data_iter1[2]["image"]
        
    images21 =data_iter2[0]["image"]
    images22 =data_iter2[1]["image"]
    images23 =data_iter2[2]["image"]
        
    label1 =data_iter1[0]["label"]
    label2 = data_iter2[0]["label"]

    with torch.no_grad():
        # Perform inference with both models
        logits1 = model1(images11.to(device),images12.to(device),images13.to(device))
        logits2 = model2(images21.to(device),images22.to(device),images23.to(device))
    
        # Apply sigmoid to the logits to get the probabilities
        pr_masks1 = logits1.sigmoid()
        pr_masks2 = logits2.sigmoid()

    saved_counter = 0
    
    for image1,image2, gt_mask1,gt_mask2, pr_mask1, pr_mask2 in zip(images11,images12, label1,label2, pr_masks1, pr_masks2):
        # Threshold the probabilities to get binary masks
        pr_mask1 = (pr_mask1.cpu().numpy().squeeze() > 0.8).astype(int)
        pr_mask2 = (pr_mask2.cpu().numpy().squeeze() > 0.8).astype(int)
    
        gt_mask1 = gt_mask1.squeeze()
        gt_mask2 = gt_mask2.squeeze()

        if crop_to_size:
            # Calculate bounding boxes for the ground truth masks
            bbox1 = calculate_bbox(gt_mask1, crop_to_size)
            bbox2 = calculate_bbox(gt_mask2,crop_to_size)
        
            # Crop images, predictions, and ground truths to the bounding box
            image1 = crop_to_bbox(image1.cpu().numpy().transpose(1, 2, 0), bbox1)
            image2 = crop_to_bbox(image2.cpu().numpy().transpose(1, 2, 0), bbox2)
            pr_mask1 = crop_to_bbox(pr_mask1, bbox1)
            pr_mask2 = crop_to_bbox(pr_mask2, bbox2)
            gt_mask1 = crop_to_bbox(gt_mask1, bbox1)
            gt_mask2 = crop_to_bbox(gt_mask2, bbox2)

        else:
            image1 = image1.cpu().numpy().transpose(1, 2, 0)
            image2 = image2.cpu().numpy().transpose(1, 2, 0)
        
        plt.figure(figsize=(18, 10))
        # Original Image 1
        plt.subplot(1, 4, 1)
        plt.imshow(image1, cmap='gray')
        plt.title("Post Contrast Image (Subtracted)")
        plt.axis("off")
        
        # Prediction from Model 1
        plt.subplot(1, 4, 2)
        plt.imshow(pr_mask1, cmap='gray')
        plt.title("Prediction AUFL")
        plt.axis("off")

        # Prediction from Model 2
        plt.subplot(1, 4, 3)
        plt.imshow(pr_mask2, cmap='gray')
        plt.title("Prediction CABFL")
        plt.axis("off")
        
        
        # Ground Truth 1
        plt.subplot(1, 4, 4)
        plt.imshow(gt_mask1, cmap='gray')
        plt.title("Ground Truth")
        plt.axis("off")
    
        
        plt.tight_layout()

        if prefix_fname:
            if saved_counter<30:
                # Save the plot to a file
                plt.savefig(f'{prefix_fname}_{saved_counter}.{extension}', dpi=300)  # Adjust the filename and DPI as needed
                saved_counter+=1
            
        plt.show()

        print()
        print()
        print()

In [49]:
def inference_base(model_path, patient_ids, datasets, dataset_key, strict=False, arch_name=None,filter=False):

    if arch_name:
        model = BreastModel2.load_from_checkpoint(model_path, strict=strict, arch=arch_name)

    else:

        model = BreastModel2.load_from_checkpoint(model_path, strict=strict)


    results = {}
    
    
    for patient_id in patient_ids:
        
         predicted_label_slices = []
         gt_label_slices=[]
         image_slices=[]
         image_slices_sub=[]
        
         print(patient_id)
         dataset = datasets[patient_id][dataset_key]
        
         for idx, e in tqdm(enumerate(dataset), total = len(dataset)):
            im_path_key_sub = 'subtracted_filename_or_obj' 
            im_path_key = 'filename_or_obj'
            original_image = np.load(e['image_meta_dict'][im_path_key])
            original_image = np.expand_dims(original_image,0)

            original_image_sub = np.load(e['image_meta_dict'][im_path_key_sub])
            original_image_sub = np.expand_dims(original_image_sub,0)

            gt_label = np.load(e['label_meta_dict']['filename_or_obj'])
            gt_label= np.expand_dims(gt_label,0)

            if e['keep_sample']:
                image = torch.unsqueeze(e['image'], 0)
                    
                with torch.no_grad():
                    model = model.to("cuda")
                    model.eval()
                    if arch_name:
                        masks = model(image.to("cuda"))[0]
                    else:
                        masks = model(image.to("cuda"))[0]
                    masks = masks.sigmoid()
                    
                pred_label = masks[0]
                pred_label = (pred_label > 0.4).int()
                pred_label = torch.squeeze(pred_label)
                pred_label = torch.unsqueeze(pred_label, 0)
                pred_label = reverse_transformations(dataset[idx], pred_label, mode='whole')
                pred_label = monai.transforms.Resize(spatial_size=(original_image.shape[1], original_image.shape[2]), mode='nearest-exact')(pred_label)
            else:
                pred_label = torch.zeros(original_image.shape, dtype=torch.uint8)
    
    
            predicted_label_slices.append(pred_label.squeeze())
            gt_label_slices.append(gt_label.squeeze())
            image_slices.append(original_image.squeeze())
            image_slices_sub.append(original_image.squeeze())

         predicted_label_volume = np.stack(predicted_label_slices, axis=-1)  # Stack along the first axis to create a 3D volume
                    
         gt_label_volume = np.stack(gt_label_slices, axis=-1)
         images_volume = np.stack(image_slices, axis=-1)
         images_volume_sub = np.stack(image_slices_sub, axis=-1)
        
    
         if filter:
             predicted_label_volume = filter_masses(predicted_label_volume, min_slices=3, window_size=3) # H x W x N

         dice_mass_volume = compute_dice_score_npy(gt_label_volume, predicted_label_volume,reduction='none', class_id=1, exclude_empty=True)
         results[patient_id] = {
            'images_volume': images_volume,
            'images_volume_sub': images_volume_sub,
            'gt_volume': gt_label_volume,
            'predicted_volume': predicted_label_volume,
             'dice_mass_scores':dice_mass_volume
        }


    return results


def inference_patches(model_path, patient_ids, datasets, dataset_key, arch_name=None, strict=False, subtracted=False):

    if arch_name:
        model = BreastModel2.load_from_checkpoint(model_path, strict=strict, arch=arch_name)

    else:
        model = BreastModel2.load_from_checkpoint(model_path, strict=strict)
    

    results = {}
    
    
    for patient_id in patient_ids:
         predicted_label_slices = []
         gt_label_slices=[]
         image_slices=[]
    
         print(patient_id)
         dataset = datasets[patient_id][dataset_key]
        
         for idx, e in tqdm(enumerate(dataset), total = len(dataset)):
            im_path_key = 'subtracted_filename_or_obj' if subtracted else 'filename_or_obj'
            original_image = np.load(e[0]['image_meta_dict'][im_path_key])
            original_image = np.expand_dims(original_image,0)

            pred_label = torch.zeros(original_image.shape, dtype=torch.uint8)
            
            gt_label = np.load(e[0]['label_meta_dict']['filename_or_obj'])
            gt_label= np.expand_dims(gt_label,0)


            merged_label = torch.zeros(original_image.shape)
            merged_label_for_fusion = torch.zeros(original_image.shape)

            for elem in e:
                if elem['keep_sample']:
                    image = torch.unsqueeze(elem['image'], 0)
                    with torch.no_grad():
                        model = model.to("cuda")
                        model.eval()
                        if arch_name:
                            logits = model(image.to("cuda"))
                        else:
                            logits = model(image.to("cuda"))[0]
    
                    pr_mask = logits.sigmoid()
                    pr_mask = pr_mask[0]
                    #pr_mask_to_viz = (pr_mask.cpu().numpy() > 0.4).astype(int)
    
                    if pr_mask.sum()>0:
                        pr_mask = (pr_mask > 0.4).int()
                        label = reverse_transformations(elem, pr_mask, mode='patches')
                        merged_label += label
                    

                pred_label = merged_label
                pred_label = monai.transforms.Resize(spatial_size=(original_image.shape[1], original_image.shape[2]), mode='nearest-exact')(pred_label)

            predicted_label_slices.append(pred_label.squeeze())
            gt_label_slices.append(gt_label.squeeze())
            image_slices.append(original_image.squeeze())

         predicted_label_volume = np.stack(predicted_label_slices, axis=-1)  # Stack along the first axis to create a 3D volume
       
         gt_label_volume = np.stack(gt_label_slices, axis=-1)
         images_volume = np.stack(image_slices, axis=-1)
         dice_mass_volume = compute_dice_score_npy(gt_label_volume, predicted_label_volume,reduction='none', class_id=1, exclude_empty=True)
         results[patient_id] = {
            'images_volume': images_volume,
            'gt_volume': gt_label_volume,
            'predicted_volume': predicted_label_volume,
             'dice_mass_scores':dice_mass_volume
         }

        

    return results

def inference_fusion(model_path, patient_ids, datasets, whole_dataset_key, patches_dataset_key , subtracted=False,filter = False):

    model = BreastModel.load_from_checkpoint(model_path, strict=False)

    results = {}
    
    for patient_id in patient_ids:

         predicted_label_slices = []
         gt_label_slices=[]
         image_slices=[]
         image_slices_sub=[]
    
         print(patient_id)
         patches_ds = datasets[patient_id][patches_dataset_key]
         whole_image_ds = datasets[patient_id][whole_dataset_key]

         fusion_dataset = PairedDataset(whole_image_ds, patches_ds, augment=False)

         for idx, e in tqdm(enumerate(fusion_dataset), total = len(patches_ds)):
            im_path_key_sub = 'subtracted_filename_or_obj' 
            im_path_key = 'filename_or_obj'
            original_image = np.load(e['image_meta_dict'][im_path_key])
            original_image = np.expand_dims(original_image,0)

            original_image_sub = np.load(e['image_meta_dict'][im_path_key_sub])
            original_image_sub = np.expand_dims(original_image_sub,0)

    
            
            pred_label = torch.zeros(original_image.shape, dtype=torch.uint8)
    
            gt_label = np.load(e[0]['label_meta_dict']['filename_or_obj'])
            gt_label= np.expand_dims(gt_label,0)

            if fusion_dataset[idx][0]['keep_sample']:
    
                whole_image = torch.unsqueeze(fusion_dataset[idx][0]['image'], 0)
                patch_image2 = torch.unsqueeze(fusion_dataset[idx][1]['image'], 0)
                patch_image3 = torch.unsqueeze(fusion_dataset[idx][2]['image'], 0)
                    
        
                with torch.no_grad():
                    masks = []
                    # pass to model
                    model = model.to("cuda")
                    model.eval()
                    
                    masks = model(whole_image.to("cuda"),patch_image2.to("cuda"),patch_image3.to("cuda"))
                    masks = masks.sigmoid()
                    
    
                pred_label = masks[0]
                pred_label = (pred_label > 0.4).int()
                pred_label = reverse_transformations(fusion_dataset[idx][0], pred_label, mode='whole')
                

                
            pred_label = monai.transforms.Resize(spatial_size=(original_image.shape[1], original_image.shape[2]), mode='nearest-exact')(pred_label)
            
            predicted_label_slices.append(pred_label.squeeze())
            gt_label_slices.append(gt_label.squeeze())
            images_volume = np.stack(image_slices, axis=-1)
            images_volume_sub = np.stack(image_slices_sub, axis=-1)

         predicted_label_volume = np.stack(predicted_label_slices, axis=-1)  # Stack along the first axis to create a 3D volume
                    
         gt_label_volume = np.stack(gt_label_slices, axis=-1)
         images_volume = np.stack(image_slices, axis=-1)
    
         if filter:
             predicted_label_volume = filter_masses(predicted_label_volume, min_slices=3, window_size=3) # H x W x N 
         
         dice_mass_volume = compute_dice_score_npy(gt_label_volume, predicted_label_volume,reduction='none', class_id=1, exclude_empty=True)
         results[patient_id] = {
            'images_volume': images_volume,
             'images_volume_sub': images_volume_sub,
            'gt_volume': gt_label_volume,
            'predicted_volume': predicted_label_volume,
             'dice_mass_scores':dice_mass_volume
         }

    return results

def inference_ensemble(model_whole_path, model_patches_path, patient_ids, datasets, whole_dataset_key, patches_dataset_key,filter = False):

    model_whole = BreastModel.load_from_checkpoint(model_whole_path, strict=False)
    model_patches = BreastModel2.load_from_checkpoint(model_patches_path, strict=False)

    results = {}
    
    for patient_id in patient_ids:
         predicted_label_slices = []         
         image_slices=[]
         image_slices_sub=[]

         print(patient_id)
         patches_ds = datasets[patient_id][patches_dataset_key]
         whole_image_ds = datasets[patient_id][whole_dataset_key]

         fusion_dataset = PairedDataset(whole_image_ds, patches_ds, augment=False)
        
         prev_had_mask=False

         for idx, e in tqdm(enumerate(patches_ds), total = len(patches_ds)):
            im_path_key_sub = 'subtracted_filename_or_obj' 
            im_path_key = 'filename_or_obj'
            original_image = np.load(e['image_meta_dict'][im_path_key])
            original_image = np.expand_dims(original_image,0)

            original_image_sub = np.load(e['image_meta_dict'][im_path_key_sub])
            original_image_sub = np.expand_dims(original_image_sub,0)

    
            merged_label_for_fusion = torch.zeros(original_image.shape)
            
            ## FIRST MODEL
            for elem in e:
                if elem['keep_sample']:
                    image = torch.unsqueeze(elem['image'], 0)
                    with torch.no_grad():
                        model_patches = model_patches.to("cuda")
                        model_patches.eval()
                        logits = model_patches(image.to("cuda"))[0]
    
                    pr_mask = logits.sigmoid()
                    pr_mask = pr_mask[0]
                    #pr_mask_to_viz = (pr_mask.cpu().numpy() > 0.4).astype(int)
    
                    if pr_mask.sum()>0:
                        #label = pr_mask
                        label = reverse_transformations(elem, pr_mask, mode='patches')
                        merged_label_for_fusion += label
    
    
            original_image = np.transpose(original_image, (1,2,0))
    
            label_patches_for_fusion = merged_label_for_fusion[0]
    
            # SECOND MODEL

            if fusion_dataset[idx][0]['keep_sample'] or fusion_dataset[idx][1]['keep_sample'] or fusion_dataset[idx][2]['keep_sample']:
    
                whole_image = torch.unsqueeze(fusion_dataset[idx][0]['image'], 0)
                patch_image2 = torch.unsqueeze(fusion_dataset[idx][1]['image'], 0)
                patch_image3 = torch.unsqueeze(fusion_dataset[idx][2]['image'], 0)
                    
        
                with torch.no_grad():
                    masks = []
                    # pass to model
                    model_whole = model_whole.to("cuda")
                    model_whole.eval()
                    
                    masks = model_whole(whole_image.to("cuda"),patch_image2.to("cuda"),patch_image3.to("cuda"))
                    masks = masks.sigmoid()
                    
    
                label_whole = masks[0]
                label_whole = (label_whole > 0.4).int()
                label_whole = reverse_transformations(whole_image_ds[idx], label_whole, mode='whole')
                
                label_whole = label_whole.squeeze()
        
                label_whole_for_fusion= masks[0]
                label_whole_for_fusion = reverse_transformations(whole_image_ds[idx], label_whole_for_fusion, mode='whole')
                
                # Plot the first image
            else:
                label_whole_for_fusion = torch.zeros(original_image.shape)
                
            original_image_squeeze = np.load(e[0]['image_meta_dict']['filename_or_obj'])
    
            fusion = fuse_segmentations(label_whole_for_fusion.numpy(), label_patches_for_fusion.numpy(), prob_threshold=0.4, boost_factor=3, penalty_factor=0.5, kernel_size=150)
            
            fusion = (fusion > 0.4).astype(int)
    
            fusion = np.expand_dims(fusion, 0)
            pred_label=fusion
    
            predicted_label_slices.append(pred_label.squeeze())
            image_slices.append(original_image.squeeze())

         predicted_label_volume = np.stack(predicted_label_slices, axis=-1)  # Stack along the first axis to create a 3D volume
 
         images_volume = np.stack(image_slices, axis=-1)


         if filter:
             predicted_label_volume = filter_masses(predicted_label_volume, min_slices=3, window_size=3) # H x W x N

         dice_mass_volume = compute_dice_score_npy(gt_label_volume, predicted_label_volume,reduction='none', class_id=1, exclude_empty=True)
         results[patient_id] = {
            'images_volume': images_volume,
            'images_volume_sub': images_volume_sub,
            'gt_volume': gt_label_volume,
            'predicted_volume': predicted_label_volume,
             'dice_mass_scores':dice_mass_volume
         }


    return results


In [50]:
def process_data_dict(d, key_to_del='gt_volume', old_key_name='predicted_volume', new_key_name='label_volume'):
    del [key_to_del]
    d[new_key_name] = d.pop(old_key_name)

In [51]:
def plot_image_groups_multiple_ids(results_groups, captions, num_plots, indexes, bbox_size=None, plot_captions=True):
    # Determine the figure size dynamically based on the number of patients and plots
    width_per_plot = max(2, 10 / num_plots)  # Ensure a minimum width
    total_width = width_per_plot * num_plots
    total_height = max(3, 5 - num_plots) * len(indexes)  # Height increases with more patients

    fig, axs = plt.subplots(len(indexes), num_plots, figsize=(total_width, total_height))

    # If there is only one patient, axs might not be a 2D array
    if len(indexes) == 1:
        axs = [axs]  # Make it a 2D array for consistency

    # Flatten axs array for easier indexing if multiple patients
    axs = axs.flatten()

    # Use subplots_adjust to reduce horizontal spacing
    fig.subplots_adjust(wspace=0.0001, hspace=0.000007)  # Adjust this value as needed

    # Loop over each index
    for elem_idx, index in enumerate(indexes):
        for i, (results_current, caption) in enumerate(zip(results_groups, captions)):
            # Calculate index in axs for the current patient
            ax_index =  elem_idx * num_plots + i
            images_volume = results_current['images_volume']
            predicted_volume = results_current['label_volume']
            dice_mass_scores = results_current['dice_mass_scores']
            img = images_volume[:,:, index] 
            pred = predicted_volume[:,:, index]
            dice_mass_score = dice_mass_scores[index]

            if bbox_size:
                gt_mask = results_groups[-1]['label_volume'][:,:, index]
                bbox = calculate_bbox(gt_mask, bbox_size)
                pred = crop_to_bbox(pred, bbox)
                img = crop_to_bbox(img, bbox)
                
            axs[ax_index].imshow(img, cmap='gray', interpolation='bilinear')
            axs[ax_index].contour(pred, colors='lime', linewidths=1, alpha=1)
            axs[ax_index].axis('off')

            if plot_captions:
                if i != len(results_groups)-1:
                    # Place captions below each image
                    caption_detail = f"{caption}\n(mDSC:{dice_mass_score:.4f})"
                else:
                    caption_detail = f"{caption}"
                axs[ax_index].set_title(caption_detail, fontsize=15, pad=3)
    
    plt.tight_layout()  # Adjust layout to prevent overlap
    plt.savefig("multiple_patients.pdf", bbox_inches='tight')
    plt.show()




def save_image_groups(results_groups, captions, indexes, suffixes, save_images=False, bbox_size=None):
    for elem_idx, index in enumerate(indexes):
        for i, (results_current, caption) in enumerate(zip(results_groups, captions)):
            plt.figure(figsize=(10, 10))  # Higher DPI for better resolution
            images_volume = results_current['images_volume']
            predicted_volume = results_current['label_volume']
            dice_mass_scores = results_current['dice_mass_scores']
            img = images_volume[:, :, index]
            pred = predicted_volume[:, :, index]

            if bbox_size:
                gt_mask = results_groups[-1]['label_volume'][:, :, index]
                bbox = calculate_bbox(gt_mask, bbox_size)
                pred = crop_to_bbox(pred, bbox)
                img = crop_to_bbox(img, bbox)

            plt.imshow(img, cmap='gray', interpolation='bilinear')
            plt.contour(pred, colors='lime', linewidths=1)
            plt.axis('off')

            if save_images:
                image_filename = f"{caption.replace(' ', '_')}_{suffixes[elem_idx]}.png"
                plt.savefig(image_filename, bbox_inches='tight', pad_inches=0)
                plt.close()
            else:
                plt.show()

# Viz for Paper

In [52]:
datasets.keys()

dict_keys(['LAXXX', 'AS0170', 'PR0760', 'GF0380', 'FP211261', 'D2MP3(VR)', 'MG0477', 'OL1062R', 'BV1252', 'D1AP7(VR)', 'SD080569', 'CC0167', 'RHCL031174', 'LA0248', 'RP271052', 'SL191251', 'LGM0159(1,5)', 'PA150139', 'HF230274', 'CF160366', 'GLA1074'])

In [53]:
from matplotlib import rc
rc('font', **{'family': 'serif', 'serif': ['cmr10']})

In [54]:
# resnet, segnet, fcffnet (skinny), multiunet cabfl
resnet_results = inference_base(model_path="RESNET-PRIVATE.ckpt", 
                              datasets=datasets, 
                              dataset_key="no_thorax_sub_test_ds", arch_name='UNet', patient_ids = datasets.keys())



LAXXX


  0%|          | 0/228 [00:00<?, ?it/s]

AS0170


  0%|          | 0/192 [00:00<?, ?it/s]

PR0760


  0%|          | 0/228 [00:00<?, ?it/s]

GF0380


  0%|          | 0/204 [00:00<?, ?it/s]

FP211261


  0%|          | 0/188 [00:00<?, ?it/s]

D2MP3(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

MG0477


  0%|          | 0/132 [00:00<?, ?it/s]

OL1062R


  0%|          | 0/228 [00:00<?, ?it/s]

BV1252


  0%|          | 0/188 [00:00<?, ?it/s]

D1AP7(VR)


  0%|          | 0/228 [00:00<?, ?it/s]

SD080569


  0%|          | 0/236 [00:00<?, ?it/s]

CC0167


  0%|          | 0/228 [00:00<?, ?it/s]

RHCL031174


  0%|          | 0/188 [00:00<?, ?it/s]

LA0248


  0%|          | 0/284 [00:00<?, ?it/s]

RP271052


  0%|          | 0/96 [00:00<?, ?it/s]

SL191251


  0%|          | 0/80 [00:00<?, ?it/s]

LGM0159(1,5)


  0%|          | 0/112 [00:00<?, ?it/s]

PA150139


  0%|          | 0/96 [00:00<?, ?it/s]

HF230274


  0%|          | 0/204 [00:00<?, ?it/s]

CF160366


  0%|          | 0/212 [00:00<?, ?it/s]

GLA1074


  0%|          | 0/112 [00:00<?, ?it/s]

LAXXX


  0%|          | 0/228 [00:00<?, ?it/s]

In [None]:
skinny_results = inference_base(model_path="skinny_model_private.ckpt", 
                              datasets=datasets, 
                              dataset_key="no_thorax_sub_test_ds", arch_name='skinny', patient_ids = datasets.keys())


In [None]:
multiunet_cabfl = inference_fusion(model_path="FUSION-SUB-CABL.ckpt", 
                              datasets=datasets, 
                              whole_dataset_key="no_thorax_sub_test_ds", 
                              patches_dataset_key="patches_sub_test_ds",patient_ids = datasets.keys())
                             

In [None]:
segnet_results = inference_base(model_path="segnet_model_private.ckpt", 
                              datasets=datasets, 
                              dataset_key="no_thorax_sub_test_ds", arch_name='segnet', patient_ids = datasets.keys())


In [None]:
ground_truth = copy.deepcopy(skinny_results)
process_data_dict(ground_truth, key_to_del='predicted_volume', old_key_name='gt_volume', new_key_name='label_volume')

In [None]:
process_data_dict(multiunet_cabfl)
process_data_dict(segnet_results)
process_data_dict(skinny_results)
process_data_dict(resnet_results)

In [None]:
res = []
for g in results_groups[:-1]:
    res.append(g['dice_mass_scores'])


# Function to find increasing indices across arrays
def find_increasing_indices(arrays):
    num_arrays = len(arrays)
    num_elements = arrays[0].size  # Assumes all arrays have the same length
    increasing_indices = []

    for index in range(num_elements):
        is_increasing = True
        for i in range(num_arrays - 1):
            if arrays[i][index] >= arrays[i + 1][index]:
                is_increasing = False
                break
        if is_increasing:
            increasing_indices.append(index)

    return increasing_indices
# Get the increasing indexes
increasing_indexes = find_increasing_indices(res)

In [None]:
resnet, segnet, fcffnet (skinny), multiunet cabfl

In [None]:
results_groups = [no_thorax_sub_cabfl_base_unet_results, patches_sub_cabfl_results, multiunet_cabfl,cabl_cabfl_ensemble, ground_truth]
captions = ['ResNet50-UNet \nFull Breasts', ,'Multi-UNet', 'Ensemble & \n Post-Processing', 'Ground Truth'] 

plot_image_groups_multiple_ids(results_groups=results_groups, captions=captions, num_plots= len(results_groups), indexes=indexes, bbox_size = (100,100), plot_captions=True)

In [None]:
import matplotlib
matplotlib.rcParams.update(matplotlib.rcParamsDefault)
save_image_groups(results_groups=results_groups, captions=captions, suffixes=['mass1', 'mass2'], indexes=indexes, save_images=True, bbox_size = (100,100))

# Other 

In [None]:
spacing_dict = {'GF220280': [0.7589285969734192, 0.7589285969734192, 1.7999992370605469],
 'MP140270': [0.8035714030265808, 0.8035714030265808, 1.7999992370605469],
 'SL191251': [0.7589285969734192, 0.7589285969734192, 1.8000030517578125],
 'LPA310774': [0.7031000256538391, 0.7031000256538391, 1.0],
 'AM14051962': [0.7031000256538391, 0.7031000256538391, 1.0],
 'MA020377': [0.7031000256538391, 0.7031000256538391, 1.0],
 'PBS050277': [0.8088235259056091, 0.8088235259056091, 1.7999992370605469],
 'VF260656': [0.7031000256538391, 0.7031000256538391, 1.0],
 'SD080569': [0.7031000256538391, 0.7031000256538391, 1.0],
 'CC100582': [0.625, 0.625, 0.9999945163726807],
 'BP130964': [0.5468999743461609, 0.5468999743461609, 1.0],
 'CF160366': [0.5859000086784363, 0.5859000086784363, 1.0],
 'EA030650': [0.7031000256538391, 0.7031000256538391, 1.0],
 'HF230274': [0.625, 0.625, 1.0],
 'IV100377': [0.625, 0.625, 1.0],
 'PV200741': [0.7031000256538391, 0.7031000256538391, 1.0],
 'PA150139': [0.5625, 0.5625, 1.5999984741210938],
 'RP271052': [0.59375, 0.59375, 2.0],
 'SG170880': [0.7031000256538391, 0.7031000256538391, 1.0]}

In [None]:
import SimpleITK as sitk

In [None]:
def save_nifti_volumes(results_dict, dest_folder, key='label_volume', interpolator=sitk.sitkNearestNeighbor):
    def resample_image(volume, original_spacing, new_spacing, interpolator=sitk.sitkNearestNeighbor):
        # Ensure the spacing is 3D

        if len(new_spacing) == 3:
            new_spacing = (new_spacing[0],new_spacing[1], 1)
            
        
        if len(original_spacing) == 2:
            original_spacing = (*original_spacing, 1)  # assuming the third dimension has unit length
            
        if len(new_spacing) == 2:
            new_spacing = (*new_spacing, 1)  # assuming the third dimension has unit length
    
        # Create a SimpleITK image from the numpy array
        sitk_volume = sitk.GetImageFromArray(volume)
        sitk_volume.SetSpacing(original_spacing)
    
        # Calculate the new size
        original_size = sitk_volume.GetSize()
        new_size = [int(np.floor(os * osp / nsp)) for os, osp, nsp in zip(original_size, original_spacing, new_spacing)]
    
        # Create the resampler
        resampler = sitk.ResampleImageFilter()
        resampler.SetOutputSpacing(new_spacing)
        resampler.SetSize(new_size)
        resampler.SetInterpolator(interpolator)  
    
        # Resample the image
        resampled_sitk_volume = resampler.Execute(sitk_volume)
        resampled_volume = sitk.GetArrayFromImage(resampled_sitk_volume)
    
        return resampled_volume
        
    original_pixel_spacing = [0.5,0.5, 1]
    
    for patient_id in results_dict.keys():

        patient_dir = os.path.join(dest_folder, str(patient_id))
        os.makedirs(patient_dir, exist_ok=True)

        volume = results_dict[patient_id][key]  
        volume = volume.astype(np.int16)
        volume = np.transpose(volume, (2,0,1))

        #label_volume = np.expand_dims(label_volume, 2)
    
        new_pixel_spacing = spacing_dict[patient_id]
    
        # Assuming label_volume is your final NumPy array

        print(f"patient id: {patient_id}")
        print(f"new_pixel_spacing: {new_pixel_spacing}")
        print(f"original_pixel_spacing: {original_pixel_spacing}")
        print(f"original_shape: {volume.shape}")
        
        if new_pixel_spacing != original_pixel_spacing:
            print("resampling..")

            volume =  resample_image(volume, original_pixel_spacing, new_pixel_spacing, interpolator=interpolator)
        
        print(f"new_shape{volume.shape}")
        #print(f"unique vals: {np.unique(volume)}")
        print()
        
        volume_sitk = sitk.GetImageFromArray(volume)
        
        
        # Specify the file name for your NIfTI file
        nifti_file_name = f'{patient_dir}/{patient_id}.nii.gz'
        
        # Save the SimpleITK Image as a NIfTI file
        sitk.WriteImage(volume_sitk, nifti_file_name)
        

In [None]:
datasets.keys()

In [None]:
cabl_ensemble =inference_ensemble(
                              model_whole_path="PRIVATE-FUSION-SUB-CABL.ckpt", 
                              model_patches_path="PRIVATE-PATCHES-SUB-CABFL.ckpt", 
                              patient_ids=p_id, 
                              datasets=datasets, 
                              whole_dataset_key="no_thorax_sub_test_ds", 
                              patches_dataset_key="patches_sub_test_ds",
                              filter=True,
                              subtracted=True
)

In [None]:
save_nifti_volumes(cabl_ensemble, dest_folder='images-subtracted-blank-december2024', key='images_volume', interpolator=sitk.sitkLinear)

In [None]:
save_nifti_volumes(cabl_ensemble, dest_folder='predicted_volume-blank-december2024', key='predicted_volume', interpolator=sitk.sitkNearestNeighbor)

In [None]:
cabl_ensemble.keys()

In [None]:
volume1, volume2 = cabl_ensemble['HF230274']['images_volume'], cabl_ensemble['HF230274']['predicted_volume']

In [None]:
plot_slices_side_by_side(volume1, volume2)

In [None]:
volume1, volume2 = cabl_ensemble['SD080569']['images_volume'], cabl_ensemble['SD080569']['predicted_volume']

In [None]:
plot_slices_side_by_side(volume1, volume2) #MM0478