In Kaggle, add the following to the dependencies:
```
pip install torch
pip install torchvision
pip install numpy
pip install pydicom
pip install Pillow
pip install matplotlib
```
Enable file persistence and internet access.
Remember that you can run the whole notebook and close the runtime without wasting resources by going to File > Save Version > Save & Run All (double-check that a GPU is selected in the advanced settings).
Later, by going to 'File' > 'Version history' you can view the full logs and download the output files.

In [None]:
# Check if running in Kaggle
import os

IN_KAGGLE = False
if os.environ.get("KAGGLE_URL_BASE", ""):
    IN_KAGGLE = True
    !git clone https://github.com/parmigggiana/xai /kaggle/working/xai
    %cd xai
    !git fetch
    !git reset --hard origin/main
    %pip install 'napari[pyqt6,optional]==0.6.2a1' 'monai[einops,nibabel]>=1.1.0'

In [2]:
# Check if running in Google Colab
IN_COLAB = False
if not IN_KAGGLE:
    try:
        import google.colab

        IN_COLAB = True
    except Exception:
        pass

    if IN_COLAB:
        !git clone https://github.com/parmigggiana/xai /content/xai
        %cd /content/xai
        !git fetch
        !git reset --hard origin/main
        %pip install -r requirements.txt

In [3]:
from src.datasets.registry import get_dataset
from src.datasets.common import BaseDataset
from pathlib import Path
import json
from src.task_vector import TaskVector
from src.utils import download_and_extract_dataset

In [None]:
# Experiment configuration: datasets, domains, storage locations and hyperparameters.
dataset_names = ["CHAOS", "MMWHS"]
domains = ["MR", "CT"]
# Relative directories; a later cell converts these strings to pathlib.Path objects.
data_path = "data/"
checkpoint_path = "checkpoints/"
outputs_path = "outputs/"
# True selects the 3D pipeline (checkpoints tagged "3d"); False requests 2D slices ("2d").
use_3d = True
# Finetuning epochs per (dataset, domain) pair; the finetuning loop iterates over this mapping.
training_epochs = {
    ("CHAOS", "MR"): 30,
    ("CHAOS", "CT"): 15,
    ("MMWHS", "MR"): 35,
    ("MMWHS", "CT"): 20,
}
# Passed to model.finetune() as learning_rate / weight_decay.
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-5

In [5]:
# Promote the configured directory strings to Path objects.
checkpoint_path, outputs_path, data_path = (
    Path(location) for location in (checkpoint_path, outputs_path, data_path)
)
# Only the checkpoint and output directories are created here.
for directory in (checkpoint_path, outputs_path):
    directory.mkdir(parents=True, exist_ok=True)

In [None]:
import torch


def update_metrics(name, new_metrics):
    """Store ``new_metrics`` under ``name`` in ``outputs_path / "metrics.json"``.

    The file is created on first use. An existing entry with the same name is
    overwritten; all other entries are preserved.

    Args:
        name: Key identifying the experiment run.
        new_metrics: JSON-serializable metrics to record under that key.

    Raises:
        TypeError: If the merged metrics cannot be serialized to JSON. The
            file on disk is left untouched in that case.
    """
    metrics_file = outputs_path / "metrics.json"

    if metrics_file.exists():
        with open(metrics_file, "r") as f:
            metrics = json.load(f)
    else:
        metrics = {}

    metrics[name] = new_metrics

    # Serialize before opening the file for writing: opening with "w" truncates
    # it immediately, so a serialization error raised mid-dump would otherwise
    # destroy every previously recorded result.
    serialized = json.dumps(metrics, indent=4)
    with open(metrics_file, "w") as f:
        f.write(serialized)


