In Kaggle, add the following to the dependencies:
```
pip install torch
pip install numpy
pip install pydicom
pip install Pillow
pip install matplotlib
```
Enable file persistence and internet access.
Remember that you can run the whole notebook and close the runtime without wasting resources by going to File > Save Version > Save & Run All (Double check that GPU is selected in the advanced settings).
Later, by going to 'File' > 'Version history' you can view the full logs and download the output files.

In [1]:
# Check if running in Kaggle
import os

IN_KAGGLE = False
if os.environ.get("KAGGLE_URL_BASE", ""):
    IN_KAGGLE = True
    !git clone https://github.com/parmigggiana/xai /kaggle/working/xai
    %cd xai
    !git fetch
    !git reset --hard origin/main
    %pip install 'monai[einops,itk,nibabel]>=1.5.0' git+https://github.com/timojl/clipseg.git

In [2]:
# Check if running in Google Colab
IN_COLAB = False
if not IN_KAGGLE:
    try:
        import google.colab
        from google.colab import drive

        IN_COLAB = True
        import os

        drive.mount("/content/drive")
        os.makedirs("/content/drive/MyDrive/xai", exist_ok=True)
        !git clone https://github.com/parmigggiana/xai /content/xai
        %cd /content/xai
        !git fetch
        !git reset --hard origin/main
        %pip install -r requirements.txt
    except Exception:
        pass

In [3]:
from src.datasets.registry import get_dataset
from src.datasets.common import BaseDataset
from pathlib import Path
import json
from src.task_vector import TaskVector
from src.utils import download_and_extract_dataset

IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


In [4]:
# Datasets and imaging domains covered by the experiments
DATASET_NAMES = ["CHAOS", "MMWHS"]
DOMAINS = ["CT", "MR"]
# Working directories (converted to pathlib.Path objects in a later cell)
DATA_PATH = "data/"
CHECKPOINT_PATH = "checkpoints/"
OUTPUTS_PATH = "outputs/"
# False -> 2D slices with the "clipseg" encoder; True -> "swin_unetr" encoder
USE_3D = False
# Number of finetuning epochs for each (dataset, domain) pair. The finetuning
# loop iterates over this dict, so its order is also the training order.
TRAINING_EPOCHS = {
    ("MMWHS", "CT"): 15,
    ("MMWHS", "MR"): 15,
    ("CHAOS", "MR"): 15,
    ("CHAOS", "CT"): 15,

}
# Batch size handed to get_dataset()
BATCH_SIZE = 8
# Target size passed as `spatial_size` to the MONAI Resize transforms
SPATIAL_SIZE = 96
# Optimizer hyperparameters forwarded to model.finetune()
LEARNING_RATE = 5e-4
WEIGHT_DECAY = 0

In [5]:
# Both values are forwarded to get_dataset() as `cache_max_items` / `enable_cache`.
CACHE_MAX_ITEMS = 96  # in-memory file cache size per dataset (images and segs)
ENABLE_CACHE = True    # set to False to disable caching entirely

In [6]:
# Promote the configured path strings to Path objects.
CHECKPOINT_PATH, OUTPUTS_PATH, DATA_PATH = (
    Path(CHECKPOINT_PATH),
    Path(OUTPUTS_PATH),
    Path(DATA_PATH),
)
# Checkpoints and outputs are written by this notebook, so make sure their
# folders exist (the data folder is not created here).
CHECKPOINT_PATH.mkdir(exist_ok=True, parents=True)
OUTPUTS_PATH.mkdir(exist_ok=True, parents=True)

# Encoder selection follows the dimensionality switch.
encoder_type = "swin_unetr" if USE_3D else "clipseg"

In [7]:
import torch
from monai import transforms

def update_metrics(name, new_metrics):
    """Store ``new_metrics`` under ``name`` in ``OUTPUTS_PATH / "metrics.json"``.

    The file is read if it already exists (otherwise an empty mapping is
    used), the entry for ``name`` is inserted or overwritten, and the whole
    mapping is written back with 4-space indentation.

    Args:
        name: Key of the entry to create or replace.
        new_metrics: JSON-serializable value stored under ``name``.
    """
    metrics_file = OUTPUTS_PATH / "metrics.json"

    # Start from the entries already on disk, if any.
    metrics = {}
    if metrics_file.exists():
        with open(metrics_file, "r") as src:
            metrics = json.load(src)

    metrics[name] = new_metrics

    with open(metrics_file, "w") as dst:
        json.dump(metrics, dst, indent=4)


def debug_metadata(data):
    """Pass-through debug transform.

    Prints the type of ``data`` and, when the attributes exist, its ``meta``
    (keys and full content), ``shape`` and ``dtype``, followed by a separator
    line. ``data`` is returned unchanged so the function can sit inside a
    transform pipeline.
    """
    print(f"🔍 DEBUG - Data type: {type(data)}")

    has_meta = hasattr(data, "meta")
    if has_meta:
        # An empty/falsy meta is reported as 'No meta' instead of an empty list
        print(
            f"🔍 DEBUG - Metadata keys: {list(data.meta.keys()) if data.meta else 'No meta'}"
        )
        print(f"🔍 DEBUG - Full metadata: {data.meta}")

    has_shape = hasattr(data, "shape")
    if has_shape:
        print(f"🔍 DEBUG - Shape: {data.shape}")

    has_dtype = hasattr(data, "dtype")
    if has_dtype:
        print(f"🔍 DEBUG - Dtype: {data.dtype}")

    separator = "=" * 50
    print("🔍 DEBUG - " + separator)
    return data


