# SETUP



## 1. Installs & Imports


In [None]:
# Standard libraries
import os
import json
import glob
import shutil
import tempfile
import random
import warnings
import pprint
# Shared pretty-printer used to display the per-model test metric dictionaries
pp = pprint.PrettyPrinter(indent=4)

# Third-party libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
import nibabel as nib
import albumentations as A
import segmentation_models_pytorch as smp
import torch.nn as nn
import torch.nn.functional as F
from skimage import filters
from skimage.measure import label as label_fn, regionprops
from skimage import morphology
from sklearn.model_selection import train_test_split
from copy import deepcopy
from tqdm.notebook import tqdm

# MONAI related imports
from monai.config import print_config
from monai.networks.nets import UNet, SwinUNETR, BasicUNetPlusPlus
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss, DiceCELoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.transforms import (
    AsDiscrete, AsDiscreted, EnsureChannelFirstd, Compose, CropForegroundd,
    LoadImaged, Orientationd, RandCropByPosNegLabeld, SaveImaged, ScaleIntensityRanged,
    Spacingd, Invertd, ResizeWithPadOrCropd, Resized, MapTransform, ScaleIntensityd,
    LabelToContourd, ForegroundMaskd, HistogramNormalized, RandFlipd, RandGridDistortiond,
    RandHistogramShiftd, RandRotated
)
from monai.handlers.utils import from_engine
from monai.utils.type_conversion import convert_to_numpy

# PyTorch Lightning related imports
import lightning.pytorch as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch import seed_everything

# Miscellaneous utilities

from natsort import natsorted, ns
from PIL import Image
from numpy import einsum
from torch.utils.data import default_collate
import psutil
from typing import List
import monai

# Import modules
from breast_segmentation.config.settings import config
from breast_segmentation.utils.seed import set_deterministic_mode, seed_worker, reseed
from breast_segmentation.data.dataset import (
    get_image_label_files, create_data_dicts, split_data, create_dataloaders
)

from breast_segmentation.transforms.compose import Preprocess
from breast_segmentation.models.lightning_module import BreastSegmentationModel
from breast_segmentation.models.architectures import get_model
from breast_segmentation.metrics.losses import get_loss_function, CrossEntropy2d, compute_class_weight, AsymmetricUnifiedFocalLoss
from breast_segmentation.utils.visualization import (
    plot_batch_predictions, plot_training_history
)
from breast_segmentation.utils.postprocessing import (
    remove_far_masses_based_on_largest_mass
)

# Select the "medium" float32 matmul precision, then print the MONAI configuration report
torch.set_float32_matmul_precision("medium")
print_config()

## 2. Environment Setup


In [None]:
# Run-wide settings, all read from the shared config object
batch_size, num_workers = config.BATCH_SIZE, config.NUM_WORKERS
checkpoints_dir = config.checkpoints_dir_breadm
get_boundaryloss = False

# The checkpoint callbacks write here, so create the folder up front
os.makedirs(checkpoints_dir, exist_ok=True)

# Seed everything; the returned generator is handed to the DataLoaders below
g = set_deterministic_mode(config.SEED)

print(f"Batch size: {batch_size}")
print(f"Number of workers: {num_workers}")
print(f"Checkpoints directory: {checkpoints_dir}")

cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")
if cuda_available:
    print(f"Device: {torch.cuda.get_device_name()}")


## 3. Data Preparation


In [None]:
# Location and modality of the images to load
dataset_base_path = config.DATASET_BASE_PATH_BREADM
image_type = "VIBRANT+C2"

# Collect the (images, labels) file lists split by split, in train/val/test order
(train_images, train_labels), (val_images, val_labels), (test_images, test_labels) = (
    get_image_label_files(dataset_base_path, split_name, image_type)
    for split_name in ("train", "val", "test")
)

# Pair every image with its label in the dictionary format used by the transforms
train_dicts, val_dicts, test_dicts = (
    create_data_dicts(images, labels)
    for images, labels in (
        (train_images, train_labels),
        (val_images, val_labels),
        (test_images, test_labels),
    )
)

print("Dataset statistics:")
for split_title, split_dicts in (
    ("Training", train_dicts),
    ("Validation", val_dicts),
    ("Test", test_dicts),
):
    print(f"  {split_title} samples: {len(split_dicts)}")


## 4. Subtracted Images Preprocessing Pipeline


In [None]:
# Path prefixes handed to Preprocess as `subtracted_images_path_prefixes`
sub_third_images_path_prefixes = ("VIBRANT+C2", "SUB2")

