# VENUS FUSION TRAINING NOTEBOOK

This notebook implements fusion-based breast segmentation using the VENUS model with multi-scale patch processing. It also trains and evaluates a patch-only ResNet18 U-Net (Section 7).


## 1. Installs & Imports


In [None]:
# Standard libraries
import os
import json
import glob
import shutil
import tempfile
import random
import warnings
import pprint
import copy
pp = pprint.PrettyPrinter(indent=4)

# Third-party libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
import nibabel as nib
import albumentations as A
import segmentation_models_pytorch as smp
import torch.nn as nn
import torch.nn.functional as F
from skimage import filters
from skimage.measure import label as label_fn, regionprops
from skimage import morphology
from sklearn.model_selection import train_test_split
from copy import deepcopy
from tqdm.notebook import tqdm

# MONAI related imports
from monai.config import print_config
from monai.networks.nets import UNet, SwinUNETR, BasicUNetPlusPlus
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss, DiceCELoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.transforms import (
    AsDiscrete, AsDiscreted, EnsureChannelFirstd, Compose, CropForegroundd,
    LoadImaged, Orientationd, RandCropByPosNegLabeld, SaveImaged, ScaleIntensityRanged,
    Spacingd, Invertd, ResizeWithPadOrCropd, Resized, MapTransform, ScaleIntensityd,
    LabelToContourd, ForegroundMaskd, HistogramNormalized, RandFlipd, RandGridDistortiond,
    RandHistogramShiftd, RandRotated
)
from monai.handlers.utils import from_engine
from monai.utils.type_conversion import convert_to_numpy

# PyTorch Lightning related imports
import lightning.pytorch as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch import seed_everything

from natsort import natsorted, ns
from PIL import Image
from numpy import einsum
from torch.utils.data import default_collate
import psutil
from typing import List, Tuple
import monai

# Import modules
from breast_segmentation.config.settings import config
from breast_segmentation.utils.seed import set_deterministic_mode, seed_worker, reseed
from breast_segmentation.data.dataset import (
    get_image_label_files, create_data_dicts,
    PairedDataset, PairedDataLoader
)
from breast_segmentation.data import custom_collate_no_patches, custom_collate
from breast_segmentation.transforms.compose import Preprocess
from breast_segmentation.models.fusion_module import BreastFusionModel
from breast_segmentation.models.lightning_module import BreastSegmentationModel
from breast_segmentation.models.architectures import get_model, VENUS
from breast_segmentation.metrics.losses import (
    get_loss_function, CrossEntropy2d, compute_class_weight, 
    AsymmetricUnifiedFocalLoss, CABFL
)

# Select the "medium" float32 matmul precision before any model is built
torch.set_float32_matmul_precision("medium")

# Print MONAI's configuration report for the current environment
print_config()

# Named module-level converter: used instead of a lambda so objects holding it stay picklable
def convert_to_grayscale(image):
    """Return ``image`` converted to grayscale.

    Args:
        image: Object exposing a PIL-style ``convert(mode)`` method.

    Returns:
        Whatever ``image.convert("L")`` returns (mode ``"L"`` is PIL's grayscale mode).
    """
    grayscale_image = image.convert("L")
    return grayscale_image


## 2. Environment Setup


In [None]:
# Run-wide settings, all read from the shared project config
batch_size = config.BATCH_SIZE
num_workers = config.NUM_WORKERS
checkpoints_dir = config.checkpoints_dir_breadm
# Forwarded as Preprocess(get_boundaryloss=...) by the train/test transforms
get_boundaryloss = True

# Make sure the checkpoint directory is there (no error if it already exists)
os.makedirs(checkpoints_dir, exist_ok=True)

# Set random seed for reproducibility; `g` is later handed to the loaders as `generator=`
g = set_deterministic_mode(config.SEED)

# Summarise the effective configuration
_cuda_available = torch.cuda.is_available()
for _message in (
    f"Batch size: {batch_size}",
    f"Number of workers: {num_workers}",
    f"Checkpoints directory: {checkpoints_dir}",
    f"CUDA available: {_cuda_available}",
):
    print(_message)
