# SETUP


## 1. Installs & Imports


In [None]:
# Standard libraries
import os
import json
import glob
import shutil
import tempfile
import random
import warnings
import pprint
# Shared pretty-printer (indent=4) used to display the metrics dicts returned by
# each Trainer.test call at the end of every baseline section.
pp = pprint.PrettyPrinter(indent=4)

# Third-party libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
import nibabel as nib
import albumentations as A
import segmentation_models_pytorch as smp
import torch.nn as nn
import torch.nn.functional as F
from skimage import filters
from skimage.measure import label as label_fn, regionprops
from skimage import morphology
from sklearn.model_selection import train_test_split
from copy import deepcopy
from tqdm.notebook import tqdm

# MONAI related imports
from monai.config import print_config
from monai.networks.nets import UNet, SwinUNETR, BasicUNetPlusPlus
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss, DiceCELoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.transforms import (
    AsDiscrete, AsDiscreted, EnsureChannelFirstd, Compose, CropForegroundd,
    LoadImaged, Orientationd, RandCropByPosNegLabeld, SaveImaged, ScaleIntensityRanged,
    Spacingd, Invertd, ResizeWithPadOrCropd, Resized, MapTransform, ScaleIntensityd,
    LabelToContourd, ForegroundMaskd, HistogramNormalized, RandFlipd, RandGridDistortiond,
    RandHistogramShiftd, RandRotated
)
from monai.handlers.utils import from_engine
from monai.utils.type_conversion import convert_to_numpy

# PyTorch Lightning related imports
import lightning.pytorch as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch import seed_everything

from natsort import natsorted, ns
from PIL import Image
from numpy import einsum
from torch.utils.data import default_collate
import psutil
from typing import List
import monai

# Import modules
from breast_segmentation.config.settings import config
from breast_segmentation.utils.seed import set_deterministic_mode, seed_worker, reseed
from breast_segmentation.data.dataset import (
    get_image_label_files, create_data_dicts
)
from breast_segmentation.data import custom_collate_no_patches
from breast_segmentation.transforms.compose import Preprocess
from breast_segmentation.models.lightning_module import BreastSegmentationModel
from breast_segmentation.models.architectures import get_model
from breast_segmentation.metrics.losses import get_loss_function, CrossEntropy2d, compute_class_weight, AsymmetricUnifiedFocalLoss

from breast_segmentation.data.private_dataset import (
    PATIENT_INFO, get_filenames, get_train_val_test_dicts, PATIENTS_TO_EXCLUDE
)


# Import additional loss functions used in original
from monai.losses import DiceLoss, DiceCELoss

# Allow PyTorch to use reduced-precision internal arithmetic for float32 matmuls
# ("medium" per torch.set_float32_matmul_precision docs), then log the MONAI
# and dependency versions so the run output records the environment.
torch.set_float32_matmul_precision("medium")
print_config()


In [None]:
# Private-dataset baselines save their checkpoints under the private-dataset
# checkpoint directory from the shared config; make sure it exists.
checkpoints_dir = config.checkpoints_dir_private
os.makedirs(name=checkpoints_dir, exist_ok=True)
print("Baselines Training - Using checkpoint directory: " + str(checkpoints_dir))


## 2. Environment Setup and Patient-Level Data Split


In [None]:
# Patient-level split: start from every patient in PATIENT_INFO, drop the
# excluded ones, then hold out 20% for test and 25% of the remainder for
# validation (roughly 60/20/20 train/val/test by patient count).
patient_ids = list(PATIENT_INFO.keys())

print(f"Initial patients from PATIENT_INFO: {len(patient_ids)}")
print(f"Patients to exclude: {PATIENTS_TO_EXCLUDE}")
patient_ids = list(filter(lambda pid: pid not in PATIENTS_TO_EXCLUDE, patient_ids))
print(f"Patients after exclusion: {len(patient_ids)}")