def get_preprocessing(domain):
    """Build the per-sample preprocessing callable for the current mode.

    Args:
        domain: Imaging domain of the dataset. Currently unused.

    Returns:
        A callable applied to a single sample. When ``use_3d`` is true it takes
        a mapping with "image" and "label" entries and returns a new dict with
        only those two keys: the image as float and the label as long (or None),
        both resampled to (96, 128, 128) when the image's last three dimensions
        differ from that size. Otherwise it returns its input unchanged.
    """
    if use_3d:

        def preprocess(x):
            # Every volume whose spatial size differs from the target is
            # resampled (smaller volumes are upsampled, larger ones downsampled).
            target_shape = (96, 128, 128)
            img = x["image"].float()
            lbl = x["label"].float() if x["label"] is not None else None

            # The decision is based on the image shape only; the label follows it.
            needs_resize = img.shape[-3:] != target_shape
            if needs_resize:
                # interpolate() needs a leading batch dimension; assumes tensors
                # are laid out as (C, D, H, W) -- TODO confirm against the datasets.
                img = torch.nn.functional.interpolate(
                    img.unsqueeze(0),
                    size=target_shape,
                    mode="trilinear",
                    align_corners=False,
                ).squeeze(0)

            if lbl is not None:
                if needs_resize:
                    # Nearest-neighbour keeps label values discrete.
                    lbl = torch.nn.functional.interpolate(
                        lbl.unsqueeze(0), size=target_shape, mode="nearest"
                    ).squeeze(0)
                lbl = lbl.long()

            return {"image": img, "label": lbl}

    else:
        # Instantiate the module: calling the bare class on a sample would
        # construct an Identity module instead of returning the sample.
        preprocess = torch.nn.Identity()

    return preprocess

In [7]:
# Finetuning loop: for each (dataset, domain) pair, evaluate and save the
# baseline encoder, finetune, then evaluate and save the finetuned encoder.

dim_tag = "3d" if use_3d else "2d"

for (dataset_name, domain), epochs in training_epochs.items():
    download_and_extract_dataset(dataset_name, data_path)
    preprocess = get_preprocessing(domain)

    run_id = f"{dataset_name}_{domain}_{dim_tag}"
    filename = checkpoint_path / f"{run_id}_finetuned.pth"

    # An existing finetuned checkpoint marks this pair as already done.
    if filename.exists():
        print(
            f"Finetuned model for {dataset_name} in {domain} domain with {dim_tag} images already exists at {filename}. Skipping finetuning."
        )
        continue

    print(
        f"Finetuning on {dataset_name} dataset in {domain} domain with {dim_tag} images "
    )
    dataset: BaseDataset = get_dataset(
        dataset_name=dataset_name,
        domain=domain,
        preprocess=preprocess,
        base_path=data_path,
        batch_size=1,
        num_workers=1,
        slice_2d=not use_3d,
    )
    model = dataset.get_model(encoder_type="swin_unetr")

    # Persist the encoder before finetuning. Note that the `model.encoder`
    # object itself is saved here (no `.state_dict()` call).
    baseline_filename = checkpoint_path / f"{run_id}_baseline.pth"
    torch.save(model.encoder, baseline_filename)
    print(f"Processing {dataset_name} in {domain} domain with {dim_tag} images")
    model_metrics = model.evaluate()
    update_metrics(f"{run_id}_baseline", model_metrics)

    history = model.finetune(
        epochs=epochs,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
    )

    # Persist the finetuned encoder the same way, then record its metrics.
    torch.save(model.encoder, filename)
    model_metrics = model.evaluate()
    update_metrics(f"{run_id}_finetuned", model_metrics)

    # Release cached GPU memory before moving on to the next pair.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

Finetuned model for CHAOS in MR domain with 3d images already exists at checkpoints/CHAOS_MR_3d_finetuned.pth. Skipping finetuning.
Finetuned model for CHAOS in CT domain with 3d images already exists at checkpoints/CHAOS_CT_3d_finetuned.pth. Skipping finetuning.
Finetuned model for MMWHS in MR domain with 3d images already exists at checkpoints/MMWHS_MR_3d_finetuned.pth. Skipping finetuning.
Finetuned model for MMWHS in CT domain with 3d images already exists at checkpoints/MMWHS_CT_3d_finetuned.pth. Skipping finetuning.


# Domain adaptation

