# Framework Benchmark: RadiObject vs MONAI vs TorchIO

Publication-quality benchmark comparing medical imaging data loading frameworks using the LIDC-IDRI dataset.

## Benchmark Scenarios

### Scenario A: Local Storage (Fair Comparison)
All frameworks read from local filesystem:
- **RadiObject**: Local TileDB array
- **MONAI**: Local NIfTI files
- **TorchIO**: Local NIfTI files

### Scenario B: S3 Remote Storage (RadiObject Advantage)
Demonstrates RadiObject's unique S3 capability via TileDB:
- **RadiObject (S3)**: S3-backed TileDB array
- **RadiObject (Local)**: Baseline comparison

## Metrics
| Metric | Description | Unit |
|--------|-------------|------|
| Throughput | Samples loaded per second | samples/sec |
| Cold Start | First batch load time | seconds |
| Warm Load | Subsequent batch load time | seconds |
| Peak Memory | Peak memory traced by `tracemalloc` during loading | MB |
| Epoch Time | Full iteration over dataset (listed for reference; not measured by the benchmark code in this notebook) | seconds |

## 1. Environment Setup

In [None]:
import subprocess
import platform
import psutil
import time
import gc
import tracemalloc
from datetime import datetime
from pathlib import Path
from dataclasses import dataclass, field
from typing import Callable
import json
import os
import sys

# Put the parent directory first on the import path (presumably so the
# project-local modules imported by later cells resolve from a notebook dir).
sys.path.insert(0, '..')

rule = "=" * 60
print(rule)
print("MACHINE SPECIFICATIONS")
print(rule)
print(f"Timestamp: {datetime.now().isoformat()}")
print(f"Platform: {platform.platform()}")
print(f"Python: {platform.python_version()}")

# Ask sysctl for the CPU brand string; on any failure fall back to whatever
# the platform module reports.
try:
    cpu_name = subprocess.check_output(
        ["sysctl", "-n", "machdep.cpu.brand_string"],
        text=True,
    ).strip()
except Exception:
    cpu_name = platform.processor()
print(f"CPU: {cpu_name}")

physical_cores = psutil.cpu_count(logical=False)
logical_cores = psutil.cpu_count(logical=True)
total_ram_gb = psutil.virtual_memory().total / (1024**3)
print(f"CPU Cores: {physical_cores} physical, {logical_cores} logical")
print(f"RAM: {total_ram_gb:.1f} GB")
print(rule)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader

print(f"NumPy: {np.__version__}")
print(f"PyTorch: {torch.__version__}")

# Availability flags for the optional third-party packages. Each starts False
# and is flipped to True only when the matching import below succeeds; later
# cells consult these flags to skip frameworks that are not installed.
HAVE_MONAI = False
HAVE_TORCHIO = False
HAVE_SIMPLEITK = False

# MONAI: dataset/loader classes plus the dict-based transforms used by the
# MONAI adapter.
try:
    import monai
    from monai.data import Dataset as MonaiDataset, DataLoader as MonaiDataLoader
    from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, RandSpatialCropd
    HAVE_MONAI = True
    print(f"MONAI: {monai.__version__}")
except ImportError:
    print("MONAI: Not installed")

# TorchIO: referenced everywhere through the `tio` alias.
try:
    import torchio as tio
    HAVE_TORCHIO = True
    print(f"TorchIO: {tio.__version__}")
except ImportError:
    print("TorchIO: Not installed")

# SimpleITK: only needed for the DICOM -> NIfTI conversion step.
try:
    import SimpleITK as sitk
    HAVE_SIMPLEITK = True
    print(f"SimpleITK: {sitk.__version__}")
except ImportError:
    print("SimpleITK: Not installed")

from radiobject.radi_object import RadiObject
from radiobject.ctx import configure, S3Config, get_config
from ml import create_training_dataloader

print("\nRadiObject: Loaded")

In [None]:
# --- Benchmark parameters -------------------------------------------------
BATCH_SIZE = 4
PATCH_SIZE = (64, 64, 64)
NUM_WORKERS = 0  # load in the main process so every framework is single-threaded
N_WARMUP = 3
N_BATCHES = 20
N_RUNS = 5

# --- Local data layout ----------------------------------------------------
DATA_DIR = Path("../data")
DICOM_DIR = DATA_DIR / "lidc_subset"
NIFTI_DIR = DATA_DIR / "nifti"
RADIOBJECT_LOCAL_URI = str(DATA_DIR / "radiobject_local")
ASSETS_DIR = Path("../assets/benchmark")

# --- Remote (S3) layout ---------------------------------------------------
# The bucket comes from the environment; an empty value disables the S3 URI.
S3_BUCKET = os.environ.get("RADIOBJECT_S3_BUCKET", "souzy-scratch")
if S3_BUCKET:
    RADIOBJECT_S3_URI = f"s3://{S3_BUCKET}/lidc-idri/radiobject"
else:
    RADIOBJECT_S3_URI = ""

# Create every working directory up front.
for _directory in (DATA_DIR, DICOM_DIR, NIFTI_DIR, ASSETS_DIR):
    _directory.mkdir(parents=True, exist_ok=True)

