In Kaggle, add the following to the dependencies:
```
pip install torch
pip install torchvision
pip install numpy
pip install pydicom
pip install Pillow
pip install matplotlib
```
Enable file persistence and internet access.
Remember that you can run the whole notebook and close the runtime without wasting resources by going to File > Save Version > Save & Run All. Double-check that a GPU is selected in the advanced settings.
Later, by going to 'File' > 'Version history' you can view the full logs and download the output files.

In [1]:
# Check if running in Kaggle
import os

IN_KAGGLE = False
if os.environ.get("KAGGLE_URL_BASE", ""):
    IN_KAGGLE = True
    !git clone https://github.com/parmigggiana/xai /kaggle/working/xai
    %cd xai
    !git fetch
    !git reset --hard origin/main
    %pip install 'napari[pyqt6,optional]==0.6.2a1' 'monai[einops,nibabel]>=1.1.0'

In [2]:
# Check if running in Google Colab
IN_COLAB = False
if not IN_KAGGLE:
    try:
        import google.colab
        from google.colab import drive
        IN_COLAB = True
        import os
        drive.mount('/content/drive')
        os.makedirs('/content/drive/MyDrive/xai', exist_ok=True)
        !git clone https://github.com/parmigggiana/xai /content/xai
        %cd /content/xai
        !git fetch
        !git reset --hard origin/main
        %pip install -r requirements.txt
    except Exception:
        pass




In [3]:
from src.datasets.registry import get_dataset
from src.datasets.common import BaseDataset
from pathlib import Path
import json
from src.task_vector import TaskVector
from src.utils import download_and_extract_dataset

In [4]:
# --- Experiment scope -------------------------------------------------------
use_3d = True  # selects the "3d" / "2d" tag used in checkpoint names and metric keys
dataset_names = ["CHAOS", "MMWHS"]
domains = ["MR", "CT"]

# --- Filesystem layout (plain strings here; turned into Path objects below) --
data_path, checkpoint_path, outputs_path = "data/", "checkpoints/", "outputs/"

# --- Optimiser settings passed to model.finetune ----------------------------
LEARNING_RATE, WEIGHT_DECAY = 1e-3, 1e-5

# --- Epoch budget per (dataset, domain) pair --------------------------------
# Insertion order is the order in which the finetuning loop visits the pairs.
training_epochs = dict(
    [
        (("CHAOS", "MR"), 30),
        (("CHAOS", "CT"), 10),
        (("MMWHS", "MR"), 30),
        (("MMWHS", "CT"), 20),
    ]
)

In [5]:
# Promote the configured location strings to Path objects (re-running is harmless:
# Path(Path(x)) is still the same path).
checkpoint_path, outputs_path, data_path = (
    Path(location) for location in (checkpoint_path, outputs_path, data_path)
)

# Only the two directories this notebook writes into are created here.
for writable_dir in (checkpoint_path, outputs_path):
    writable_dir.mkdir(parents=True, exist_ok=True)

In [None]:
import torch
from monai import transforms

def update_metrics(name, new_metrics):
    """
    Insert or replace one entry in the shared ``metrics.json`` results file.

    The file is located at ``outputs_path / "metrics.json"`` (``outputs_path`` is
    the notebook-level output directory) and holds one JSON object keyed by
    experiment name. Entries stored under other names are preserved.

    Args:
        name: Key under which the metrics are stored; an existing entry with the
            same key is overwritten.
        new_metrics: JSON-serialisable value to store under ``name``.

    Raises:
        TypeError, ValueError: If the combined metrics cannot be JSON-encoded.
            The file on disk is left untouched in that case.
        json.JSONDecodeError: If an existing metrics file is not valid JSON.
    """
    metrics_file = outputs_path / "metrics.json"

    if metrics_file.exists():
        with open(metrics_file, "r") as f:
            metrics = json.load(f)
    else:
        metrics = {}

    metrics[name] = new_metrics

    # Encode BEFORE opening the file for writing: mode "w" truncates immediately,
    # so an encoding error raised halfway through a streaming dump would wipe the
    # results accumulated by every earlier (and expensive) training run.
    serialized = json.dumps(metrics, indent=4)
    with open(metrics_file, "w") as f:
        f.write(serialized)