if _cuda_available:
    print(f"Device: {torch.cuda.get_device_name()}")


## 3. Data Preparation


In [None]:
# Locate the image/label files of every split for the chosen image type
dataset_base_path = config.DATASET_BASE_PATH_BREADM
image_type = "VIBRANT+C2"

_files_by_split = {
    _split: get_image_label_files(dataset_base_path, _split, image_type)
    for _split in ("train", "val", "test")
}
train_images, train_labels = _files_by_split["train"]
val_images, val_labels = _files_by_split["val"]
test_images, test_labels = _files_by_split["test"]

# Pair each image with its label in the dict format consumed by the datasets below
train_dicts = create_data_dicts(train_images, train_labels)
val_dicts = create_data_dicts(val_images, val_labels)
test_dicts = create_data_dicts(test_images, test_labels)

# Report how many samples each split contains
print("Dataset statistics:")
for _title, _dicts in (
    ("Training", train_dicts),
    ("Validation", val_dicts),
    ("Test", test_dicts),
):
    print(f"  {_title} samples: {len(_dicts)}")


## 4. Data Preprocessing and Dataset Creation


In [None]:
# Path prefixes passed to Preprocess(subtracted_images_path_prefixes=...) in every pipeline
sub_third_images_path_prefixes = (
    "VIBRANT+C2",
    "SUB2",
)

print("Will calculate normalization statistics from data...")

In [None]:
# Pipeline used to gather intensity statistics on global (no patches) data:
# load as grayscale -> channel-first -> rotate by 90 degrees -> Preprocess in 'statistics' mode
_global_stats_steps = [
    LoadImaged(
        keys=["image", "label"],
        image_only=False,
        reader=monai.data.PILReader(converter=convert_to_grayscale),
    ),
    EnsureChannelFirstd(keys=["image", "label"]),
    monai.transforms.Rotate90d(keys=["image", "label"]),
    Preprocess(
        keys=None,
        mode="statistics",
        dataset="BREADM",
        subtracted_images_path_prefixes=sub_third_images_path_prefixes,
        get_patches=False,
        get_boundaryloss=False,
    ),
]
statistics_transforms_no_thorax_third_sub = Compose(_global_stats_steps)

print("Creating statistics dataset for global data...")

In [None]:
# Cache the training split after running it through the statistics pipeline
statistics_ds_no_thorax_third_sub = CacheDataset(
    data=train_dicts,
    transform=statistics_transforms_no_thorax_third_sub,
    num_workers=num_workers,
)

# Sequential, unshuffled pass over the whole split (the last partial batch is kept)
_global_stats_loader_kwargs = {
    "batch_size": batch_size,
    "worker_init_fn": seed_worker,
    "generator": g,
    "shuffle": False,
    "drop_last": False,
}
statistics_loader_no_thorax_third_sub = DataLoader(
    statistics_ds_no_thorax_third_sub,
    **_global_stats_loader_kwargs,
)

print("Statistics dataset and loader created for global data")


In [None]:
# Calculate mean and std for global data
def get_mean_std_dataloader(dataloader, masked=False):
    """Compute the global mean and standard deviation of ``batch["image"]`` over a loader.

    The statistics are pooled over every element of every batch (single pass, using the
    running sums of ``x`` and ``x**2``), i.e. the population standard deviation.

    Args:
        dataloader: Iterable of batches. A batch is either ``None`` (reported with a
            "none batch" message and skipped) or a mapping whose ``"image"`` entry is a
            tensor.
        masked: When ``True``, only strictly positive values (``image > 0.0``) contribute.

    Returns:
        Tuple ``(mean, std)`` of Python floats.

    Raises:
        ValueError: If no value was accumulated (empty loader, only ``None`` batches, or
            ``masked=True`` with no positive value), since the statistics are undefined.
    """
    total = 0.0
    total_of_squares = 0.0
    num_pixels = 0

    for batch in tqdm(dataloader):
        if batch is None:
            print("none batch")
            continue

        values = batch["image"]
        if masked:
            values = values[values > 0.0]

        # Accumulate in float64: float32 running sums lose precision on large datasets,
        # and integer-typed images could overflow when squared.
        values = values.double()
        total += values.sum().item()
        total_of_squares += (values * values).sum().item()
        num_pixels += values.numel()

    if num_pixels == 0:
        raise ValueError(
            "Cannot compute mean/std: the dataloader produced no image values"
        )

    mean = total / num_pixels
    # Rounding can push E[x^2] - E[x]^2 marginally below zero; clamp to avoid a NaN std.
    variance = max(total_of_squares / num_pixels - mean * mean, 0.0)
    std_dev = variance ** 0.5

    print(f'Mean: {mean}, Standard Deviation: {std_dev}')
    return mean, std_dev