# Echo the effective configuration.
for _label, _value in (
    ("Batch size", BATCH_SIZE),
    ("Patch size", PATCH_SIZE),
    ("Warmup iterations", N_WARMUP),
    ("Benchmark batches", N_BATCHES),
    ("Runs per framework", N_RUNS),
    ("S3 RadiObject URI", RADIOBJECT_S3_URI),
):
    print(f"{_label}: {_value}")

## 2. Benchmark Utilities

In [None]:
@dataclass
class BenchmarkResult:
    """Timing and memory measurements from one benchmark run of one loader.

    Attributes:
        framework: Display name of the framework under test.
        scenario: Storage scenario label (e.g. ``"local"`` or ``"s3"``).
        cold_start_s: Seconds spent fetching the first batch.
        batch_times_s: Wall-clock seconds for each timed batch; may be empty
            when the loader produced no timed batches.
        peak_memory_mb: Peak traced memory for the run, in MB.
    """
    framework: str
    scenario: str
    cold_start_s: float
    batch_times_s: list[float]
    peak_memory_mb: float

    @property
    def throughput(self) -> float:
        """Samples per second, computed as ``BATCH_SIZE / mean_batch_s``.

        Returns NaN when no batches were timed and ``inf`` when the mean
        batch time is exactly zero.

        NOTE(review): this assumes every timed batch held ``BATCH_SIZE``
        samples; a short final batch would overstate throughput — confirm.
        """
        if not self.batch_times_s:
            return float("nan")
        mean_s = self.mean_batch_s
        if mean_s == 0.0:
            # Plain-float division would raise; keep the "infinitely fast"
            # result NumPy division used to give.
            return float("inf")
        return float(BATCH_SIZE / mean_s)

    @property
    def mean_batch_s(self) -> float:
        """Mean seconds per timed batch, or NaN when nothing was timed."""
        if not self.batch_times_s:
            # np.mean([]) would emit a RuntimeWarning before returning NaN.
            return float("nan")
        return float(np.mean(self.batch_times_s))

    @property
    def std_batch_s(self) -> float:
        """Population standard deviation of the batch times, or NaN when empty."""
        if not self.batch_times_s:
            return float("nan")
        return float(np.std(self.batch_times_s))


@dataclass
class AggregatedResult:
    """Mean and standard deviation of each metric across repeated benchmark runs.

    Built by ``aggregate_results``; every ``*_mean`` / ``*_std`` pair is the
    NumPy mean and (population) standard deviation over the per-run values.
    """
    # Identity of the aggregated runs, copied from the first run.
    framework: str
    scenario: str
    # Per-run throughput (samples/sec).
    throughput_mean: float
    throughput_std: float
    # Per-run first-batch time (seconds).
    cold_start_mean: float
    cold_start_std: float
    # Per-run mean batch time (seconds).
    batch_time_mean: float
    batch_time_std: float
    # Per-run peak traced memory (MB).
    peak_memory_mean: float
    peak_memory_std: float


def aggregate_results(results: list[BenchmarkResult]) -> AggregatedResult:
    """Summarise repeated runs of one framework/scenario as mean and std per metric.

    The framework and scenario labels are taken from the first run, so
    ``results`` must be non-empty (an empty list raises ``IndexError``).
    """
    # Per-run samples keyed by the metric prefix used on AggregatedResult.
    per_run = {
        "throughput": [run.throughput for run in results],
        "cold_start": [run.cold_start_s for run in results],
        "batch_time": [run.mean_batch_s for run in results],
        "peak_memory": [run.peak_memory_mb for run in results],
    }

    first_run = results[0]

    stats = {}
    for metric, samples in per_run.items():
        stats[f"{metric}_mean"] = np.mean(samples)
        stats[f"{metric}_std"] = np.std(samples)

    return AggregatedResult(
        framework=first_run.framework,
        scenario=first_run.scenario,
        **stats,
    )

In [None]:
def run_dataloader_benchmark(
    loader: DataLoader,
    framework: str,
    scenario: str,
    image_key: str = "image",
) -> BenchmarkResult:
    """Run one benchmark pass (cold start, warm-up, timed batches) on a loader.

    Args:
        loader: Iterable yielding batches. A batch is either a tensor-like
            object with ``.shape`` or a dict holding the image under
            ``image_key``.
        framework: Label stored on the result.
        scenario: Storage scenario label stored on the result.
        image_key: Key of the image entry in dict batches.

    Returns:
        A ``BenchmarkResult`` with the first-batch time, one timing per
        measured batch (``N_BATCHES`` of them) and the peak memory reported
        by ``tracemalloc`` for the whole pass.

    Raises:
        StopIteration: If the loader yields no batches at all.
    """

    def _touch(batch) -> None:
        """Read the image tensor's shape so the batch is known to be loaded."""
        payload = batch[image_key] if isinstance(batch, dict) else batch
        if isinstance(payload, dict):
            # TorchIO batches nest the tensor one level deeper, under the
            # "data" key (the value of `tio.DATA`).
            payload = payload["data"]
        _ = payload.shape

    gc.collect()
    tracemalloc.start()
    try:
        # Cold start: time only the first fetch, not iterator construction.
        loader_iter = iter(loader)
        cold_start = time.perf_counter()
        _touch(next(loader_iter))
        cold_start_time = time.perf_counter() - cold_start

        # Warm-up: the cold-start batch counts as the first warm-up step.
        for _ in range(N_WARMUP - 1):
            try:
                next(loader_iter)
            except StopIteration:
                loader_iter = iter(loader)
                next(loader_iter)

        # Timed batches. Small datasets run out of batches well before
        # N_BATCHES, so start a new epoch instead of stopping early (which
        # previously could leave zero timings and a NaN throughput).
        batch_times = []
        for _ in range(N_BATCHES):
            batch_start = time.perf_counter()
            try:
                batch = next(loader_iter)
            except StopIteration:
                loader_iter = iter(loader)
                # Restart the clock so iterator construction stays excluded,
                # matching how the cold start is timed.
                batch_start = time.perf_counter()
                batch = next(loader_iter)
            _touch(batch)
            batch_times.append(time.perf_counter() - batch_start)

        _, peak = tracemalloc.get_traced_memory()
    finally:
        # Always stop tracing, even when the loader raises, so a failed run
        # does not leave tracemalloc active for later runs.
        tracemalloc.stop()

    return BenchmarkResult(
        framework=framework,
        scenario=scenario,
        cold_start_s=cold_start_time,
        batch_times_s=batch_times,
        peak_memory_mb=peak / (1024 * 1024),
    )