def get_preprocessing(domain, is_training=True):
    """
    Assemble the MONAI transform chain applied to volumetric medical data.

    Args:
        domain: Imaging modality. "CT" (case-insensitive) selects fixed-window
            intensity scaling; any other value is treated as MR and z-score
            normalised.
        is_training: When True, random augmentations are appended after the
            resize step.

    Returns:
        A ``transforms.Compose`` when the notebook-level ``use_3d`` flag is set,
        otherwise ``None`` (no 2D pipeline is defined).
    """
    # 2D preprocessing is not implemented; callers get None.
    if not use_3d:
        return None

    # Domain-specific intensity normalisation.
    if domain.upper() == "CT":
        # CT: clip HU values to [-200, 300] and rescale them to [0, 1].
        intensity_step = transforms.ScaleIntensityRange(
            a_min=-200, a_max=300, b_min=0.0, b_max=1.0, clip=True
        )
    else:
        # MR/MRI: z-score normalisation over non-zero voxels, per channel.
        intensity_step = transforms.NormalizeIntensity(nonzero=True, channel_wise=True)

    # Deterministic steps shared by training and evaluation.
    pipeline = [
        transforms.Orientation(axcodes="RAS"),  # standardise spatial orientation
        transforms.Spacing(pixdim=(1.5, 1.5, 2.0), mode="trilinear"),  # consistent voxel spacing
        intensity_step,
        transforms.Resize(spatial_size=(96, 96, 96), mode="trilinear"),  # fixed spatial size
    ]

    # Random augmentations, training only.
    if is_training:
        pipeline += [
            transforms.RandRotate90(prob=0.3, spatial_axes=(0, 1)),
            transforms.RandFlip(prob=0.3, spatial_axis=0),
            transforms.RandAffine(
                prob=0.3,
                rotate_range=0.1,
                translate_range=5,
                scale_range=0.1,
                mode="trilinear",
            ),
            transforms.RandGaussianNoise(prob=0.2, std=0.05),
            transforms.RandAdjustContrast(prob=0.2, gamma=(0.9, 1.1)),
        ]

    # Final conversion to a float32 tensor.
    pipeline += [
        # transforms.EnsureChannelFirst(),
        transforms.ToTensor(),
        transforms.EnsureType(dtype=torch.float32),
    ]

    return transforms.Compose(pipeline)

In [7]:
# Finetuning loop
#
# For every (dataset, domain) pair in `training_epochs`:
#   1. evaluate and save the encoder as the baseline,
#   2. train the segmentation head with the body frozen,
#   3. finetune the rest of the model with the head frozen.
# A checkpoint and a metrics entry are written after each stage. Pairs whose
# finetuned checkpoint already exists are skipped.

# Loop-invariant tag used in every checkpoint name and metrics key.
dim_tag = "3d" if use_3d else "2d"