# Normalization stats (mean, std) per dataset/domain.
# Keys are (dataset_name, domain); get_preprocessing() looks a pair up and uses
# the values as subtrahend/divisor of NormalizeIntensity. Pairs that are not
# listed here get no intensity normalization.
# NOTE(review): presumably measured on the raw training intensities — confirm.
NORM_STATS = {
    ("MMWHS", "MR"):  (186.5875, 258.5917),
    ("MMWHS", "CT"):  (-745.0086, 1042.7251),
    ("CHAOS",  "MR"): (90.8292, 168.8922),
    ("CHAOS",  "CT"): (-478.1732, 476.7163),
}

# Preprocessing factory that also applies per-dataset intensity normalization
def get_preprocessing(dataset_name: str, domain: str, is_training=True):
    """Build the image and segmentation transform pipelines for one dataset/domain.

    Image pipeline: layout transforms -> (training only) random noise and
    contrast augmentation -> (2D only) repeat to 3 channels -> resize ->
    (when NORM_STATS has an entry) mean/std normalization -> float32 tensor.

    Segmentation pipeline: layout transforms -> label decoding via
    ``get_decode_func`` -> nearest-neighbour resize -> float32 tensor.
    No normalization or augmentation is applied to segmentations.

    Args:
        dataset_name: Dataset identifier, used to look up NORM_STATS and the
            label decoder.
        domain: Imaging domain, used together with ``dataset_name``.
        is_training: When true, the random augmentations are added to the
            image pipeline.

    Returns:
        Tuple ``(image_transform, seg_transform)`` of ``transforms.Compose``.
    """
    decode_func = get_decode_func(dataset_name, domain)
    stats = NORM_STATS.get((dataset_name, domain))
    mean, std = (None, None) if stats is None else stats

    def layout_transforms():
        # Shared head of both pipelines; a fresh list of fresh transform
        # instances is built on every call.
        if USE_3D:
            return [
                transforms.EnsureChannelFirst(channel_dim="no_channel"),
                transforms.Orientation(axcodes="RAS"),
            ]
        return [
            transforms.Lambda(lambda x: x.squeeze(-1)),
            transforms.EnsureChannelFirst(channel_dim="no_channel"),
        ]

    # ---- image pipeline ----
    image_transforms = layout_transforms()

    # Augmentations (training only)
    if is_training:
        image_transforms.append(transforms.RandGaussianNoise(prob=0.2, std=0.05))
        image_transforms.append(
            transforms.RandAdjustContrast(prob=0.2, gamma=(0.9, 1.1))
        )

    if not USE_3D:
        image_transforms.append(transforms.RepeatChannel(repeats=3))

    # Resize -> Normalize (mean/std) -> ToTensor
    image_transforms.append(
        transforms.Resize(
            spatial_size=SPATIAL_SIZE,
            size_mode="longest",
            mode="area",
            anti_aliasing=True,
        )
    )
    if mean is not None and std is not None:
        image_transforms.append(
            transforms.NormalizeIntensity(
                subtrahend=float(mean),
                divisor=float(std),
                channel_wise=False,
            )
        )
    image_transforms.append(transforms.ToTensor())
    image_transforms.append(transforms.EnsureType(dtype=torch.float32))

    # ---- segmentation pipeline (no normalization) ----
    seg_transforms = layout_transforms() + [
        transforms.Lambda(lambda x: decode_func(x)),
        transforms.Resize(
            spatial_size=SPATIAL_SIZE, size_mode="longest", mode="nearest"
        ),
        transforms.ToTensor(),
        transforms.EnsureType(dtype=torch.float32),
    ]

    image_transform = transforms.Compose(image_transforms)
    seg_transform = transforms.Compose(seg_transforms)
    return image_transform, seg_transform


def get_decode_func(dataset_name, domain):
    """Return the callable that turns raw label values into class indices.

    * CHAOS / MR (or "MRI"): integer-divides the labels by 63.
    * CHAOS / CT: maps every value greater than 0 to 1.0, everything else to 0.0.
    * MMWHS (any domain): maps each key of ``mmwhs_labels`` to its position in
      the mapping's iteration order; unmatched values become 0.
    * Anything else: prints a warning and returns an identity function.
    """
    from src.datasets.mmwhs import mmwhs_labels

    def chaos_mr(labels):
        # Convert intensity values to class indices (keep as float32)
        return labels // 63

    def chaos_ct(labels):
        return torch.where(labels > 0, 1.0, 0.0)

    def mmwhs(labels):
        decoded = torch.zeros_like(labels, dtype=torch.float32)
        for class_index, raw_value in enumerate(mmwhs_labels.keys()):
            decoded[labels == raw_value] = class_index
        return decoded

    def passthrough(labels):
        return labels

    if dataset_name == "CHAOS":
        if domain in ("MR", "MRI"):
            return chaos_mr
        if domain == "CT":
            return chaos_ct
    elif dataset_name == "MMWHS":
        return mmwhs

    # No decoder matched: tell the user and leave the labels untouched.
    print(
        f"Warning: No decode function defined for {dataset_name} in {domain}. Returning labels unchanged."
    )
    return passthrough