## 3. Dataset Preparation

Download a subset of LIDC-IDRI from S3 and prepare data for all frameworks.

In [None]:
# Make sure a local DICOM subset exists, downloading up to 10 patient folders
# from S3 when too few files are present.

# Check for existing DICOM data
existing_dcm = list(DICOM_DIR.rglob("*.dcm"))
print(f"Existing DICOM files: {len(existing_dcm)}")

# Check if we need to download from S3
# Fewer than 100 .dcm files is treated as "dataset missing".
if len(existing_dcm) < 100:
    print("\nDownloading LIDC-IDRI subset from S3...")
    print("Source: s3://souzy-scratch/lidc-idri/dicom/")
    
    try:
        # Imported lazily so the notebook still runs without these packages.
        import boto3
        from tqdm import tqdm
        
        s3 = boto3.client('s3')
        # NOTE(review): the bucket is hard-coded here instead of reusing
        # S3_BUCKET from the configuration cell — confirm this is intended.
        bucket = 'souzy-scratch'
        prefix = 'lidc-idri/dicom/'
        
        # List patient directories (limit to 10 patients)
        # With Delimiter='/', each CommonPrefixes entry is one "directory"
        # directly under `prefix`.
        paginator = s3.get_paginator('list_objects_v2')
        patient_dirs = set()
        
        for page in paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter='/'):
            for p in page.get('CommonPrefixes', []):
                patient_dirs.add(p['Prefix'])
                if len(patient_dirs) >= 10:
                    break
            # Second check leaves the page loop too, so no further pages are requested.
            if len(patient_dirs) >= 10:
                break
        
        print(f"Found {len(patient_dirs)} patient directories")
        
        # Download files
        for patient_prefix in tqdm(sorted(patient_dirs), desc="Downloading patients"):
            # Prefixes end with '/', so the last path component is at index -2.
            patient_id = patient_prefix.split('/')[-2]
            local_dir = DICOM_DIR / patient_id
            local_dir.mkdir(parents=True, exist_ok=True)
            
            for page in paginator.paginate(Bucket=bucket, Prefix=patient_prefix):
                for obj in page.get('Contents', []):
                    key = obj['Key']
                    if key.endswith('.dcm'):
                        # Only the base name is kept, flattening any nested
                        # folders into the patient directory.
                        # NOTE(review): same-named files from different
                        # sub-folders would collide and be skipped — confirm
                        # the bucket layout.
                        filename = key.split('/')[-1]
                        local_path = local_dir / filename
                        # Skip files already present so reruns resume.
                        if not local_path.exists():
                            s3.download_file(bucket, key, str(local_path))
        
        print(f"\nDownload complete")
        existing_dcm = list(DICOM_DIR.rglob("*.dcm"))
        print(f"Total DICOM files: {len(existing_dcm)}")
        
    except ImportError:
        # Also reached when tqdm (not just boto3) is missing.
        print("boto3 not installed. Install with: uv sync --group dev")
    except Exception as e:
        # Best effort: report the failure and let later cells use whatever
        # data is already on disk.
        print(f"Download failed: {e}")
        print("\nPlease ensure AWS credentials are configured.")
else:
    print(f"Using existing {len(existing_dcm)} DICOM files")

In [None]:
import pydicom

