# Dataset-Aware Inference Notebook


## Setup and Imports


# In [None]:
import os
import sys
import json
import numpy as np
import torch
import monai
import matplotlib.pyplot as plt
import pprint
from tqdm import tqdm
from typing import Dict, List, Tuple, Optional

# MONAI imports
from monai.data import CacheDataset, DataLoader

# Import custom modules
from breast_segmentation.transforms import Preprocess
from breast_segmentation.models import BreastSegmentationModel, BreastFusionModel
from breast_segmentation.metrics.losses import (
    CABFL, SurfaceLossBinary, AsymmetricUnifiedFocalLoss, 
    AsymmetricFocalLoss, AsymmetricFocalTverskyLoss
)
from breast_segmentation.inference import (
    test_dataset_aware_no_patches,
    test_dataset_aware_fusion,
    test_dataset_aware_ensemble
)
from breast_segmentation.utils import (
    get_patient_ids, get_image_label_files_patient_aware, reverse_transformations,
    visualize_predictions, visualize_volume_predictions
)
from breast_segmentation.data import custom_collate_no_patches, custom_collate
from breast_segmentation.metrics import filter_masses

pp = pprint.PrettyPrinter(width=100, indent=2)

# Expose the packaged loss classes as attributes of ``__main__`` so that
# checkpoint loading can resolve them by name there. NOTE(review): presumably
# the checkpoints were pickled from a session that defined these classes in
# ``__main__`` — confirm against how the checkpoints were saved.
_main_module = sys.modules['__main__']
for _loss_name, _loss_cls in {
    'CABFL': CABFL,
    'SurfaceLossBinary': SurfaceLossBinary,
    'AsymmetricUnifiedFocalLoss': AsymmetricUnifiedFocalLoss,
    'AsymmetricFocalLoss': AsymmetricFocalLoss,
    'AsymmetricFocalTverskyLoss': AsymmetricFocalTverskyLoss,
}.items():
    setattr(_main_module, _loss_name, _loss_cls)

## Configuration


# In [None]:
# ---- Run settings ----
BATCH_SIZE = 128
NUM_WORKERS = os.cpu_count()
SEED = 200
USE_SUBTRACTED = True

# ---- Data and checkpoint locations ----
BREADM_DIR = "./BreaDM"
TEST_DIR = os.path.join(BREADM_DIR, "seg/")
dataset_base_path = TEST_DIR
CHECKPOINTS_DIR = "./checkpoints/breadm-dataset"

# ---- Compute device: prefer CUDA when present ----
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_available else "cpu")
print(f"Using device: {device}")

# ---- Reproducibility: seed every RNG used below ----
np.random.seed(SEED)
torch.manual_seed(SEED)
if _cuda_available:
    torch.cuda.manual_seed(SEED)


## Load Test Patient IDs


# In [None]:
# Resolve the patient IDs that belong to the 'test' split under TEST_DIR.
x_test = get_patient_ids(TEST_DIR, split='test')
num_test_patients = len(x_test)
print(f"Number of test patients: {num_test_patients}")
print(f"Test patient IDs: {x_test[:5]}...")  # preview only the first five IDs


## Prepare Datasets


# In [None]:
# Pre-computed intensity statistics (from training). The whole-image and the
# patch pipelines are normalised with different statistics, so each transform
# below receives its own mean/std pair.
GLOBAL_MEAN = 10.217766761779785
GLOBAL_STD = 26.677101135253906
PATCHES_MEAN = 20.630817413330078
PATCHES_STD = 35.328887939453125


def _pil_to_grayscale(image):
    """Convert a PIL image to single-channel grayscale (PIL mode "L")."""
    return image.convert("L")


def _build_test_transforms(subtrahend, divisor, get_patches):
    """Build the BreaDM test-time pipeline: load, channel-first, rotate, Preprocess.

    Args:
        subtrahend: Mean passed to ``Preprocess`` for intensity normalisation.
        divisor: Standard deviation passed to ``Preprocess``.
        get_patches: Forwarded to ``Preprocess``; True for the patch pipeline,
            False for the whole-image pipeline.

    Returns:
        A ``monai.transforms.Compose`` applied to dicts with "image"/"label" keys.
    """
    return monai.transforms.Compose([
        monai.transforms.LoadImaged(
            keys=["image", "label"],
            image_only=False,
            reader=monai.data.PILReader(converter=_pil_to_grayscale),
        ),
        monai.transforms.EnsureChannelFirstd(keys=["image", "label"]),
        monai.transforms.Rotate90d(keys=["image", "label"]),
        Preprocess(
            keys=None,
            mode='test',
            dataset="BRADM",
            # NOTE(review): presumably Preprocess locates the subtraction image
            # by swapping "VIBRANT+C2" for "SUB2" in the path — confirm.
            subtracted_images_path_prefixes=("VIBRANT+C2", "SUB2"),
            subtrahend=subtrahend,
            divisor=divisor,
            get_patches=get_patches,
            get_boundaryloss=False,
        ),
    ])