In [8]:
# Finetuning loop
#
# For every (dataset, domain) pair in TRAINING_EPOCHS:
#   1. make sure the dataset is downloaded and build the preprocessing,
#   2. skip the pair if its finetuned checkpoint already exists,
#   3. build the dataset and model and save the baseline encoder,
#   4. (3D only) train with a frozen body first and record the "_head" metrics,
#   5. unfreeze, finetune the full model, and report parameter/loss diagnostics,
#   6. save the finetuned encoder and record the "_finetuned" metrics.

for (dataset_name, domain), epochs in TRAINING_EPOCHS.items():
    download_and_extract_dataset(dataset_name, DATA_PATH)

    image_transform, seg_transform = get_preprocessing(
        dataset_name, domain, is_training=True
    )

    filename = f"{dataset_name}_{domain}_{'3d' if USE_3D else '2d'}_finetuned.pth"
    filename = CHECKPOINT_PATH / filename
    # Check if the finetuned checkpoint already exists
    if filename.exists():
        print(
            f"Finetuned model for {dataset_name} in {domain} domain with {'3d' if USE_3D else '2d'} images already exists at {filename}. Skipping finetuning."
        )
        continue

    print(
        f"Finetuning on {dataset_name} dataset in {domain} domain with {'3d' if USE_3D else '2d'} images "
    )
    dataset: BaseDataset = get_dataset(
        dataset_name=dataset_name,
        domain=domain,
        transform=image_transform,  # Use transform instead of preprocess
        seg_transform=seg_transform,  # Pass seg_transform too
        base_path=DATA_PATH,
        batch_size=BATCH_SIZE,
        num_workers=0,
        slice_2d=not USE_3D,
        # new cache knobs
        cache_max_items=CACHE_MAX_ITEMS,
        enable_cache=ENABLE_CACHE,
    )

    #  Ensure the dataset is loaded correctly
    if not isinstance(dataset, BaseDataset):
        raise TypeError(
            f"Expected dataset to be an instance of BaseDataset, got {type(dataset)}"
        )
    # Print dataset information
    print()
    print(f"Dataset: {dataset_name}, Domain: {domain}")
    print(f"Number of training samples: {len(dataset.train_dataset)}")
    print(f"Number of validation samples: {len(dataset.val_dataset)}")
    print(f"Number of test samples: {len(dataset.test_dataset)}")
    print(f"Image shape: {dataset.train_dataset[0]['image'].shape}")
    print(f"Segmentation shape: {dataset.train_dataset[0]['label'].shape}")
    print(f"Number of classes: {dataset.num_classes}")
    print()

    model = dataset.get_model(
        encoder_type=encoder_type,
    )

    # 🔧 DEBUG: Check initial model parameters
    # A detached copy of every parameter is kept so the amount of change can
    # be measured after finetuning.
    print("🔧 DEBUG: Initial model parameter check")
    initial_params = {}
    param_count = 0
    trainable_count = 0
    for name, param in model.named_parameters():
        initial_params[name] = param.clone().detach()
        param_count += 1
        if param.requires_grad:
            trainable_count += 1
    print(f"   Total parameters: {param_count}")
    print(f"   Trainable parameters: {trainable_count}")
    print(f"   Model device: {next(model.parameters()).device}")
    print()

    # Save the baseline encoder before finetuning.
    # NOTE: this pickles the `model.encoder` object itself, not a state_dict.
    baseline_filename = (
        CHECKPOINT_PATH
        / f"{dataset_name}_{domain}_{'3d' if USE_3D else '2d'}_baseline.pth"
    )
    torch.save(model.encoder, baseline_filename)
    print(
        f"Processing {dataset_name} in {domain} domain with {'3d' if USE_3D else '2d'} images"
    )

    if USE_3D:
        print(
            "Warning: Using 3D model requires SWIN UNETR, which is not compatible with zero-shot training."
        )

        # 🔧 DEBUG: Check freeze_body functionality
        print("🔧 DEBUG: Before freeze_body()")
        frozen_before = sum(1 for p in model.parameters() if not p.requires_grad)
        model.freeze_body()
        frozen_after = sum(1 for p in model.parameters() if not p.requires_grad)
        print(f"   Frozen parameters before: {frozen_before}")
        print(f"   Frozen parameters after: {frozen_after}")
        print(f"   Parameters frozen: {frozen_after - frozen_before}")

        # Check which parameters are trainable
        print("   Trainable parameters after freeze_body:")
        for name, param in model.named_parameters():
            if param.requires_grad:
                print(f"     {name}: {param.shape}")
        print()

        # First stage: train with the body frozen
        model.finetune(
            epochs=epochs, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
        )

        metrics = model.evaluate()
        update_metrics(
            f"{dataset_name}_{domain}_{'3d' if USE_3D else '2d'}_head",
            metrics,
        )

    # 🔧 DEBUG: Check unfreeze functionality
    print("🔧 DEBUG: Before unfreeze()")
    frozen_before = sum(1 for p in model.parameters() if not p.requires_grad)
    model.unfreeze()
    frozen_after = sum(1 for p in model.parameters() if not p.requires_grad)
    print(f"   Frozen parameters before unfreeze: {frozen_before}")
    print(f"   Frozen parameters after unfreeze: {frozen_after}")
    print(f"   Total trainable parameters: {sum(1 for p in model.parameters() if p.requires_grad)}")
    print()

    # 🔧 DEBUG: Monitor parameter changes during training
    print("🔧 DEBUG: Starting full model finetuning")
    history = model.finetune(
        epochs=epochs,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
    )

    # 🔧 DEBUG: Check if parameters actually changed
    # (this header and the counter reset used to be duplicated, which printed
    # the banner twice)
    print("🔧 DEBUG: Parameter change analysis after finetuning")
    changed_params = 0
    unchanged_params = 0
    max_change = 0.0
    # Compare both tensors on the same device
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    for name, param in model.named_parameters():
        if name in initial_params:
            param_dev = param.detach().to(dev)
            initial_param_dev = initial_params[name].detach().to(dev)
            param_change = (param_dev - initial_param_dev).norm().item()
            if param_change > 1e-8:  # Consider very small changes as unchanged
                changed_params += 1
                max_change = max(max_change, param_change)
            else:
                unchanged_params += 1
                print(f"   Parameter {name} didn't change during training!")

    print(f"   Parameters that changed: {changed_params}")
    print(f"   Parameters that didn't change: {unchanged_params}")
    print(f"   Maximum parameter change: {max_change:.6f}")


    if changed_params == 0:
        print("   ⚠️ WARNING: No parameters changed during training!")
    elif max_change < 1e-6:
        print(f"   ⚠️ WARNING: Very small parameter changes (max: {max_change:.8f})")
    else:
        print("   ✅ Parameters updated successfully")
    print()

    # 🔧 DEBUG: Check training history
    if history:
        print("🔧 DEBUG: Training history analysis")
        if 'train_loss' in history:
            train_losses = history['train_loss']
            print(f"   Training losses: {train_losses[:5]}...{train_losses[-5:] if len(train_losses) > 5 else train_losses}")
            print(f"   Loss range: {min(train_losses):.6f} - {max(train_losses):.6f}")
            if len(train_losses) > 1:
                loss_change = abs(train_losses[-1] - train_losses[0])
                print(f"   Total loss change: {loss_change:.6f}")
                if loss_change < 1e-6:
                    print("   ⚠️ WARNING: Training loss barely changed!")
        else:
            print("   ⚠️ No 'train_loss' found in history")
        print(f"   History keys: {list(history.keys()) if history else 'None'}")
    else:
        print("🔧 DEBUG: No training history returned")
    print()

    # Save the finetuned encoder (the `model.encoder` object, not a state_dict)
    torch.save(model.encoder, filename)
    model_metrics = model.evaluate()
    update_metrics(
        f"{dataset_name}_{domain}_{'3d' if USE_3D else '2d'}_finetuned",
        model_metrics,
    )
    # Release cached GPU memory before moving on to the next pair
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