# Pipeline used only to gather intensity statistics: load as 8-bit grayscale,
# move channels first, rotate by 90 degrees, then run Preprocess in 'statistics' mode
_statistics_steps = [
    LoadImaged(
        keys=["image", "label"],
        image_only=False,
        reader=monai.data.PILReader(converter=lambda pil_img: pil_img.convert("L")),
    ),
    EnsureChannelFirstd(keys=["image", "label"]),
    monai.transforms.Rotate90d(keys=["image", "label"]),
    Preprocess(
        keys=None,
        mode='statistics',
        dataset="BREADM",
        subtracted_images_path_prefixes=sub_third_images_path_prefixes,
        get_patches=False,
        get_boundaryloss=get_boundaryloss,
    ),
]
statistics_transforms_no_thorax_third_sub = Compose(_statistics_steps)

print("Statistics transforms created")


In [None]:
# Cache the training split through the statistics pipeline
statistics_ds_no_thorax_third_sub = CacheDataset(
    num_workers=num_workers,
    transform=statistics_transforms_no_thorax_third_sub,
    data=train_dicts,
)

# Sequential, complete pass over the cached data (no shuffling, keep the last partial batch)
_statistics_loader_args = dict(
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    worker_init_fn=seed_worker,
    generator=g,
)
statistics_loader_no_thorax_third_sub = DataLoader(
    statistics_ds_no_thorax_third_sub, **_statistics_loader_args
)

print("Statistics dataset and loader created")


In [None]:
# Calculate mean and std for normalization
def get_mean_std_dataloader(dataloader, masked=False):
    """Compute the global mean and standard deviation of ``batch["image"]``.

    Sums and squared sums are accumulated as float64 values. This keeps
    integer images from overflowing when squared and stops float32 round-off
    from building up across batches.

    Args:
        dataloader: Iterable yielding dict-like batches that hold an
            ``"image"`` tensor. ``None`` batches are reported and skipped.
        masked: If True, only strictly positive values contribute to the
            statistics.

    Returns:
        Tuple ``(mean, std)`` of Python floats; ``std`` is the population
        standard deviation ``sqrt(E[x^2] - E[x]^2)``.

    Raises:
        ValueError: If no value was accumulated (no batches, only ``None``
            batches, or the mask removed every value).
    """
    total = 0.0
    total_sq = 0.0
    num_pixels = 0

    for batch in tqdm(dataloader):
        if batch is None:
            print("none batch")
            continue

        values = batch["image"]
        if masked:
            # Boolean indexing flattens the tensor to the selected values
            values = values[values > 0.0]

        # Promote before summing/squaring so uint8/int inputs cannot wrap around
        values = values.double()
        total += values.sum().item()
        total_sq += (values * values).sum().item()
        num_pixels += values.numel()

    if num_pixels == 0:
        raise ValueError(
            "Cannot compute mean/std: the dataloader produced no image values"
        )

    mean = total / num_pixels
    # Round-off can push the variance of near-constant data slightly below zero
    variance = max(total_sq / num_pixels - mean ** 2, 0.0)
    std_dev = variance ** 0.5

    print(f'Mean: {mean}, Standard Deviation: {std_dev}')
    return mean, std_dev

# Run the statistics pass over every value (unmasked) of the cached training split
mean_no_thorax_third_sub, std_no_thorax_third_sub = get_mean_std_dataloader(
    statistics_loader_no_thorax_third_sub,
    masked=False,
)
print(
    f"Calculated - Mean: {mean_no_thorax_third_sub}, "
    f"Standard Deviation: {std_no_thorax_third_sub}"
)


In [None]:
# Fixed normalization statistics. Running this cell replaces whatever values
# the statistics pass assigned to these two names.
mean_no_thorax_third_sub = 10.217764854431152
std_no_thorax_third_sub = 26.677101135253906


In [None]:
# Create final transforms with calculated statistics
def _build_no_thorax_third_sub_transforms(mode):
    """Return the load/rotate/Preprocess pipeline for the given Preprocess ``mode``."""
    return Compose([
        LoadImaged(
            keys=["image", "label"],
            image_only=False,
            reader=monai.data.PILReader(converter=lambda pil_img: pil_img.convert("L")),
        ),
        EnsureChannelFirstd(keys=["image", "label"]),
        monai.transforms.Rotate90d(keys=["image", "label"]),
        Preprocess(
            keys=None,
            mode=mode,
            dataset="BREADM",
            subtracted_images_path_prefixes=sub_third_images_path_prefixes,
            subtrahend=mean_no_thorax_third_sub,
            divisor=std_no_thorax_third_sub,
            get_patches=False,
            get_boundaryloss=get_boundaryloss,
        ),
    ])


