# VENUS FUSION TRAINING NOTEBOOK

This notebook implements fusion-based breast segmentation using the VENUS model with multi-scale patch processing.


## 1. Installs & Imports


In [None]:
# Standard libraries
import os
import json
import glob
import shutil
import tempfile
import random
import warnings
import pprint
import copy
pp = pprint.PrettyPrinter(indent=4)

# Third-party libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
import nibabel as nib
import albumentations as A
import segmentation_models_pytorch as smp
import torch.nn as nn
import torch.nn.functional as F
from skimage import filters
from skimage.measure import label as label_fn, regionprops
from skimage import morphology
from sklearn.model_selection import train_test_split
from copy import deepcopy
from tqdm.notebook import tqdm

# MONAI related imports
from monai.config import print_config
from monai.networks.nets import UNet, SwinUNETR, BasicUNetPlusPlus
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss, DiceCELoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
from monai.transforms import (
    AsDiscrete, AsDiscreted, EnsureChannelFirstd, Compose, CropForegroundd,
    LoadImaged, Orientationd, RandCropByPosNegLabeld, SaveImaged, ScaleIntensityRanged,
    Spacingd, Invertd, ResizeWithPadOrCropd, Resized, MapTransform, ScaleIntensityd,
    LabelToContourd, ForegroundMaskd, HistogramNormalized, RandFlipd, RandGridDistortiond,
    RandHistogramShiftd, RandRotated
)
from monai.handlers.utils import from_engine
from monai.utils.type_conversion import convert_to_numpy

# PyTorch Lightning related imports
import lightning.pytorch as L
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch import seed_everything

# Weights & Biases
import wandb

from natsort import natsorted, ns
from PIL import Image
from numpy import einsum
from torch.utils.data import default_collate
import psutil
from typing import List, Tuple
import monai

# Import modules
from breast_segmentation.config.settings import config
from breast_segmentation.utils.seed import set_deterministic_mode, seed_worker, reseed
from breast_segmentation.data.dataset import (
    get_image_label_files, create_data_dicts, split_data, create_dataloaders,
    PairedDataset, PairedDataLoader
)
from breast_segmentation.transforms.compose import Preprocess
from breast_segmentation.models.fusion_module import BreastFusionModel
from breast_segmentation.models.lightning_module import BreastSegmentationModel
from breast_segmentation.models.architectures import get_model, VENUS
from breast_segmentation.metrics.losses import (
    get_loss_function, CrossEntropy2d, compute_class_weight, 
    AsymmetricUnifiedFocalLoss, CABFL
)
from breast_segmentation.utils.visualization import (
    plot_batch_predictions, plot_training_history
)
from breast_segmentation.utils.postprocessing import (
    remove_far_masses_based_on_largest_mass
)

from utiils import *
from boundaryloss.dataloader import dist_map_transform

# Set precision for matmul operations and print MONAI config
torch.set_float32_matmul_precision('medium')
print_config()

# Define converter function to avoid pickling issues with lambda
def convert_to_grayscale(image):
    """Convert PIL image to grayscale - replaces lambda for pickling compatibility."""
    return image.convert("L")


  from torch.distributed.optim import ZeroRedundancyOptimizer


MONAI version: 1.6.dev2535
Numpy version: 2.0.2
Pytorch version: 2.5.1+cu121
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: 8e677816bfd1fb2ec541d7f951db4caaf210b150
MONAI __file__: c:\Users\<username>\AppData\Local\pypoetry\Cache\virtualenvs\venus-nCPuPPcI-py3.9\lib\site-packages\monai\__init__.py

Optional dependencies:
Pytorch Ignite version: 0.4.11
ITK version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: 5.3.2
scikit-image version: 0.24.0
scipy version: 1.13.1
Pillow version: 11.3.0
Tensorboard version: NOT INSTALLED or UNKNOWN VERSION.
gdown version: 5.2.0
TorchVision version: 0.20.1+cu121
tqdm version: 4.67.1
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 7.0.0
pandas version: NOT INSTALLED or UNKNOWN VERSION.
einops version: 0.8.1
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.
pynrrd version: NOT INSTALLED or UNKNOWN VERSION.
clearml version: NOT INSTALLED o