print("Calculating mean and std for global data...")
# One full pass over the global statistics loader
_global_stats = get_mean_std_dataloader(statistics_loader_no_thorax_third_sub)
mean_no_thorax_third_sub_calc, std_no_thorax_third_sub_calc = _global_stats
print(
    f"Calculated - Global Mean: {mean_no_thorax_third_sub_calc}, "
    f"Global Std: {std_no_thorax_third_sub_calc}"
)

In [None]:
# Pre-computed normalization constants for the global (no patches) images.
# These fixed values -- not the freshly calculated *_calc ones -- are what the train/test
# transforms receive as subtrahend/divisor.
mean_no_thorax_third_sub = 10.217766761779785
std_no_thorax_third_sub = 26.677101135253906

In [None]:
# Pipeline used to gather intensity statistics on patches data:
# load as grayscale -> channel-first -> rotate by 90 degrees -> Preprocess in 'statistics'
# mode with patch extraction enabled
_patch_stats_steps = [
    LoadImaged(
        keys=["image", "label"],
        image_only=False,
        reader=monai.data.PILReader(converter=convert_to_grayscale),
    ),
    EnsureChannelFirstd(keys=["image", "label"]),
    monai.transforms.Rotate90d(keys=["image", "label"]),
    Preprocess(
        keys=None,
        mode="statistics",
        dataset="BREADM",
        subtracted_images_path_prefixes=sub_third_images_path_prefixes,
        get_patches=True,
        get_boundaryloss=False,
    ),
]
statistics_transforms_patches_sub = Compose(_patch_stats_steps)

print("Creating statistics dataset for patches data...")


In [None]:
# Cache the training split after running it through the patches statistics pipeline
statistics_ds_patches_sub = CacheDataset(
    data=train_dicts,
    transform=statistics_transforms_patches_sub,
    num_workers=num_workers,
)

# Sequential, unshuffled pass over the whole split (the last partial batch is kept)
_patch_stats_loader_kwargs = {
    "batch_size": batch_size,
    "worker_init_fn": seed_worker,
    "generator": g,
    "shuffle": False,
    "drop_last": False,
}
statistics_loader_patches_sub = DataLoader(
    statistics_ds_patches_sub,
    **_patch_stats_loader_kwargs,
)

print("Calculating mean and std for patches data...")
_patch_stats = get_mean_std_dataloader(statistics_loader_patches_sub)
mean_patches_sub_calc, std_patches_sub_calc = _patch_stats
print(
    f"Calculated - Patches Mean: {mean_patches_sub_calc}, "
    f"Patches Std: {std_patches_sub_calc}"
)

In [None]:
# Pre-computed normalization constants for the patches data; the patch train/test
# transforms receive these fixed values as subtrahend/divisor.
mean_patches_sub = 20.63081550598144
std_patches_sub = 35.328887939453125

In [None]:
# Re-read the checkpoint location from the config so this cell also works when run on its own
checkpoints_dir = config.checkpoints_dir_breadm
# Create it if missing; an existing directory is left untouched
os.makedirs(name=checkpoints_dir, exist_ok=True)
print(f"VENUS Training - Using checkpoint directory: {checkpoints_dir}")