def find_dicom_series(root_dir: Path) -> list[tuple[Path, str]]:
    """Find DICOM series directories under ``root_dir`` and derive a patient ID for each.

    A direct sub-directory counts as a series when it holds at least 10
    ``*.dcm`` files. The patient ID is read from the ``PatientID`` header of
    the first file (in sorted order); if the header cannot be read or is
    blank, the first 20 characters of the directory name are used instead.
    Spaces and slashes are replaced with underscores in both cases so the ID
    is safe to use as a file-name stem.

    Args:
        root_dir: Directory whose immediate sub-directories are scanned.

    Returns:
        ``(series_dir, patient_id)`` tuples ordered by directory path.
    """

    def _sanitize(raw: str) -> str:
        """Replace characters that are awkward in file names."""
        return raw.replace(' ', '_').replace('/', '_')

    series_list = []

    for series_dir in sorted(root_dir.iterdir()):
        if not series_dir.is_dir():
            continue

        # Sorted so the file used for the header read is deterministic.
        dcm_files = sorted(series_dir.glob("*.dcm"))
        if len(dcm_files) < 10:
            continue

        header_id = ''
        try:
            ds = pydicom.dcmread(dcm_files[0], stop_before_pixels=True)
            raw_id = getattr(ds, 'PatientID', None)
            header_id = '' if raw_id is None else str(raw_id)
        except Exception:
            # Deliberately best-effort: an unreadable header falls back to
            # the directory name below.
            header_id = ''

        if header_id.strip():
            patient_id = _sanitize(header_id)
        else:
            # A blank PatientID would otherwise produce an empty ID (and a
            # file named ".nii.gz" in the conversion step).
            patient_id = _sanitize(series_dir.name[:20])

        series_list.append((series_dir, patient_id))

    return series_list

# Scan the DICOM directory once and preview the first few series.
dicom_series = find_dicom_series(DICOM_DIR)
print(f"Found {len(dicom_series)} DICOM series")

if dicom_series:
    print("\nSample series:")
    for series_path, series_pid in dicom_series[:3]:
        n_files = sum(1 for _ in series_path.glob('*.dcm'))
        print(f"  {series_pid}: {n_files} files")

In [None]:
# Produce one NIfTI file per DICOM series for the NIfTI-based frameworks,
# reusing files that already exist on disk.
nifti_paths = []

if not (HAVE_SIMPLEITK and dicom_series):
    # Nothing to convert (or no converter): fall back to whatever is on disk.
    nifti_paths = list(NIFTI_DIR.glob("*.nii.gz"))
    if not nifti_paths:
        print("No NIfTI files available (install SimpleITK for conversion)")
    else:
        print(f"Using existing {len(nifti_paths)} NIfTI files")
else:
    print("Converting DICOM to NIfTI...")

    for series_dir, patient_id in dicom_series:
        nifti_path = NIFTI_DIR / f"{patient_id}.nii.gz"

        if not nifti_path.exists():
            try:
                reader = sitk.ImageSeriesReader()
                reader.SetFileNames(reader.GetGDCMSeriesFileNames(str(series_dir)))
                sitk.WriteImage(reader.Execute(), str(nifti_path))
                print(f"  Converted: {patient_id}")
            except Exception as e:
                # A failed conversion is reported and left out of the list.
                print(f"  Failed {patient_id}: {e}")
                continue

        nifti_paths.append(nifti_path)

    print(f"\nNIfTI files ready: {len(nifti_paths)}")

In [None]:
import tiledb

def radiobject_exists(uri: str) -> bool:
    """Report whether something exists at ``uri``.

    ``s3://`` URIs are checked by asking TileDB whether the object is a
    group; any other URI is treated as a filesystem path. Every error is
    swallowed and reported as ``False``.
    """
    try:
        if not uri.startswith("s3://"):
            return Path(uri).exists()
        return tiledb.object_type(uri) == "group"
    except Exception:
        return False

# Open the local RadiObject when it already exists; otherwise build it from
# the DICOM series found earlier (or give up when there are none).
if radiobject_exists(RADIOBJECT_LOCAL_URI):
    print(f"Local RadiObject exists: {RADIOBJECT_LOCAL_URI}")
    radi_local = RadiObject(RADIOBJECT_LOCAL_URI)
elif not dicom_series:
    radi_local = None
    print("No DICOM data available")
else:
    print("Creating local RadiObject from DICOM...")
    build_started = time.perf_counter()
    radi_local = RadiObject.from_dicoms(
        uri=RADIOBJECT_LOCAL_URI,
        dicom_dirs=dicom_series,
        reorient=True,
    )
    build_elapsed = time.perf_counter() - build_started
    print(f"Created in {build_elapsed:.1f}s")

if radi_local:
    for label, value in (
        ("Subjects", len(radi_local)),
        ("Collections", radi_local.collection_names),
    ):
        print(f"{label}: {value}")

In [None]:
# Open the S3-backed RadiObject if a URI is configured and the object exists;
# `radi_s3` stays None otherwise, which makes later cells skip the S3 suite.
radi_s3 = None

if not RADIOBJECT_S3_URI:
    print("S3 benchmark disabled (RADIOBJECT_S3_URI not configured)")
else:
    aws_region = os.environ.get("AWS_REGION", "us-east-1")
    configure(s3=S3Config(region=aws_region, max_parallel_ops=8))

    if not radiobject_exists(RADIOBJECT_S3_URI):
        print(f"S3 RadiObject not found at: {RADIOBJECT_S3_URI}")
        print("S3 benchmark will be skipped.")
    else:
        print(f"Loading existing S3 RadiObject: {RADIOBJECT_S3_URI}")
        radi_s3 = RadiObject(RADIOBJECT_S3_URI)
        print(f"S3 subjects: {len(radi_s3)}")
        print(f"S3 collections: {radi_s3.collection_names}")

## 4. Framework Adapters