# The pipelines do not depend on the patient, so build them once rather than
# on every loop iteration. This also guarantees they exist for
# ``data_transforms`` below even when ``x_test`` is empty.
test_transforms_no_thorax_sub = _build_test_transforms(
    GLOBAL_MEAN, GLOBAL_STD, get_patches=False
)
test_transforms_patches_sub = _build_test_transforms(
    PATCHES_MEAN, PATCHES_STD, get_patches=True
)

# Follow the configured dataset root instead of a hard-coded copy of it.
# normpath turns the default "./BreaDM/seg/" into "BreaDM/seg", the value that
# used to be hard-coded here.
patient_files_root = os.path.normpath(dataset_base_path)

# Create patient-aware datasets: patient_id -> {dataset key -> CacheDataset}.
datasets = {}
for patient_id in tqdm(x_test, desc="Creating patient datasets"):
    image_files, label_files = get_image_label_files_patient_aware(
        dataset_base_path=patient_files_root,
        split="test",
        image_type="VIBRANT+C2",
        patient_id=patient_id,
    )

    # "subtracted" starts out as the image path itself (see the Preprocess
    # prefix note above for how it is presumably consumed).
    data_dicts = [
        {"image": img, "label": lbl, "subtracted": img}
        for img, lbl in zip(image_files, label_files)
    ]

    datasets[patient_id] = {
        # Whole images (no thorax), fully cached up front.
        'no_thorax_sub_test_ds': CacheDataset(
            data=data_dicts,
            transform=test_transforms_no_thorax_sub,
            cache_rate=1.0,
            num_workers=NUM_WORKERS,
        ),
        # Patch-based view of the same files, fully cached up front.
        'patches_sub_test_ds': CacheDataset(
            data=data_dicts,
            transform=test_transforms_patches_sub,
            cache_rate=1.0,
            num_workers=NUM_WORKERS,
        ),
    }

print("Patient datasets created successfully!")

# Define data_transforms dictionary for visualization function
data_transforms = {
    'no_thorax_sub_test_ds': test_transforms_no_thorax_sub,
    'no_thorax_sub_thorax_test_ds': test_transforms_no_thorax_sub,  # Same transform
    'patches_sub_test_ds': test_transforms_patches_sub,
}


## Model Checkpoint Paths


# In [None]:
# Checkpoint locations, keyed by the short model name used in the test cells.
model_paths = {
    # VENUS fusion models
    'venus_tiny': f'{CHECKPOINTS_DIR}/venus-tiny-best.ckpt',

    # Baseline models
    'unetplusplus': f'{CHECKPOINTS_DIR}/unetplusplus_model.ckpt',
    'skinny': f'{CHECKPOINTS_DIR}/skinny_model.ckpt',
    'resnet50': f'{CHECKPOINTS_DIR}/resnet50.ckpt',
    # NOTE(review): 'fcn' points at the UNet++ checkpoint, so any "FCN" result
    # is really a UNet++ result. Replace with the real FCN checkpoint path.
    'fcn': f'{CHECKPOINTS_DIR}/unetplusplus_model.ckpt',
    'segnet': f'{CHECKPOINTS_DIR}/segnet_model_large.ckpt',
    'swin': f'{CHECKPOINTS_DIR}/swin_model.ckpt',

    # Patches models
    'resnet50_patches': f'{CHECKPOINTS_DIR}/resnet50-patches.ckpt',
    'resnet18_patches': f'{CHECKPOINTS_DIR}/resnet18-patches.ckpt',
    'unetplusplus_patches': f'{CHECKPOINTS_DIR}/unetplusplus-patches-cabfl.ckpt',
}

# Check which models are available
print("Available models:")
for name, path in model_paths.items():
    if os.path.exists(path):
        print(f"  ✓ {name}: {path}")
    else:
        print(f"  ✗ {name}: {path} (not found)")

# Several names resolving to one checkpoint would report identical scores under
# different labels, so surface any such aliasing explicitly.
names_by_checkpoint = {}
for name, path in model_paths.items():
    names_by_checkpoint.setdefault(os.path.normpath(path), []).append(name)
for path, names in names_by_checkpoint.items():
    if len(names) > 1:
        print(f"  ! Warning: {', '.join(names)} share the same checkpoint: {path}")


## Test VENUS Model