# The two pipelines differ only in the mode passed to Preprocess
test_transforms_no_thorax_third_sub = _build_no_thorax_third_sub_transforms('test')
train_transforms_no_thorax_third_sub = _build_no_thorax_third_sub_transforms('train')


In [None]:
# Cached datasets: the training split uses the 'train' pipeline,
# validation and test both use the 'test' pipeline
train_ds_no_thorax_third_sub = CacheDataset(
    num_workers=num_workers,
    transform=train_transforms_no_thorax_third_sub,
    data=train_dicts,
)

val_ds_no_thorax_third_sub = CacheDataset(
    num_workers=num_workers,
    transform=test_transforms_no_thorax_third_sub,
    data=val_dicts,
)

test_ds_no_thorax_third_sub = CacheDataset(
    num_workers=num_workers,
    transform=test_transforms_no_thorax_third_sub,
    data=test_dicts,
)


In [None]:
# Arguments common to all three loaders; only the training loader shuffles
_shared_loader_args = dict(
    batch_size=batch_size,
    worker_init_fn=seed_worker,
    generator=g,
    drop_last=False,
)

train_loader_no_thorax_third_sub = DataLoader(
    train_ds_no_thorax_third_sub, shuffle=True, **_shared_loader_args
)

val_loader_no_thorax_third_sub = DataLoader(
    val_ds_no_thorax_third_sub, shuffle=False, **_shared_loader_args
)

test_loader_no_thorax_third_sub = DataLoader(
    test_ds_no_thorax_third_sub, shuffle=False, **_shared_loader_args
)


# Training FCN-FFNET


In [None]:
# Re-seed before building the model, as done for the other architectures
g = reseed()
ENCODER_NAME = None
fcn_ffnet_model = BreastSegmentationModel(
    arch="fcn_ffnet",
    encoder_name=ENCODER_NAME,
    loss_function="crossentropy2d",
    use_boundary_loss=False,
    img_size=config.IMAGE_SIZE[0],
    in_channels=config.IN_CHANNELS,
    out_classes=config.OUT_CHANNELS,
    batch_size=batch_size,
    # Use the loader's own length: with drop_last=False the final partial batch
    # is a real step, which len(dataset) // batch_size would leave out.
    len_train_loader=len(train_loader_no_thorax_third_sub)
)

# Stop once the validation loss has not improved for the configured patience
es = EarlyStopping(monitor="valid_loss", mode="min", patience=config.EARLY_STOPPING_PATIENCE)

# Keep only the checkpoint with the lowest validation loss; the filename names
# this architecture so it cannot be mistaken for another experiment's file
cc_fcn_ffnet_model = ModelCheckpoint(
    monitor="valid_loss",
    save_top_k=1,
    mode="min",
    filename='fcn-ffnet-{epoch:02d}-{valid_loss:.2f}',
    dirpath=checkpoints_dir,
    auto_insert_metric_name=False
)

trainer_fcn_ffnet_model = L.Trainer(
    devices=1,
    accelerator="auto",
    max_epochs=config.MAX_EPOCHS,
    callbacks=[es, cc_fcn_ffnet_model],
    log_every_n_steps=config.LOG_EVERY_N_STEPS,
    gradient_clip_val=config.GRADIENT_CLIP_VAL,
    num_sanity_val_steps=1,
    deterministic=True
)


In [None]:
# Train FCN-FFNET on the training loader, validating on the validation loader
trainer_fcn_ffnet_model.fit(
    model=fcn_ffnet_model,
    val_dataloaders=val_loader_no_thorax_third_sub,
    train_dataloaders=train_loader_no_thorax_third_sub,
)

In [None]:
# Reload the best checkpoint recorded by the callback and score it on the test loader
fcn_ffnet_model = BreastSegmentationModel.load_from_checkpoint(
    checkpoint_path=cc_fcn_ffnet_model.best_model_path,
    use_boundary_loss=False,
    loss_function="crossentropy2d",
)
test_metrics = trainer_fcn_ffnet_model.test(
    model=fcn_ffnet_model,
    dataloaders=test_loader_no_thorax_third_sub,
    verbose=False,
)
# test() returns one metrics dict per dataloader; only one loader was passed
pp.pprint(test_metrics[0])

