# VENUS FUSION TRAINING NOTEBOOK - PRIVATE DATASET

This notebook implements breast segmentation using the VENUS model with multi-scale patch processing on the private dataset.


## 1. Installs & Imports


In [1]:
# Standard libraries
import os
import json
import glob
import shutil
import tempfile
import random
import warnings
import pprint
import copy
pp = pprint.PrettyPrinter(indent=4)

# Third-party libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
import nibabel as nib
import albumentations as A
import segmentation_models_pytorch as smp
import torch.nn as nn
import torch.nn.functional as F
from skimage import filters
from skimage.measure import label as label_fn, regionprops
from skimage import morphology
from sklearn.model_selection import train_test_split
from copy import deepcopy
from tqdm.notebook import tqdm

# MONAI related imports
from monai.config import print_config
from monai.networks.nets import UNet, SwinUNETR, BasicUNetPlusPlus
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss, DiceCELoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.transforms import (
    AsDiscrete, AsDiscreted, EnsureChannelFirstd, Compose, CropForegroundd,
    LoadImaged, Orientationd, RandCropByPosNegLabeld, SaveImaged, ScaleIntensityRanged,
    Spacingd, Invertd, ResizeWithPadOrCropd, Resized, MapTransform, ScaleIntensityd,
    LabelToContourd, ForegroundMaskd, HistogramNormalized, RandFlipd, RandGridDistortiond,
    RandHistogramShiftd, RandRotated
)
from monai.handlers.utils import from_engine
from monai.utils.type_conversion import convert_to_numpy

# PyTorch Lightning related imports
import lightning.pytorch as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch import seed_everything

# 

from natsort import natsorted, ns
from PIL import Image
from numpy import einsum
from torch.utils.data import default_collate
import psutil
from typing import List, Tuple
import monai

# Import modules
from breast_segmentation.config.settings import config
from breast_segmentation.utils.seed import set_deterministic_mode, seed_worker, reseed

from breast_segmentation.data.private_dataset import (
    PATIENT_INFO, get_filenames, get_train_val_test_dicts, PATIENTS_TO_EXCLUDE,
    filter_samples_sample_aware, filter_samples_to_exclude, get_samples_size
)

from breast_segmentation.data.dataset import PairedDataset, PairedDataLoader


from breast_segmentation.data import custom_collate_no_patches, custom_collate
from breast_segmentation.transforms.compose import Preprocess
from breast_segmentation.models.fusion_module import BreastFusionModel
from breast_segmentation.models.lightning_module import BreastSegmentationModel
from breast_segmentation.models.architectures import get_model, VENUS
from breast_segmentation.metrics.losses import (
    get_loss_function, CrossEntropy2d, compute_class_weight, 
    AsymmetricUnifiedFocalLoss, CABFL
)


# Import additional loss functions used in original
from monai.losses import DiceLoss, DiceCELoss

# Set precision for matmul operations and print MONAI config
torch.set_float32_matmul_precision('medium')
print_config()

# Define converter function to avoid pickling issues with lambda
def convert_to_grayscale(image):
    """Convert PIL image to grayscale - replaces lambda for pickling compatibility."""
    return image.convert("L")


  from torch.distributed.optim import ZeroRedundancyOptimizer


MONAI version: 1.6.dev2535
Numpy version: 2.0.2
Pytorch version: 2.5.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 8e677816bfd1fb2ec541d7f951db4caaf210b150
MONAI __file__: c:\Users\<username>\AppData\Local\pypoetry\Cache\virtualenvs\venus-nCPuPPcI-py3.9\lib\site-packages\monai\__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.11
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.3.2
scikit-image version: 0.24.0
scipy version: 1.13.1
Pillow version: 11.3.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: 5.2.0
TorchVision version: 0.20.1+cu121
tqdm version: 4.67.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 7.0.0
pandas version: NOT INSTALLED or UNKNOWN VERSION.
einops version: 0.8.1
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED o

## 2. Environment Setup


In [2]:
# Settings
NUM_WORKERS = os.cpu_count()
SEED = 200
USE_SUBTRACTED = True
batch_size = 32

# Data paths
dataset_base_path = "Dataset-arrays-4-FINAL"
CHECKPOINTS_DIR = "./checkpoints/private-dataset"

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set seeds
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

Using device: cuda


## 3. Data Preparation