In [None]:
# Build the four final pipelines (global/patches x train/test). They differ only in the
# Preprocess mode, the normalization constants and whether patches are extracted.
def _build_breadm_transforms(mode, subtrahend, divisor, get_patches):
    """Return the load -> channel-first -> rotate -> Preprocess pipeline for BREADM.

    Args:
        mode: Value forwarded as ``Preprocess(mode=...)`` ('train' or 'test' here).
        subtrahend: Value forwarded as ``Preprocess(subtrahend=...)`` (the mean constant).
        divisor: Value forwarded as ``Preprocess(divisor=...)`` (the std constant).
        get_patches: Value forwarded as ``Preprocess(get_patches=...)``.

    Returns:
        A freshly constructed ``Compose``; no transform instance is shared between calls.
    """
    return Compose([
        LoadImaged(
            keys=["image", "label"],
            image_only=False,
            reader=monai.data.PILReader(converter=convert_to_grayscale),
        ),
        EnsureChannelFirstd(keys=["image", "label"]),
        monai.transforms.Rotate90d(keys=["image", "label"]),
        Preprocess(
            keys=None,
            mode=mode,
            dataset="BREADM",
            subtracted_images_path_prefixes=sub_third_images_path_prefixes,
            subtrahend=subtrahend,
            divisor=divisor,
            get_patches=get_patches,
            get_boundaryloss=get_boundaryloss,
        ),
    ])


# Global (no patches) pipelines, normalized with the global constants
test_transforms_no_thorax_third_sub = _build_breadm_transforms(
    "test", mean_no_thorax_third_sub, std_no_thorax_third_sub, False
)
train_transforms_no_thorax_third_sub = _build_breadm_transforms(
    "train", mean_no_thorax_third_sub, std_no_thorax_third_sub, False
)

# Patches pipelines, normalized with the patches constants
train_transforms_patches_sub = _build_breadm_transforms(
    "train", mean_patches_sub, std_patches_sub, True
)
test_transforms_patches_sub = _build_breadm_transforms(
    "test", mean_patches_sub, std_patches_sub, True
)

print("Final transforms created using calculated statistics")


In [None]:
# Build the cached datasets: one global and one patches dataset per split.
# Validation and test both reuse the 'test' transforms.
def _cached(data, transform):
    """Wrap ``data`` in a CacheDataset using the notebook-wide ``num_workers``."""
    return CacheDataset(data=data, transform=transform, num_workers=num_workers)


# Global (no patches) datasets
train_ds_no_thorax_third_sub = _cached(train_dicts, train_transforms_no_thorax_third_sub)
val_ds_no_thorax_third_sub = _cached(val_dicts, test_transforms_no_thorax_third_sub)
test_ds_no_thorax_third_sub = _cached(test_dicts, test_transforms_no_thorax_third_sub)

# Patches datasets
train_ds_patches_sub = _cached(train_dicts, train_transforms_patches_sub)
val_ds_patches_sub = _cached(val_dicts, test_transforms_patches_sub)
test_ds_patches_sub = _cached(test_dicts, test_transforms_patches_sub)

print("Datasets created for both global and patches data")

## 5. Create Fusion DataLoaders


In [None]:
# Start from a fresh seed so loader behaviour does not depend on earlier cells
g = reseed()

# Options shared by the three fusion loaders; only `shuffle` differs (training only)
_fusion_loader_kwargs = {
    "batch_size": batch_size,
    "worker_init_fn": seed_worker,
    "generator": g,
    "drop_last": False,
    "num_workers": num_workers,
    "augment": False,
}

# Each loader pairs the global dataset with the patches dataset of the same split
train_loader_fusion_sub = PairedDataLoader(
    train_ds_no_thorax_third_sub,
    train_ds_patches_sub,
    shuffle=True,
    **_fusion_loader_kwargs,
)

val_loader_fusion_sub = PairedDataLoader(
    val_ds_no_thorax_third_sub,
    val_ds_patches_sub,
    shuffle=False,
    **_fusion_loader_kwargs,
)