Finetuned model for MMWHS in CT domain with 2d images already exists at checkpoints\MMWHS_CT_2d_finetuned.pth. Skipping finetuning.
Finetuned model for MMWHS in MR domain with 2d images already exists at checkpoints\MMWHS_MR_2d_finetuned.pth. Skipping finetuning.
Finetuned model for CHAOS in MR domain with 2d images already exists at checkpoints\CHAOS_MR_2d_finetuned.pth. Skipping finetuning.
Finetuned model for CHAOS in CT domain with 2d images already exists at checkpoints\CHAOS_CT_2d_finetuned.pth. Skipping finetuning.


# Domain adaptation

In [9]:
# SWIN UNETR Task Vectors
from monai.networks.nets import SwinUNETR
from monai.networks.nets.swin_unetr import SwinTransformer
from monai.networks.blocks.patchembedding import PatchEmbed
from torch.nn.modules.conv import Conv3d
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.container import ModuleList
from monai.networks.nets.swin_unetr import BasicLayer
from monai.networks.nets.swin_unetr import SwinTransformerBlock
from torch.nn.modules.normalization import LayerNorm
from monai.networks.nets.swin_unetr import WindowAttention
from torch.nn.modules.linear import Linear
from torch.nn.modules.activation import Softmax
from torch.nn.modules.linear import Identity
from monai.networks.blocks.mlp import MLPBlock
from torch.nn.modules.activation import GELU
from monai.networks.nets.swin_unetr import PatchMerging
from monai.networks.blocks.unetr_block import UnetrBasicBlock
from monai.networks.blocks.dynunet_block import UnetResBlock
from monai.networks.blocks.convolutions import Convolution
from torch.nn.modules.activation import LeakyReLU
from torch.nn.modules.instancenorm import InstanceNorm3d
from monai.networks.blocks.unetr_block import UnetrUpBlock
from monai.networks.blocks.dynunet_block import UnetOutBlock
from torch.nn.modules.conv import ConvTranspose3d

# Classes that make up a pickled SwinUNETR module. The list is later handed to
# `torch.serialization.safe_globals(...)` around the TaskVector construction,
# presumably so the checkpoints can be unpickled with weights-only loading.
# NOTE: `LayerNorm` here is the torch.nn class; the name is re-imported from
# `clip.model` further down, so the later list extension refers to a different class.
safe_globals = [
    SwinUNETR,
    SwinTransformer,
    PatchEmbed,
    Conv3d,
    Dropout,
    ModuleList,
    BasicLayer,
    SwinTransformerBlock,
    LayerNorm,
    WindowAttention,
    Linear,
    Softmax,
    Identity,
    MLPBlock,
    GELU,
    PatchMerging,
    UnetrBasicBlock,
    UnetResBlock,
    Convolution,
    LeakyReLU,
    InstanceNorm3d,
    UnetrUpBlock,
    ConvTranspose3d,
    UnetOutBlock,
]
##