# Training Swin-UNETR


In [None]:
# Re-seed before building the model, as done for the other architectures
g = reseed()
ENCODER_NAME = None
swin_unetr_model = BreastSegmentationModel(
    arch="swin_unetr",
    encoder_name=ENCODER_NAME,
    loss_function='soft_dice',
    use_boundary_loss=False,
    img_size=config.IMAGE_SIZE[0],
    in_channels=config.IN_CHANNELS,
    out_classes=config.OUT_CHANNELS,
    batch_size=batch_size,
    # Use the loader's own length: with drop_last=False the final partial batch
    # is a real step, which len(dataset) // batch_size would leave out.
    len_train_loader=len(train_loader_no_thorax_third_sub)
)

# Stop once the validation loss has not improved for the configured patience
es_swin = EarlyStopping(monitor="valid_loss", mode="min", patience=config.EARLY_STOPPING_PATIENCE)

# Keep only the checkpoint with the lowest validation loss
cc_swin_unetr_model = ModelCheckpoint(
    monitor="valid_loss",
    save_top_k=1,
    mode="min",
    filename='swin-unetr-{epoch:02d}-{valid_loss:.2f}',
    dirpath=checkpoints_dir,
    auto_insert_metric_name=False
)


trainer_swin_unetr_model = L.Trainer(
    devices=1,
    accelerator="auto",
    max_epochs=config.MAX_EPOCHS,
    callbacks=[es_swin, cc_swin_unetr_model],
    log_every_n_steps=config.LOG_EVERY_N_STEPS,
    gradient_clip_val=config.GRADIENT_CLIP_VAL,
    num_sanity_val_steps=1,
    deterministic=True
)


In [None]:
# Train Swin-UNETR on the training loader, validating on the validation loader
trainer_swin_unetr_model.fit(
    model=swin_unetr_model,
    val_dataloaders=val_loader_no_thorax_third_sub,
    train_dataloaders=train_loader_no_thorax_third_sub,
)

In [None]:
# Reload the best checkpoint recorded by the callback and score it on the test loader
swin_unetr_model = BreastSegmentationModel.load_from_checkpoint(
    checkpoint_path=cc_swin_unetr_model.best_model_path,
    use_boundary_loss=False,
    loss_function='soft_dice',
)
test_metrics_swin = trainer_swin_unetr_model.test(
    model=swin_unetr_model,
    dataloaders=test_loader_no_thorax_third_sub,
    verbose=False,
)
# test() returns one metrics dict per dataloader; only one loader was passed
pp.pprint(test_metrics_swin[0])


# Training Unet++


In [None]:
# Re-seed before building the model, as done for the other architectures
g = reseed()
ENCODER_NAME = None
unetplusplus_model = BreastSegmentationModel(
    arch="unetplusplus",
    encoder_name=ENCODER_NAME,
    loss_function="dice_ce",
    loss_kwargs={"sigmoid": True, "lambda_dice": 0.5, "lambda_ce": 0.5},
    use_boundary_loss=False,
    img_size=config.IMAGE_SIZE[0],
    in_channels=config.IN_CHANNELS,
    out_classes=config.OUT_CHANNELS,
    batch_size=batch_size,
    # Use the loader's own length: with drop_last=False the final partial batch
    # is a real step, which len(dataset) // batch_size would leave out.
    len_train_loader=len(train_loader_no_thorax_third_sub)
)

# Stop once the validation loss has not improved for the configured patience
es_unetpp = EarlyStopping(monitor="valid_loss", mode="min", patience=config.EARLY_STOPPING_PATIENCE)

# Keep only the checkpoint with the lowest validation loss
cc_unetplusplus_model = ModelCheckpoint(
    monitor="valid_loss",
    save_top_k=1,
    mode="min",
    filename='unetplusplus-{epoch:02d}-{valid_loss:.2f}',
    dirpath=checkpoints_dir,
    auto_insert_metric_name=False
)


trainer_unetplusplus_model = L.Trainer(
    devices=1,
    accelerator="auto",
    max_epochs=config.MAX_EPOCHS,
    callbacks=[es_unetpp, cc_unetplusplus_model],
    log_every_n_steps=config.LOG_EVERY_N_STEPS,
    gradient_clip_val=config.GRADIENT_CLIP_VAL,
    num_sanity_val_steps=1,
    deterministic=True
)