SEED = config.SEED
dataset_base_path = config.DATASET_BASE_PATH_PRIVATE
x_train_val, x_test = train_test_split(patient_ids, test_size=0.2, random_state=SEED)
x_train, x_val = train_test_split(x_train_val, test_size=0.25, random_state=SEED)

# Expand the patient-id splits into per-sample data dicts.
print("Creating data dictionaries with filtering logic...")
train_dicts, val_dicts, test_dicts = get_train_val_test_dicts(
    dataset_base_path, x_train, x_val, x_test
)

print(
    "\ndata:",
    f"  Training samples: {len(train_dicts)}",
    f"  Validation samples: {len(val_dicts)}",
    f"  Test samples: {len(test_dicts)}",
    sep="\n",
)


In [None]:
# Run-wide settings shared by the dataset, loader and model cells below.
batch_size = config.BATCH_SIZE
num_workers = config.NUM_WORKERS
checkpoints_dir = config.checkpoints_dir_private  # same value as the checkpoint-directory cell
get_boundaryloss = False  # forwarded to Preprocess(get_boundaryloss=...) in the transform cell

os.makedirs(checkpoints_dir, exist_ok=True)

# Enable deterministic mode for config.SEED; the returned `g` is passed as
# `generator=` to the DataLoaders below.
g = set_deterministic_mode(config.SEED)

print("Batch size:", batch_size)
print("Number of workers:", num_workers)
print("Checkpoints directory:", checkpoints_dir)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device:", torch.cuda.get_device_name())


## 3. Private Dataset Preprocessing Pipeline


In [None]:
# Path prefixes identifying the subtracted-image arrays; forwarded to every
# Preprocess call in this notebook.
sub_third_images_path_prefixes = ("Dataset-arrays-4-FINAL", "Dataset-arrays-FINAL")

# Pipeline for the intensity-statistics pass: load image/label with the NumPy
# reader, add a channel axis, then run Preprocess in its "statistics" mode.
# NOTE(review): the train/test pipelines read images through PILReader +
# convert("L") instead of NumpyReader; confirm both yield the same intensity
# scale, otherwise the mean/std computed from this pipeline will not match the
# data they are later used to normalise.
statistics_transforms_private = Compose(
    [
        LoadImaged(keys=["image", "label"], image_only=False, reader=monai.data.NumpyReader()),
        EnsureChannelFirstd(keys=["image", "label"]),
        Preprocess(
            keys=None,
            mode="statistics",
            dataset="private",
            subtracted_images_path_prefixes=sub_third_images_path_prefixes,
            get_patches=False,
        ),
    ]
)

print("Statistics transforms created for private dataset")

In [None]:
# Statistics pass over the TRAINING split only: cache the statistics-mode
# samples, then iterate them in a fixed order keeping the last partial batch.
statistics_ds_private = CacheDataset(
    data=train_dicts, transform=statistics_transforms_private, num_workers=num_workers
)

# NOTE: no num_workers is given to this DataLoader, so batches are produced in
# the main process and seed_worker (a per-worker init hook) is not used here.
statistics_loader_private = DataLoader(
    statistics_ds_private,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    collate_fn=custom_collate_no_patches,
    worker_init_fn=seed_worker,
    generator=g,
)

print("Statistics dataset and loader created")