## CLIPSeg Task Vectors
from src.CLIPSeg import CLIPSeg
from clipseg.clipseg import CLIPDensePredT
from clip.model import (
    CLIP,
    VisionTransformer,
    LayerNorm,
    Transformer,
    ResidualAttentionBlock,
    QuickGELU,
)
from torch.nn.modules.conv import Conv2d, ConvTranspose2d
from torch.nn.modules.container import Sequential
from torch.nn.modules.activation import MultiheadAttention, ReLU
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.modules.sparse import Embedding
from torch.nn.modules.transformer import (
    TransformerEncoderLayer,
    TransformerEncoder,
    TransformerDecoderLayer,
    TransformerDecoder,
)
from torch.nn.functional import relu
from torch.nn.modules.container import ModuleDict

# Add the classes (and the `relu` function) needed to unpickle a CLIPSeg model.
# `LayerNorm` at this point is the class imported from `clip.model` just above,
# which shadows the torch.nn LayerNorm already stored in the list.
safe_globals.extend(
    [
        CLIPSeg,
        CLIPDensePredT,
        CLIP,
        VisionTransformer,
        Conv2d,
        LayerNorm,
        Transformer,
        Sequential,
        ResidualAttentionBlock,
        MultiheadAttention,
        NonDynamicallyQuantizableLinear,
        QuickGELU,
        Embedding,
        ReLU,
        ConvTranspose2d,
        TransformerEncoderLayer,
        TransformerEncoder,
        TransformerDecoderLayer,
        TransformerDecoder,
        relu,
        ModuleDict,
    ]
)

# Build Task Vectors for each dataset and domain.
# A task vector is built from the baseline and finetuned checkpoints of a
# (dataset, domain) pair; pairs with a missing checkpoint are skipped. Results
# are stored in `task_vectors` under the key "<dataset>_<domain>".
task_vectors = {}
for dataset_name in DATASET_NAMES:
    for domain in DOMAINS:
        print(
            f"Building task vector for {dataset_name} dataset in {domain} domain with {'3d' if USE_3D else '2d'} images"
        )
        baseline_checkpoint = (
            CHECKPOINT_PATH
            / f"{dataset_name}_{domain}_{'3d' if USE_3D else '2d'}_baseline.pth"
        )
        finetuned_checkpoint = (
            CHECKPOINT_PATH
            / f"{dataset_name}_{domain}_{'3d' if USE_3D else '2d'}_finetuned.pth"
        )
        if not baseline_checkpoint.exists():
            print(
                f"Baseline checkpoint for {dataset_name} {domain} does not exist. Skipping task vector creation."
            )
            continue
        if not finetuned_checkpoint.exists():
            print(
                f"Finetuned checkpoint {dataset_name} {domain} does not exist. Skipping task vector creation."
            )
            continue

        # Allowlist the model classes while the checkpoints are being loaded
        with torch.serialization.safe_globals(
            safe_globals=safe_globals,
        ):
            task_vector = TaskVector(baseline_checkpoint, finetuned_checkpoint)
            # Remove keys associated with the output layers from the task vector
            # For swin it's all layers starting with '.out'
            # For clipseg it might not be necessary since the model architecture isn't dependent on the number of output features
            if encoder_type == "swin_unetr":
                # Iterate over a snapshot of the keys: deleting entries while
                # iterating the live key view raises
                # "dictionary changed size during iteration" for dict-like
                # containers.
                # NOTE(review): verify the ".out" prefix against the actual
                # key names of the task vector.
                for k in list(task_vector.keys()):
                    if k.startswith(".out"):
                        del task_vector[k]
        task_vectors[f"{dataset_name}_{domain}"] = task_vector

Building task vector for CHAOS dataset in CT domain with 2d images


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possibl

Building task vector for CHAOS dataset in MR domain with 2d images
Building task vector for MMWHS dataset in CT domain with 2d images
Building task vector for MMWHS dataset in MR domain with 2d images


In [None]:
# 🔄 Simple Composite Task Vector: 4 clear metrics only
#
# Four pairwise sums of task vectors are built; each one is loaded into fresh
# models for its two target (dataset, domain) pairs, and the Dice scores read
# from the evaluation results are averaged into one metric per composite.

# Define the four simple composite vectors
composite_task_vectors_simple = {
    group: task_vectors[left] + task_vectors[right]
    for group, (left, right) in (
        ("MMWHS", ("MMWHS_MR", "MMWHS_CT")),  # MMWHS CT+MR
        ("CHAOS", ("CHAOS_MR", "CHAOS_CT")),  # CHAOS CT+MR
        ("MR", ("CHAOS_MR", "MMWHS_MR")),     # MR across datasets
        ("CT", ("CHAOS_CT", "MMWHS_CT")),     # CT across datasets
    )
}

print("🔄 Simple Composite Task Vector — computing ONLY 4 metrics (averaged over selected dataset/domain pairs)")
print("=" * 80)

from statistics import mean