In [None]:
def create_radiobject_loader(radi: RadiObject) -> DataLoader:
    """Build a training DataLoader over the first collection of ``radi``.

    Batch size, patch size and worker count come from the benchmark
    configuration constants.
    """
    first_collection = radi.collection_names[0]
    return create_training_dataloader(
        radi,
        modalities=[first_collection],
        batch_size=BATCH_SIZE,
        patch_size=PATCH_SIZE,
        num_workers=NUM_WORKERS,
    )


# Smoke test: pull one batch and show its shape.
if radi_local:
    loader = create_radiobject_loader(radi_local)
    batch = next(iter(loader))
    sample_shape = batch['image'].shape
    print(f"RadiObject batch shape: {sample_shape}")

In [None]:
def create_monai_loader(nifti_paths: list[Path]) -> DataLoader | None:
    """Build a shuffled MONAI DataLoader over the given NIfTI files.

    Each sample is loaded, made channel-first and randomly cropped to
    ``PATCH_SIZE``. Returns ``None`` when MONAI is unavailable or no paths
    were given.
    """
    if not (HAVE_MONAI and nifti_paths):
        return None

    pipeline = Compose([
        LoadImaged(keys=["image"]),
        EnsureChannelFirstd(keys=["image"]),
        RandSpatialCropd(keys=["image"], roi_size=PATCH_SIZE, random_size=False),
    ])
    records = [{"image": str(path)} for path in nifti_paths]

    return MonaiDataLoader(
        MonaiDataset(data=records, transform=pipeline),
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        shuffle=True,
    )


# Smoke test: pull one batch and show its shape.
if HAVE_MONAI and nifti_paths:
    monai_loader = create_monai_loader(nifti_paths)
    if monai_loader:
        batch = next(iter(monai_loader))
        sample_shape = batch['image'].shape
        print(f"MONAI batch shape: {sample_shape}")
else:
    print("MONAI loader not available")

In [None]:
def create_torchio_loader(nifti_paths: list[Path]) -> DataLoader | None:
    """Build a shuffled DataLoader over a TorchIO ``SubjectsDataset``.

    Every subject wraps one NIfTI file as a ``ScalarImage`` and is cropped or
    padded to ``PATCH_SIZE``. Returns ``None`` when TorchIO is unavailable or
    no paths were given.
    """
    if not (HAVE_TORCHIO and nifti_paths):
        return None

    subjects = []
    for path in nifti_paths:
        subjects.append(tio.Subject(image=tio.ScalarImage(str(path))))

    crop_or_pad = tio.Compose([
        tio.CropOrPad(PATCH_SIZE),
    ])

    return DataLoader(
        tio.SubjectsDataset(subjects, transform=crop_or_pad),
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        shuffle=True,
    )


# Smoke test: pull one batch and show its shape. The tensor sits one level
# deeper than in the other loaders, under the tio.DATA key.
if HAVE_TORCHIO and nifti_paths:
    tio_loader = create_torchio_loader(nifti_paths)
    if tio_loader:
        batch = next(iter(tio_loader))
        sample_shape = batch['image'][tio.DATA].shape
        print(f"TorchIO batch shape: {sample_shape}")
else:
    print("TorchIO loader not available")

## 5. Benchmark Suite A: Local Storage Comparison

Fair comparison - all frameworks reading from local storage.

In [None]:
# Suite A: benchmark every available framework against local storage,
# N_RUNS times each. Results are stored per framework as one BenchmarkResult
# per run.
local_results: dict[str, list[BenchmarkResult]] = {
    "RadiObject": [],
    "MONAI": [],
    "TorchIO": [],
}

print("=" * 60)
print("BENCHMARK SUITE A: LOCAL STORAGE")
print("=" * 60)

for run in range(N_RUNS):
    print(f"\nRun {run + 1}/{N_RUNS}")
    
    # RadiObject
    # A fresh loader is built for every run, then dropped and collected so
    # one run's objects do not linger into the next measurement.
    if radi_local:
        loader = create_radiobject_loader(radi_local)
        result = run_dataloader_benchmark(loader, "RadiObject", "local")
        local_results["RadiObject"].append(result)
        print(f"  RadiObject: {result.throughput:.1f} samples/sec")
        del loader
        gc.collect()
    
    # MONAI
    if HAVE_MONAI and nifti_paths:
        loader = create_monai_loader(nifti_paths)
        if loader:
            result = run_dataloader_benchmark(loader, "MONAI", "local")
            local_results["MONAI"].append(result)
            print(f"  MONAI: {result.throughput:.1f} samples/sec")
            del loader
            gc.collect()
    
    # TorchIO
    if HAVE_TORCHIO and nifti_paths:
        loader = create_torchio_loader(nifti_paths)
        if loader:
            result = run_dataloader_benchmark(
                loader, "TorchIO", "local", 
                image_key="image"
            )
            # NOTE(review): TorchIO batches keep the tensor at
            # batch['image'][tio.DATA] (see the adapter smoke test); confirm
            # that run_dataloader_benchmark handles that nesting.
            local_results["TorchIO"].append(result)
            print(f"  TorchIO: {result.throughput:.1f} samples/sec")
            del loader
            gc.collect()

print("\nBenchmark A complete")

## 6. Benchmark Suite B: S3 vs Local (RadiObject Only)

Demonstrates RadiObject's unique S3 capability via TileDB.