# In [None]:
# Test 1: evaluate the VENUS Tiny fusion checkpoint (venus-tiny-best.ckpt) on
# every test patient, combining the whole-image and patch datasets.
# scores_for_statistics_fusion_tiny is None when the checkpoint is missing.
if os.path.exists(model_paths['venus_tiny']):
    print("Testing VENUS Tiny ...")
    scores_for_statistics_fusion_tiny = test_dataset_aware_fusion(
        model_path=model_paths['venus_tiny'],
        patient_ids=x_test,
        datasets=datasets,
        whole_dataset_key="no_thorax_sub_test_ds",
        patches_dataset_key="patches_sub_test_ds",
        use_simple_fusion=False,
        use_decoder_attention=True,
        strict=True,
        filter=False,
        subtracted=True,
        get_scores_for_statistics=False,
        get_only_masses=False,
        base_channels=16,
    )
    print("\nVENUS Tiny:")
    pp.pprint(scores_for_statistics_fusion_tiny)
else:
    # Report the path that was actually checked (venus-tiny-best.ckpt), matching
    # the not-found message used by the baseline cell.
    print(f"{model_paths['venus_tiny']} not found.")
    scores_for_statistics_fusion_tiny = None

## Test Baseline Models


# In [None]:
# Whole-image baselines, one entry per model:
#   (key into model_paths, display name, arch_name forwarded to the loader).
# NOTE(review): 'fcn' reuses the UNet++ checkpoint and arch name, so its scores
# duplicate UNet++'s — confirm whether a real FCN checkpoint exists.
baseline_tests = [
    ('unetplusplus', 'UNet++', 'unetplusplus'),
    ('skinny', 'SkinnyNet', 'skinny'),
    ('resnet50', 'ResNet50', 'resnet50'),
    ('fcn', 'FCN', 'unetplusplus'),
    ('segnet', 'SegNet', 'segnet'),
    ('swin', 'Swin-UNETR', 'swin_unetr'),
]

# model key -> metrics returned by the tester, or None if the checkpoint is missing.
baseline_results = {}
for model_key, model_name, arch_name in baseline_tests:
    checkpoint_path = model_paths[model_key]

    if not os.path.exists(checkpoint_path):
        print(f"{checkpoint_path} not found.")
        baseline_results[model_key] = None
        continue

    print(f"Testing {model_name} model...")
    result = test_dataset_aware_no_patches(
        model_path=checkpoint_path,
        patient_ids=x_test,
        datasets=datasets,
        dataset_key="no_thorax_sub_test_ds",
        filter=False,
        get_scores_for_statistics=False,
        get_only_masses=False,
        arch_name=arch_name,
        strict=True,
        subtracted=True,
    )
    baseline_results[model_key] = result
    print(f"\n{model_name} Results:")
    pp.pprint(result)

completed_baselines = sum(r is not None for r in baseline_results.values())
print(f"\nCompleted {completed_baselines} baseline model tests.")


## Test Ensemble Models


# In [None]:

# Ensembles of the whole-image VENUS model with a patch-based model, each run
# with filter=False and filter=True. Tuple layout:
#   (whole-image key, patches key, filter flag, base_channels, description).
# NOTE(review): unlike the VENUS-only cell, this does not pass
# use_simple_fusion / use_decoder_attention / strict, so it relies on
# test_dataset_aware_ensemble's defaults matching — confirm.
ensemble_tests = [
    ('venus_tiny', 'unetplusplus_patches', False, 16, 'VENUS Tiny + UNet++ patches'),
    ('venus_tiny', 'unetplusplus_patches', True, 16, 'VENUS Tiny + UNet++ patches (filtered)'),

    ('venus_tiny', 'resnet18_patches', False, 16, 'VENUS Tiny + ResNet18 patches'),
    ('venus_tiny', 'resnet18_patches', True, 16, 'VENUS Tiny + ResNet18 patches (filtered)'),
]

# "<whole>+<patches>[_filtered]" -> metrics, or None if a checkpoint is missing.
ensemble_results = {}
for whole_key, patches_key, use_filter, base_channels, description in ensemble_tests:
    filter_suffix = '_filtered' if use_filter else ''
    ensemble_key = f"{whole_key}+{patches_key}{filter_suffix}"

    both_present = (
        os.path.exists(model_paths[whole_key])
        and os.path.exists(model_paths[patches_key])
    )
    if not both_present:
        print(f"Required models not found for ensemble: {description}")
        ensemble_results[ensemble_key] = None
        continue

    print(f"Testing Ensemble: {description}...")
    result = test_dataset_aware_ensemble(
        model_whole_path=model_paths[whole_key],
        model_patches_path=model_paths[patches_key],
        patient_ids=x_test,
        datasets=datasets,
        whole_dataset_key="no_thorax_sub_test_ds",
        patches_dataset_key="patches_sub_test_ds",
        filter=use_filter,
        get_scores_for_statistics=False,
        get_only_masses=False,
        subtracted=True,
        base_channels=base_channels,
    )
    ensemble_results[ensemble_key] = result
    print(f"\n{description} Results:")
    pp.pprint(result)

completed_ensembles = sum(r is not None for r in ensemble_results.values())
print(f"\nCompleted {completed_ensembles} ensemble model tests.")