In [None]:
# Calculate mean and std for normalization
def get_mean_std_dataloader(dataloader, masked=False):
    """Compute the global mean and standard deviation of ``batch["image"]``.

    Makes a single pass over ``dataloader``, accumulating the pixel sum, the sum
    of squares and the pixel count, and returns the population statistics
    ``mean = E[x]`` and ``std = sqrt(E[x^2] - E[x]^2)``.

    Args:
        dataloader: Iterable yielding dict batches with an ``"image"`` tensor,
            or ``None`` for batches the collate function dropped (these are
            reported and skipped).
        masked: If True, only pixels with value > 0 contribute.

    Returns:
        Tuple ``(mean, std)`` of Python floats.

    Raises:
        ValueError: If no pixel was accumulated (empty loader, every batch
            ``None``, or every pixel masked out).
    """
    # Accumulate in Python floats (IEEE float64). Summing millions of pixels in
    # float32 loses precision, and squaring in an integer input dtype can overflow.
    sum_of_images = 0.0
    sum_of_squares = 0.0
    num_pixels = 0

    for batch in tqdm(dataloader):
        if batch is None:
            print("batch is None")
            continue

        image = batch["image"]
        if masked:
            image = image[image > 0.0]

        image = image.double()
        sum_of_images += image.sum().item()
        sum_of_squares += image.square().sum().item()
        num_pixels += image.numel()

    if num_pixels == 0:
        raise ValueError(
            "get_mean_std_dataloader: no pixels accumulated "
            "(empty dataloader, all batches None, or everything masked out)"
        )

    mean = sum_of_images / num_pixels
    # E[x^2] - E[x]^2 can round to a tiny negative number; clamp so sqrt stays real.
    variance = max(sum_of_squares / num_pixels - mean ** 2, 0.0)
    std_dev = variance ** 0.5

    print(f'Mean: {mean}, Standard Deviation: {std_dev}')
    return mean, std_dev

# Normalisation statistics over all (unmasked) training-split pixels.
mean_private, std_private = get_mean_std_dataloader(statistics_loader_private, masked=False)
print("Calculated - Mean: {}, Standard Deviation: {}".format(mean_private, std_private))

In [None]:
# Hard-coded mean / std, presumably from an earlier statistics run (going by the
# names) on the subtracted, thorax-removed images.
# NOTE(review): neither value is referenced by the transforms in this notebook,
# which normalise with mean_private / std_private; confirm whether it is still needed.
mean_no_thorax_third_sub = 43.14976119995117
std_no_thorax_third_sub = 172.67039489746094

In [None]:
# Training / evaluation pipelines for the private dataset. Both load image and
# label through PIL converted to mode "L" (8-bit grayscale), add a channel axis,
# apply Rotate90d and finish with Preprocess using the training-split statistics
# as subtrahend / divisor. Only the Preprocess `mode` differs between the two.
def _private_pipeline(mode):
    """Return the load -> channel-first -> rotate -> Preprocess pipeline for ``mode``."""
    return Compose(
        [
            LoadImaged(
                keys=["image", "label"],
                image_only=False,
                reader=monai.data.PILReader(converter=lambda image: image.convert("L")),
            ),
            EnsureChannelFirstd(keys=["image", "label"]),
            monai.transforms.Rotate90d(keys=["image", "label"]),
            Preprocess(
                keys=None,
                mode=mode,
                dataset="private",
                subtracted_images_path_prefixes=sub_third_images_path_prefixes,
                subtrahend=mean_private,
                divisor=std_private,
                get_patches=False,
                get_boundaryloss=get_boundaryloss,
            ),
        ]
    )


test_transforms_private = _private_pipeline("test")
train_transforms_private = _private_pipeline("train")

In [None]:
# Cached datasets, built in train -> val -> test order. Validation and test
# share the deterministic "test" pipeline; only training uses "train".
train_ds_no_thorax_third_sub, val_ds_no_thorax_third_sub, test_ds_no_thorax_third_sub = (
    CacheDataset(data=split_dicts, transform=split_transform, num_workers=num_workers)
    for split_dicts, split_transform in (
        (train_dicts, train_transforms_private),
        (val_dicts, test_transforms_private),
        (test_dicts, test_transforms_private),
    )
)

In [None]:
# One DataLoader per split. All share batch size, collate function, worker init
# hook and the generator `g`; only the training loader shuffles. The last
# partial batch is kept in every split (drop_last=False).
no_thorax_sub_train_loader, no_thorax_sub_val_loader, no_thorax_sub_test_loader = (
    DataLoader(
        split_ds,
        batch_size=batch_size,
        shuffle=shuffle_split,
        drop_last=False,
        collate_fn=custom_collate_no_patches,
        worker_init_fn=seed_worker,
        generator=g,
    )
    for split_ds, shuffle_split in (
        (train_ds_no_thorax_third_sub, True),
        (val_ds_no_thorax_third_sub, False),
        (test_ds_no_thorax_third_sub, False),
    )
)