In [None]:
# Suite B: compare RadiObject reading from local storage against reading from
# S3. Runs only when the S3 RadiObject was opened successfully.
s3_results: dict[str, list[BenchmarkResult]] = {
    "Local": [],
    "S3": [],
}

if radi_s3:
    print("=" * 60)
    print("BENCHMARK SUITE B: S3 vs LOCAL")
    print("=" * 60)
    
    for run in range(N_RUNS):
        print(f"\nRun {run + 1}/{N_RUNS}")
        
        # Local
        # The local baseline is optional; without it only "S3" gets results.
        if radi_local:
            loader = create_radiobject_loader(radi_local)
            result = run_dataloader_benchmark(loader, "RadiObject", "local")
            s3_results["Local"].append(result)
            print(f"  Local: {result.throughput:.1f} samples/sec")
            del loader
            gc.collect()
        
        # S3
        loader = create_radiobject_loader(radi_s3)
        result = run_dataloader_benchmark(loader, "RadiObject", "s3")
        s3_results["S3"].append(result)
        print(f"  S3: {result.throughput:.1f} samples/sec")
        # Drop the loader between runs so its objects do not carry over.
        del loader
        gc.collect()
    
    print("\nBenchmark B complete")
else:
    print("Skipping S3 benchmark (not configured)")

## 7. Results Aggregation

In [None]:
# Collapse the per-run results into one AggregatedResult per framework
# (local suite) and per storage backend (S3 suite), skipping empty entries.
local_aggregated = {
    framework: aggregate_results(runs)
    for framework, runs in local_results.items()
    if runs
}
s3_aggregated = {
    scenario: aggregate_results(runs)
    for scenario, runs in s3_results.items()
    if runs
}

# Local-storage table
print("=" * 60)
print("LOCAL STORAGE RESULTS")
print("=" * 60)

local_df_data = [
    {
        "Framework": name,
        "Throughput (samples/sec)": f"{agg.throughput_mean:.1f} +/- {agg.throughput_std:.1f}",
        "Cold Start (s)": f"{agg.cold_start_mean:.3f} +/- {agg.cold_start_std:.3f}",
        "Batch Time (ms)": f"{agg.batch_time_mean*1000:.1f} +/- {agg.batch_time_std*1000:.1f}",
        "Peak Memory (MB)": f"{agg.peak_memory_mean:.1f} +/- {agg.peak_memory_std:.1f}",
    }
    for name, agg in local_aggregated.items()
]

if local_df_data:
    local_df = pd.DataFrame(local_df_data)
    display(local_df)

# S3-vs-local table
if s3_aggregated:
    print("\n" + "=" * 60)
    print("S3 vs LOCAL RESULTS (RadiObject)")
    print("=" * 60)

    s3_df_data = [
        {
            "Storage": scenario,
            "Throughput (samples/sec)": f"{agg.throughput_mean:.1f} +/- {agg.throughput_std:.1f}",
            "Cold Start (s)": f"{agg.cold_start_mean:.3f} +/- {agg.cold_start_std:.3f}",
            "Batch Time (ms)": f"{agg.batch_time_mean*1000:.1f} +/- {agg.batch_time_std*1000:.1f}",
        }
        for scenario, agg in s3_aggregated.items()
    ]

    s3_df = pd.DataFrame(s3_df_data)
    display(s3_df)

    # Overhead: how much faster local is than S3, as a percentage.
    if {"Local", "S3"} <= s3_aggregated.keys():
        local_over_s3 = s3_aggregated["Local"].throughput_mean / s3_aggregated["S3"].throughput_mean
        overhead = (local_over_s3 - 1) * 100
        print(f"\nS3 throughput overhead: {overhead:.1f}%")

## 8. Visualization Generation

In [None]:
# Colorblind-friendly palette: one colour per framework / storage backend.
_BLUE = "#0077BB"
_ORANGE = "#EE7733"
_TEAL = "#009988"
_RED = "#CC3311"

COLORS = {
    "RadiObject": _BLUE,
    "MONAI": _ORANGE,
    "TorchIO": _TEAL,
    "Local": _BLUE,
    "S3": _RED,
}

# Figure defaults shared by every chart below.
_FIGURE_DEFAULTS = {
    'font.size': 12,
    'axes.titlesize': 14,
    'axes.labelsize': 12,
    'xtick.labelsize': 11,
    'ytick.labelsize': 11,
    'figure.dpi': 150,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight',
}
plt.rcParams.update(_FIGURE_DEFAULTS)

In [None]:
# 1. Local Throughput Comparison
# Bar chart of mean throughput per framework; error bars are the standard
# deviation across runs. Skipped when no local results were aggregated.
if local_aggregated:
    fig, ax = plt.subplots(figsize=(8, 5))
    
    frameworks = list(local_aggregated.keys())
    throughputs = [local_aggregated[f].throughput_mean for f in frameworks]
    errors = [local_aggregated[f].throughput_std for f in frameworks]
    # Unknown framework names fall back to grey.
    colors = [COLORS.get(f, "#999999") for f in frameworks]
    
    bars = ax.bar(frameworks, throughputs, yerr=errors, capsize=5, color=colors, edgecolor='black', linewidth=1)
    
    ax.set_ylabel("Throughput (samples/sec)")
    ax.set_title("Data Loading Throughput Comparison")
    ax.grid(axis='y', alpha=0.3)
    
    # Add value labels
    # The vertical offset (+1) is in data units, i.e. samples/sec.
    for bar, val in zip(bars, throughputs):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1, 
                f"{val:.1f}", ha='center', va='bottom', fontsize=11)
    
    plt.tight_layout()
    # Save before show() so the written file contains the rendered figure.
    plt.savefig(ASSETS_DIR / "local_throughput.png")
    print(f"Saved: {ASSETS_DIR / 'local_throughput.png'}")
    plt.show()