In [3]:
print("Using private dataset backend functions")


patient_ids = list(PATIENT_INFO.keys())

# Apply exclusions
print(f"Initial patients from PATIENT_INFO: {len(patient_ids)}")
print(f"Patients after exclusion: {len(patient_ids)}")


x_train_val, x_test = train_test_split(patient_ids, test_size=0.2, random_state=SEED)
x_train, x_val = train_test_split(x_train_val, test_size=0.25, random_state=SEED)
train_dicts, val_dicts,test_dicts = get_train_val_test_dicts(dataset_base_path, x_train,x_val,x_test)
print(f"Dataset base path: {dataset_base_path}")
print(f"Total patients: {len(patient_ids)}")
print(f"Train patients: {len(x_train)}")
print(f"Validation patients: {len(x_val)}")
print(f"Test patients: {len(x_test)}")
print(f"Test patient IDs: {x_test}")

print(len(train_dicts))
print(len(val_dicts))
print(len(test_dicts))


Using private dataset backend functions
Initial patients from PATIENT_INFO: 103
Patients after exclusion: 103
Dataset base path: Dataset-arrays-4-FINAL
Total patients: 103
Train patients: 61
Validation patients: 21
Test patients: 21
Test patient IDs: ['TE0966', 'ASMK0783', 'BP130964', 'SD0462', 'PS0446(1,5)', 'GMG0961(3)', 'PV0781', 'MS0478', 'D2MP3(VR)', 'GA07(DF)', 'CS300759', 'BNB1172(DF)', 'RP271052', 'RP0178', 'BD0510', 'DBR270865', 'PRP0185', 'MP140270', 'GP0454', 'DMA0247', 'VS0976(1,5)']
3456
3968
4108


## 4. Data Preprocessing and Dataset Creation


In [4]:
# Define subtracted images path prefixes (CRITICAL: Private dataset DOES use subtracted images)
sub_third_images_path_prefixes = ("Dataset-arrays-4-FINAL", "Dataset-arrays-FINAL")

print("Will calculate normalization statistics from data...")
print("This allows verification against original pre-computed values:")

Will calculate normalization statistics from data...
This allows verification against original pre-computed values:


In [5]:
# Calculate mean and std for global (no patches) data
statistics_transforms_no_thorax_third_sub = Compose([
    LoadImaged(
        keys=["image", "label"], 
        image_only=False, 
        reader=monai.data.NumpyReader()
    ),
    EnsureChannelFirstd(keys=["image", "label"]),
    Preprocess(
        keys=None, 
        mode='statistics',  
        dataset="private", 
        subtracted_images_path_prefixes=sub_third_images_path_prefixes, 
        get_patches=False,
        get_boundaryloss=False
    )
])

print("Creating statistics dataset for global data...")


Creating statistics dataset for global data...


In [None]:
# Create statistics dataset and loader for global data
statistics_ds_no_thorax_third_sub = CacheDataset(
    data=train_dicts, 
    transform=statistics_transforms_no_thorax_third_sub,
    num_workers=NUM_WORKERS
)

statistics_loader_no_thorax_third_sub = DataLoader(
    statistics_ds_no_thorax_third_sub, 
    batch_size=batch_size, 
    worker_init_fn=seed_worker,
    generator=g, 
    shuffle=False, 
    drop_last=False,
)

print("Statistics dataset and loader created for global data")


Loading dataset:  37%|███▋      | 1273/3456 [00:40<01:09, 31.50it/s]

In [None]:
g = reseed()
statistics_loader_no_thorax_third_sub = DataLoader(
    statistics_ds_no_thorax_third_sub, 
    batch_size=batch_size, 
    worker_init_fn=seed_worker,
    generator=g, 
    shuffle=False, 
    drop_last=False
)

print("Statistics dataset and loader created for global data")


Seed set to 200


Using random seed 200...
Statistics dataset and loader created for global data


In [None]:
# Calculate mean and std for global data
def get_mean_std_dataloader(dataloader, masked=False):
    """Calculate mean and std from dataloader."""
    sum_of_images = 0.0
    sum_of_squares = 0.0
    num_pixels = 0

    for batch in tqdm(dataloader):
        if batch is not None:
            image = batch["image"]
      
            if masked:
                mask = image > 0.0
                image = image[mask]
      
            sum_of_images += image.sum()
            sum_of_squares += (image ** 2).sum()
            num_pixels += image.numel()
        else:
            print("none batch")
    
    mean = sum_of_images / num_pixels
    std_dev = (sum_of_squares / num_pixels - mean ** 2) ** 0.5

    print(f'Mean: {mean}, Standard Deviation: {std_dev}')
    return mean.item(), std_dev.item()