In [None]:
# Sanity check: show channel 0 of the first image in one training batch.
# NOTE(review): drawing this batch also draws from the train loader's generator
# `g`, which may shift the shuffle order seen by the first model trained below;
# confirm whether that matters for reproducing reported results.
i = next(iter(no_thorax_sub_train_loader))["image"]
print(i.shape)
plt.imshow(X=i[0][0], cmap="gray")
plt.title("Sample from training data")
plt.show()


# Training FCN-FFNET


In [None]:
# FCN-FFNet baseline (loss_function="crossentropy2d"): model, early stopping and
# best-checkpoint callback (both monitor valid_loss), and its trainer.
# NOTE(review): rebinding `g` does not replace the generator object the train
# loader was built with; confirm reseed() resets the RNG state that loader uses.
g = reseed()
ENCODER_NAME = None
fcn_ffnet_model = BreastSegmentationModel(
    arch="fcn_ffnet",
    encoder_name=ENCODER_NAME,
    loss_function="crossentropy2d",
    use_boundary_loss=False,
    img_size=config.IMAGE_SIZE[0],
    in_channels=config.IN_CHANNELS,
    out_classes=config.OUT_CHANNELS,
    batch_size=batch_size,
    # Batches per epoch. The train loader keeps its last partial batch
    # (drop_last=False), so this is ceil(N / batch_size) == len(loader);
    # len(dataset) // batch_size under-counts by one when N % batch_size != 0.
    len_train_loader=len(no_thorax_sub_train_loader),
)

es = EarlyStopping(monitor="valid_loss", mode="min", patience=config.EARLY_STOPPING_PATIENCE)

cc_fcn_ffnet_model = ModelCheckpoint(
    monitor="valid_loss",
    save_top_k=1,
    mode="min",
    # Prefix with the architecture name so the baselines that share
    # checkpoints_dir produce distinguishable files.
    filename="fcn-ffnet-{epoch:02d}-{valid_loss:.2f}",
    dirpath=checkpoints_dir,
    auto_insert_metric_name=False,
)

trainer_fcn_ffnet_model = L.Trainer(
    devices=1,
    accelerator="auto",
    max_epochs=config.MAX_EPOCHS,
    callbacks=[es, cc_fcn_ffnet_model],
    log_every_n_steps=config.LOG_EVERY_N_STEPS,
    gradient_clip_val=config.GRADIENT_CLIP_VAL,
    num_sanity_val_steps=1,
    deterministic=True,
)


In [None]:
# Fit FCN-FFNet on the training split; validation on the val split drives the
# early-stopping and checkpoint callbacks (both monitor valid_loss).
trainer_fcn_ffnet_model.fit(
    model=fcn_ffnet_model,
    train_dataloaders=no_thorax_sub_train_loader,
    val_dataloaders=no_thorax_sub_val_loader,
)


In [None]:
# Reload the best (lowest valid_loss) FCN-FFNet checkpoint and evaluate it on
# the held-out test split.
# NOTE(review): the constructor takes `use_boundary_loss`, not `boundaryloss`;
# Lightning may silently drop load_from_checkpoint kwargs that __init__ does not
# accept, so this override may be a no-op. Confirm against the model's __init__.
fcn_ffnet_model = BreastSegmentationModel.load_from_checkpoint(
    checkpoint_path=cc_fcn_ffnet_model.best_model_path,
    loss_function="crossentropy2d",
    boundaryloss=False,
)
test_metrics = trainer_fcn_ffnet_model.test(
    model=fcn_ffnet_model,
    dataloaders=no_thorax_sub_test_loader,
    verbose=False,
)
pp.pprint(test_metrics[0])  # metrics dict for the single test dataloader

# Training Swin-UNETR