test_loader_fusion_sub = PairedDataLoader(
    test_ds_no_thorax_third_sub,
    test_ds_patches_sub,
    shuffle=False,
    **_fusion_loader_kwargs,
)

print("Fusion dataloaders created")

## 6. Training VENUS Model


In [None]:
# Ensure checkpoints_dir is properly set for VENUS training.
# (This cell precedes the VENUS fusion run; the message previously said "ResNet".)
checkpoints_dir = config.checkpoints_dir_breadm
os.makedirs(checkpoints_dir, exist_ok=True)
print(f"VENUS Training - Using checkpoint directory: {checkpoints_dir}")


In [None]:
import gc

# Collect unreachable Python objects first so the memory they hold can be released
gc.collect()

# Then release PyTorch's cached, currently unused CUDA memory (only relevant with a GPU)
if torch.cuda.is_available():
    torch.cuda.empty_cache()


In [None]:
g = reseed()

# Number of training batches per epoch. The training loader is created with
# drop_last=False, so the trailing partial batch counts as a step: round up instead of
# flooring (floor division also yields 0 when the dataset is smaller than one batch).
# NOTE(review): assumes PairedDataLoader yields ceil(N / batch_size) batches -- confirm.
_fusion_steps_per_epoch = (
    len(train_ds_no_thorax_third_sub) + batch_size - 1
) // batch_size

# Create VENUS fusion model with CABFL loss
model_fusion_sub_cabl = BreastFusionModel(
    arch="venus",
    encoder_name=None,
    in_channels=config.IN_CHANNELS,
    out_classes=config.OUT_CHANNELS,
    batch_size=batch_size,
    len_train_loader=_fusion_steps_per_epoch,
    use_boundary_loss=True,
    loss_function="cabfl",
    loss_kwargs={"idc": [1], "weight_aufl": 0.5, "delta": 0.4, "gamma": 0.1},
    base_channels=64,
    use_simple_fusion=True,
    use_decoder_attention=True
)

# Stop when the validation loss has not improved for the configured number of epochs
es = EarlyStopping(monitor="valid_loss", mode="min", patience=config.EARLY_STOPPING_PATIENCE)

# Keep only the single checkpoint with the lowest validation loss
cc_fusion_sub_cabl = ModelCheckpoint(
    monitor="valid_loss",
    save_top_k=1,
    mode="min",
    filename='venus-fusion-cabl-{epoch:02d}-{valid_loss:.2f}',
    dirpath=checkpoints_dir,
    auto_insert_metric_name=False
)


trainer_fusion_sub_cabl = L.Trainer(
    devices=1,
    accelerator='auto',
    max_epochs=config.MAX_EPOCHS,
    callbacks=[es, cc_fusion_sub_cabl],
    log_every_n_steps=config.LOG_EVERY_N_STEPS,
    gradient_clip_val=config.GRADIENT_CLIP_VAL,
    num_sanity_val_steps=1,
    deterministic=False
)

# Train on the fusion loaders (global + patches pairs)
trainer_fusion_sub_cabl.fit(
    model_fusion_sub_cabl,
    train_dataloaders=train_loader_fusion_sub,
    val_dataloaders=val_loader_fusion_sub
)


In [None]:
# Reload the best VENUS checkpoint (lowest validation loss) with the same loss settings
# used for training, then evaluate it on the fusion test loader
_venus_loss_kwargs = {"idc": [1], "weight_aufl": 0.5, "delta": 0.4, "gamma": 0.1}
_venus_best_ckpt = cc_fusion_sub_cabl.best_model_path

model_fusion_sub_cabl = BreastFusionModel.load_from_checkpoint(
    _venus_best_ckpt,
    use_boundary_loss=True,
    loss_function="cabfl",
    loss_kwargs=_venus_loss_kwargs,
)

test_metrics = trainer_fusion_sub_cabl.test(
    model_fusion_sub_cabl,
    dataloaders=test_loader_fusion_sub,
    verbose=False,
)