In [None]:
# Train UNet++ on the training loader, validating on the validation loader
trainer_unetplusplus_model.fit(
    model=unetplusplus_model,
    val_dataloaders=val_loader_no_thorax_third_sub,
    train_dataloaders=train_loader_no_thorax_third_sub,
)


In [None]:
# Reload the best UNet++ checkpoint with the same loss configuration used for
# training, so the loss reported at test time matches the training setup.
# NOTE(review): assumes loss_kwargs is not already restored from the checkpoint's
# saved hyperparameters; passing it again is harmless if it is.
unetplusplus_model = BreastSegmentationModel.load_from_checkpoint(
    cc_unetplusplus_model.best_model_path,
    loss_function="dice_ce",
    loss_kwargs={"sigmoid": True, "lambda_dice": 0.5, "lambda_ce": 0.5},
    use_boundary_loss=False
)
test_metrics_unetpp = trainer_unetplusplus_model.test(
    unetplusplus_model,
    dataloaders=test_loader_no_thorax_third_sub,
    verbose=False
)
# test() returns one metrics dict per dataloader; only one loader was passed
pp.pprint(test_metrics_unetpp[0])


# Training SegNet


In [None]:
# Re-seed before building the model, as done for the other architectures
g = reseed()
ENCODER_NAME = None
segnet_model = BreastSegmentationModel(
    arch="segnet",
    encoder_name=ENCODER_NAME,
    loss_function="crossentropy2d",
    use_boundary_loss=False,
    img_size=config.IMAGE_SIZE[0],
    in_channels=config.IN_CHANNELS,
    out_classes=config.OUT_CHANNELS,
    batch_size=batch_size,
    # Use the loader's own length: with drop_last=False the final partial batch
    # is a real step, which len(dataset) // batch_size would leave out.
    len_train_loader=len(train_loader_no_thorax_third_sub)
)

# Stop once the validation loss has not improved for the configured patience
es_segnet = EarlyStopping(monitor="valid_loss", mode="min", patience=config.EARLY_STOPPING_PATIENCE)

# Keep only the checkpoint with the lowest validation loss
cc_segnet_model = ModelCheckpoint(
    monitor="valid_loss",
    save_top_k=1,
    mode="min",
    filename='segnet-{epoch:02d}-{valid_loss:.2f}',
    dirpath=checkpoints_dir,
    auto_insert_metric_name=False
)


trainer_segnet_model = L.Trainer(
    devices=1,
    accelerator="auto",
    max_epochs=config.MAX_EPOCHS,
    callbacks=[es_segnet, cc_segnet_model],
    log_every_n_steps=config.LOG_EVERY_N_STEPS,
    gradient_clip_val=config.GRADIENT_CLIP_VAL,
    num_sanity_val_steps=1,
    deterministic=True
)


In [None]:
# Train SegNet on the training loader, validating on the validation loader
trainer_segnet_model.fit(
    model=segnet_model,
    val_dataloaders=val_loader_no_thorax_third_sub,
    train_dataloaders=train_loader_no_thorax_third_sub,
)


In [None]:
# Reload the best checkpoint recorded by the callback and score it on the test loader
segnet_model = BreastSegmentationModel.load_from_checkpoint(
    checkpoint_path=cc_segnet_model.best_model_path,
    use_boundary_loss=False,
    loss_function="crossentropy2d",
)
test_metrics_segnet = trainer_segnet_model.test(
    model=segnet_model,
    dataloaders=test_loader_no_thorax_third_sub,
    verbose=False,
)
# test() returns one metrics dict per dataloader; only one loader was passed
pp.pprint(test_metrics_segnet[0])


# Training Skinny


In [None]:
# Re-seed before building the model, as done for the other architectures
g = reseed()

# Initialize SkinnyNet model
skinny_model = BreastSegmentationModel(
    arch="skinny",
    encoder_name=None,
    loss_function="dice_ce",
    loss_kwargs={"sigmoid": True, "lambda_dice": 0.5, "lambda_ce": 0.5},
    use_boundary_loss=False,
    img_size=config.IMAGE_SIZE[0],
    in_channels=config.IN_CHANNELS,
    out_classes=config.OUT_CHANNELS,
    batch_size=batch_size,
    # Use the loader's own length: with drop_last=False the final partial batch
    # is a real step, which len(dataset) // batch_size would leave out.
    len_train_loader=len(train_loader_no_thorax_third_sub)
)

# Early stopping for SkinnyNet
es_skinny = EarlyStopping(monitor="valid_loss", mode="min", patience=config.EARLY_STOPPING_PATIENCE)