## 2. Environment Setup


In [2]:
# Configuration
batch_size = 16  # Reduced for fusion training
num_workers = min(4, os.cpu_count())  # Use multiprocessing for faster data loading
checkpoints_dir = config.CHECKPOINTS_DIR
get_boundaryloss = True

# Ensure checkpoints directory exists
os.makedirs(checkpoints_dir, exist_ok=True)

# Set random seed for reproducibility
g = set_deterministic_mode(config.SEED)

print(f"Batch size: {batch_size}")
print(f"Number of workers: {num_workers}")
print(f"Checkpoints directory: {checkpoints_dir}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"Device: {torch.cuda.get_device_name()}")


Seed set to 200


Using random seed 200...
Batch size: 16
Number of workers: 4
Checkpoints directory: checkpoints
CUDA available: True
Device: NVIDIA GeForce RTX 3060


## 3. Data Preparation


In [3]:
# Get image and label files
dataset_base_path = "BreaDM/seg"
image_type = "VIBRANT+C2"

train_images, train_labels = get_image_label_files(
    dataset_base_path, "train", image_type
)
val_images, val_labels = get_image_label_files(
    dataset_base_path, "val", image_type
)
test_images, test_labels = get_image_label_files(
    dataset_base_path, "test", image_type
)

# Create data dictionaries
train_dicts = create_data_dicts(train_images, train_labels)
val_dicts = create_data_dicts(val_images, val_labels)
test_dicts = create_data_dicts(test_images, test_labels)

print(f"Dataset statistics:")
print(f"  Training samples: {len(train_dicts)}")
print(f"  Validation samples: {len(val_dicts)}")
print(f"  Test samples: {len(test_dicts)}")


Dataset statistics:
  Training samples: 1202
  Validation samples: 117
  Test samples: 417


## 4. Data Preprocessing and Dataset Creation


In [None]:
# Define subtracted images path prefixes
sub_third_images_path_prefixes = ("VIBRANT+C2", "SUB2")

print("Will calculate normalization statistics from data...")


Will calculate normalization statistics from data...
This allows verification against original pre-computed values:


In [5]:
# Calculate mean and std for global (no patches) data
statistics_transforms_no_thorax_third_sub = Compose([
    LoadImaged(
        keys=["image", "label"], 
        image_only=False, 
        reader=monai.data.PILReader(converter=convert_to_grayscale)
    ),
    EnsureChannelFirstd(keys=["image", "label"]),
    monai.transforms.Rotate90d(keys=["image", "label"]),
    Preprocess(
        keys=None, 
        mode='statistics',  
        dataset="BRADM", 
        subtracted_images_path_prefixes=sub_third_images_path_prefixes, 
        get_patches=False,
        get_boundaryloss=False
    )
])

print("Creating statistics dataset for global data...")


Creating statistics dataset for global data...


In [6]:
# Create statistics dataset and loader for global data
statistics_ds_no_thorax_third_sub = CacheDataset(
    data=train_dicts, 
    transform=statistics_transforms_no_thorax_third_sub,
    num_workers=num_workers
)

statistics_loader_no_thorax_third_sub = DataLoader(
    statistics_ds_no_thorax_third_sub, 
    batch_size=batch_size, 
    worker_init_fn=seed_worker,
    generator=g, 
    shuffle=False, 
    drop_last=False
)

print("Statistics dataset and loader created for global data")


Loading dataset: 100%|██████████| 1202/1202 [00:39<00:00, 30.31it/s]

Statistics dataset and loader created for global data





In [7]:
# Calculate mean and std for global data
def get_mean_std_dataloader(dataloader, masked=False):
    """Calculate mean and std from dataloader."""
    sum_of_images = 0.0
    sum_of_squares = 0.0
    num_pixels = 0

    for batch in tqdm(dataloader):
        if batch is not None:
            image = batch["image"]
      
            if masked:
                mask = image > 0.0
                image = image[mask]
      
            sum_of_images += image.sum()
            sum_of_squares += (image ** 2).sum()
            num_pixels += image.numel()
        else:
            print("none batch")
    
    mean = sum_of_images / num_pixels
    std_dev = (sum_of_squares / num_pixels - mean ** 2) ** 0.5

    print(f'Mean: {mean}, Standard Deviation: {std_dev}')
    return mean.item(), std_dev.item()

print("Calculating mean and std for global data...")
mean_no_thorax_third_sub_calc, std_no_thorax_third_sub_calc = get_mean_std_dataloader(statistics_loader_no_thorax_third_sub)
print(f"Calculated - Global Mean: {mean_no_thorax_third_sub_calc}, Global Std: {std_no_thorax_third_sub_calc}")

# Use calculated or pre-computed values
mean_no_thorax_third_sub, std_no_thorax_third_sub = mean_no_thorax_third_sub_calc, std_no_thorax_third_sub_calc


Calculating mean and std for global data...


  0%|          | 0/76 [00:00<?, ?it/s]

Mean: 10.217766761779785, Standard Deviation: 26.677101135253906
Calculated - Global Mean: 10.217766761779785, Global Std: 26.677101135253906


In [8]:
# Calculate mean and std for patches data
statistics_transforms_patches_sub = Compose([
    LoadImaged(
        keys=["image", "label"], 
        image_only=False, 
        reader=monai.data.PILReader(converter=convert_to_grayscale)
    ),
    EnsureChannelFirstd(keys=["image", "label"]),
    monai.transforms.Rotate90d(keys=["image", "label"]),
    Preprocess(
        keys=None, 
        mode='statistics', 
        dataset="BRADM", 
        subtracted_images_path_prefixes=sub_third_images_path_prefixes, 
        get_patches=True,
        get_boundaryloss=False
    )
])

print("Creating statistics dataset for patches data...")


Creating statistics dataset for patches data...


In [9]:
statistics_ds_patches_sub = CacheDataset(
     data=train_dicts, 
     transform=statistics_transforms_patches_sub,
     num_workers=num_workers
 )

statistics_loader_patches_sub = DataLoader(
     statistics_ds_patches_sub, 
     batch_size=batch_size, 
     worker_init_fn=seed_worker,
     generator=g, 
     shuffle=False, 
     drop_last=False
 )

print("Calculating mean and std for patches data...")
mean_patches_sub_calc, std_patches_sub_calc = get_mean_std_dataloader(statistics_loader_patches_sub)
print(f"Calculated - Patches Mean: {mean_patches_sub_calc}, Patches Std: {std_patches_sub_calc}")

# Set final values
mean_patches_sub, std_patches_sub = mean_patches_sub_calc, std_patches_sub_calc

Loading dataset: 100%|██████████| 1202/1202 [01:57<00:00, 10.25it/s]

Calculating mean and std for patches data...





  0%|          | 0/76 [00:00<?, ?it/s]

Mean: 20.630815505981445, Standard Deviation: 35.328887939453125
Calculated - Patches Mean: 20.630815505981445, Patches Std: 35.328887939453125


In [10]:
# Create final transforms using calculated statistics
test_transforms_no_thorax_third_sub = Compose([
    LoadImaged(
        keys=["image", "label"], 
        image_only=False, 
        reader=monai.data.PILReader(converter=convert_to_grayscale)
    ),
    EnsureChannelFirstd(keys=["image", "label"]),
    monai.transforms.Rotate90d(keys=["image", "label"]),
    Preprocess(
        keys=None, 
        mode='test',  
        dataset="BRADM", 
        subtracted_images_path_prefixes=sub_third_images_path_prefixes, 
        subtrahend=mean_no_thorax_third_sub, 
        divisor=std_no_thorax_third_sub, 
        get_patches=False,
        get_boundaryloss=get_boundaryloss
    )
])

train_transforms_no_thorax_third_sub = Compose([
    LoadImaged(
        keys=["image", "label"], 
        image_only=False, 
        reader=monai.data.PILReader(converter=convert_to_grayscale)
    ),
    EnsureChannelFirstd(keys=["image", "label"]),
    monai.transforms.Rotate90d(keys=["image", "label"]),
    Preprocess(
        keys=None, 
        mode='train', 
        dataset="BRADM", 
        subtracted_images_path_prefixes=sub_third_images_path_prefixes, 
        subtrahend=mean_no_thorax_third_sub, 
        divisor=std_no_thorax_third_sub, 
        get_patches=False,
        get_boundaryloss=get_boundaryloss
    )
])

# Create transforms for patches data
train_transforms_patches_sub = Compose([
    LoadImaged(
        keys=["image", "label"], 
        image_only=False, 
        reader=monai.data.PILReader(converter=convert_to_grayscale)
    ),
    EnsureChannelFirstd(keys=["image", "label"]),
    monai.transforms.Rotate90d(keys=["image", "label"]),
    Preprocess(
        keys=None, 
        mode='train', 
        dataset="BRADM", 
        subtracted_images_path_prefixes=sub_third_images_path_prefixes, 
        subtrahend=mean_patches_sub, 
        divisor=std_patches_sub, 
        get_patches=True,
        get_boundaryloss=get_boundaryloss
    )
])

test_transforms_patches_sub = Compose([
    LoadImaged(
        keys=["image", "label"], 
        image_only=False, 
        reader=monai.data.PILReader(converter=convert_to_grayscale)
    ),
    EnsureChannelFirstd(keys=["image", "label"]),
    monai.transforms.Rotate90d(keys=["image", "label"]),
    Preprocess(
        keys=None, 
        mode='test',  
        dataset="BRADM", 
        subtracted_images_path_prefixes=sub_third_images_path_prefixes, 
        subtrahend=mean_patches_sub, 
        divisor=std_patches_sub, 
        get_patches=True,
        get_boundaryloss=get_boundaryloss
    )
])

print("Final transforms created using calculated statistics")


Final transforms created using calculated statistics


In [11]:
# Create datasets for both global and patches data
train_ds_no_thorax_third_sub = CacheDataset(
    data=train_dicts, 
    transform=train_transforms_no_thorax_third_sub,
    num_workers=num_workers
)

val_ds_no_thorax_third_sub = CacheDataset(
    data=val_dicts, 
    transform=test_transforms_no_thorax_third_sub,
    num_workers=num_workers
)

test_ds_no_thorax_third_sub = CacheDataset(
    data=test_dicts, 
    transform=test_transforms_no_thorax_third_sub,
    num_workers=num_workers
)

train_ds_patches_sub = CacheDataset(
    data=train_dicts, 
    transform=train_transforms_patches_sub,
    num_workers=num_workers
)

val_ds_patches_sub = CacheDataset(
    data=val_dicts, 
    transform=test_transforms_patches_sub,
    num_workers=num_workers
)

test_ds_patches_sub = CacheDataset(
    data=test_dicts, 
    transform=test_transforms_patches_sub,
    num_workers=num_workers
)

print("Datasets created for both global and patches data")


Loading dataset: 100%|██████████| 1202/1202 [00:45<00:00, 26.61it/s]
Loading dataset: 100%|██████████| 117/117 [00:04<00:00, 28.62it/s]
Loading dataset: 100%|██████████| 417/417 [00:14<00:00, 28.94it/s]
Loading dataset: 100%|██████████| 1202/1202 [01:57<00:00, 10.19it/s]
Loading dataset: 100%|██████████| 117/117 [00:11<00:00,  9.76it/s]
Loading dataset: 100%|██████████| 417/417 [00:36<00:00, 11.28it/s]

Datasets created for both global and patches data





## 5. Create Fusion DataLoaders


In [12]:
# Reseed before creating dataloaders
g = reseed()

# Create fusion dataloaders that combine global and patches data
train_loader_fusion_sub = PairedDataLoader(
    train_ds_no_thorax_third_sub, 
    train_ds_patches_sub, 
    batch_size=batch_size, 
    worker_init_fn=seed_worker,
    generator=g, 
    shuffle=True, 
    drop_last=False, 
    num_workers=num_workers,
    augment=False,
)

val_loader_fusion_sub = PairedDataLoader(
    val_ds_no_thorax_third_sub, 
    val_ds_patches_sub, 
    batch_size=batch_size, 
    worker_init_fn=seed_worker,
    generator=g, 
    shuffle=False, 
    drop_last=False, 
    num_workers=num_workers,
    augment=False
)

test_loader_fusion_sub = PairedDataLoader(
    test_ds_no_thorax_third_sub, 
    test_ds_patches_sub, 
    batch_size=batch_size, 
    worker_init_fn=seed_worker,
    generator=g, 
    shuffle=False, 
    drop_last=False, 
    num_workers=num_workers,
    augment=False
)

print("Fusion dataloaders created")


Seed set to 200


Using random seed 200...
Fusion dataloaders created


## 6. Training VENUS Fusion Model with CABFL


In [13]:
import gc
gc.collect()

with torch.no_grad():
    torch.cuda.empty_cache()


In [None]:
g = reseed()

# Create VENUS fusion model with CABFL loss
model_fusion_sub_cabl = BreastFusionModel(
    arch="venus",
    encoder_name=None,
    in_channels=1,
    out_classes=1,
    batch_size=batch_size,
    len_train_loader=len(train_ds_no_thorax_third_sub) // batch_size,
    use_boundary_loss=True,
    loss_function="cabfl",
    loss_kwargs={"idc": [1], "weight_aufl": 0.5, "delta": 0.4, "gamma": 0.1},
    base_channels=64,
    use_simple_fusion=False,
    use_decoder_attention=True
)

es = EarlyStopping(monitor="valid_loss", mode="min", patience=10)

cc_fusion_sub_cabl = ModelCheckpoint(
    monitor="valid_loss",
    save_top_k=1,
    mode="min",
    filename='venus-fusion-cabl-{epoch:02d}-{valid_loss:.2f}',
    dirpath=checkpoints_dir,
    auto_insert_metric_name=False
)

wandb.login(key="2bc18e4744fb0771a16fd009b7aa2c98c79efc49")
wandb_logger = WandbLogger(project='Tesi-final', log_model=False)

trainer_fusion_sub_cabl = L.Trainer(
    logger=wandb_logger,
    devices=1,
    accelerator='gpu',
    max_epochs=1000,
    callbacks=[es, cc_fusion_sub_cabl],
    log_every_n_steps=10,
    gradient_clip_val=1,
    num_sanity_val_steps=1,
    deterministic=False
)

trainer_fusion_sub_cabl.fit(
    model_fusion_sub_cabl,
    train_dataloaders=train_loader_fusion_sub,
    val_dataloaders=val_loader_fusion_sub
)


Seed set to 200


Using random seed 200...
Initialized SurfaceLossBinary with [1]


wandb: Appending key for api.wandb.ai to your netrc file: C:\Users\pabli\_netrc
wandb: Currently logged in as: pablo-giaccaglia to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


c:\Users\pabli\AppData\Local\pypoetry\Cache\virtualenvs\venus-nCPuPPcI-py3.9\lib\site-packages\lightning\pytorch\callbacks\model_checkpoint.py:701: Checkpoint directory C:\Users\pabli\Desktop\venus\checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type  | Params | Mode 
------------------------------------------
0 | model   | VENUS | 21.7 M | train
1 | loss_fn | CABFL | 0      | train
------------------------------------------
21.7 M    Trainable params
0         Non-trainable params
21.7 M    Total params
86.860    Total estimated model params size (MB)
331       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\pabli\AppData\Local\pypoetry\Cache\virtualenvs\venus-nCPuPPcI-py3.9\lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:428: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


In [None]:
# Load best model and test
model_fusion_sub_cabl = BreastFusionModel.load_from_checkpoint(
    cc_fusion_sub_cabl.best_model_path,
    use_boundary_loss=True,
    loss_function="cabfl",
    loss_kwargs={"idc": [1], "weight_aufl": 0.5, "delta": 0.4, "gamma": 0.1}
)
test_metrics = trainer_fusion_sub_cabl.test(
    model_fusion_sub_cabl, 
    dataloaders=test_loader_fusion_sub, 
    verbose=False
)
pp.pprint(test_metrics[0])


In [None]:
wandb.finish()


## 7. Training VENUS Fusion Model with Simple Fusion


In [None]:
import gc
gc.collect()

with torch.no_grad():
    torch.cuda.empty_cache()


In [None]:
g = reseed()

# Create VENUS fusion model with simple fusion and smaller base channels
model_fusion_sub_cabl_simple = BreastFusionModel(
    arch="venus",
    encoder_name=None,
    in_channels=1,
    out_classes=1,
    batch_size=batch_size,
    len_train_loader=len(train_ds_no_thorax_third_sub) // batch_size,
    use_boundary_loss=True,
    loss_function="cabfl",
    loss_kwargs={"idc": [1], "weight_aufl": 0.5, "delta": 0.4, "gamma": 0.1},
    base_channels=16,
    use_simple_fusion=True,
    use_decoder_attention=True
)

es_simple = EarlyStopping(monitor="valid_loss", mode="min", patience=10)

cc_fusion_sub_cabl_simple = ModelCheckpoint(
    monitor="valid_loss",
    save_top_k=1,
    mode="min",
    filename='venus-fusion-simple-cabl-{epoch:02d}-{valid_loss:.2f}',
    dirpath=checkpoints_dir,
    auto_insert_metric_name=False
)

wandb.login(key="2bc18e4744fb0771a16fd009b7aa2c98c79efc49")
wandb_logger_simple = WandbLogger(project='Tesi-final', log_model=False)

trainer_fusion_sub_cabl_simple = L.Trainer(
    logger=wandb_logger_simple,
    devices=1,
    accelerator='gpu',
    max_epochs=1000,
    callbacks=[es_simple, cc_fusion_sub_cabl_simple],
    log_every_n_steps=10,
    gradient_clip_val=1,
    num_sanity_val_steps=1,
    deterministic=False
)

trainer_fusion_sub_cabl_simple.fit(
    model_fusion_sub_cabl_simple,
    train_dataloaders=train_loader_fusion_sub,
    val_dataloaders=val_loader_fusion_sub
)


In [None]:
# Load best simple fusion model and test
model_fusion_sub_cabl_simple = BreastFusionModel.load_from_checkpoint(
    cc_fusion_sub_cabl_simple.best_model_path,
    base_channels=16,
    use_boundary_loss=True,
    use_simple_fusion=True,
    loss_function="cabfl",
    loss_kwargs={"idc": [1], "weight_aufl": 0.5, "delta": 0.4, "gamma": 0.1}
)
test_metrics_simple = trainer_fusion_sub_cabl_simple.test(
    model_fusion_sub_cabl_simple, 
    dataloaders=test_loader_fusion_sub, 
    verbose=False
)
pp.pprint(test_metrics_simple[0])


In [None]:
trainer_fusion_sub_cabl_simple.save_checkpoint("VENUS-FUSION-SIMPLE-CABL.ckpt")


In [None]:
wandb.finish()


## 8. Training ResNet with Patches

In [None]:
from typing import Tuple

# Create individual loaders for patches data (used by ResNet baseline)
train_loader_patches_sub = DataLoader(
    train_ds_patches_sub, 
    batch_size=batch_size, 
    worker_init_fn=seed_worker,
    generator=g, 
    shuffle=True, 
    drop_last=False
)

val_loader_patches_sub = DataLoader(
    val_ds_patches_sub, 
    batch_size=batch_size, 
    worker_init_fn=seed_worker,
    generator=g, 
    shuffle=False, 
    drop_last=False
)

test_loader_patches_sub = DataLoader(
    test_ds_patches_sub, 
    batch_size=batch_size, 
    worker_init_fn=seed_worker,
    generator=g, 
    shuffle=False, 
    drop_last=False
)


In [None]:
g = reseed()
ENCODER_NAME = "resnet18"
model_resnet_patches_cabl = BreastSegmentationModel(
    arch="UNet",
    encoder_name=ENCODER_NAME,
    in_channels=1,
    out_classes=1,
    batch_size=batch_size,
    len_train_loader=len(train_ds_patches_sub) // batch_size,
    use_boundary_loss=True,
    loss_function="cabfl",
    loss_kwargs={"idc": [1], "weight_aufl": 0.5, "delta": 0.7, "gamma": 0.4}
)

es_resnet = EarlyStopping(monitor="valid_loss", mode="min", patience=10)

cc_resnet_cabl = ModelCheckpoint(
    monitor="valid_loss",
    save_top_k=1,
    mode="min",
    filename='resnet18-patches-cabl-{epoch:02d}-{valid_loss:.2f}',
    dirpath=checkpoints_dir,
    auto_insert_metric_name=False
)

wandb.login(key="2bc18e4744fb0771a16fd009b7aa2c98c79efc49")
wandb_logger_resnet = WandbLogger(project='Tesi-final', log_model=False)

trainer_resnet_cabl = L.Trainer(
    logger=wandb_logger_resnet,
    devices=1,
    accelerator='gpu',
    max_epochs=1000,
    callbacks=[es_resnet, cc_resnet_cabl],
    log_every_n_steps=10,
    gradient_clip_val=1,
    num_sanity_val_steps=1,
    deterministic=False
)

trainer_resnet_cabl.fit(
    model_resnet_patches_cabl,
    train_dataloaders=train_loader_patches_sub,
    val_dataloaders=val_loader_patches_sub
)


In [None]:
# Load best ResNet patches model and test
model_resnet_patches_cabl = BreastSegmentationModel.load_from_checkpoint(
    cc_resnet_cabl.best_model_path,
    strict=True,
    use_boundary_loss=True,
    loss_function="cabfl",
    loss_kwargs={"idc": [1], "weight_aufl": 0.5, "delta": 0.4, "gamma": 0.1}
)
test_metrics_resnet = trainer_resnet_cabl.test(
    model_resnet_patches_cabl, 
    dataloaders=test_loader_patches_sub, 
    verbose=False
)
pp.pprint(test_metrics_resnet[0])


In [None]:
trainer_resnet_cabl.save_checkpoint("RESNET18-PATCHES-BREADM-CABL.ckpt")


In [None]:
wandb.finish()


## 9. Results Summary


In [None]:
# Summary of all results
print("="*60)
print("FINAL RESULTS SUMMARY - FUSION MODELS")
print("="*60)

models_results = [
    ("VENUS Fusion (Full)", test_metrics[0] if 'test_metrics' in locals() else None),
    ("VENUS Fusion (Simple)", test_metrics_simple[0] if 'test_metrics_simple' in locals() else None),
    ("ResNet18-Patches", test_metrics_resnet[0] if 'test_metrics_resnet' in locals() else None)
]

for model_name, results in models_results:
    if results:
        print(f"\n{model_name}:")
        key_metrics = [
            'test_per_dataset_dice', 'test_per_dataset_iou', 
            'test_mean_dice_per_dataset', 'test_mean_iou_per_dataset',
            'test_accuracy', 'test_precision', 'test_recall'
        ]
        for metric in key_metrics:
            if metric in results:
                print(f"  {metric}: {results[metric]:.4f}")
    else:
        print(f"\n{model_name}: Not trained")

print("\n" + "="*60)
print("FUSION NOTEBOOK COMPLETED")
print("="*60)