# trainer.test returns one metrics dict per test dataloader; show the only one
pp.pprint(test_metrics[0])


## 7. Training ResNet with Patches

In [None]:
# Stand-alone loaders over the patches datasets (no pairing with the global data).
# NOTE(review): these loaders capture whatever `g` is bound when this cell runs; the
# `g = reseed()` in the following cell happens after they are built -- confirm that
# this ordering is intended.
_patch_loader_kwargs = {
    "batch_size": batch_size,
    "worker_init_fn": seed_worker,
    "generator": g,
    "drop_last": False,
}

# Only the training loader shuffles
train_loader_patches_sub = DataLoader(
    train_ds_patches_sub,
    shuffle=True,
    **_patch_loader_kwargs,
)

val_loader_patches_sub = DataLoader(
    val_ds_patches_sub,
    shuffle=False,
    **_patch_loader_kwargs,
)

test_loader_patches_sub = DataLoader(
    test_ds_patches_sub,
    shuffle=False,
    **_patch_loader_kwargs,
)


In [None]:
g = reseed()
ENCODER_NAME = "resnet18"

# Create the ResNet18 U-Net trained on patches only, with CABFL loss.
# len_train_loader is taken from the loader itself: with drop_last=False the trailing
# partial batch is a real step, which `len(dataset) // batch_size` would not count
# (and which would be 0 for a dataset smaller than one batch).
model_resnet_patches_cabl = BreastSegmentationModel(
    arch="UNet",
    encoder_name=ENCODER_NAME,
    in_channels=config.IN_CHANNELS,
    out_classes=config.OUT_CHANNELS,
    batch_size=batch_size,
    len_train_loader=len(train_loader_patches_sub),
    use_boundary_loss=True,
    loss_function="cabfl",
    loss_kwargs={"idc": [1], "weight_aufl": 0.5, "delta": 0.7, "gamma": 0.4}
)

# Stop when the validation loss has not improved for the configured number of epochs
es_resnet = EarlyStopping(monitor="valid_loss", mode="min", patience=config.EARLY_STOPPING_PATIENCE)

# Keep only the single checkpoint with the lowest validation loss
cc_resnet_cabl = ModelCheckpoint(
    monitor="valid_loss",
    save_top_k=1,
    mode="min",
    filename='resnet18-patches-cabl-{epoch:02d}-{valid_loss:.2f}',
    dirpath=checkpoints_dir,
    auto_insert_metric_name=False
)


trainer_resnet_cabl = L.Trainer(
    devices=1,
    accelerator='auto',
    max_epochs=config.MAX_EPOCHS,
    callbacks=[es_resnet, cc_resnet_cabl],
    log_every_n_steps=config.LOG_EVERY_N_STEPS,
    gradient_clip_val=config.GRADIENT_CLIP_VAL,
    num_sanity_val_steps=1,
    deterministic=False
)

# Train on the patches-only loaders
trainer_resnet_cabl.fit(
    model_resnet_patches_cabl,
    train_dataloaders=train_loader_patches_sub,
    val_dataloaders=val_loader_patches_sub
)


In [None]:
# Load best ResNet patches model and test.
# The loss settings passed here override the checkpoint's, so they must match the ones
# the model was trained with (delta=0.7, gamma=0.4); the previous delta=0.4 / gamma=0.1
# were the VENUS fusion values.
model_resnet_patches_cabl = BreastSegmentationModel.load_from_checkpoint(
    cc_resnet_cabl.best_model_path,
    strict=True,
    use_boundary_loss=True,
    loss_function="cabfl",
    loss_kwargs={"idc": [1], "weight_aufl": 0.5, "delta": 0.7, "gamma": 0.4}
)
test_metrics_resnet = trainer_resnet_cabl.test(
    model_resnet_patches_cabl, 
    dataloaders=test_loader_patches_sub, 
    verbose=False
)
# trainer.test returns one metrics dict per test dataloader; show the only one
pp.pprint(test_metrics_resnet[0])