In [None]:
# Swin-UNETR baseline (loss_function="soft_dice"): model, early stopping and
# best-checkpoint callback (both monitor valid_loss), and its trainer.
# NOTE(review): rebinding `g` does not replace the generator object the train
# loader was built with; confirm reseed() resets the RNG state that loader uses.
g = reseed()
ENCODER_NAME = None
swin_unetr_model = BreastSegmentationModel(
    arch="swin_unetr",
    encoder_name=ENCODER_NAME,
    loss_function="soft_dice",
    use_boundary_loss=False,
    img_size=config.IMAGE_SIZE[0],
    in_channels=config.IN_CHANNELS,
    out_classes=config.OUT_CHANNELS,
    batch_size=batch_size,
    # Batches per epoch. The train loader keeps its last partial batch
    # (drop_last=False), so this is ceil(N / batch_size) == len(loader);
    # len(dataset) // batch_size under-counts by one when N % batch_size != 0.
    len_train_loader=len(no_thorax_sub_train_loader),
)

es_swin = EarlyStopping(monitor="valid_loss", mode="min", patience=config.EARLY_STOPPING_PATIENCE)

cc_swin_unetr_model = ModelCheckpoint(
    monitor="valid_loss",
    save_top_k=1,
    mode="min",
    filename="swin-unetr-{epoch:02d}-{valid_loss:.2f}",
    dirpath=checkpoints_dir,
    auto_insert_metric_name=False,
)

trainer_swin_unetr_model = L.Trainer(
    devices=1,
    accelerator="auto",
    max_epochs=config.MAX_EPOCHS,
    callbacks=[es_swin, cc_swin_unetr_model],
    log_every_n_steps=config.LOG_EVERY_N_STEPS,
    gradient_clip_val=config.GRADIENT_CLIP_VAL,
    num_sanity_val_steps=1,
    deterministic=True,
)


In [None]:
# Fit Swin-UNETR on the training split; validation on the val split drives the
# early-stopping and checkpoint callbacks (both monitor valid_loss).
trainer_swin_unetr_model.fit(
    model=swin_unetr_model,
    train_dataloaders=no_thorax_sub_train_loader,
    val_dataloaders=no_thorax_sub_val_loader,
)


In [None]:
# Reload the best (lowest valid_loss) Swin-UNETR checkpoint and evaluate it on
# the held-out test split.
# NOTE(review): the constructor takes `use_boundary_loss`, not `boundaryloss`;
# Lightning may silently drop load_from_checkpoint kwargs that __init__ does not
# accept, so this override may be a no-op. Confirm against the model's __init__.
swin_unetr_model = BreastSegmentationModel.load_from_checkpoint(
    checkpoint_path=cc_swin_unetr_model.best_model_path,
    loss_function="soft_dice",
    boundaryloss=False,
)
test_metrics_swin = trainer_swin_unetr_model.test(
    model=swin_unetr_model,
    dataloaders=no_thorax_sub_test_loader,
    verbose=False,
)
pp.pprint(test_metrics_swin[0])  # metrics dict for the single test dataloader


# Training UNet++


In [None]:
# UNet++ baseline (Dice + cross-entropy, weighted 0.5 / 0.5, sigmoid outputs):
# model, early stopping and best-checkpoint callback (both monitor valid_loss),
# and its trainer.
# NOTE(review): rebinding `g` does not replace the generator object the train
# loader was built with; confirm reseed() resets the RNG state that loader uses.
g = reseed()
ENCODER_NAME = None
unetplusplus_model = BreastSegmentationModel(
    arch="unetplusplus",
    encoder_name=ENCODER_NAME,
    loss_function="dice_ce",
    loss_kwargs={"sigmoid": True, "lambda_dice": 0.5, "lambda_ce": 0.5},
    use_boundary_loss=False,
    img_size=config.IMAGE_SIZE[0],
    in_channels=config.IN_CHANNELS,
    out_classes=config.OUT_CHANNELS,
    batch_size=batch_size,
    # Batches per epoch. The train loader keeps its last partial batch
    # (drop_last=False), so this is ceil(N / batch_size) == len(loader);
    # len(dataset) // batch_size under-counts by one when N % batch_size != 0.
    len_train_loader=len(no_thorax_sub_train_loader),
)

