# Visualization Function Testing Notebook

This notebook tests the visualization functions `visualize_sample_slice` with configurable dataset and domain settings. It focuses on ground truth visualization and skips 3D visualization as requested.

## Features:
- Tests CHAOS and MMWHS datasets
- Supports CT and MR domains  
- Configurable visualization parameters
- Ground truth visualization testing
- Error handling and cleanup

In [1]:
# First, import torch and torchvision
import torch
from pathlib import Path
from src.datasets.registry import get_dataset
from monai import transforms
from monai.data import MetaTensor

# Patch per retro-compatibilità: PyTorch < 2.6 non ha safe_globals (inutile?)
#if not hasattr(torch.serialization, "safe_globals"):
#    torch.serialization.safe_globals = {}

IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


In [2]:
# Check if running in Kaggle
import os

IN_KAGGLE = False
if os.environ.get("KAGGLE_URL_BASE", ""):
    IN_KAGGLE = True
    !git clone https://github.com/parmigggiana/xai /kaggle/working/xai
    %cd xai
    !git fetch
    !git reset --hard origin/main
    %pip install 'monai[einops,itk,nibabel]>=1.5.0' git+https://github.com/timojl/clipseg.git

## Configuration

Change these constants to test different datasets and domains:
- **DATASET_NAME**: "CHAOS" or "MMWHS"
- **DOMAIN**: "CT" or "MR" 
- **ENCODER_TYPE**: "resnet" or "swin_unetr"

In [3]:
# Configuration constants
DATASET_NAME = "CHAOS"  # Change to "MMWHS" if needed
DOMAIN = "MR"  # Change to "CT" if needed
ENCODER_TYPE = "clipseg"  # Changed from "swin_unetr" to "clipseg"
BATCH_SIZE = 1
NUM_WORKERS = 1
USE_3D = False  # Set to True for 3D data processing
SPATIAL_SIZE = 256

In [4]:
def get_preprocessing(dataset_name: str, domain: str, is_training=True):
    """
    Build MONAI-native Compose transform pipelines for images and segmentations.
    ImageDataset will wrap arrays as MetaTensor so metadata flows automatically.
    """
    decode_func = get_decode_func(dataset_name, domain)

    if USE_3D:
        image_transforms = [
            transforms.EnsureChannelFirst(channel_dim="no_channel"),
            # transforms.Orientation(axcodes="RAS"),
        ]
    else:
        image_transforms = [
            transforms.Lambda(lambda x: x.squeeze(-1)),
            transforms.EnsureChannelFirst(channel_dim="no_channel"),
        ]

    if domain == "CT":
        image_transforms.append(
            transforms.ScaleIntensityRange(
                a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True
            )
        )
    else:
        image_transforms.append(
            transforms.NormalizeIntensity(nonzero=True, channel_wise=True)
        )

    if is_training:
        image_transforms.extend(
            [
                transforms.RandGaussianNoise(prob=0.2, std=0.05),
                transforms.RandAdjustContrast(prob=0.2, gamma=(0.9, 1.1)),
            ]
        )

    if not USE_3D:
        image_transforms.append(transforms.RepeatChannel(repeats=3))

    image_transforms.extend(
        [
            transforms.Resize(
                spatial_size=SPATIAL_SIZE,
                size_mode="longest",
                mode="area",
                anti_aliasing=True,
            ),
            transforms.ToTensor(),
            transforms.EnsureType(dtype=torch.float32),
        ]
    )

    if not USE_3D:
        seg_transforms = [
            transforms.Lambda(lambda x: x.squeeze(-1)),
            transforms.EnsureChannelFirst(channel_dim="no_channel"),
        ]
    else:
        seg_transforms = [
            transforms.EnsureChannelFirst(channel_dim="no_channel"),
            # transforms.Orientation(axcodes="RAS"),
        ]

    seg_transforms.extend(
        [
            transforms.Lambda(lambda x: decode_func(x)),
            transforms.Resize(
                spatial_size=SPATIAL_SIZE, size_mode="longest", mode="nearest"
            ),
            transforms.ToTensor(),
            transforms.EnsureType(dtype=torch.float32),
        ]
    )

    return transforms.Compose(image_transforms), transforms.Compose(seg_transforms)