In [8]:
from monai.networks.nets import SwinUNETR
from monai.networks.nets.swin_unetr import SwinTransformer
from monai.networks.blocks.patchembedding import PatchEmbed
from torch.nn.modules.conv import Conv3d
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.container import ModuleList
from monai.networks.nets.swin_unetr import BasicLayer
from monai.networks.nets.swin_unetr import SwinTransformerBlock
from torch.nn.modules.normalization import LayerNorm
from monai.networks.nets.swin_unetr import WindowAttention
from torch.nn.modules.linear import Linear
from torch.nn.modules.activation import Softmax
from torch.nn.modules.linear import Identity
from monai.networks.blocks.mlp import MLPBlock
from torch.nn.modules.activation import GELU
from monai.networks.nets.swin_unetr import PatchMerging
from monai.networks.blocks.unetr_block import UnetrBasicBlock
from monai.networks.blocks.dynunet_block import UnetResBlock
from monai.networks.blocks.convolutions import Convolution
from torch.nn.modules.activation import LeakyReLU
from torch.nn.modules.instancenorm import InstanceNorm3d
from monai.networks.blocks.unetr_block import UnetrUpBlock
from torch.nn.modules.conv import ConvTranspose3d
from monai.networks.blocks.dynunet_block import UnetOutBlock

# Build one task vector per (dataset, domain) pair from its baseline and
# finetuned checkpoints.

# Classes allow-listed while the checkpoints are read; presumably TaskVector
# loads them with torch.load and needs these to unpickle the saved encoders.
checkpoint_safe_globals = [
    SwinUNETR,
    SwinTransformer,
    PatchEmbed,
    Conv3d,
    Dropout,
    ModuleList,
    BasicLayer,
    SwinTransformerBlock,
    LayerNorm,
    WindowAttention,
    Linear,
    Softmax,
    Identity,
    MLPBlock,
    GELU,
    PatchMerging,
    UnetrBasicBlock,
    UnetResBlock,
    Convolution,
    LeakyReLU,
    InstanceNorm3d,
    UnetrUpBlock,
    ConvTranspose3d,
    UnetOutBlock,
]
dim_tag = "3d" if use_3d else "2d"

task_vectors = {}
for dataset_name in dataset_names:
    for domain in domains:
        print(
            f"Building task vector for {dataset_name} dataset in {domain} domain with {dim_tag} images"
        )
        checkpoint_stem = f"{dataset_name}_{domain}_{dim_tag}"
        baseline_checkpoint = checkpoint_path / f"{checkpoint_stem}_baseline.pth"
        finetuned_checkpoint = checkpoint_path / f"{checkpoint_stem}_finetuned.pth"
        with torch.serialization.safe_globals(checkpoint_safe_globals):
            task_vector = TaskVector(baseline_checkpoint, finetuned_checkpoint)
        # Task vectors are keyed without the 2d/3d tag.
        task_vectors[f"{dataset_name}_{domain}"] = task_vector

Building task vector for CHAOS dataset in MR domain with 3d images
Building task vector for CHAOS dataset in CT domain with 3d images
Building task vector for MMWHS dataset in MR domain with 3d images
Building task vector for MMWHS dataset in CT domain with 3d images


In [None]:
# Task Vector Cross-Domain Evaluation (merged version, with nested loops for all configs)
#
# For every dataset and every ordered pair of distinct domains, a fresh model is
# built on the target-domain dataset, the source-domain task vector is applied
# to it, and the resulting metrics are stored under
# "<dataset>_<target>_from_<source>".
print("\n🔄 Task Vector Cross-Domain Adaptation Experiments")
print("=" * 80)