es_unetpp = EarlyStopping(monitor="valid_loss", mode="min", patience=config.EARLY_STOPPING_PATIENCE)

cc_unetplusplus_model = ModelCheckpoint(
    monitor="valid_loss",
    save_top_k=1,
    mode="min",
    filename="unetplusplus-{epoch:02d}-{valid_loss:.2f}",
    dirpath=checkpoints_dir,
    auto_insert_metric_name=False,
)

trainer_unetplusplus_model = L.Trainer(
    devices=1,
    accelerator="auto",
    max_epochs=config.MAX_EPOCHS,
    callbacks=[es_unetpp, cc_unetplusplus_model],
    log_every_n_steps=config.LOG_EVERY_N_STEPS,
    gradient_clip_val=config.GRADIENT_CLIP_VAL,
    num_sanity_val_steps=1,
    deterministic=True,
)


In [None]:
# Fit UNet++ on the training split; validation on the val split drives the
# early-stopping and checkpoint callbacks (both monitor valid_loss).
trainer_unetplusplus_model.fit(
    model=unetplusplus_model,
    train_dataloaders=no_thorax_sub_train_loader,
    val_dataloaders=no_thorax_sub_val_loader,
)


In [None]:
# Reload the best (lowest valid_loss) UNet++ checkpoint and evaluate it on the
# held-out test split.
# NOTE(review): the constructor takes `use_boundary_loss`, not `boundaryloss`;
# Lightning may silently drop load_from_checkpoint kwargs that __init__ does not
# accept, so this override may be a no-op. Confirm against the model's __init__.
unetplusplus_model = BreastSegmentationModel.load_from_checkpoint(
    checkpoint_path=cc_unetplusplus_model.best_model_path,
    loss_function="dice_ce",
    boundaryloss=False,
)
test_metrics_unetpp = trainer_unetplusplus_model.test(
    model=unetplusplus_model,
    dataloaders=no_thorax_sub_test_loader,
    verbose=False,
)
pp.pprint(test_metrics_unetpp[0])  # metrics dict for the single test dataloader


# Training SegNet


In [None]:
# SegNet baseline (loss_function="crossentropy2d"): model, early stopping and
# best-checkpoint callback (both monitor valid_loss), and its trainer.
# NOTE(review): rebinding `g` does not replace the generator object the train
# loader was built with; confirm reseed() resets the RNG state that loader uses.
g = reseed()
ENCODER_NAME = None
segnet_model = BreastSegmentationModel(
    arch="segnet",
    encoder_name=ENCODER_NAME,
    loss_function="crossentropy2d",
    use_boundary_loss=False,
    img_size=config.IMAGE_SIZE[0],
    in_channels=config.IN_CHANNELS,
    out_classes=config.OUT_CHANNELS,
    batch_size=batch_size,
    # Batches per epoch. The train loader keeps its last partial batch
    # (drop_last=False), so this is ceil(N / batch_size) == len(loader);
    # len(dataset) // batch_size under-counts by one when N % batch_size != 0.
    len_train_loader=len(no_thorax_sub_train_loader),
)

es_segnet = EarlyStopping(monitor="valid_loss", mode="min", patience=config.EARLY_STOPPING_PATIENCE)

cc_segnet_model = ModelCheckpoint(
    monitor="valid_loss",
    save_top_k=1,
    mode="min",
    filename="segnet-{epoch:02d}-{valid_loss:.2f}",
    dirpath=checkpoints_dir,
    auto_insert_metric_name=False,
)

trainer_segnet_model = L.Trainer(
    devices=1,
    accelerator="auto",
    max_epochs=config.MAX_EPOCHS,
    callbacks=[es_segnet, cc_segnet_model],
    log_every_n_steps=config.LOG_EVERY_N_STEPS,
    gradient_clip_val=config.GRADIENT_CLIP_VAL,
    num_sanity_val_steps=1,
    deterministic=True,
)