def get_decode_func(dataset_name, domain):
    from src.datasets.mmwhs import mmwhs_labels

    decode = None
    if dataset_name == "CHAOS":
        if domain in ["MR", "MRI"]:
            def decode(labels):
                return labels
        elif domain == "CT":
            def decode(labels):
                return torch.where(labels > 0, 255.0, 0.0)
    elif dataset_name == "MMWHS":
        def decode(labels):
            decoded_labels = torch.zeros_like(labels, dtype=torch.float32)
            for i, label_val in enumerate(mmwhs_labels.keys()):
                decoded_labels[labels == label_val] = i
            return decoded_labels

    if decode is None:
        print(
            f"Warning: No decode function defined for {dataset_name} in {domain}. Returning labels unchanged."
        )
        def decode(labels):
            return labels

    return decode

def metatensor_batch_to_dict(batch):
    """
    Converts a batch from your dataset (list of MetaTensor) into a dict with 'image' and 'label'.
    Assumes:
    - batch is a list of length batch_size, each element is a MetaTensor or (MetaTensor, metadata)
    - label is stored in metadata as 'label' key or as a separate tensor
    """
    if len(batch) >= 2:
        # If batch contains separate image and label tensors
        image_data = batch[0]
        label_data = batch[1]

        # Extract image
        if isinstance(image_data, tuple):
            image, _ = image_data
        else:
            image = image_data

        # Extract label
        if isinstance(label_data, tuple):
            label, _ = label_data
        else:
            label = label_data

        # Convert to tensors if needed
        image = image.data.float() if hasattr(image, "data") else torch.tensor(image, dtype=torch.float32)
        label = label.data.float() if hasattr(label, "data") else torch.tensor(label, dtype=torch.float32)

    else:
        # Fallback to original logic if only one element
        sample = batch[0]
        if isinstance(sample, tuple):
            data, metadata = sample
        else:
            data = sample
            metadata = getattr(sample, "meta", {}) or {}

        # Extract image and label
        image = data.data.float() if hasattr(data, "data") else torch.tensor(data, dtype=torch.float32)
        label = metadata.get("label", None)

    return {"image": image, "label": label}

## Ground Truth Visualization Testing

This section loads the dataset and tests the `visualize_sample_slice` function with ground truth data.

In [5]:
from src.datasets.common import BaseDataset


def test_ground_truth_visualization():
    """Test visualization functions with ground truth data."""
    print("🔍 Testing Ground Truth Visualization...")
    print(f"Dataset: {DATASET_NAME}, Domain: {DOMAIN}")

    image_transform, seg_transform = get_preprocessing(
        DATASET_NAME, DOMAIN, is_training=False
    )
    # image_transform = None
    # seg_transform = None

    # Load dataset
    dataset: BaseDataset = get_dataset(
        dataset_name=DATASET_NAME,
        domain=DOMAIN,
        transform=image_transform,  # Use transform instead of preprocess
        seg_transform=seg_transform,  # Pass seg_transform too
        base_path=Path("data"),
        batch_size=1,
        num_workers=0,
        slice_2d=True,
    )


    # Get a sample from the dataset
    loader = dataset.train_loader
    iterat = iter(loader)
    batch = next(iterat)
    batch = next(iterat)

    # Convert batch to dictionary format
    sample_dict = metatensor_batch_to_dict(batch)
    print(f"Image shape: {sample_dict['image'].shape}")
    if sample_dict['label'] is not None:
        print(f"Label shape: {sample_dict['label'].shape}")
    else:
        print("Label: None")

    # Extract a single sample
    #sample = {
    #    "image": batch["image"],  # Keep batch dimension for inference
    #    "label": batch["label"],  # Keep batch dimension for inference
    #}
    return dataset, sample_dict, batch