# Composite vector evaluated under each label
vectors_to_test = {
    "MMWHS_CT_MR": composite_task_vectors_simple["MMWHS"],
    "CHAOS_CT_MR": composite_task_vectors_simple["CHAOS"],
    "CT_both_datasets": composite_task_vectors_simple["CT"],
    "MR_both_datasets": composite_task_vectors_simple["MR"],
}

# (dataset, domain) pairs each label is evaluated on
target_groups = {
    "MMWHS_CT_MR": [("MMWHS", "CT"), ("MMWHS", "MR")],
    "CHAOS_CT_MR": [("CHAOS", "CT"), ("CHAOS", "MR")],
    "CT_both_datasets": [("CHAOS", "CT"), ("MMWHS", "CT")],
    "MR_both_datasets": [("CHAOS", "MR"), ("MMWHS", "MR")],
}

# Evaluate only the selected target pairs and average Dice over them (4 metrics total)
for label in vectors_to_test:
    composite_vec = vectors_to_test[label]
    pairs = target_groups[label]
    print(f"\n▶ {label} ({len(pairs)} targets)")
    scores = []

    for dataset_name, target_domain in pairs:
        print(f"  - {dataset_name}/{target_domain}")

        image_transform, seg_transform = get_preprocessing(
            dataset_name, target_domain, is_training=False
        )

        # CHAOS is restricted to the liver for these evaluations
        extra_kwargs = {"liver_only": True} if dataset_name == "CHAOS" else {}

        # Build target dataset and model
        target_dataset = get_dataset(
            dataset_name=dataset_name,
            base_path=DATA_PATH,
            domain=target_domain,
            transform=image_transform,
            seg_transform=seg_transform,
            batch_size=BATCH_SIZE,
            num_workers=0,
            slice_2d=not USE_3D,
            cache_max_items=CACHE_MAX_ITEMS,
            enable_cache=ENABLE_CACHE,
            **extra_kwargs,
        )
        target_model = target_dataset.get_model(encoder_type=encoder_type)

        # Load and evaluate the composite task vector
        target_model.load_task_vector(composite_vec)
        metrics = target_model.evaluate()
        # The Dice score is read from the "train" entry; 0.0 when it is absent
        dice = float(metrics.get("train", {}).get("dice", 0.0))
        scores.append(dice)
        print(f"    Dice={dice:.3f}")

    if scores:
        avg_dice = mean(scores)
    else:
        avg_dice = 0.0
    update_metrics(f"{label}_avg", {"dice": avg_dice, "n": len(scores)})
    print(f"✅ {label}: avg Dice={avg_dice:.3f}")

print("=" * 80)

🔄 Simple Composite Task Vector — computing ONLY 4 metrics (averaged over selected dataset/domain pairs)

▶ MMWHS_CT_MR (2 targets)
  - MMWHS/CT

▶ MMWHS_CT_MR (2 targets)
  - MMWHS/CT
Dataset CT total samples: 5305
Split sizes - Train: 3713, Val: 795, Test: 797
Dataset CT total samples: 5305
Split sizes - Train: 3713, Val: 795, Test: 797
Found explicit background class in input. Treating it separately.
Non-background classes: ['Left ventricle blood cavity', 'Right ventricle blood cavity', 'Left atrium blood cavity', 'Right atrium blood cavity', 'Myocardium of the left ventricle', 'Ascending aorta', 'Pulmonary artery']
Found explicit background class in input. Treating it separately.
Non-background classes: ['Left ventricle blood cavity', 'Right ventricle blood cavity', 'Left atrium blood cavity', 'Right atrium blood cavity', 'Myocardium of the left ventricle', 'Ascending aorta', 'Pulmonary artery']
🔄 Loading CLIPSeg weights...
🔄 Loading CLIPSeg weights...


Evaluating train:   0%|          | 0/465 [00:00<?, ?it/s]

In [None]:
# Build composite task vectors using arithmetic.
# Each target is approximated without using its own task vector:
#   target = (same dataset, other domain)
#          + (other dataset, target domain)
#          - (other dataset, other domain)
composite_task_vectors = {
    target: task_vectors[same_dataset] + task_vectors[same_domain] - task_vectors[neither]
    for target, (same_dataset, same_domain, neither) in (
        ("MMWHS_CT", ("MMWHS_MR", "CHAOS_CT", "CHAOS_MR")),
        ("MMWHS_MR", ("MMWHS_CT", "CHAOS_MR", "CHAOS_CT")),
        ("CHAOS_CT", ("CHAOS_MR", "MMWHS_CT", "MMWHS_MR")),
        ("CHAOS_MR", ("CHAOS_CT", "MMWHS_MR", "MMWHS_CT")),
    )
}

In [None]:
# 🔄 Task Vector Cross-Domain Adaptation Experiments
# For every (dataset, target domain) pair, load the matching composite task
# vector into a fresh model, evaluate it, and store the metrics under
# "<dataset>_<domain>_adaptation".
print("🔄 Task Vector Cross-Domain Adaptation Experiments")
print("=" * 80)