# Model checkpoint for SkinnyNet; written to the same configured directory as
# the other models' checkpoints instead of a hard-coded relative path
cc_skinny_model = ModelCheckpoint(
    monitor="valid_loss",
    save_top_k=1,
    mode="min",
    dirpath=checkpoints_dir,
    filename="skinny_best",
    auto_insert_metric_name=False
)

In [None]:
# SkinnyNet trainer: single device, deterministic, precision=16.
# NOTE(review): unlike the other trainers in this notebook, no log_every_n_steps,
# gradient_clip_val or num_sanity_val_steps is passed here - confirm this is intended.
trainer_skinny_model = L.Trainer(
    devices=1,
    accelerator="auto",
    max_epochs=config.MAX_EPOCHS,
    callbacks=[es_skinny, cc_skinny_model],
    precision=16,
    deterministic=True,
)

# Fit SkinnyNet on the training loader, validating on the validation loader
trainer_skinny_model.fit(
    model=skinny_model,
    val_dataloaders=val_loader_no_thorax_third_sub,
    train_dataloaders=train_loader_no_thorax_third_sub,
)


In [None]:
# Load best SkinnyNet model and test, using the same loss configuration as in
# training so the loss reported at test time matches the training setup.
# NOTE(review): assumes loss_kwargs is not already restored from the checkpoint's
# saved hyperparameters; passing it again is harmless if it is.
skinny_model = BreastSegmentationModel.load_from_checkpoint(
    cc_skinny_model.best_model_path,
    loss_function="dice_ce",
    loss_kwargs={"sigmoid": True, "lambda_dice": 0.5, "lambda_ce": 0.5},
    use_boundary_loss=False
)
test_metrics_skinny = trainer_skinny_model.test(
    skinny_model,
    dataloaders=test_loader_no_thorax_third_sub,
    verbose=False
)
# test() returns one metrics dict per dataloader; only one loader was passed
pp.pprint(test_metrics_skinny[0])

# Training ResNet-UNet


In [None]:
# Re-seed before building the model, as done for the other architectures
g = reseed()
ENCODER_NAME = "resnet50"
resnet_model = BreastSegmentationModel(
    arch="UNet",
    encoder_name=ENCODER_NAME,
    loss_function="dice",
    use_boundary_loss=False,
    img_size=config.IMAGE_SIZE[0],
    in_channels=config.IN_CHANNELS,
    out_classes=config.OUT_CHANNELS,
    batch_size=batch_size,
    # Use the loader's own length: with drop_last=False the final partial batch
    # is a real step, which len(dataset) // batch_size would leave out.
    len_train_loader=len(train_loader_no_thorax_third_sub)
)

# Stop once the validation loss has not improved for the configured patience
es_resnet = EarlyStopping(monitor="valid_loss", mode="min", patience=config.EARLY_STOPPING_PATIENCE)

# Keep only the checkpoint with the lowest validation loss
cc_resnet_model = ModelCheckpoint(
    monitor="valid_loss",
    save_top_k=1,
    mode="min",
    filename='resnet-unet-{epoch:02d}-{valid_loss:.2f}',
    dirpath=checkpoints_dir,
    auto_insert_metric_name=False
)


trainer_resnet_model = L.Trainer(
    devices=1,
    accelerator="auto",
    max_epochs=config.MAX_EPOCHS,
    callbacks=[es_resnet, cc_resnet_model],
    log_every_n_steps=config.LOG_EVERY_N_STEPS,
    gradient_clip_val=config.GRADIENT_CLIP_VAL,
    num_sanity_val_steps=1,
    deterministic=True
)


In [None]:
# Train the ResNet-UNet on the training loader, validating on the validation loader
trainer_resnet_model.fit(
    model=resnet_model,
    val_dataloaders=val_loader_no_thorax_third_sub,
    train_dataloaders=train_loader_no_thorax_third_sub,
)


In [None]:
# Reload the best checkpoint recorded by the callback and score it on the test loader
resnet_model = BreastSegmentationModel.load_from_checkpoint(
    checkpoint_path=cc_resnet_model.best_model_path,
    use_boundary_loss=False,
    loss_function="dice",
)
test_metrics_resnet = trainer_resnet_model.test(
    model=resnet_model,
    dataloaders=test_loader_no_thorax_third_sub,
    verbose=False,
)
# test() returns one metrics dict per dataloader; only one loader was passed
pp.pprint(test_metrics_resnet[0])