for (dataset_name, domain), epochs in training_epochs.items():
    # Kept ahead of the skip check: the cross-domain evaluation cell reads from
    # data_path without downloading anything itself.
    download_and_extract_dataset(dataset_name, data_path)

    run_name = f"{dataset_name}_{domain}_{dim_tag}"
    filename = checkpoint_path / f"{run_name}_finetuned.pth"
    # Check if the finetuned checkpoint already exists
    if filename.exists():
        print(
            f"Finetuned model for {dataset_name} in {domain} domain with {dim_tag} images already exists at {filename}. Skipping finetuning."
        )
        continue

    print(
        f"Finetuning on {dataset_name} dataset in {domain} domain with {dim_tag} images "
    )
    # Only built when training actually runs.
    preprocess = get_preprocessing(domain, is_training=True)
    dataset: BaseDataset = get_dataset(
        dataset_name=dataset_name,
        domain=domain,
        preprocess=preprocess,
        base_path=data_path,
        batch_size=1,
        num_workers=1,
        slice_2d=not use_3d,
    )

    model = dataset.get_model(
        encoder_type="swin_unetr",
    )

    # Save the baseline encoder before finetuning. Note that torch.save is given
    # the module object itself (not a state_dict), which is why loading these
    # files later needs the classes allow-listed via safe_globals.
    baseline_filename = checkpoint_path / f"{run_name}_baseline.pth"
    torch.save(model.encoder, baseline_filename)
    print(
        f"Processing {dataset_name} in {domain} domain with {dim_tag} images"
    )
    model_metrics = model.evaluate()
    update_metrics(f"{run_name}_baseline", model_metrics)

    # Train the segmentation head
    # For the proposal of the paper this part should not be done like this!
    # This requires data - the original task arithmetic paper builds classification heads using no data, only templates
    # if we have time this point should be addressed, otherwise at least mention that the technical problem
    # requires more development time and the proof of concept should still be somewhat solid.
    model.freeze_body()
    model.finetune(epochs=epochs, learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    # Save the head (again the module object, not a state_dict)
    torch.save(model.head, checkpoint_path / f"{run_name}_head.pth")

    metrics = model.evaluate()
    update_metrics(f"{run_name}_head", metrics)

    # Finetune the encoder-decoder with the freshly trained head frozen.
    # The return value of finetune() is not used anywhere, so it is not kept.
    model.unfreeze()
    model.freeze_head()
    model.finetune(
        epochs=epochs,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
    )

    # Save the finetuned encoder; the existence of this file is what marks the
    # pair as done on later runs.
    torch.save(model.encoder, filename)
    model_metrics = model.evaluate()
    update_metrics(f"{run_name}_finetuned", model_metrics)

    # Release cached GPU memory before moving on to the next pair.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

Finetuned model for CHAOS in MR domain with 3d images already exists at checkpoints/CHAOS_MR_3d_finetuned.pth. Skipping finetuning.
Finetuned model for CHAOS in CT domain with 3d images already exists at checkpoints/CHAOS_CT_3d_finetuned.pth. Skipping finetuning.
Finetuned model for MMWHS in MR domain with 3d images already exists at checkpoints/MMWHS_MR_3d_finetuned.pth. Skipping finetuning.
Finetuned model for MMWHS in CT domain with 3d images already exists at checkpoints/MMWHS_CT_3d_finetuned.pth. Skipping finetuning.
Finetuned model for MMWHS in MR domain with 3d images already exists at checkpoints/MMWHS_MR_3d_finetuned.pth. Skipping finetuning.
Finetuned model for MMWHS in CT domain with 3d images already exists at checkpoints/MMWHS_CT_3d_finetuned.pth. Skipping finetuning.


# Domain adaptation

In [8]:
from monai.networks.nets import SwinUNETR
from monai.networks.nets.swin_unetr import SwinTransformer
from monai.networks.blocks.patchembedding import PatchEmbed
from torch.nn.modules.conv import Conv3d
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.container import ModuleList
from monai.networks.nets.swin_unetr import BasicLayer
from monai.networks.nets.swin_unetr import SwinTransformerBlock
from torch.nn.modules.normalization import LayerNorm
from monai.networks.nets.swin_unetr import WindowAttention
from torch.nn.modules.linear import Linear
from torch.nn.modules.activation import Softmax
from torch.nn.modules.linear import Identity
from monai.networks.blocks.mlp import MLPBlock
from torch.nn.modules.activation import GELU
from monai.networks.nets.swin_unetr import PatchMerging
from monai.networks.blocks.unetr_block import UnetrBasicBlock
from monai.networks.blocks.dynunet_block import UnetResBlock
from monai.networks.blocks.convolutions import Convolution
from torch.nn.modules.activation import LeakyReLU
from torch.nn.modules.instancenorm import InstanceNorm3d
from monai.networks.blocks.unetr_block import UnetrUpBlock
from torch.nn.modules.conv import ConvTranspose3d
from monai.networks.blocks.dynunet_block import UnetOutBlock

# Build Task Vectors for each dataset and domain

# Classes allow-listed for torch's restricted unpickler while the checkpoints are
# read; the list is the same for every checkpoint, so it is defined once.
_CHECKPOINT_SAFE_GLOBALS = [
    SwinUNETR,
    SwinTransformer,
    PatchEmbed,
    Conv3d,
    Dropout,
    ModuleList,
    BasicLayer,
    SwinTransformerBlock,
    LayerNorm,
    WindowAttention,
    Linear,
    Softmax,
    Identity,
    MLPBlock,
    GELU,
    PatchMerging,
    UnetrBasicBlock,
    UnetResBlock,
    Convolution,
    LeakyReLU,
    InstanceNorm3d,
    UnetrUpBlock,
    ConvTranspose3d,
    UnetOutBlock,
]

task_vectors = {}
for dataset_name in dataset_names:
    for domain in domains:
        dims = '3d' if use_3d else '2d'
        print(
            f"Building task vector for {dataset_name} dataset in {domain} domain with {dims} images"
        )

        # Both checkpoints share the same stem and differ only in the suffix.
        checkpoint_stem = f"{dataset_name}_{domain}_{dims}"
        baseline_checkpoint = checkpoint_path / f"{checkpoint_stem}_baseline.pth"
        finetuned_checkpoint = checkpoint_path / f"{checkpoint_stem}_finetuned.pth"

        with torch.serialization.safe_globals(_CHECKPOINT_SAFE_GLOBALS):
            task_vector = TaskVector(baseline_checkpoint, finetuned_checkpoint)
            # Remove keys associated with the .out layer from the task vector.
            # The matching keys are collected first so nothing is deleted while
            # the key view is being iterated.
            discarded_keys = tuple(
                key for key in task_vector.keys() if "out." in key
            )
            for key in discarded_keys:
                del task_vector[key]

        task_vectors[f"{dataset_name}_{domain}"] = task_vector

Building task vector for CHAOS dataset in MR domain with 3d images
Building task vector for CHAOS dataset in CT domain with 3d images
Building task vector for CHAOS dataset in CT domain with 3d images
Building task vector for MMWHS dataset in MR domain with 3d images
Building task vector for MMWHS dataset in MR domain with 3d images
Building task vector for MMWHS dataset in CT domain with 3d images
Building task vector for MMWHS dataset in CT domain with 3d images


In [9]:
# Build composite task vectors using arithmetic
#
# Each target is approximated as  plus_a + plus_b - minus  where
#   plus_a = same dataset, other domain
#   plus_b = other dataset, target domain
#   minus  = other dataset, other domain
_COMPOSITE_RECIPES = {
    "MMWHS_CT": ("MMWHS_MR", "CHAOS_CT", "CHAOS_MR"),
    "MMWHS_MR": ("MMWHS_CT", "CHAOS_MR", "CHAOS_CT"),
    "CHAOS_CT": ("CHAOS_MR", "MMWHS_CT", "MMWHS_MR"),
    "CHAOS_MR": ("CHAOS_CT", "MMWHS_MR", "MMWHS_CT"),
}

composite_task_vectors = {
    target: task_vectors[plus_a] + task_vectors[plus_b] - task_vectors[minus]
    for target, (plus_a, plus_b, minus) in _COMPOSITE_RECIPES.items()
}

In [10]:
# Task Vector Cross-Domain Evaluation (merged version, with nested loops for all configs)
#
# For every (dataset, target domain) pair: build a fresh model, apply the
# composite task vector for that pair, load the head trained on the same task,
# evaluate, and record the result under "<dataset>_<domain>_adaptation".
print("\n🔄 Task Vector Cross-Domain Adaptation Experiments")
print("=" * 80)
for dataset_name in dataset_names:
    for target_domain in domains:

        # Only add extra_kwarg 'liver_only' for CHAOS dataset
        extra_kwargs = {}
        if dataset_name == "CHAOS":
            extra_kwargs["liver_only"] = True

        try:
            composite_key = f"{dataset_name}_{target_domain}"
            if composite_key not in composite_task_vectors:
                # If the composite task vector does not exist, skip this iteration
                print(
                    f"   ❗ Warning: Composite task vector for {composite_key} does not exist. Skipping evaluation."
                )
                continue

            print(
                f"\n{dataset_name}: {target_domain} adaptation"
            )

            # Preprocessing depends on the domain (CT and MR are normalised
            # differently), so it must be built for the domain being evaluated
            # rather than from a `domain` variable left over from an earlier cell.
            preprocess = get_preprocessing(target_domain, is_training=False)

            # Load target domain dataset
            dataset_kwargs = dict(
                dataset_name=dataset_name,
                domain=target_domain,
                base_path=data_path,
                preprocess=preprocess,
                batch_size=1,
                num_workers=1,
                # Same rule as the finetuning cell, and consistent with the
                # 3d/2d tag of the head checkpoint loaded below.
                slice_2d=not use_3d,
            )
            dataset_kwargs.update(extra_kwargs)
            target_dataset = get_dataset(**dataset_kwargs)
            target_model = target_dataset.get_model(encoder_type="swin_unetr")

            # Apply composite task vector for target domain
            composite_task_vector = composite_task_vectors[composite_key]
            target_model.load_task_vector(composite_task_vector)

            # Overwrite the model's head with the saved one from the same task.
            # The checkpoint holds a pickled module, hence the allow-list and the
            # .state_dict() call on the loaded object.
            head_filename = checkpoint_path / f"{dataset_name}_{target_domain}_{'3d' if use_3d else '2d'}_head.pth"
            if head_filename.exists():
                with torch.serialization.safe_globals([UnetOutBlock, Convolution, Conv3d]):
                    target_model.head.load_state_dict(torch.load(head_filename, map_location="cuda" if torch.cuda.is_available() else "cpu").state_dict())
            else:
                print(
                    f"   ❗ Warning: Head file {head_filename} does not exist. Skipping evaluation."
                )
                continue

            # Evaluate cross-domain performance
            metrics = target_model.evaluate()
            update_metrics(f"{composite_key}_adaptation", metrics)

            # The results cell reads the stored metrics as metrics["train"][...];
            # use the same split here and fall back to a flat dict if there is none.
            summary = metrics.get("train", metrics)
            print(
                f"   ✅ {dataset_name} {target_domain}: Dice={summary.get('dice', 0):.3f}, Hausdorff={summary.get('hausdorff', 0):.3f}"
            )
        except Exception as e:
            # Re-raise with the failing pair in the message; the original error
            # stays attached as __cause__.
            raise RuntimeError(f"   ❌ {dataset_name} {target_domain}") from  e


🔄 Task Vector Cross-Domain Adaptation Experiments

CHAOS: MR adaptation
Split 'test' - Found 0 segmentation files for patient 11

=== DEBUG: Metadata Analysis for Patient 11 ===
Image metadata keys: ['spacing', original_affine, space, affine, spatial_shape, original_channel_dim]
Image shape: (256, 256, 26)
No segmentation data (test split or no files found)
=== END DEBUG ===

Split 'test' - Found 0 segmentation files for patient 11

=== DEBUG: Metadata Analysis for Patient 11 ===
Image metadata keys: ['spacing', original_affine, space, affine, spatial_shape, original_channel_dim]
Image shape: (256, 256, 26)
No segmentation data (test split or no files found)
=== END DEBUG ===

2025-07-28 12:55:58,428 - INFO - Expected md5 is None, skip md5 check for file data/ssl_pretrained_weights.pth.
2025-07-28 12:55:58,429 - INFO - File exists: data/ssl_pretrained_weights.pth, skipped downloading.
2025-07-28 12:55:58,428 - INFO - Expected md5 is None, skip md5 check for file data/ssl_pretrained_we

Evaluating train:   0%|          | 0/20 [00:00<?, ?it/s]

Split 'train' - Found 36 segmentation files for patient 1


=== DEBUG: Metadata Analysis for Patient 1 ===
Image metadata keys: ['spacing', original_affine, space, affine, spatial_shape, original_channel_dim]

=== DEBUG: Metadata Analysis for Patient 1 ===
Image metadata keys: ['spacing', original_affine, space, affine, spatial_shape, original_channel_dim]
Image shape: (256, 256, 36)
Segmentation metadata keys: ['format', 'mode', 'width', 'height', spatial_shape, original_channel_dim]
Segmentation shape: (256, 256, 36)Image shape: (256, 256, 36)
Segmentation metadata keys: ['format', 'mode', 'width', 'height', spatial_shape, original_channel_dim]
Segmentation shape: (256, 256, 36)

Total seg slices loaded: 36
Checking if segmentation metadata varies across slices:Total seg slices loaded: 36
Checking if segmentation metadata varies across slices:

  Slice 1: original_channel_dim differs from first slice
    First: nan  Slice 1: original_channel_dim differs from first slice
    First: na

Evaluating train:   0%|          | 0/20 [00:03<?, ?it/s]



RuntimeError:    ❌ CHAOS MR

In [None]:
# Load and display all metrics
def _print_metric_section(all_metrics, title, tag):
    """
    Print Dice and Hausdorff of the "train" split for selected metrics entries.

    Args:
        all_metrics: Mapping of experiment key -> metrics dict, as stored in
            metrics.json.
        title: Heading printed before the entries.
        tag: Substring an experiment key must contain to be listed.

    Entries without a "train" split are reported with zeros instead of raising.
    """
    print(title)
    for key, entry in all_metrics.items():
        if tag not in key:
            continue
        train_metrics = entry.get("train") or {}
        dice = train_metrics.get("dice", 0)
        hausdorff = train_metrics.get("hausdorff", 0)
        print(f"   {key}: Dice={dice:.3f}, HD={hausdorff:.3f}")


metrics_file = outputs_path / "metrics.json"
if metrics_file.exists():
    with open(metrics_file, "r") as f:
        all_metrics = json.load(f)

    print("\n📊 COMPREHENSIVE RESULTS ANALYSIS")
    print("=" * 80)

    # (heading, key substring) for each stage, in pipeline order.
    report_sections = (
        ("\n🏁 Baseline Performance:", "baseline"),
        ("\n🏋️‍♂️ After Head-Training Performance:", "head"),
        ("\n🏆 Finetuned Performance:", "finetuned"),
        ("\n🔄 Cross-Domain Adaptation Results:", "adaptation"),
    )
    for section_title, section_tag in report_sections:
        _print_metric_section(all_metrics, section_title, section_tag)
else:
    print("No metrics file found. Run the experiments first.")


📊 COMPREHENSIVE RESULTS ANALYSIS

🏁 Baseline Performance:
   CHAOS_MR_3d_baseline: Dice=0.241, HD=74.361
   CHAOS_CT_3d_baseline: Dice=0.127, HD=100.085
   MMWHS_MR_3d_baseline: Dice=0.001, HD=72.752
   MMWHS_CT_3d_baseline: Dice=0.161, HD=62.818

🏋️‍♂️ After Head-Training Performance:
   CHAOS_MR_3d_head: Dice=0.229, HD=74.416
   CHAOS_CT_3d_head: Dice=0.384, HD=46.022
   MMWHS_MR_3d_head: Dice=0.223, HD=29.732
   MMWHS_CT_3d_head: Dice=0.205, HD=37.370

🏆 Finetuned Performance:
   CHAOS_MR_3d_finetuned: Dice=0.753, HD=61.213
   CHAOS_CT_3d_finetuned: Dice=0.518, HD=69.145
   MMWHS_MR_3d_finetuned: Dice=0.777, HD=32.957
   MMWHS_CT_3d_finetuned: Dice=0.790, HD=28.309

🔄 Cross-Domain Adaptation Results:
   CHAOS_MR_adaptation: Dice=0.229, HD=74.416
   CHAOS_CT_adaptation: Dice=0.384, HD=46.022
   MMWHS_MR_adaptation: Dice=0.223, HD=29.732
   MMWHS_CT_adaptation: Dice=0.205, HD=37.370


In [None]:
if IN_COLAB:
    import shutil


    # Copy checkpoints.zip to Google Drive
    !zip -r /content/checkpoints.zip /content/xai/checkpoints
    shutil.copy('/content/checkpoints.zip', '/content/drive/MyDrive/xai/checkpoints.zip')

    # Copy metrics.json to Google Drive
    shutil.copy('/content/xai/outputs/metrics.json', '/content/drive/MyDrive/xai/metrics.json')

In [None]:
# On Kaggle, bundle the checkpoints directory into a single zip under
# /kaggle/working so it can be downloaded with the notebook's output files.
if IN_KAGGLE:
    !zip -r /kaggle/working/checkpoints.zip /kaggle/working/xai/checkpoints