In [None]:
# 2. S3 vs Local Comparison
# Needs both entries: s3_results only ever has the keys "Local" and "S3", so
# a length of 2 means both backends produced results.
if s3_aggregated and len(s3_aggregated) == 2:
    fig, ax = plt.subplots(figsize=(6, 5))
    
    # Fixed order so the local baseline is always the left-hand bar.
    scenarios = ["Local", "S3"]
    throughputs = [s3_aggregated[s].throughput_mean for s in scenarios]
    errors = [s3_aggregated[s].throughput_std for s in scenarios]
    colors = [COLORS[s] for s in scenarios]
    
    bars = ax.bar(scenarios, throughputs, yerr=errors, capsize=5, color=colors, edgecolor='black', linewidth=1)
    
    ax.set_ylabel("Throughput (samples/sec)")
    ax.set_title("RadiObject: Local vs S3 Storage")
    ax.grid(axis='y', alpha=0.3)
    
    # Value labels; the +0.5 offset is in data units (samples/sec).
    for bar, val in zip(bars, throughputs):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.5, 
                f"{val:.1f}", ha='center', va='bottom', fontsize=11)
    
    plt.tight_layout()
    # Save before show() so the written file contains the rendered figure.
    plt.savefig(ASSETS_DIR / "s3_comparison.png")
    print(f"Saved: {ASSETS_DIR / 's3_comparison.png'}")
    plt.show()
else:
    print("S3 comparison chart skipped (not enough data)")

In [None]:
# 3. Memory Usage Comparison
# Bar chart of mean peak memory per framework; error bars are the standard
# deviation across runs. Skipped when no local results were aggregated.
if local_aggregated:
    fig, ax = plt.subplots(figsize=(8, 5))
    
    frameworks = list(local_aggregated.keys())
    memories = [local_aggregated[f].peak_memory_mean for f in frameworks]
    errors = [local_aggregated[f].peak_memory_std for f in frameworks]
    # Unknown framework names fall back to grey.
    colors = [COLORS.get(f, "#999999") for f in frameworks]
    
    bars = ax.bar(frameworks, memories, yerr=errors, capsize=5, color=colors, edgecolor='black', linewidth=1)
    
    ax.set_ylabel("Peak Memory (MB)")
    ax.set_title("Peak Memory Usage During Data Loading")
    ax.grid(axis='y', alpha=0.3)
    
    # Value labels; the +1 offset is in data units (MB).
    for bar, val in zip(bars, memories):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1, 
                f"{val:.0f}", ha='center', va='bottom', fontsize=11)
    
    plt.tight_layout()
    # Save before show() so the written file contains the rendered figure.
    plt.savefig(ASSETS_DIR / "memory.png")
    print(f"Saved: {ASSETS_DIR / 'memory.png'}")
    plt.show()

In [None]:
# 4. Cold Start Comparison
# Bar chart of mean first-batch time per framework; error bars are the
# standard deviation across runs. Skipped when no local results exist.
if local_aggregated:
    fig, ax = plt.subplots(figsize=(8, 5))
    
    frameworks = list(local_aggregated.keys())
    cold_starts = [local_aggregated[f].cold_start_mean for f in frameworks]
    errors = [local_aggregated[f].cold_start_std for f in frameworks]
    # Unknown framework names fall back to grey.
    colors = [COLORS.get(f, "#999999") for f in frameworks]
    
    bars = ax.bar(frameworks, cold_starts, yerr=errors, capsize=5, color=colors, edgecolor='black', linewidth=1)
    
    ax.set_ylabel("Cold Start Time (seconds)")
    ax.set_title("First Batch Load Time (Cold Start)")
    ax.grid(axis='y', alpha=0.3)
    
    # Value labels; the +0.01 offset is in data units (seconds).
    for bar, val in zip(bars, cold_starts):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, 
                f"{val:.2f}s", ha='center', va='bottom', fontsize=11)
    
    plt.tight_layout()
    # Save before show() so the written file contains the rendered figure.
    plt.savefig(ASSETS_DIR / "cold_start.png")
    print(f"Saved: {ASSETS_DIR / 'cold_start.png'}")
    plt.show()

## 9. Export for README

In [None]:
# Generate markdown tables for README
def format_winner(values: dict[str, float], metric: str, higher_is_better: bool = True) -> str:
    """Return the key with the best value, or ``"-"`` when ``values`` is empty.

    "Best" is the largest value when ``higher_is_better`` is true and the
    smallest otherwise; ties go to the key that appears first. ``metric`` is
    accepted for call-site readability and is not used.
    """
    if not values:
        return "-"
    pick = max if higher_is_better else min
    return pick(values, key=values.get)