def test_inference_and_visualization(dataset, batch):
    """Test inference with semantic head training and visualize both GT and prediction."""
    print("🚀 Testing Inference with Semantic Head Training...")
    print(f"Dataset: {DATASET_NAME}, Domain: {DOMAIN}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Get model with semantic head
    model = dataset.get_model(ENCODER_TYPE)

    # Check what checkpoint files exist
    checkpoint_dir = Path("checkpoints")
    available_files = list(checkpoint_dir.glob("*.pth"))
    print(f"Available checkpoint files: {[f.name for f in available_files]}")

    # Try different possible checkpoint names
    possible_names = [
        f"{DATASET_NAME}_{DOMAIN}_2d_finetuned.pth",
        f"{DATASET_NAME}_{DOMAIN}_{'3d' if USE_3D else '2d'}_finetuned.pth",
        f"{DATASET_NAME}_{DOMAIN}_2d_baseline.pth",
        f"{DATASET_NAME}_{DOMAIN}_{'3d' if USE_3D else '2d'}_baseline.pth"
    ]

    checkpoint_path = None
    for name in possible_names:
        potential_path = checkpoint_dir / name
        if potential_path.exists():
            checkpoint_path = potential_path
            print(f"Using checkpoint: {checkpoint_path}")
            break

    if checkpoint_path is None:
        print("No suitable checkpoint found. Available files:")
        for f in available_files:
            print(f"  - {f.name}")
        raise FileNotFoundError("No suitable checkpoint found")

    # Load the checkpoint
    try:
        finetuned_state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
        if hasattr(finetuned_state_dict, 'state_dict'):
            finetuned_state_dict = finetuned_state_dict.state_dict()
        model.encoder.load_state_dict(finetuned_state_dict)
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        # Try without weights_only for older checkpoints
        finetuned_state_dict = torch.load(checkpoint_path, map_location=device)
        if hasattr(finetuned_state_dict, 'state_dict'):
            finetuned_state_dict = finetuned_state_dict.state_dict()
        model.encoder.load_state_dict(finetuned_state_dict)

    model.to(device)
    batch["image"] = batch["image"].to(device)

    # Run inference
    print("🔮 Running inference...")
    outputs = model(batch["image"])
    preds = torch.argmax(outputs, dim=1, keepdim=True)

    print(f"Prediction unique values: {torch.unique(preds)}")

    return preds



In [6]:
from monai.networks.nets import SwinUNETR
from monai.networks.nets.swin_unetr import SwinTransformer
from monai.networks.blocks.patchembedding import PatchEmbed
from torch.nn.modules.conv import Conv3d, Conv2d, ConvTranspose2d
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.container import ModuleList, Sequential, ModuleDict
from monai.networks.nets.swin_unetr import BasicLayer, SwinTransformerBlock, WindowAttention, PatchMerging
from torch.nn.modules.normalization import LayerNorm
from torch.nn.modules.linear import Linear, NonDynamicallyQuantizableLinear, Identity
from torch.nn.modules.activation import Softmax, GELU, LeakyReLU, ReLU, MultiheadAttention
from monai.networks.blocks.mlp import MLPBlock
from monai.networks.blocks.unetr_block import UnetrBasicBlock, UnetrUpBlock
from monai.networks.blocks.dynunet_block import UnetResBlock, UnetOutBlock
from monai.networks.blocks.convolutions import Convolution
from torch.nn.modules.instancenorm import InstanceNorm3d
from torch.nn.modules.sparse import Embedding
from torch.nn.modules.transformer import (
    TransformerEncoderLayer, TransformerEncoder,
    TransformerDecoderLayer, TransformerDecoder
)

# CLIPSeg imports
from src.CLIPSeg import CLIPSeg
from clipseg.clipseg import CLIPDensePredT
from clip.model import CLIP, VisionTransformer, Transformer, ResidualAttentionBlock, QuickGELU

SAFE_GLOBALS = [
    # SwinUNETR components
    SwinUNETR, SwinTransformer, PatchEmbed, Conv3d, Dropout, ModuleList,
    BasicLayer, SwinTransformerBlock, LayerNorm, WindowAttention, Linear,
    Softmax, Identity, MLPBlock, GELU, PatchMerging, UnetrBasicBlock,
    UnetResBlock, Convolution, LeakyReLU, InstanceNorm3d, UnetrUpBlock,
    UnetOutBlock,

    # CLIPSeg components
    CLIPSeg, CLIPDensePredT, CLIP, VisionTransformer, Conv2d, Sequential,
    ResidualAttentionBlock, MultiheadAttention, NonDynamicallyQuantizableLinear,
    QuickGELU, Embedding, ReLU, ConvTranspose2d, TransformerEncoderLayer,
    TransformerEncoder, TransformerDecoderLayer, TransformerDecoder,
    LayerNorm, Transformer, ModuleDict,
]

def load_checkpoint_safely(checkpoint_path, device="cpu"):
    """
    Utility function to safely load PyTorch checkpoints with proper error handling.

    Args:
        checkpoint_path: Path to the checkpoint file
        device: Device to load the checkpoint on

    Returns:
        Loaded checkpoint state dict
    """
    print(f"Loading checkpoint from: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)


    # Handle different checkpoint formats
    if isinstance(checkpoint, dict):
        if 'model_state_dict' in checkpoint:
            return checkpoint['model_state_dict']
        elif 'state_dict' in checkpoint:
            return checkpoint['state_dict']
        else:
            return checkpoint
    else:
        return checkpoint.state_dict() if hasattr(checkpoint, 'state_dict') else checkpoint

In [7]:
def test_inference_and_visualization_clipseg(dataset, batch):
    """Test inference with CLIPSeg model using pre-trained weights."""
    print("🚀 Testing Inference with CLIPSeg...")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model = dataset.get_model(encoder_type=ENCODER_TYPE)


    # Check for fine-tuned checkpoint (optional)
    checkpoint_dir = Path("checkpoints")
    possible_names = [
        f"{DATASET_NAME}_{DOMAIN}_2d_clipseg_finetuned.pth",
        f"{DATASET_NAME}_{DOMAIN}_clipseg_finetuned.pth",
        f"clipseg_{DATASET_NAME}_{DOMAIN}_finetuned.pth",
    ]

    checkpoint_path = None
    for name in possible_names:
        potential_path = checkpoint_dir / name
        if potential_path.exists():
            checkpoint_path = potential_path
            break

    if checkpoint_path is not None:
        try:
            # Use the safe loading function
            state_dict = load_checkpoint_safely(checkpoint_path, device)

            # Load fine-tuned state dict
            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
            print(f"✅ Successfully loaded fine-tuned checkpoint")

        except Exception as e:
            print(f"Error loading fine-tuned checkpoint: {e}")
            print("Continuing with pre-trained weights only...")

    model.to(device)

    # Handle batch properly - it's a tuple from the DataLoader
    if isinstance(batch, (list, tuple)):
        sample_dict = metatensor_batch_to_dict(batch)
        batch_image = sample_dict['image'].unsqueeze(0) if sample_dict['image'].dim() == 3 else sample_dict['image']
    else:
        # Fallback if batch is already a dict
        batch_image = batch["image"]

    # Handle tensor shape properly
    print(f"Original batch_image shape: {batch_image.shape}")

    if len(batch_image.shape) == 5:
        batch_image = batch_image.squeeze(1)
        print(f"After squeezing dimension 1: {batch_image.shape}")

    batch_image = batch_image.to(device)

    # Run inference
    print("🔮 Running CLIPSeg inference...")
    model.eval()
    with torch.no_grad():
        outputs = model(batch_image)
        preds = torch.argmax(outputs, dim=1, keepdim=True)

    print(f"Prediction unique values: {torch.unique(preds)}")
    return preds

## Dataset Visualization Method Test

Test the `dataset.visualize_sample_slice()` method with the loaded data. This section will show both ground truth and prediction visualizations for the **same sample** to enable direct comparison. The dataset-specific implementation automatically applies the correct rotation and flip parameters for each dataset type.

In [8]:
import matplotlib.pyplot as plt
import numpy as np


## Dataset Visualization Method Test
dataset, sample, batch = test_ground_truth_visualization()

# Test dataset.visualize_sample_slice method
print("\n📊 Testing dataset.visualize_sample_slice method...")

# Debug: check shapes before processing
print(f"Original image shape: {sample['image'].shape}")
print(f"Original label shape: {sample['label'].shape}")



# Apply squeeze to remove batch and channel dimensions for visualization
sample_viz = {
    "image": sample["image"].squeeze().cpu().numpy(),
    "label": sample["label"].squeeze().cpu().numpy()
}

print(f"After squeeze - Image shape: {sample_viz['image'].shape}")
print(f"After squeeze - Label shape: {sample_viz['label'].shape}")

# 🔍 DEBUG: Check dimensionality before visualization
print(f"\n🔍 DEBUG - seg_slice dimensionality:")
seg_slice = sample_viz['label']
print(f"  seg_slice.shape: {seg_slice.shape}")
print(f"  seg_slice.ndim: {seg_slice.ndim}")
print(f"  seg_slice type: {type(seg_slice)}")

# Check if it's the right shape for legend generation
if seg_slice.ndim >= 2:
    print("  ✅ Shape is compatible with legend generation")
    unique_vals = np.unique(seg_slice)
    print(f"  Unique values in seg_slice: {unique_vals}")
else:
    print("  ❌ Shape is NOT compatible with legend generation")
    print("  Need to fix the shape before visualization")

print(f"  seg_slice.shape: {seg_slice.shape}")
print(f"  seg_slice.ndim: {seg_slice.ndim}")

dataset.visualize_sample_slice(sample_viz)

🔍 Testing Ground Truth Visualization...
Dataset: CHAOS, Domain: MR
Dataset MR total samples: 623
Split sizes - Train: 436, Val: 93, Test: 94


KeyError: 0

In [None]:
# Run inference on the SAME sample with CLIPSeg
print("\n🔮 Testing dataset.visualize_sample_slice method with CLIPSeg PREDICTIONS...")

preds = test_inference_and_visualization_clipseg(dataset, batch)


print(f"Applying {DOMAIN} domain encoding...")
encoded_preds = dataset.encode(preds)
print(f"Encoded prediction unique values: {torch.unique(encoded_preds)}")

# Prepare prediction sample for visualization with proper squeeze
pred_sample_viz = {
    "image": sample["image"].squeeze().cpu().numpy(),
    "label": encoded_preds.squeeze().cpu().numpy()
}


# Use the dataset visualization method
print("\n📊 Visualizing CLIPSeg predictions...")
# Use the SAME orientation parameters as GT to avoid mismatches
dataset.visualize_sample_slice(pred_sample_viz)

In [None]:
# Use the dataset method for prediction visualization

dataset.visualize_sample_slice(sample_viz)
dataset.visualize_sample_slice(pred_sample_viz)
#dataset.visualize_sample_slice(comb_sample)

# Visualization Composite Vectors


In [13]:
# SWIN UNETR Task Vectors
from monai.networks.nets import SwinUNETR
from monai.networks.nets.swin_unetr import SwinTransformer
from monai.networks.blocks.patchembedding import PatchEmbed
from torch.nn.modules.conv import Conv3d
from torch.nn.modules.dropout import Dropout
from torch.nn.modules.container import ModuleList
from monai.networks.nets.swin_unetr import BasicLayer
from monai.networks.nets.swin_unetr import SwinTransformerBlock
from torch.nn.modules.normalization import LayerNorm
from monai.networks.nets.swin_unetr import WindowAttention
from torch.nn.modules.linear import Linear
from torch.nn.modules.activation import Softmax
from torch.nn.modules.linear import Identity
from monai.networks.blocks.mlp import MLPBlock
from torch.nn.modules.activation import GELU
from monai.networks.nets.swin_unetr import PatchMerging
from monai.networks.blocks.unetr_block import UnetrBasicBlock
from monai.networks.blocks.dynunet_block import UnetResBlock
from monai.networks.blocks.convolutions import Convolution
from torch.nn.modules.activation import LeakyReLU
from torch.nn.modules.instancenorm import InstanceNorm3d
from monai.networks.blocks.unetr_block import UnetrUpBlock
from monai.networks.blocks.dynunet_block import UnetOutBlock
from torch.nn.modules.conv import ConvTranspose3d

safe_globals = [
    SwinUNETR,
    SwinTransformer,
    PatchEmbed,
    Conv3d,
    Dropout,
    ModuleList,
    BasicLayer,
    SwinTransformerBlock,
    LayerNorm,
    WindowAttention,
    Linear,
    Softmax,
    Identity,
    MLPBlock,
    GELU,
    PatchMerging,
    UnetrBasicBlock,
    UnetResBlock,
    Convolution,
    LeakyReLU,
    InstanceNorm3d,
    UnetrUpBlock,
    ConvTranspose3d,
    UnetOutBlock,
]
##

## CLIPSeg Task Vectors
from src.CLIPSeg import CLIPSeg
from clipseg.clipseg import CLIPDensePredT
from clip.model import (
    CLIP,
    VisionTransformer,
    LayerNorm,
    Transformer,
    ResidualAttentionBlock,
    QuickGELU,
)
from torch.nn.modules.conv import Conv2d, ConvTranspose2d
from torch.nn.modules.container import Sequential
from torch.nn.modules.activation import MultiheadAttention, ReLU
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.modules.sparse import Embedding
from torch.nn.modules.transformer import (
    TransformerEncoderLayer,
    TransformerEncoder,
    TransformerDecoderLayer,
    TransformerDecoder,
)
from torch.nn.functional import relu
from torch.nn.modules.container import ModuleDict

safe_globals.extend(
    [
        CLIPSeg,
        CLIPDensePredT,
        CLIP,
        VisionTransformer,
        Conv2d,
        LayerNorm,
        Transformer,
        Sequential,
        ResidualAttentionBlock,
        MultiheadAttention,
        NonDynamicallyQuantizableLinear,
        QuickGELU,
        Embedding,
        ReLU,
        ConvTranspose2d,
        TransformerEncoderLayer,
        TransformerEncoder,
        TransformerDecoderLayer,
        TransformerDecoder,
        relu,
        ModuleDict,
    ]
)

from pathlib import Path
from src.task_vector import TaskVector

DATASET_NAMES = ["CHAOS", "MMWHS"]
DOMAINS = ["MR","CT"]
CHECKPOINT_PATH = "checkpoints/"
DATA_PATH = "data/"
CHECKPOINT_PATH = Path(CHECKPOINT_PATH)
DATA_PATH = Path(DATA_PATH)


# Build Task Vectors for each dataset and domain
task_vectors = {}
for dataset_name in DATASET_NAMES:
    for domain in DOMAINS:
        print(
            f"Building task vector for {dataset_name} dataset in {domain} domain with {'3d' if USE_3D else '2d'} images"
        )
        baseline_checkpoint = (
            CHECKPOINT_PATH
            / f"{dataset_name}_{domain}_{'3d' if USE_3D else '2d'}_baseline.pth"
        )
        finetuned_checkpoint = (
            CHECKPOINT_PATH
            / f"{dataset_name}_{domain}_{'3d' if USE_3D else '2d'}_finetuned.pth"
        )
        if not baseline_checkpoint.exists():
            print(
                f"Baseline checkpoint for {dataset_name} {domain} does not exist. Skipping task vector creation."
            )
            continue
        if not finetuned_checkpoint.exists():
            print(
                f"Finetuned checkpoint {dataset_name} {domain} does not exist. Skipping task vector creation."
            )
            continue

        with torch.serialization.safe_globals(
            safe_globals=safe_globals,
        ):
            task_vector = TaskVector(baseline_checkpoint, finetuned_checkpoint)
            # Remove keys associated with the output layers from the task vector
            # For swin it's all layers starting with '.out'
            # For clipseg it might not be necessary since the model architecture isn't dependent on the number of output features
            if ENCODER_TYPE == "swin_unetr":
                for k in task_vector.keys():
                    if k.startswith(".out"):
                        del task_vector[k]
        task_vectors[f"{dataset_name}_{domain}"] = task_vector



Building task vector for CHAOS dataset in MR domain with 2d images
Building task vector for CHAOS dataset in CT domain with 2d images
Building task vector for MMWHS dataset in MR domain with 2d images
Building task vector for MMWHS dataset in CT domain with 2d images


In [None]:
# Build composite task vectors using arithmetic
composite_task_vectors = {
    "MMWHS": task_vectors["MMWHS_MR"]
    + task_vectors["MMWHS_CT"],
    "CHAOS": task_vectors["CHAOS_MR"]
    + task_vectors["CHAOS_CT"],
    "MR": task_vectors["CHAOS_MR"]
    + task_vectors["MMWHS_MR"],
    "CT": task_vectors["CHAOS_CT"]
    + task_vectors["MMWHS_CT"],
}
alpha = 0.5


# CLIPSeg inference using composite_task_vectors (no checkpoints)
import torch
import matplotlib.pyplot as plt
import numpy as np


def _get_images_from_batch(batch):
    if isinstance(batch, dict):
        images = batch.get("image") or batch.get("images")
    elif isinstance(batch, (list, tuple)):
        images = batch[0]
    else:
        raise ValueError(f"Unsupported batch type: {type(batch)}")
    if hasattr(images, "as_tensor"):
        images = images.as_tensor()
    return images


def _to_numpy_img(x):
    if hasattr(x, "detach"):
        x = x.detach()
    if hasattr(x, "cpu"):
        x = x.cpu()
    x = x.float()
    # expect [B, C, H, W]
    if x.ndim == 4:
        x = x[0, 0]
    elif x.ndim == 3:
        x = x[0]
    return x.numpy()


def _to_numpy_mask(x):
    if hasattr(x, "detach"):
        x = x.detach()
    if hasattr(x, "cpu"):
        x = x.cpu()
    # expect [B, 1, H, W]
    if x.ndim == 4:
        x = x[0, 0]
    elif x.ndim == 3:
        x = x[0]
    return x.long().numpy()


def _plot_triplet(img_np, pred_np, gt_np=None, title_prefix=""):
    cols = 3 if gt_np is not None else 2
    fig, axes = plt.subplots(1, cols, figsize=(12, 4))
    axes = axes if isinstance(axes, (list, tuple, np.ndarray)) else [axes]
    axes[0].imshow(img_np, cmap="gray")
    axes[0].set_title(f"{title_prefix} image")
    axes[0].axis("off")

    axes[1].imshow(pred_np, cmap="nipy_spectral", interpolation="nearest")
    axes[1].set_title(f"{title_prefix} pred")
    axes[1].axis("off")

    if gt_np is not None and cols == 3:
        axes[2].imshow(gt_np, cmap="nipy_spectral", interpolation="nearest")
        axes[2].set_title(f"{title_prefix} label")
        axes[2].axis("off")

    plt.tight_layout()
    plt.show()


def test_inference_and_visualization_clipseg(dataset, batch, composite_key=None, scaling=None, encoder_type_override=None):
    """Run CLIPSeg inference applying a composite task vector instead of a checkpoint.

    Args:
        dataset: dataset instance providing get_model and classnames
        batch: a single batch from a DataLoader
        composite_key: optional key to select the composite task vector. If None,
                       tries dataset name, domain (e.g., "MMWHS", "CHAOS", "CT", "MR"),
                       then combined key "<dataset>_<domain>" (e.g., "MMWHS_CT").
        scaling: optional scaling coef for the task vector (defaults to global alpha if available)
        encoder_type_override: force a specific encoder_type (defaults to global encoder_type)
    Returns:
        preds tensor shaped [B, 1, H, W] with class indices
    """
    print("🚀 Testing Inference with CLIPSeg + composite_task_vectors…")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    enc = encoder_type_override or encoder_type
    model = dataset.get_model(encoder_type=enc)

    # Resolve composite key
    key = composite_key
    if key is None:
        ds_key = getattr(dataset, "name", type(dataset).__name__)
        dom_key = getattr(dataset, "domain", None)
        dom_key = str(dom_key).upper().replace("MRI", "MR") if dom_key else None
        combined_key = f"{ds_key}_{dom_key}" if dom_key else None
        if "composite_task_vectors" in globals() and ds_key in composite_task_vectors:
            key = ds_key
        elif "composite_task_vectors" in globals() and dom_key in composite_task_vectors:
            key = dom_key
        elif "composite_task_vectors" in globals() and combined_key in composite_task_vectors:
            key = combined_key

    # Resolve scaling
    if scaling is None:
        scaling = globals().get("alpha", 1.0)

    # Apply composite task vector if available
    if "composite_task_vectors" in globals():
        ctv = composite_task_vectors.get(key)
        if ctv is not None:
            print(f"🔧 Applying composite task vector: key='{key}', alpha={scaling}")
            model.load_task_vector(ctv, scaling_coef=float(scaling))
        else:
            print(f"⚠️ No composite task vector found for key '{key}'. Using base model.")
    else:
        print("⚠️ composite_task_vectors not defined. Using base model.")

    model.to(device).eval()

    # Prepare images
    images = _get_images_from_batch(batch)
    if images.ndim == 5:  # e.g. [B, 1, H, W, D] -> squeeze depth/channel dims for 2D case
        images = images.squeeze(1)
    images = images.to(device)

    # Inference
    with torch.no_grad():
        outputs = model(images)
        preds = torch.argmax(outputs, dim=1, keepdim=True)

    print(f"Prediction unique values: {torch.unique(preds)}")
    return preds


# Helper: pick an available loader (test > val > train)

def _get_any_loader(ds):
    return getattr(ds, "test_loader", None) or getattr(ds, "val_loader", None) or getattr(ds, "train_loader", None)


# Quick visualization harness for composite_task_vectors

def visualize_composite_predictions(dataset_name: str, target_domain: str, use_key: str, n_samples: int = 2):
    print(f"\n=== Visualize predictions: dataset='{dataset_name}' domain='{target_domain}' with composite='{use_key}' (alpha={alpha}) ===")
    image_transform, seg_transform = get_preprocessing(dataset_name, target_domain, is_training=False)

    extra_kwargs = {}
    if dataset_name == "CHAOS":
        # Keep consistent with experiments above
        extra_kwargs["liver_only"] = True

    ds = get_dataset(
        dataset_name=dataset_name,
        base_path=DATA_PATH,
        domain=target_domain,
        transform=image_transform,
        seg_transform=seg_transform,
        batch_size=BATCH_SIZE,
        num_workers=0,
        slice_2d=not USE_3D,
        cache_max_items=CACHE_MAX_ITEMS,
        enable_cache=ENABLE_CACHE,
        **extra_kwargs,
    )

    loader = _get_any_loader(ds)
    if loader is None:
        print("⚠️ No loader available.")
        return

    batch = next(iter(loader))
    preds = test_inference_and_visualization_clipseg(ds, batch, composite_key=use_key, scaling=alpha)

    # Prepare small preview
    images = _get_images_from_batch(batch)
    labels = batch.get("label") if isinstance(batch, dict) else (batch[1] if isinstance(batch, (list, tuple)) and len(batch) > 1 else None)

    img_np = _to_numpy_img(images)
    pred_np = _to_numpy_mask(preds)
    gt_np = _to_numpy_mask(labels) if labels is not None else None

    _plot_triplet(img_np, pred_np, gt_np, title_prefix=f"{dataset_name}/{target_domain} [{use_key}]")


# Drive a few previews for each dataset/domain with all composite key variants
for ds_name in DATASET_NAMES:
    for dom in DOMAINS:
        # dataset-level key (Part 1)
        if ds_name in composite_task_vectors:
            visualize_composite_predictions(ds_name, dom, use_key=ds_name)
        # domain-level key (Part 1)
        if dom in composite_task_vectors:
            visualize_composite_predictions(ds_name, dom, use_key=dom)
        # combined key (Part 2)
        combined_key = f"{ds_name}_{dom}"
        if combined_key in composite_task_vectors:
            visualize_composite_predictions(ds_name, dom, use_key=combined_key)