In [None]:
# Fit SegNet on the training split; validation on the val split drives the
# early-stopping and checkpoint callbacks (both monitor valid_loss).
trainer_segnet_model.fit(
    model=segnet_model,
    train_dataloaders=no_thorax_sub_train_loader,
    val_dataloaders=no_thorax_sub_val_loader,
)


In [None]:
# Reload the best (lowest valid_loss) SegNet checkpoint and evaluate it on the
# held-out test split.
# NOTE(review): the constructor takes `use_boundary_loss`, not `boundaryloss`;
# Lightning may silently drop load_from_checkpoint kwargs that __init__ does not
# accept, so this override may be a no-op. Confirm against the model's __init__.
segnet_model = BreastSegmentationModel.load_from_checkpoint(
    checkpoint_path=cc_segnet_model.best_model_path,
    loss_function="crossentropy2d",
    boundaryloss=False,
)
test_metrics_segnet = trainer_segnet_model.test(
    model=segnet_model,
    dataloaders=no_thorax_sub_test_loader,
    verbose=False,
)
pp.pprint(test_metrics_segnet[0])  # metrics dict for the single test dataloader


# Training Skinny


In [None]:
# SkinnyNet baseline (Dice + cross-entropy, weighted 0.5 / 0.5, sigmoid outputs):
# model plus early-stopping and best-checkpoint callbacks (both monitor valid_loss).
# Re-seed before building the model, as every other baseline cell does, so the
# SkinnyNet run is seeded the same way.
# NOTE(review): rebinding `g` does not replace the generator object the train
# loader was built with; confirm reseed() resets the RNG state that loader uses.
g = reseed()
skinny_model = BreastSegmentationModel(
    arch="skinny",
    encoder_name=None,
    loss_function="dice_ce",
    loss_kwargs={"sigmoid": True, "lambda_dice": 0.5, "lambda_ce": 0.5},
    use_boundary_loss=False,
    img_size=config.IMAGE_SIZE[0],
    in_channels=config.IN_CHANNELS,
    out_classes=config.OUT_CHANNELS,
    batch_size=batch_size,
    # Batches per epoch. The train loader keeps its last partial batch
    # (drop_last=False), so this is ceil(N / batch_size) == len(loader);
    # len(dataset) // batch_size under-counts by one when N % batch_size != 0.
    len_train_loader=len(no_thorax_sub_train_loader),
)

# Early stopping for SkinnyNet
es_skinny = EarlyStopping(monitor="valid_loss", mode="min", patience=config.EARLY_STOPPING_PATIENCE)

# Model checkpoint for SkinnyNet: same directory and naming scheme as the other
# baselines. A hard-coded "./checkpoints/" would land outside checkpoints_dir,
# and a fixed filename gets version-suffixed (e.g. "-v1") on re-runs.
cc_skinny_model = ModelCheckpoint(
    monitor="valid_loss",
    save_top_k=1,
    mode="min",
    filename="skinny-{epoch:02d}-{valid_loss:.2f}",
    dirpath=checkpoints_dir,
    auto_insert_metric_name=False,
)


In [None]:
# Trainer for SkinnyNet, aligned with the other baselines' trainer settings
# (logging cadence, gradient clipping, one sanity-validation step) so the
# comparison uses the same training protocol. SkinnyNet additionally trains with
# 16-bit mixed precision; "16-mixed" is Lightning 2.x's explicit spelling of the
# legacy `precision=16`.
trainer_skinny_model = L.Trainer(
    devices=1,
    accelerator="auto",
    max_epochs=config.MAX_EPOCHS,
    callbacks=[es_skinny, cc_skinny_model],
    log_every_n_steps=config.LOG_EVERY_N_STEPS,
    gradient_clip_val=config.GRADIENT_CLIP_VAL,
    num_sanity_val_steps=1,
    deterministic=True,
    precision="16-mixed",
)

# Train SkinnyNet; validation on the val split drives the callbacks above.
trainer_skinny_model.fit(
    skinny_model,
    train_dataloaders=no_thorax_sub_train_loader,
    val_dataloaders=no_thorax_sub_val_loader,
)