for dataset_name in dataset_names:
    for source_domain in domains:
        for target_domain in domains:
            if source_domain == target_domain:
                continue

            metrics_key = f"{dataset_name}_{target_domain}_from_{source_domain}"
            task_vector_key = f"{dataset_name}_{source_domain}"
            if task_vector_key not in task_vectors:
                # Report the gap instead of skipping the combination silently.
                print(
                    f"\n{dataset_name}: {source_domain} → {target_domain} skipped (no task vector for {task_vector_key})"
                )
                continue

            # Only add extra_kwarg 'liver_only' for CHAOS dataset
            extra_kwargs = {"liver_only": True} if dataset_name == "CHAOS" else {}

            # Best effort: a failure in one combination must not abort the others.
            try:
                print(
                    f"\n{dataset_name}: {source_domain} → {target_domain} adaptation"
                )
                # Load the target domain dataset with the same preprocessing and
                # 2D/3D mode as the finetuning cell, so the model input matches
                # what the task vectors were trained on and the metrics are
                # comparable with the baseline/finetuned ones.
                dataset_kwargs = dict(
                    dataset_name=dataset_name,
                    domain=target_domain,
                    preprocess=get_preprocessing(target_domain),
                    base_path=data_path,
                    batch_size=1,
                    num_workers=1,
                    slice_2d=not use_3d,
                )
                dataset_kwargs.update(extra_kwargs)
                target_dataset = get_dataset(**dataset_kwargs)
                target_model = target_dataset.get_model(encoder_type="swin_unetr")

                # Apply task vector from source domain
                source_task_vector = task_vectors[task_vector_key]
                target_model.load_task_vector(source_task_vector)

                # Evaluate cross-domain performance
                cross_domain_metrics = target_model.evaluate()
                update_metrics(metrics_key, cross_domain_metrics)

                print(
                    f"   ✅ {source_domain}→{target_domain}: Dice={cross_domain_metrics.get('dice', 0):.3f}"
                )
                print(
                    f"   📊 Hausdorff Distance: {cross_domain_metrics.get('hausdorff', 0):.3f}"
                )
            except Exception as e:
                # Include the exception type: some exceptions (e.g. KeyError)
                # have messages that are meaningless without it.
                print(
                    f"   ❌ {dataset_name} {source_domain}→{target_domain} error: {type(e).__name__}: {e}"
                )


🔄 Task Vector Cross-Domain Adaptation Experiments

CHAOS: MR → CT adaptation
2025-07-20 22:12:35,112 - INFO - Expected md5 is None, skip md5 check for file data/ssl_pretrained_weights.pth.
2025-07-20 22:12:35,112 - INFO - File exists: data/ssl_pretrained_weights.pth, skipped downloading.
Total updated layers 159 / 159
Pretrained Weights Succesfully Loaded !
🔍 Evaluating train split...


Evaluating train:   0%|          | 0/20 [00:00<?, ?it/s]

In [None]:
# Summarize the recorded metrics, grouped by experiment type.
metrics_file = outputs_path / "metrics.json"
if not metrics_file.exists():
    print("No metrics file found. Run the experiments first.")
else:
    with open(metrics_file, "r") as f:
        all_metrics = json.load(f)

    print("\n📊 COMPREHENSIVE RESULTS ANALYSIS")
    print("=" * 80)

    # (section heading, substring an entry name must contain to be listed)
    report_sections = (
        ("\n🏁 Baseline Performance:", "baseline"),
        ("\n🔄 Cross-Domain Adaptation Results:", "from"),
    )
    for heading, marker in report_sections:
        print(heading)
        for key, metrics in all_metrics.items():
            if marker not in key:
                continue
            # Missing values are reported as 0.
            dice = metrics.get("dice", 0)
            hausdorff = metrics.get("hausdorff", 0)
            print(f"   {key}: Dice={dice:.3f}, HD={hausdorff:.3f}")

In [None]:
# On Colab: archive the checkpoints, download the archive and the metrics file,
# then call runtime.unassign().
if IN_COLAB:
    from google.colab import files, runtime

    # Bundle the whole checkpoints directory into a single zip archive.
    !zip -r /content/checkpoints.zip /content/xai/checkpoints
    files.download("/content/checkpoints.zip")
    files.download("/content/xai/outputs/metrics.json")
    # NOTE(review): presumably releases the Colab VM so it stops consuming resources -- confirm.
    runtime.unassign()

In [None]:
if IN_KAGGLE:
    from IPython.display import FileLink

    !zip -r /kaggle/working/checkpoints.zip /kaggle/working/xai/checkpoints