print("Calculating mean and std for global data...")
mean_no_thorax_third_sub_calc, std_no_thorax_third_sub_calc = get_mean_std_dataloader(statistics_loader_no_thorax_third_sub)
print(f"Calculated - Global Mean: {mean_no_thorax_third_sub_calc}, Global Std: {std_no_thorax_third_sub_calc}")

Calculating mean and std for global data...


  0%|          | 0/108 [00:00<?, ?it/s]

Mean: 40.89126968383789, Standard Deviation: 167.93341064453125
Calculated - Global Mean: 40.89126968383789, Global Std: 167.93341064453125


In [34]:
mean_no_thorax_third_sub, std_no_thorax_third_sub = 43.14976119995117, 172.67039489746094

In [35]:
# Calculate mean and std for patches data
statistics_transforms_patches_sub = Compose([
    LoadImaged(
        keys=["image", "label"], 
        image_only=False, 
        reader=monai.data.NumpyReader()
    ),
    EnsureChannelFirstd(keys=["image", "label"]),
    Preprocess(
        keys=None, 
        mode='statistics', 
        dataset="private", 
        subtracted_images_path_prefixes=sub_third_images_path_prefixes, 
        get_patches=True,
        get_boundaryloss=False
    )
])

print("Creating statistics dataset for patches data...")


Creating statistics dataset for patches data...


In [28]:
statistics_ds_patches_sub = CacheDataset(
     data=train_dicts, 
     transform=statistics_transforms_patches_sub,
     num_workers=NUM_WORKERS
 )

statistics_loader_patches_sub = DataLoader(
     statistics_ds_patches_sub, 
     batch_size=batch_size, 
     collate_fn = custom_collate,
     worker_init_fn=seed_worker,
     generator=g, 
     shuffle=False, 
     drop_last=False
 )

Loading dataset:   0%|          | 0/3456 [00:00<?, ?it/s]

Loading dataset: 100%|██████████| 3456/3456 [06:16<00:00,  9.17it/s]


In [29]:
print("Calculating mean and std for patches data...")
mean_patches_sub_calc, std_patches_sub_calc = get_mean_std_dataloader(statistics_loader_patches_sub)
print(f"Calculated - Patches Mean: {mean_patches_sub_calc}, Patches Std: {std_patches_sub_calc}")

# Set final values
mean_patches_sub, std_patches_sub = mean_patches_sub_calc, std_patches_sub_calc

Calculating mean and std for patches data...


  0%|          | 0/108 [00:00<?, ?it/s]

Mean: 86.1353759765625, Standard Deviation: 238.13462829589844
Calculated - Patches Mean: 86.1353759765625, Patches Std: 238.13462829589844


In [None]:
mean_patches_sub, std_patches_sub = 86.1353759765625, 238.13462829589844

In [None]:
# Create final transforms using calculated statistics
test_transforms_no_thorax_third_sub = Compose([
    LoadImaged(
        keys=["image", "label"], 
        image_only=False, 
        reader=monai.data.NumpyReader()
    ),
    EnsureChannelFirstd(keys=["image", "label"]),
    Preprocess(
        keys=None, 
        mode='test',  
        dataset="private", 
        subtracted_images_path_prefixes=sub_third_images_path_prefixes, 
        subtrahend=mean_no_thorax_third_sub, 
        divisor=std_no_thorax_third_sub, 
        get_patches=False,
        get_boundaryloss=get_boundaryloss
    )
])

train_transforms_no_thorax_third_sub = Compose([
    LoadImaged(
        keys=["image", "label"], 
        image_only=False, 
        reader=monai.data.NumpyReader()
    ),
    EnsureChannelFirstd(keys=["image", "label"]),
    Preprocess(
        keys=None, 
        mode='train', 
        dataset="private", 
        subtracted_images_path_prefixes=sub_third_images_path_prefixes, 
        subtrahend=mean_no_thorax_third_sub, 
        divisor=std_no_thorax_third_sub, 
        get_patches=False,
        get_boundaryloss=get_boundaryloss
    )
])