In [None]:
# Reload the best (lowest valid_loss) SkinnyNet checkpoint and evaluate it on
# the held-out test split.
# NOTE(review): the constructor takes `use_boundary_loss`, not `boundaryloss`;
# Lightning may silently drop load_from_checkpoint kwargs that __init__ does not
# accept, so this override may be a no-op. Confirm against the model's __init__.
skinny_model = BreastSegmentationModel.load_from_checkpoint(
    checkpoint_path=cc_skinny_model.best_model_path,
    loss_function="dice_ce",
    boundaryloss=False,
)
test_metrics_skinny = trainer_skinny_model.test(
    model=skinny_model,
    dataloaders=no_thorax_sub_test_loader,
    verbose=False,
)
pp.pprint(test_metrics_skinny[0])  # metrics dict for the single test dataloader


# Training ResNet-UNet


In [None]:
# ResNet-50-encoder UNet baseline (Dice loss on sigmoid outputs): model, early
# stopping and best-checkpoint callback (both monitor valid_loss), and its trainer.
# NOTE(review): rebinding `g` does not replace the generator object the train
# loader was built with; confirm reseed() resets the RNG state that loader uses.
g = reseed()
ENCODER_NAME = "resnet50"
resnet_model = BreastSegmentationModel(
    arch="UNet",
    encoder_name=ENCODER_NAME,
    loss_function="dice",
    loss_kwargs={"sigmoid": True},
    use_boundary_loss=False,
    img_size=config.IMAGE_SIZE[0],
    in_channels=config.IN_CHANNELS,
    out_classes=config.OUT_CHANNELS,
    batch_size=batch_size,
    # Batches per epoch. The train loader keeps its last partial batch
    # (drop_last=False), so this is ceil(N / batch_size) == len(loader);
    # len(dataset) // batch_size under-counts by one when N % batch_size != 0.
    len_train_loader=len(no_thorax_sub_train_loader),
)

es_resnet = EarlyStopping(monitor="valid_loss", mode="min", patience=config.EARLY_STOPPING_PATIENCE)

cc_resnet_model = ModelCheckpoint(
    monitor="valid_loss",
    save_top_k=1,
    mode="min",
    filename="resnet-unet-{epoch:02d}-{valid_loss:.2f}",
    dirpath=checkpoints_dir,
    auto_insert_metric_name=False,
)

trainer_resnet_model = L.Trainer(
    devices=1,
    accelerator="auto",
    max_epochs=config.MAX_EPOCHS,
    callbacks=[es_resnet, cc_resnet_model],
    log_every_n_steps=config.LOG_EVERY_N_STEPS,
    gradient_clip_val=config.GRADIENT_CLIP_VAL,
    num_sanity_val_steps=1,
    deterministic=True,
)

In [None]:
# Fit the ResNet-UNet on the training split; validation on the val split drives
# the early-stopping and checkpoint callbacks (both monitor valid_loss).
trainer_resnet_model.fit(
    model=resnet_model,
    train_dataloaders=no_thorax_sub_train_loader,
    val_dataloaders=no_thorax_sub_val_loader,
)

In [None]:
# Reload the best (lowest valid_loss) ResNet-UNet checkpoint and evaluate it on
# the held-out test split.
# NOTE(review): the constructor takes `use_boundary_loss`, not `boundaryloss`;
# Lightning may silently drop load_from_checkpoint kwargs that __init__ does not
# accept, so this override may be a no-op. Confirm against the model's __init__.
resnet_model = BreastSegmentationModel.load_from_checkpoint(
    checkpoint_path=cc_resnet_model.best_model_path,
    loss_function="dice",
    boundaryloss=False,
)
test_metrics_resnet = trainer_resnet_model.test(
    model=resnet_model,
    dataloaders=no_thorax_sub_test_loader,
    verbose=False,
)
pp.pprint(test_metrics_resnet[0])  # metrics dict for the single test dataloader