for dataset_name in DATASET_NAMES:
    for target_domain in DOMAINS:
        print(f"\n{dataset_name}: {target_domain} adaptation")

        image_transform, seg_transform = get_preprocessing(
            dataset_name, target_domain, is_training=False
        )

        # CHAOS is restricted to the liver for these evaluations
        extra_kwargs = {"liver_only": True} if dataset_name == "CHAOS" else {}

        target_dataset = get_dataset(
            dataset_name=dataset_name,
            base_path=DATA_PATH,
            domain=target_domain,
            transform=image_transform,
            seg_transform=seg_transform,
            batch_size=BATCH_SIZE,
            num_workers=0,
            slice_2d=not USE_3D,
            cache_max_items=CACHE_MAX_ITEMS,
            enable_cache=ENABLE_CACHE,
            **extra_kwargs,
        )

        composite_key = f"{dataset_name}_{target_domain}"
        if composite_key not in composite_task_vectors:
            print(f"   ⚠️ No composite task vector found for {composite_key}")
            continue

        composite_task_vector = composite_task_vectors[composite_key]

        target_model = target_dataset.get_model(encoder_type=encoder_type)
        target_model.load_task_vector(composite_task_vector)

        metrics = target_model.evaluate()
        update_metrics(f"{composite_key}_adaptation", metrics)
        print(
            f"   ✅ {composite_key}: Dice={metrics.get('train', {}).get('dice', 0):.3f}"
        )

print("=" * 80)

In [None]:
if IN_COLAB:
    import shutil

    # Copy checkpoints.zip to Google Drive
    !zip -r /content/checkpoints.zip /content/xai/checkpoints
    shutil.copy(
        "/content/checkpoints.zip", "/content/drive/MyDrive/xai/checkpoints.zip"
    )

    # Copy metrics.json to Google Drive
    shutil.copy(
        "/content/xai/outputs/metrics.json", "/content/drive/MyDrive/xai/metrics.json"
    )

In [None]:
# On Kaggle, bundle the checkpoints folder into a single archive under
# /kaggle/working.
if IN_KAGGLE:
    !zip -r /kaggle/working/checkpoints.zip /kaggle/working/xai/checkpoints

# Statistiche dei 4 dataset (CHAOS/MMWHS × CT/MR)
Questo blocco calcola e visualizza statistiche per ciascuna combinazione dataset/dominio:
- Dimensioni degli split (train/val/test)
- Forma media di immagini e maschere
- Statistiche di intensità (min/max/media/dev.std) su un sottoinsieme del train
- Distribuzione delle classi (bar chart) sul sottoinsieme del train

Nota: per rapidità, le statistiche vengono calcolate su un sottoinsieme dei primi N campioni del train.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter
import json

# Maximum number of train samples used for the statistics; the value is so
# large that in practice the whole train split is scanned.
SUBSET_N = 1e16-1
# NOTE(review): PRINT_EVERY is not referenced anywhere in the visible code.
PRINT_EVERY = 8

# helper: extract a numpy array from a MetaTensor or torch.Tensor

def to_numpy(x):
    """Convert ``x`` to a NumPy array.

    If ``x`` offers ``detach`` and/or ``cpu`` methods, they are called in
    that order before the result is handed to ``np.asarray``.
    """
    for step in ("detach", "cpu"):
        if hasattr(x, step):
            x = getattr(x, step)()
    return np.asarray(x)


def class_histogram(labels_np):
    """Count how often each non-negative class value occurs in ``labels_np``.

    The array is cast to int64 and flattened; negative values are ignored.
    Returns a ``Counter`` keyed by Python ints.
    """
    values = labels_np.astype(np.int64).ravel()
    # only values >= 0 are counted
    non_negative = values[values >= 0]
    return Counter(non_negative.tolist())


def summarize_split(loader, max_items=SUBSET_N):
    """Collect intensity and class statistics over the batches of ``loader``.

    Batches are read until at least ``max_items`` samples have been seen
    (samples are counted through the first dimension of the image array).
    Batches without an ``"image"`` entry are skipped.

    Args:
        loader: Iterable of dict-like batches with an ``"image"`` entry and an
            optional ``"label"`` entry, or ``None``.
        max_items: Stop after this many samples have been seen.

    Returns:
        dict with the keys
        ``n_seen`` (samples seen),
        ``img_shape_examples`` / ``seg_shape_examples`` (shapes of up to the
        first three batches),
        ``img_min`` / ``img_max`` (smallest / largest intensity over all seen
        batches),
        ``img_mean`` / ``img_std`` (average of the per-batch mean / std, i.e.
        an approximation rather than the exact global statistic),
        ``class_hist`` (class value -> number of occurrences).
        The four intensity values are ``None`` when no batch was seen.
    """
    if loader is None:
        return {
            "n_seen": 0,
            "img_shape_examples": [],
            "seg_shape_examples": [],
            "img_min": None,
            "img_max": None,
            "img_mean": None,
            "img_std": None,
            "class_hist": {},
        }

    n = 0
    shapes_img, shapes_seg = [], []
    batch_mins, batch_maxs, batch_means, batch_stds = [], [], [], []
    class_counts = Counter()

    for batch in loader:
        img = batch.get("image")
        seg = batch.get("label")
        if img is None:
            continue
        # img/seg may be MetaTensors; presumably shaped (B, C, H, W) or
        # (B, C, H, W, D) — only the leading batch dimension is relied on here
        img_np = to_numpy(img)
        batch_mins.append(float(img_np.min()))
        batch_maxs.append(float(img_np.max()))
        batch_means.append(float(img_np.mean()))
        batch_stds.append(float(img_np.std()))
        shapes_img.append(tuple(img_np.shape))
        if seg is not None:
            seg_np = to_numpy(seg)
            shapes_seg.append(tuple(seg_np.shape))
            class_counts.update(class_histogram(seg_np))
        n += img_np.shape[0]
        if n >= max_items:
            break

    # Aggregate. The extrema are the true min/max over all seen batches
    # (averaging the per-batch extrema would under-report the range).
    return {
        "n_seen": n,
        "img_shape_examples": shapes_img[:3],
        "seg_shape_examples": shapes_seg[:3],
        "img_min": min(batch_mins) if batch_mins else None,
        "img_max": max(batch_maxs) if batch_maxs else None,
        "img_mean": float(np.mean(batch_means)) if batch_means else None,
        "img_std": float(np.mean(batch_stds)) if batch_stds else None,
        "class_hist": dict(class_counts),
    }


