In Kaggle, enable file persistence and internet access.
Remember that you can run the whole notebook and close the runtime without wasting resources by going to File > Save Version > Save & Run All (double-check that a GPU accelerator is selected in the advanced settings).
Later, via File > Version history, you can view the full logs and download the output files.

In [4]:
# Check if running in Kaggle
import os

IN_KAGGLE = False
if os.environ.get("KAGGLE_URL_BASE", ""):
    IN_KAGGLE = True
    !git clone https://github.com/parmigggiana/xai /kaggle/working/xai
    %cd xai
    !git fetch
    !git reset --hard origin/main
    %pip install 'monai[einops,itk,nibabel]>=1.5.0' git+https://github.com/timojl/clipseg.git

In [5]:
# Check if running in Google Colab
IN_COLAB = False
if not IN_KAGGLE:
    try:
        import google.colab
        from google.colab import drive

        IN_COLAB = True
        import os

        drive.mount("/content/drive")
        os.makedirs("/content/drive/MyDrive/xai", exist_ok=True)
        !git clone https://github.com/parmigggiana/xai /content/xai
        %cd /content/xai
        !git fetch
        !git reset --hard origin/main
        %pip install -r requirements.txt
    except Exception:
        pass

In [6]:
from src.datasets.registry import get_dataset
from src.datasets.common import BaseDataset
from pathlib import Path
import json
from src.task_vector import TaskVector
from src.utils import download_and_extract_dataset

In [7]:
DATASET_NAMES = ["CHAOS", "MMWHS"]
DOMAINS = ["CT", "MR"]

# Working directories (turned into Path objects in the next cell)
DATA_PATH = "data/"
CHECKPOINT_PATH = "checkpoints/"
OUTPUTS_PATH = "outputs/"

# False: 2D slices with the CLIPSeg encoder; True: 3D volumes with SwinUNETR
USE_3D = False

# Finetuning epochs for each (dataset, domain) run
TRAINING_EPOCHS = {
    ("CHAOS", "CT"): 100,
    ("CHAOS", "MR"): 100,
    ("MMWHS", "CT"): 100,
    ("MMWHS", "MR"): 100,
}

# Optimisation hyper-parameters
BATCH_SIZE = 8
SPATIAL_SIZE = 128  # longest image side after resizing
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 5e-5

# DataLoader workers: 2 on Kaggle/Colab, 0 elsewhere (>0 enables parallel loading)
NUM_WORKERS = 2 if (IN_KAGGLE or IN_COLAB) else 0

# Set True to enable debug prints/timers/visualizations
DEBUG = False

# Profiling controls: False | 'cprofile' | 'torch'
PROFILE = False

In [8]:
# Promote the configured directory strings to Path objects.
CHECKPOINT_PATH = Path(CHECKPOINT_PATH)
OUTPUTS_PATH = Path(OUTPUTS_PATH)
DATA_PATH = Path(DATA_PATH)
PROFILE_DIR = OUTPUTS_PATH / "profiling"

# Make sure every output directory exists before anything writes to it.
for _directory in (CHECKPOINT_PATH, OUTPUTS_PATH, PROFILE_DIR):
    _directory.mkdir(parents=True, exist_ok=True)

# 3D volumes use SwinUNETR, 2D slices use CLIPSeg.
encoder_type = "swin_unetr" if USE_3D else "clipseg"

In [9]:
import torch
from monai import transforms


def update_metrics(name, new_metrics):
    """Store ``new_metrics`` under ``name`` in ``OUTPUTS_PATH / "metrics.json"``.

    Entries recorded earlier are kept; an existing entry with the same
    ``name`` is overwritten.

    Args:
        name: key of the entry (e.g. "CHAOS_CT_2d_finetuned").
        new_metrics: JSON-serializable value to store.

    Raises:
        TypeError: if the updated metrics cannot be serialized to JSON. The
            file on disk is left untouched in that case.
    """
    metrics_file = OUTPUTS_PATH / "metrics.json"

    if metrics_file.exists():
        with open(metrics_file, "r") as f:
            metrics = json.load(f)
    else:
        metrics = {}

    metrics[name] = new_metrics

    # Serialize before touching the file: json.dump() streams into an already
    # truncated file, so a non-serializable value would leave a half-written
    # metrics.json and lose every previously recorded result.
    payload = json.dumps(metrics, indent=4)

    # Write to a sibling temp file and swap it in, so an interrupted write
    # (e.g. a KeyboardInterrupt) cannot leave a corrupt metrics.json behind.
    tmp_file = metrics_file.with_name(metrics_file.name + ".tmp")
    tmp_file.write_text(payload)
    tmp_file.replace(metrics_file)


# Normalization stats (mean, std) per dataset/domain, keyed by (dataset_name, domain).
# get_preprocessing() feeds them to NormalizeIntensity; a (dataset, domain) pair
# missing from this table gets no intensity normalization.
# NOTE(review): presumably measured on the raw image intensities of each
# dataset/domain — confirm how these values were derived before changing them.
NORM_STATS = {
    ("MMWHS", "MR"): (186.5875, 258.5917),
    ("MMWHS", "CT"): (-745.0086, 1042.7251),
    ("CHAOS", "MR"): (90.8292, 168.8922),
    ("CHAOS", "CT"): (-478.1732, 476.7163),
}

# Preprocessing pipelines: resizing happens early so later steps work on smaller arrays.


def get_preprocessing(dataset_name: str, domain: str, is_training=True):
    """Build the MONAI image and segmentation pipelines for one dataset/domain.

    Image pipeline: channel-first layout (2D: drop the trailing singleton axis;
    3D: RAS orientation), area resize so the longest side is SPATIAL_SIZE,
    float32 tensor, NORM_STATS normalization when stats are available, light
    intensity augmentation when ``is_training``, and (2D only) repetition to
    3 channels.

    Segmentation pipeline: same layout step, long tensor, label decoding via
    get_decode_func, then nearest-neighbour resize to SPATIAL_SIZE.

    Returns:
        (image_transform, seg_transform), both ``transforms.Compose`` objects.
    """
    decode_func = get_decode_func(dataset_name, domain)
    stats = NORM_STATS.get((dataset_name, domain))
    if stats is None:
        mean, std = None, None
    else:
        mean, std = stats

    def layout_steps():
        # Built fresh on every call so the two pipelines never share transform objects.
        if not USE_3D:
            return [
                transforms.Lambda(lambda x: x.squeeze(-1)),
                transforms.EnsureChannelFirst(channel_dim="no_channel"),
            ]
        return [
            transforms.EnsureChannelFirst(channel_dim="no_channel"),
            transforms.Orientation(axcodes="RAS"),
        ]

    # ---- image pipeline ----
    image_steps = layout_steps()
    image_steps += [
        # Resize early to reduce compute for every later step
        transforms.Resize(
            spatial_size=SPATIAL_SIZE,
            size_mode="longest",
            mode="area",
            anti_aliasing=True,
        ),
        # Tensor conversion + float32 for stable CPU ops
        transforms.ToTensor(),
        transforms.EnsureType(dtype=torch.float32),
    ]
    if not (mean is None or std is None):
        image_steps.append(
            transforms.NormalizeIntensity(
                subtrahend=float(mean),
                divisor=float(std),
                channel_wise=False,
            )
        )
    if is_training:
        # Intensity-only augmentations, still float32 on CPU
        image_steps += [
            transforms.RandGaussianNoise(prob=0.15, std=0.05),
            transforms.RandAdjustContrast(prob=0.15, gamma=(0.95, 1.05)),
        ]
    if not USE_3D:
        # Repeat to 3 channels as the very last step (2D only)
        image_steps.append(transforms.RepeatChannel(repeats=3))
    image_transform = transforms.Compose(image_steps)

    # ---- segmentation pipeline ----
    seg_steps = layout_steps()
    seg_steps += [
        transforms.ToTensor(),
        transforms.EnsureType(dtype=torch.long),
        # Decode raw label values into class indices after tensor conversion
        transforms.Lambda(lambda x: decode_func(x)),
        transforms.Resize(
            spatial_size=SPATIAL_SIZE, size_mode="longest", mode="nearest"
        ),
    ]
    seg_transform = transforms.Compose(seg_steps)

    return image_transform, seg_transform


def get_decode_func(dataset_name, domain):
    """Return a function that maps raw label values to class indices.

    Args:
        dataset_name: "CHAOS" or "MMWHS"; any other name yields the identity.
        domain: imaging domain. For CHAOS, "MR"/"MRI" and "CT" are handled and
            any other domain yields the identity.

    Returns:
        A callable ``decode(labels)`` on label tensors:
        - CHAOS MR: ``labels // 63`` (integer division, input dtype preserved;
          e.g. 0/63/126/189/252 -> 0/1/2/3/4).
        - CHAOS CT: float tensor of 1.0 where ``labels > 0`` and 0.0 elsewhere.
        - MMWHS: float32 tensor where the i-th key of ``mmwhs_labels`` maps to
          i; values that are not keys of ``mmwhs_labels`` map to 0.
        - otherwise: ``labels`` returned unchanged.
    """
    if dataset_name == "CHAOS":
        if domain in ["MR", "MRI"]:

            def decode(labels):
                # Intensity-coded labels -> class indices. Floor division keeps
                # the input dtype (long in the segmentation pipeline).
                return labels // 63

            return decode

        if domain == "CT":

            def decode(labels):
                # Binary foreground mask
                return torch.where(labels > 0, 1.0, 0.0)

            return decode

    elif dataset_name == "MMWHS":
        # Imported lazily so CHAOS pipelines do not depend on the MMWHS module.
        from src.datasets.mmwhs import mmwhs_labels

        def decode(labels):
            decoded_labels = torch.zeros_like(labels, dtype=torch.float32)
            for i, label_val in enumerate(mmwhs_labels.keys()):
                decoded_labels[labels == label_val] = i
            return decoded_labels

        return decode

    def decode(labels):
        # Unknown dataset/domain: labels are used as-is.
        return labels

    return decode

In [None]:
# Finetuning loop

for (dataset_name, domain), epochs in TRAINING_EPOCHS.items():
    # Download even when finetuning is skipped below: later cells also build
    # datasets from DATA_PATH.
    download_and_extract_dataset(dataset_name, DATA_PATH)

    dim_tag = "3d" if USE_3D else "2d"
    run_name = f"{dataset_name}_{domain}_{dim_tag}"

    filename = CHECKPOINT_PATH / f"{run_name}_finetuned.pth"
    # Check if the finetuned checkpoint already exists
    if filename.exists():
        print(
            f"Finetuned model for {dataset_name} in {domain} domain with {dim_tag} images already exists at {filename}. Skipping finetuning."
        )
        continue

    # Only build the training pipelines for runs that are actually trained.
    image_transform, seg_transform = get_preprocessing(
        dataset_name, domain, is_training=True
    )

    print(
        f"Finetuning on {dataset_name} dataset in {domain} domain with {dim_tag} images "
    )
    dataset: BaseDataset = get_dataset(
        dataset_name=dataset_name,
        domain=domain,
        transform=image_transform,
        seg_transform=seg_transform,
        base_path=DATA_PATH,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        slice_2d=not USE_3D,
    )

    #  Ensure the dataset is loaded correctly
    if not isinstance(dataset, BaseDataset):
        raise TypeError(
            f"Expected dataset to be an instance of BaseDataset, got {type(dataset)}"
        )

    model = dataset.get_model(
        encoder_type=encoder_type,
    )

    # Save the baseline encoder (whole module) before finetuning; the task
    # vector cells diff it against the finetuned checkpoint.
    baseline_filename = CHECKPOINT_PATH / f"{run_name}_baseline.pth"
    torch.save(model.encoder, baseline_filename)

    if USE_3D:
        # Phase 1 (3D only): train the head with the body frozen.
        model.freeze_body()
        model.finetune(
            epochs=epochs,
            learning_rate=LEARNING_RATE,
            weight_decay=WEIGHT_DECAY,
            debug=DEBUG,
            profile=PROFILE,
            profile_dir=str(PROFILE_DIR),
        )
        metrics = model.evaluate(profile=PROFILE, profile_dir=str(PROFILE_DIR))
        update_metrics(f"{run_name}_head", metrics)
        # NOTE(review): nothing unfreezes the body afterwards, so the second
        # finetune() below also runs with the body frozen in 3D mode — confirm
        # whether a model.unfreeze() was intended here.

    # Alternative training regimes (enable one if needed):

    # Train Only Segmentation Head
    # pass

    # Train Visual Encoder + Segmentation head
    # model.unfreeze()
    # model.freeze_text_encoder()

    # Train Last 2 ResBlocks of Visual Encoder + Segmentation head
    # for p in model.encoder.clipseg.reduce.parameters():  # Not in forward pass anyway
    #     p.requires_grad = False
    # for i in range(8, 10):
    #     for p in model.encoder.clipseg.clip_model.visual.transformer.resblocks[
    #         i
    #     ].parameters():
    #         p.requires_grad = True

    # Train Only Visual Encoder
    # for p in model.encoder.clipseg.model.parameters():
    #     p.requires_grad_(not p.requires_grad)
    # model.freeze_text_encoder()

    history = model.finetune(
        epochs=epochs,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        debug=DEBUG,
        profile=PROFILE,
        profile_dir=str(PROFILE_DIR),
    )

    # Save before evaluating so a crash during evaluation keeps the trained weights.
    torch.save(model.encoder, filename)
    model_metrics = model.evaluate(profile=PROFILE, profile_dir=str(PROFILE_DIR))
    update_metrics(f"{run_name}_finetuned", model_metrics)

    # Drop this run's model/data before releasing cached CUDA memory:
    # empty_cache() cannot hand back blocks still held by live tensors, and
    # otherwise the old model stays resident while the next one is built.
    del model, dataset
    if torch.cuda.is_available():
        torch.cuda.synchronize()
        torch.cuda.empty_cache()

Finetuning on CHAOS dataset in CT domain with 2d images 
Found explicit background class in input. Treating it separately.
Non-background classes: ['Liver']
🔄 Loading CLIPSeg weights...
🚀 Starting training for 100 epochs
   Device: cuda

📖 Epoch 1/100


Training:   0%|          | 0/251 [00:08<?, ?it/s]


KeyboardInterrupt: 

# Domain adaptation

In [None]:
# SWIN UNETR Task Vectors
from monai.networks.nets import SwinUNETR
from monai.networks.nets.swin_unetr import SwinTransformer
from monai.networks.blocks.patchembedding import PatchEmbed
from torch.nn.modules.conv import Conv3d
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.container import ModuleList
from monai.networks.nets.swin_unetr import BasicLayer
from monai.networks.nets.swin_unetr import SwinTransformerBlock
from torch.nn.modules.normalization import LayerNorm
from monai.networks.nets.swin_unetr import WindowAttention
from torch.nn.modules.linear import Linear
from torch.nn.modules.activation import Softmax
from torch.nn.modules.linear import Identity
from monai.networks.blocks.mlp import MLPBlock
from torch.nn.modules.activation import GELU
from monai.networks.nets.swin_unetr import PatchMerging
from monai.networks.blocks.unetr_block import UnetrBasicBlock
from monai.networks.blocks.dynunet_block import UnetResBlock
from monai.networks.blocks.convolutions import Convolution
from torch.nn.modules.activation import LeakyReLU
from torch.nn.modules.instancenorm import InstanceNorm3d
from monai.networks.blocks.unetr_block import UnetrUpBlock
from monai.networks.blocks.dynunet_block import UnetOutBlock
from torch.nn.modules.conv import ConvTranspose3d

# Allowlist of classes that may be reconstructed when the checkpoints are
# unpickled. The checkpoints are whole modules saved with
# torch.save(model.encoder, ...), so every class in the SwinUNETR module tree
# has to be listed. The list is passed to torch.serialization.safe_globals
# around TaskVector construction below (TaskVector presumably calls torch.load).
# NOTE: `LayerNorm` here is torch.nn's LayerNorm; the name is rebound to CLIP's
# LayerNorm by a later import, and both classes end up in the allowlist.
safe_globals = [
    SwinUNETR,
    SwinTransformer,
    PatchEmbed,
    Conv3d,
    Dropout,
    ModuleList,
    BasicLayer,
    SwinTransformerBlock,
    LayerNorm,
    WindowAttention,
    Linear,
    Softmax,
    Identity,
    MLPBlock,
    GELU,
    PatchMerging,
    UnetrBasicBlock,
    UnetResBlock,
    Convolution,
    LeakyReLU,
    InstanceNorm3d,
    UnetrUpBlock,
    ConvTranspose3d,
    UnetOutBlock,
]
##

## CLIPSeg Task Vectors
from src.CLIPSeg import CLIPSeg
from clipseg.clipseg import CLIPDensePredT
from clip.model import (
    CLIP,
    VisionTransformer,
    LayerNorm,
    Transformer,
    ResidualAttentionBlock,
    QuickGELU,
)
from torch.nn.modules.conv import Conv2d, ConvTranspose2d
from torch.nn.modules.container import Sequential
from torch.nn.modules.activation import MultiheadAttention, ReLU
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.modules.sparse import Embedding
from torch.nn.modules.transformer import (
    TransformerEncoderLayer,
    TransformerEncoder,
    TransformerDecoderLayer,
    TransformerDecoder,
)
from torch.nn.functional import relu
from torch.nn.modules.container import ModuleDict

# Extend the allowlist with the classes found in pickled CLIPSeg encoders.
# `LayerNorm` is clip.model.LayerNorm at this point (rebound by the import
# above). `relu` is a function rather than a class; presumably it is stored as
# the activation of the nn.Transformer* layers — confirm if it can be dropped.
safe_globals.extend(
    [
        CLIPSeg,
        CLIPDensePredT,
        CLIP,
        VisionTransformer,
        Conv2d,
        LayerNorm,
        Transformer,
        Sequential,
        ResidualAttentionBlock,
        MultiheadAttention,
        NonDynamicallyQuantizableLinear,
        QuickGELU,
        Embedding,
        ReLU,
        ConvTranspose2d,
        TransformerEncoderLayer,
        TransformerEncoder,
        TransformerDecoderLayer,
        TransformerDecoder,
        relu,
        ModuleDict,
    ]
)

# Build Task Vectors for each dataset and domain
task_vectors = {}
for dataset_name in DATASET_NAMES:
    for domain in DOMAINS:
        print(
            f"Building task vector for {dataset_name} dataset in {domain} domain with {'3d' if USE_3D else '2d'} images"
        )
        baseline_checkpoint = (
            CHECKPOINT_PATH
            / f"{dataset_name}_{domain}_{'3d' if USE_3D else '2d'}_baseline.pth"
        )
        finetuned_checkpoint = (
            CHECKPOINT_PATH
            / f"{dataset_name}_{domain}_{'3d' if USE_3D else '2d'}_finetuned.pth"
        )
        if not baseline_checkpoint.exists():
            print(
                f"Baseline checkpoint for {dataset_name} {domain} does not exist. Skipping task vector creation."
            )
            continue
        if not finetuned_checkpoint.exists():
            print(
                f"Finetuned checkpoint {dataset_name} {domain} does not exist. Skipping task vector creation."
            )
            continue

        # The checkpoints are pickled whole modules; allow their classes while
        # TaskVector loads them.
        with torch.serialization.safe_globals(
            safe_globals=safe_globals,
        ):
            task_vector = TaskVector(baseline_checkpoint, finetuned_checkpoint)
            # Remove keys associated with the output layers from the task vector
            # For swin it's all layers starting with '.out'
            # For clipseg it might not be necessary since the model architecture isn't dependent on the number of output features
            if encoder_type == "swin_unetr":
                # Iterate over a snapshot of the keys: deleting entries while
                # iterating a live dict keys() view raises RuntimeError.
                for k in list(task_vector.keys()):
                    if k.startswith(".out"):
                        del task_vector[k]
                # NOTE(review): SwinUNETR state_dict keys normally look like
                # "out.conv.conv.weight" (no leading dot). Confirm how
                # TaskVector names its keys; otherwise this filter never matches.
        task_vectors[f"{dataset_name}_{domain}"] = task_vector

## Part 1: Improve robustness post-hoc with data

In [None]:
# Build composite task vectors using arithmetic.
# Each composite is a weighted blend of two task vectors whose weights sum to 1:
#   dataset composites mix both domains of one dataset,
#   domain composites mix both datasets of one domain.
_COMPOSITE_RECIPES = {
    "MMWHS": (("MMWHS_MR", 0.55), ("MMWHS_CT", 0.45)),
    "CHAOS": (("CHAOS_MR", 0.7), ("CHAOS_CT", 0.3)),
    "MR": (("CHAOS_MR", 0.64), ("MMWHS_MR", 0.36)),
    "CT": (("CHAOS_CT", 0.475), ("MMWHS_CT", 0.525)),
}
composite_task_vectors = {
    name: (task_vectors[first] * w_first) + (task_vectors[second] * w_second)
    for name, ((first, w_first), (second, w_second)) in _COMPOSITE_RECIPES.items()
}
alpha = 1

In [None]:
# Task vector simple composition experiments
print("🔄 Task Vector Cross-Domain Adaptation Experiments")
print("=" * 80)


def _evaluate_composite(target_dataset, composite_task_vector, metrics_key, label):
    """Apply a composite task vector to a fresh model and report its Dice scores.

    Builds a model for ``target_dataset``, loads ``composite_task_vector``
    scaled by the global ``alpha``, evaluates it, stores the metrics under
    ``metrics_key`` via update_metrics and prints a summary prefixed by ``label``.
    The model stays local, so it is no longer referenced from here once the
    next one is built.
    """
    target_model = target_dataset.get_model(encoder_type=encoder_type)
    target_model.load_task_vector(composite_task_vector, scaling_coef=alpha)

    metrics = target_model.evaluate()
    update_metrics(metrics_key, metrics)
    train_d = metrics.get("train", {}).get("dice", 0)
    val_d = metrics.get("val", {}).get("dice")
    if val_d is not None:
        print(f"   ✅ {label}: Train Dice={train_d:.3f} | Val Dice={val_d:.3f}")
    else:
        print(f"   ✅ {label}: Train Dice={train_d:.3f}")


for dataset_name in DATASET_NAMES:
    for target_domain in DOMAINS:
        print(f"\n{dataset_name}: {target_domain} adaptation")

        image_transform, seg_transform = get_preprocessing(
            dataset_name, target_domain, is_training=False
        )

        dataset_kwargs = {
            "dataset_name": dataset_name,
            "base_path": DATA_PATH,
            "domain": target_domain,
            "transform": image_transform,  # Use transform instead of preprocess
            "seg_transform": seg_transform,  # Pass seg_transform too
            "batch_size": BATCH_SIZE,
            "num_workers": NUM_WORKERS,
            "slice_2d": not USE_3D,
        }
        extra_kwargs = {}
        if dataset_name == "CHAOS":
            extra_kwargs["liver_only"] = True

        target_dataset = get_dataset(**dataset_kwargs, **extra_kwargs)

        # Dataset-level composite (blend of this dataset's domains)
        _evaluate_composite(
            target_dataset,
            composite_task_vectors[dataset_name],
            f"{dataset_name}_composite_at_{target_domain}",
            f"{dataset_name} at {target_domain}",
        )
        # Domain-level composite (blend of both datasets in this domain)
        _evaluate_composite(
            target_dataset,
            composite_task_vectors[target_domain],
            f"{target_domain}_composite_at_{dataset_name}",
            f"{target_domain} at {dataset_name}",
        )
print("=" * 80)

## Part 2: Improve robustness post-hoc without data

In [None]:
# Build composite task vectors using arithmetic.
# Dataless analogy for each target (dataset, domain):
#   a * (same dataset, other domain)
#   + b * (other dataset, target domain)
#   - c * (other dataset, other domain)
composite_task_vectors = {
    "MMWHS_CT": (
        1.04 * task_vectors["MMWHS_MR"]
        + 0.76 * task_vectors["CHAOS_CT"]
        - 1.79 * task_vectors["CHAOS_MR"]
    ),
    "MMWHS_MR": (
        0.85 * task_vectors["MMWHS_CT"]
        + 1.79 * task_vectors["CHAOS_MR"]
        - 0.76 * task_vectors["CHAOS_CT"]
    ),
    "CHAOS_CT": (
        1.79 * task_vectors["CHAOS_MR"]
        + 0.85 * task_vectors["MMWHS_CT"]
        - 1.04 * task_vectors["MMWHS_MR"]
    ),
    "CHAOS_MR": (
        0.76 * task_vectors["CHAOS_CT"]
        + 1.04 * task_vectors["MMWHS_MR"]
        - 0.85 * task_vectors["MMWHS_CT"]
    ),
}
alpha = 1

In [None]:
# 🔄 Task Vector Cross-Domain Adaptation Experiments
# For every (dataset, domain) target, evaluate the dataless composite built for it.
print("🔄 Task Vector Cross-Domain Adaptation Experiments")
print("=" * 80)

for dataset_name in DATASET_NAMES:
    for target_domain in DOMAINS:
        print(f"\n{dataset_name}: {target_domain} adaptation")

        image_transform, seg_transform = get_preprocessing(
            dataset_name, target_domain, is_training=False
        )
        # CHAOS targets are restricted to the liver
        extra_kwargs = {"liver_only": True} if dataset_name == "CHAOS" else {}
        target_dataset = get_dataset(
            dataset_name=dataset_name,
            base_path=DATA_PATH,
            domain=target_domain,
            transform=image_transform,
            seg_transform=seg_transform,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            slice_2d=not USE_3D,
            **extra_kwargs,
        )

        composite_key = f"{dataset_name}_{target_domain}"
        if composite_key not in composite_task_vectors:
            print(f"   ⚠️ No composite task vector found for {composite_key}")
            continue

        target_model = target_dataset.get_model(encoder_type=encoder_type)
        target_model.load_task_vector(
            composite_task_vectors[composite_key], scaling_coef=alpha
        )

        metrics = target_model.evaluate()
        update_metrics(f"{composite_key}_adaptation", metrics)
        train_d = metrics.get("train", {}).get("dice", 0)
        val_d = metrics.get("val", {}).get("dice")
        summary = f"   ✅ {composite_key}: Train Dice={train_d:.3f}"
        if val_d is not None:
            summary += f" | Val Dice={val_d:.3f}"
        print(summary)

print("=" * 80)

In [None]:
# Load and display all metrics, grouped by the tag contained in each entry name.
metrics_file = OUTPUTS_PATH / "metrics.json"
if not metrics_file.exists():
    print("No metrics file found. Run the experiments first.")
else:
    with open(metrics_file, "r") as f:
        all_metrics = json.load(f)

    print("\n📊 COMPREHENSIVE RESULTS ANALYSIS")
    print("=" * 80)

    def fmt_pair(m):
        """Render the train/val Dice scores stored in a metrics entry."""
        if not isinstance(m, dict):
            return "Dice=N/A"
        t_d = m.get("train", {}).get("dice")
        v_d = m.get("val", {}).get("dice")
        parts = []
        if t_d is not None:
            parts.append(f"Train Dice={t_d:.3f}")
        if v_d is not None:
            parts.append(f"Val Dice={v_d:.3f}")
        return ", ".join(parts) if parts else "Dice=N/A"

    # (section header, substring selecting the entries of that section)
    sections = [
        ("\nBaseline Performance:", "baseline"),
        ("\n🏋️‍♂️ After Head-Training Performance:", "head"),
        ("\n🏆 Finetuned Performance:", "finetuned"),
        ("\n🧩 Composite Task Vector Results:", "composite_at"),
        ("\n🔄 Dataless Adaptation Results:", "adaptation"),
    ]
    for title, tag in sections:
        print(title)
        for key, metrics in all_metrics.items():
            if tag in key:
                print(f"   {key}: {fmt_pair(metrics)}")

In [None]:
# Copy checkpoints and metrics to Google Drive (mounted at /content/drive by the
# Colab setup cell).
if IN_COLAB:
    import shutil

    # Copy checkpoints.zip to Google Drive
    # NOTE(review): `zip -r` updates an existing archive in place, so entries
    # from earlier runs that no longer exist on disk remain in the zip.
    !zip -r /content/checkpoints.zip /content/xai/checkpoints
    shutil.copy(
        "/content/checkpoints.zip", "/content/drive/MyDrive/xai/checkpoints.zip"
    )

    # Copy metrics.json to Google Drive
    # NOTE(review): shutil.copy raises FileNotFoundError if no experiment has
    # written outputs/metrics.json yet.
    shutil.copy(
        "/content/xai/outputs/metrics.json", "/content/drive/MyDrive/xai/metrics.json"
    )

In [None]:
# On Kaggle, bundle the checkpoints into one archive under /kaggle/working so
# it can be downloaded together with the notebook's output files.
if IN_KAGGLE:
    !zip -r /kaggle/working/checkpoints.zip /kaggle/working/xai/checkpoints

# Statistiche dei 4 dataset (CHAOS/MMWHS × CT/MR)
Questo blocco calcola e visualizza statistiche per ciascuna combinazione dataset/dominio:
- Dimensioni degli split (train/val/test)
- Forme di esempio (prime tre) di immagini e maschere
- Statistiche di intensità (min/max/media/dev.std) su un sottoinsieme del train
- Distribuzione delle classi (bar chart) sul sottoinsieme del train

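Nota: le statistiche vengono calcolate sui primi `SUBSET_N` campioni del train. Con il valore predefinito (`1e16 - 1`) viene di fatto usato l'intero train; per velocizzare il calcolo, ridurre `SUBSET_N`.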

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
import json

SUBSET_N = 1e16 - 1  # maximum number of training samples used for the statistics (effectively no limit; a float)
# NOTE(review): PRINT_EVERY is not referenced anywhere in the notebook as shown.
PRINT_EVERY = 8

# helper: extract a NumPy array from a MetaTensor or torch.Tensor


def to_numpy(x):
    """Return ``x`` as a NumPy array.

    Tensors are detached from the autograd graph and moved to the CPU first;
    any object lacking ``detach``/``cpu`` is passed to ``np.asarray`` unchanged.
    """
    for step in ("detach", "cpu"):
        if hasattr(x, step):
            x = getattr(x, step)()
    return np.asarray(x)


def class_histogram(labels_np):
    """Count label values per class index.

    Values are truncated to int64 (so e.g. 1.7 counts as class 1) and
    negative results are ignored.

    Returns:
        Counter mapping int class index -> int count.
    """
    flat = np.asarray(labels_np).astype(np.int64).ravel()
    # only consider values >= 0
    flat = flat[flat >= 0]
    # Vectorized counting; building a Counter from flat.tolist() walks every
    # voxel in Python, which is slow for full volumes.
    values, counts = np.unique(flat, return_counts=True)
    # Cast to built-in ints so the result stays JSON-serializable.
    return Counter({int(v): int(c) for v, c in zip(values, counts)})


def summarize_split(loader, max_items=SUBSET_N):
    """Summarize the images and labels yielded by ``loader``.

    Batches are consumed until at least ``max_items`` samples (counted along
    the batch dimension) have been seen. The last batch is included in full.
    Batches without an ``"image"`` entry are skipped.

    Args:
        loader: iterable of dict batches with an ``"image"`` entry and an
            optional ``"label"`` entry, or None.
        max_items: sample budget after which iteration stops.

    Returns:
        dict with keys:
          n_seen: number of samples seen.
          img_shape_examples / seg_shape_examples: up to three batch shapes.
          img_min / img_max: global min / max over all seen image values.
          img_mean / img_std: global mean / population std over all seen
              image values.
          class_hist: {class index: count} over all seen label values.
        The intensity fields are None when no image was seen.
    """
    if loader is None:
        return {
            "n_seen": 0,
            "img_shape_examples": [],
            "seg_shape_examples": [],
            "img_min": None,
            "img_max": None,
            "img_mean": None,
            "img_std": None,
            "class_hist": {},
        }

    n = 0
    shapes_img, shapes_seg = [], []
    class_counts = Counter()
    # Exact global statistics (min of minima, max of maxima, pooled mean/std)
    # rather than averages of per-batch values.
    img_min = img_max = None
    vox_count = 0
    vox_mean = 0.0
    vox_m2 = 0.0  # sum of squared deviations from the running mean

    for batch in loader:
        img = batch.get("image")
        seg = batch.get("label")
        if img is None:
            continue
        # img/seg may be MetaTensors shaped (B, C, H, W) or (B, C, H, W, D)
        img_np = to_numpy(img)

        b_min = float(img_np.min())
        b_max = float(img_np.max())
        img_min = b_min if img_min is None else min(img_min, b_min)
        img_max = b_max if img_max is None else max(img_max, b_max)

        # Merge this batch into the running mean/variance (Chan et al.'s
        # pairwise update), in float64 to limit accumulation error.
        img64 = img_np.astype(np.float64, copy=False)
        b_count = img64.size
        b_mean = float(img64.mean())
        b_m2 = float(np.square(img64 - b_mean).sum())
        total = vox_count + b_count
        delta = b_mean - vox_mean
        vox_mean += delta * b_count / total
        vox_m2 += b_m2 + delta * delta * vox_count * b_count / total
        vox_count = total

        shapes_img.append(tuple(img_np.shape))
        if seg is not None:
            seg_np = to_numpy(seg)
            shapes_seg.append(tuple(seg_np.shape))
            class_counts.update(class_histogram(seg_np))
        n += img_np.shape[0]
        if n >= max_items:
            break

    seen_images = vox_count > 0
    return {
        "n_seen": n,
        "img_shape_examples": shapes_img[:3],
        "seg_shape_examples": shapes_seg[:3],
        "img_min": img_min,
        "img_max": img_max,
        "img_mean": vox_mean if seen_images else None,
        "img_std": float(np.sqrt(vox_m2 / vox_count)) if seen_images else None,
        "class_hist": dict(class_counts),
    }


def plot_histogram(hist_dict, title, classnames=None):
    """Draw a bar chart of per-class counts, with classes in ascending key order.

    Tick labels come from ``classnames`` when it has an entry for the class
    index; otherwise the index itself is used. An empty ``hist_dict`` only
    prints a notice.
    """
    if not hist_dict:
        print(f"   Nessuna maschera/nessuna classe trovata per {title}")
        return

    def tick_label(k):
        if classnames and k < len(classnames):
            return classnames[k]
        return str(k)

    ordered = sorted(hist_dict)
    heights = [hist_dict[k] for k in ordered]
    tick_labels = [tick_label(k) for k in ordered]
    positions = range(len(ordered))

    plt.figure(figsize=(6, 3))
    plt.bar(positions, heights)
    plt.xticks(positions, tick_labels, rotation=45, ha="right")
    plt.title(title)
    plt.tight_layout()
    plt.show()


all_stats = {}


def fmt(v):
    """Format a numeric statistic with 4 decimals, or "N/A" for anything else (e.g. None)."""
    return f"{v:.4f}" if isinstance(v, (int, float)) else "N/A"


for dataset_name in DATASET_NAMES:
    for domain in DOMAINS:
        print(f"\n== {dataset_name} / {domain} ==")
        image_transform, seg_transform = get_preprocessing(
            dataset_name, domain, is_training=False
        )
        # Optional: set liver_only=True to restrict CHAOS MR to the liver (disabled here)
        if dataset_name == "CHAOS" and domain == "MR":
            extra_kwargs = {"liver_only": False}
        else:
            extra_kwargs = {}

        ds = get_dataset(
            dataset_name=dataset_name,
            base_path=DATA_PATH,
            domain=domain,
            transform=image_transform,
            seg_transform=seg_transform,
            batch_size=BATCH_SIZE,
            num_workers=NUM_WORKERS,
            slice_2d=not USE_3D,
            **extra_kwargs,
        )

        # Split sizes (0 for a missing split)
        splits = {}
        for split_name, split_ds in (
            ("train", ds.train_dataset),
            ("val", ds.val_dataset),
            ("test", ds.test_dataset),
        ):
            splits[split_name] = 0 if split_ds is None else len(split_ds)
        print(
            f"Split -> train: {splits['train']}, val: {splits['val']}, test: {splits['test']}"
        )
        classnames = getattr(ds, "classnames", None)
        print(
            f"Num classi: {getattr(ds, 'num_classes', 'N/A')} | Classnames: {classnames}"
        )

        # Statistics over (a subset of) the training split
        train_stats = summarize_split(ds.train_loader, SUBSET_N)
        print(f"   Visti nel subset: {train_stats['n_seen']}")
        print(f"   Esempi img shape: {train_stats['img_shape_examples']}")
        print(f"   Esempi seg shape: {train_stats['seg_shape_examples']}")
        intensity = " ".join(
            f"{tag}:{fmt(train_stats['img_' + tag])}"
            for tag in ("min", "max", "mean", "std")
        )
        print(f"   Intensità ~ {intensity}")

        all_stats[f"{dataset_name}_{domain}"] = {
            "splits": splits,
            "subset": train_stats,
            "classnames": classnames,
        }

        # Bar chart of the class distribution
        plot_histogram(
            train_stats["class_hist"],
            title=f"Distribuzione classi (subset) - {dataset_name} {domain}",
            classnames=classnames,
        )

# Save the summary to disk
try:
    out_file = OUTPUTS_PATH / "dataset_stats.json"
    with open(out_file, "w") as f:
        json.dump(all_stats, f, indent=2)
    print(f"\nSalvato riepilogo in: {out_file}")
except Exception as e:
    print(f"Errore salvataggio stats: {e}")

print("\nRiepilogo sintetico:")
for key, entry in all_stats.items():
    sizes = entry["splits"]
    print(
        f" - {key}: train={sizes['train']}, val={sizes['val']}, test={sizes['test']} | visti(subset)={entry['subset']['n_seen']}"
    )