# Create transforms for patches data
train_transforms_patches_sub = Compose([
    LoadImaged(
        keys=["image", "label"], 
        image_only=False, 
        reader=monai.data.NumpyReader()
    ),
    EnsureChannelFirstd(keys=["image", "label"]),
    Preprocess(
        keys=None, 
        mode='train', 
        dataset="private", 
        subtracted_images_path_prefixes=sub_third_images_path_prefixes, 
        subtrahend=mean_patches_sub, 
        divisor=std_patches_sub, 
        get_patches=True,
        get_boundaryloss=get_boundaryloss
    )
])

test_transforms_patches_sub = Compose([
    LoadImaged(
        keys=["image", "label"], 
        image_only=False, 
        reader=monai.data.NumpyReader()
    ),
    EnsureChannelFirstd(keys=["image", "label"]),
    Preprocess(
        keys=None, 
        mode='test',  
        dataset="private", 
        subtracted_images_path_prefixes=sub_third_images_path_prefixes, 
        subtrahend=mean_patches_sub, 
        divisor=std_patches_sub, 
        get_patches=True,
        get_boundaryloss=get_boundaryloss
    )
])

print("Final transforms created using calculated statistics")


In [None]:
# Create datasets for both global and patches data
train_ds_no_thorax_third_sub = CacheDataset(
    data=train_dicts, 
    transform=train_transforms_no_thorax_third_sub,
    num_workers=num_workers
)

val_ds_no_thorax_third_sub = CacheDataset(
    data=val_dicts, 
    transform=test_transforms_no_thorax_third_sub,
    num_workers=num_workers
)

test_ds_no_thorax_third_sub = CacheDataset(
    data=test_dicts, 
    transform=test_transforms_no_thorax_third_sub,
    num_workers=num_workers
)

train_ds_patches_sub = CacheDataset(
    data=train_dicts, 
    transform=train_transforms_patches_sub,
    num_workers=num_workers
)

val_ds_patches_sub = CacheDataset(
    data=val_dicts, 
    transform=test_transforms_patches_sub,
    num_workers=num_workers
)

test_ds_patches_sub = CacheDataset(
    data=test_dicts, 
    transform=test_transforms_patches_sub,
    num_workers=num_workers
)

print("Datasets created for both global and patches data")


## 5. Create Fusion DataLoaders


In [None]:
# Reseed before creating dataloaders
g = reseed()

# Create fusion dataloaders that combine global and patches data
train_loader_fusion_sub = PairedDataLoader(
    train_ds_no_thorax_third_sub, 
    train_ds_patches_sub, 
    batch_size=batch_size, 
    worker_init_fn=seed_worker,
    generator=g, 
    shuffle=True, 
    drop_last=False, 
    num_workers=num_workers,
    augment=False,
)

val_loader_fusion_sub = PairedDataLoader(
    val_ds_no_thorax_third_sub, 
    val_ds_patches_sub, 
    batch_size=batch_size, 
    worker_init_fn=seed_worker,
    generator=g, 
    shuffle=False, 
    drop_last=False, 
    num_workers=num_workers,
    augment=False
)

test_loader_fusion_sub = PairedDataLoader(
    test_ds_no_thorax_third_sub, 
    test_ds_patches_sub, 
    batch_size=batch_size, 
    worker_init_fn=seed_worker,
    generator=g, 
    shuffle=False, 
    drop_last=False, 
    num_workers=num_workers,
    augment=False
)

print("Fusion dataloaders created")


## 6. Training VENUS Model


In [None]:
import gc
gc.collect()

with torch.no_grad():
    torch.cuda.empty_cache()


In [None]:
g = reseed()

# Create VENUS fusion model with CABFL loss
ENCODER_NAME = None
model_fusion_sub_cabl = BreastFusionModel(
    arch="venus", 
    encoder_name=ENCODER_NAME, 
    use_boundary_loss=True, 
    use_simple_fusion=True,
    loss_function="cabfl",
    loss_kwargs={"idc": [1], "weight_aufl": 0.5, "delta": 0.7, "gamma": 0.4}, 
    in_channels=1, 
    out_classes=1, 
    batch_size=batch_size, 
    len_train_loader=len(train_ds_no_thorax_third_sub)//batch_size
)