def plot_histogram(hist_dict, title, classnames=None):
    """Show a bar chart of ``hist_dict`` (class value -> count).

    Bars are ordered by class value. A tick label is taken from
    ``classnames`` when the class value is a valid index into it, otherwise
    the class value itself is used. When ``hist_dict`` is empty a message is
    printed and nothing is plotted.
    """
    if not hist_dict:
        print(f"   Nessuna maschera/nessuna classe trovata per {title}")
        return

    keys = sorted(hist_dict)
    heights = []
    tick_labels = []
    for key in keys:
        heights.append(hist_dict[key])
        if classnames and key < len(classnames):
            tick_labels.append(classnames[key])
        else:
            tick_labels.append(str(key))

    plt.figure(figsize=(6, 3))
    plt.bar(range(len(keys)), heights)
    plt.xticks(range(len(keys)), tick_labels, rotation=45, ha="right")
    plt.title(title)
    plt.tight_layout()
    plt.show()


# Per (dataset, domain) statistics, keyed "<dataset>_<domain>"
all_stats = {}


def fmt(v):
    # Four decimals for numbers, "N/A" for anything else (e.g. None)
    if isinstance(v, (int, float)):
        return f"{v:.4f}"
    return "N/A"


for dataset_name in DATASET_NAMES:
    for domain in DOMAINS:
        print(f"\n== {dataset_name} / {domain} ==")
        image_transform, seg_transform = get_preprocessing(dataset_name, domain, is_training=False)

        # CHAOS MR is explicitly loaded with liver_only disabled
        extra_kwargs = (
            {"liver_only": False}
            if dataset_name == "CHAOS" and domain == "MR"
            else {}
        )

        ds = get_dataset(
            dataset_name=dataset_name,
            base_path=DATA_PATH,
            domain=domain,
            transform=image_transform,
            seg_transform=seg_transform,
            batch_size=BATCH_SIZE,
            num_workers=0,
            slice_2d=not USE_3D,
            cache_max_items=CACHE_MAX_ITEMS,
            enable_cache=ENABLE_CACHE,
            **extra_kwargs,
        )

        # split sizes (a missing split counts as 0)
        n_train = 0 if ds.train_dataset is None else len(ds.train_dataset)
        n_val = 0 if ds.val_dataset is None else len(ds.val_dataset)
        n_test = 0 if ds.test_dataset is None else len(ds.test_dataset)
        print(f"Split -> train: {n_train}, val: {n_val}, test: {n_test}")
        print(f"Num classi: {getattr(ds, 'num_classes', 'N/A')} | Classnames: {getattr(ds, 'classnames', None)}")

        # statistics over a subset of the train split
        train_stats = summarize_split(ds.train_loader, SUBSET_N)
        imin, imax = train_stats['img_min'], train_stats['img_max']
        imean, istd = train_stats['img_mean'], train_stats['img_std']
        print(f"   Visti nel subset: {train_stats['n_seen']}")
        print(f"   Esempi img shape: {train_stats['img_shape_examples']}")
        print(f"   Esempi seg shape: {train_stats['seg_shape_examples']}")
        print(
            f"   Intensità ~ min:{fmt(imin)} max:{fmt(imax)} "
            f"mean:{fmt(imean)} std:{fmt(istd)}"
        )

        all_stats[f"{dataset_name}_{domain}"] = {
            "splits": {"train": n_train, "val": n_val, "test": n_test},
            "subset": train_stats,
            "classnames": getattr(ds, "classnames", None),
        }

        # bar chart of the class distribution
        plot_histogram(
            train_stats["class_hist"],
            title=f"Distribuzione classi (subset) - {dataset_name} {domain}",
            classnames=getattr(ds, "classnames", None),
        )

# Save the summary to a JSON file; a failure is reported but does not stop the cell
try:
    out_file = OUTPUTS_PATH / "dataset_stats.json"
    with out_file.open("w") as stats_fp:
        json.dump(all_stats, stats_fp, indent=2)
    print(f"\nSalvato riepilogo in: {out_file}")
except Exception as e:
    print(f"Errore salvataggio stats: {e}")

# One-line recap per dataset/domain
print("\nRiepilogo sintetico:")
for key, entry in all_stats.items():
    split_sizes = entry["splits"]
    seen = entry["subset"]["n_seen"]
    print(f" - {key}: train={split_sizes['train']}, val={split_sizes['val']}, test={split_sizes['test']} | visti(subset)={seen}")