# Emit README-ready markdown: one table for the local suite (one column per
# framework plus a winner column) and, when both backends ran, one for S3.
print("## Local Storage Benchmark")
print()
print("| Metric | " + " | ".join(local_aggregated.keys()) + " | Winner |")
print("|--------|" + "|".join(["--------"] * (len(local_aggregated) + 1)) + "|")

if local_aggregated:
    # (row label, attribute prefix, metric tag, decimals, higher is better)
    row_specs = (
        ("Throughput (samples/sec)", "throughput", "throughput", 1, True),
        ("Cold Start (s)", "cold_start", "cold_start", 3, False),
        ("Peak Memory (MB)", "peak_memory", "memory", 1, False),
    )
    for label, prefix, tag, digits, higher in row_specs:
        mean_attr = f"{prefix}_mean"
        std_attr = f"{prefix}_std"
        means = {name: getattr(agg, mean_attr) for name, agg in local_aggregated.items()}
        cells = [
            f"{getattr(agg, mean_attr):.{digits}f} +/- {getattr(agg, std_attr):.{digits}f}"
            for agg in local_aggregated.values()
        ]
        winner = format_winner(means, tag, higher)
        print(f"| {label} | " + " | ".join(cells) + f" | **{winner}** |")

print()

if len(s3_aggregated) == 2:
    print("## S3 Performance (RadiObject)")
    print()
    print("| Metric | Local | S3 | Overhead |")
    print("|--------|-------|-----|----------|")

    local_tp = s3_aggregated["Local"].throughput_mean
    s3_tp = s3_aggregated["S3"].throughput_mean
    # Percentage by which local throughput exceeds S3 throughput.
    overhead = (local_tp / s3_tp - 1) * 100

    print(f"| Throughput (samples/sec) | {local_tp:.1f} | {s3_tp:.1f} | {overhead:.1f}% |")

In [None]:
# Persist the configuration and aggregated statistics as JSON next to the
# generated charts.
_LOCAL_FIELDS = (
    "throughput_mean",
    "throughput_std",
    "cold_start_mean",
    "cold_start_std",
    "batch_time_mean",
    "batch_time_std",
    "peak_memory_mean",
    "peak_memory_std",
)
# The S3 section records only throughput and cold-start statistics.
_S3_FIELDS = _LOCAL_FIELDS[:4]


def _pick(agg, names):
    """Copy the named attributes of ``agg`` into a dict, in order."""
    return {name: getattr(agg, name) for name in names}


results_json = {
    "timestamp": datetime.now().isoformat(),
    "config": dict(
        batch_size=BATCH_SIZE,
        patch_size=PATCH_SIZE,
        num_workers=NUM_WORKERS,
        n_warmup=N_WARMUP,
        n_batches=N_BATCHES,
        n_runs=N_RUNS,
    ),
    "local_results": {
        framework: _pick(agg, _LOCAL_FIELDS)
        for framework, agg in local_aggregated.items()
    },
    "s3_results": {
        scenario: _pick(agg, _S3_FIELDS)
        for scenario, agg in s3_aggregated.items()
    } if s3_aggregated else {},
}

results_path = ASSETS_DIR / "benchmark_results.json"
results_path.write_text(json.dumps(results_json, indent=2))

print(f"Results saved: {results_path}")

## 10. Summary

In [None]:
# Final recap: dataset size, key settings, framework availability, generated
# files, and the advantages claimed for RadiObject.
banner = "=" * 60
print(banner)
print("BENCHMARK SUMMARY")
print(banner)

series_count = len(dicom_series) if dicom_series else 'N/A'
print(f"\nDataset: {series_count} DICOM series")
print(f"Batch size: {BATCH_SIZE}")
print(f"Patch size: {PATCH_SIZE}")
print(f"Runs per framework: {N_RUNS}")

print("\nFramework Availability:")
print("  RadiObject: Available")
for label, installed in (("MONAI", HAVE_MONAI), ("TorchIO", HAVE_TORCHIO)):
    status = 'Available' if installed else 'Not installed'
    print(f"  {label}: {status}")

print("\nGenerated Artifacts:")
for pattern in ("*.png", "*.json"):
    for artifact in ASSETS_DIR.glob(pattern):
        print(f"  {artifact}")

print("\nRadiObject Advantages:")
for line in (
    "  1. Native DICOM ingestion (no conversion step)",
    "  2. Full DICOM metadata preservation",
    "  3. S3 native support via TileDB VFS",
    "  4. TileDB tile-level caching for repeated access",
    "  5. Random sub-volume reads without loading full volume",
):
    print(line)
print(banner)

## 11. Cleanup (Optional)

In [None]:
# Optional cleanup, disabled by default. Uncomment the lines below to delete
# the generated NIfTI directory and the local RadiObject; both are rebuilt
# from the DICOM data the next time the notebook runs.
# import shutil
# 
# if NIFTI_DIR.exists():
#     shutil.rmtree(NIFTI_DIR)
#     print(f"Removed: {NIFTI_DIR}")
# 
# if Path(RADIOBJECT_LOCAL_URI).exists():
#     shutil.rmtree(RADIOBJECT_LOCAL_URI)
#     print(f"Removed: {RADIOBJECT_LOCAL_URI}")

print("Cleanup skipped (uncomment to remove generated data)")