es = EarlyStopping(monitor="valid_loss", mode="min", patience=10)

cc_fusion_sub_cabl = ModelCheckpoint(
    monitor="valid_loss",
    save_top_k=1,
    mode="min",
    filename='venus-fusion-cabl-{epoch:02d}-{valid_loss:.2f}',
    dirpath=checkpoints_dir,
    auto_insert_metric_name=False
)


trainer_fusion_sub_cabl = L.Trainer(
    devices=1,
    accelerator='auto',
    max_epochs=1000,
    callbacks=[es, cc_fusion_sub_cabl],
    log_every_n_steps=10,
    gradient_clip_val=1,
    num_sanity_val_steps=1,
    deterministic=False
)

trainer_fusion_sub_cabl.fit(
    model_fusion_sub_cabl,
    train_dataloaders=train_loader_fusion_sub,
    val_dataloaders=val_loader_fusion_sub
)


In [None]:
# Load best model and test
model_fusion_sub_cabl = BreastFusionModel.load_from_checkpoint(
    cc_fusion_sub_cabl.best_model_path, 
    use_boundary_loss=True, 
    use_simple_fusion=True
    loss_function="cabfl",
    loss_kwargs={"idc": [1], "weight_aufl": 0.5, "delta": 0.4, "gamma": 0.1}
)
test_metrics = trainer_fusion_sub_cabl.test(
    model_fusion_sub_cabl, 
    dataloaders=test_loader_fusion_sub, 
    verbose=False
)
pp.pprint(test_metrics[0])

## 7. Training ResNet with Patches


In [None]:
# Create individual loaders for patches data (used by ResNet baseline)
train_loader_patches_sub = DataLoader(
    train_ds_patches_sub, 
    batch_size=batch_size, 
    worker_init_fn=seed_worker,
    generator=g, 
    shuffle=True, 
    drop_last=False
)

val_loader_patches_sub = DataLoader(
    val_ds_patches_sub, 
    batch_size=batch_size, 
    worker_init_fn=seed_worker,
    generator=g, 
    shuffle=False, 
    drop_last=False
)

test_loader_patches_sub = DataLoader(
    test_ds_patches_sub, 
    batch_size=batch_size, 
    worker_init_fn=seed_worker,
    generator=g, 
    shuffle=False, 
    drop_last=False
)

In [None]:
g = reseed()
ENCODER_NAME = "resnet18"
model_resnet_cabl = BreastSegmentationModel(
    arch="UNet", 
    encoder_name=ENCODER_NAME, 
    use_boundary_loss=True, 
    loss_function="cabfl",
    loss_kwargs={"idc": [1], "weight_aufl": 0.5, "delta": 0.7, "gamma": 0.4}, 
    in_channels=1, 
    out_classes=1, 
    batch_size=batch_size, 
    len_train_loader=len(train_ds_patches_sub)//batch_size
)

es_resnet = EarlyStopping(monitor="valid_loss", mode="min", patience=10)

cc_resnet_cabl = ModelCheckpoint(
    monitor="valid_loss",
    save_top_k=1,
    mode="min",
    filename='resnet18-patches-cabl-{epoch:02d}-{valid_loss:.2f}',
    dirpath=checkpoints_dir,
    auto_insert_metric_name=False
)


trainer_resnet_cabl = L.Trainer(
    devices=1,
    accelerator='auto',
    max_epochs=1000,
    callbacks=[es_resnet, cc_resnet_cabl],
    log_every_n_steps=10,
    gradient_clip_val=1,
    num_sanity_val_steps=1,
    deterministic=False
)

trainer_resnet_cabl.fit(
    model_resnet_cabl,
    train_dataloaders=train_loader_patches_sub,
    val_dataloaders=val_loader_patches_sub
)

In [None]:
# Load best ResNet patches model and test
model_resnet_cabl = BreastSegmentationModel.load_from_checkpoint(
    cc_resnet_cabl.best_model_path, 
    strict=True,
    use_boundary_loss=True, 
    loss_function="cabfl",
    loss_kwargs={"idc": [1], "weight_aufl": 0.5, "delta": 0.4, "gamma": 0.1}
)
test_metrics_resnet = trainer_resnet_cabl.test(
    model_resnet_cabl, 
    dataloaders=test_loader_patches_sub, 
    verbose=False
)
pp.pprint(test_metrics_resnet[0])