# Conflict minimization when combining task vectors (AWD + TATR)

This notebook scaffolds experiments to reduce conflicts when combining **task vectors** (parameter deltas).

We implement:
- **Adaptive Weight Disentanglement (AWD)**: edit each task vector by *removing* components that are aligned with the (conflicting) subspace induced by other tasks.
- **Task Arithmetic in Trust Region (TATR)**: combine (possibly edited) vectors, then **scale** the result to remain inside a trust region (norm bound + optional cosine constraints vs each task).

We test effectiveness by:
1) Comparing task vectors before/after editing (cosine/conflict matrices, norms)
2) Comparing combined vector alignment to each task
3) Running a small sample evaluation on 4 datasets

> You will need to plug in how your repo loads models and datasets. The vector logic is framework-agnostic beyond PyTorch state_dicts.

In [10]:
# If needed (uncomment):
# %pip install -r ../requirements.txt

import copy
import math
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple, Callable, Optional

import numpy as np
import torch

def _find_repo_root(start: Path) -> Path:
    p = start.resolve()
    for _ in range(6):
        if (p / "src").exists():
            return p
        p = p.parent
    return start.resolve()

REPO_ROOT = _find_repo_root(Path.cwd())
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

print("Repo root:", REPO_ROOT)

Repo root: /home/basilef/Documents/Magistrale/Anno2_Semestre2/Explainable_and_Trustworthy_AI/project


## 0) Configuration

Define where your **base model** and **task-finetuned models** live. You can use either:
- paths to checkpoints (preferred), or
- in-memory model objects.

The only hard requirement is that you can obtain `state_dict()` for base and each task model.

In [11]:
from dataclasses import dataclass
from pathlib import Path
from typing import Tuple

def _find_repo_root(start: Path) -> Path:
    p = start.resolve()
    for _ in range(6):
        if (p / "src").exists():
            return p
        p = p.parent
    return start.resolve()

REPO_ROOT = _find_repo_root(Path.cwd())

@dataclass
class ExperimentConfig:
    # Evaluation device (models will be moved here for forward passes)
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    # Vector arithmetic device (keep on CPU to avoid GPU OOM during cosine/stacking)
    vector_device: str = "cpu"
    dtype: torch.dtype = torch.float32

    # 2D pipeline only
    use_3d: bool = False
    encoder_type: str = "clipseg"

    # The 4 tasks in this repo = (dataset, domain) pairs
    task_names: Tuple[str, str, str, str] = ("CHAOS_CT", "CHAOS_MR", "MMWHS_CT", "MMWHS_MR")

    # Paths (absolute, derived from repo root)
    repo_root: Path = REPO_ROOT
    data_path: Path = REPO_ROOT / "data"
    checkpoint_path: Path = REPO_ROOT / "checkpoints"

    # Data / eval knobs
    batch_size: int = 4
    spatial_size: int = 128
    num_workers: int = 0
    max_eval_batches: int = 10

    # Task addition scaling (mirrors local.ipynb usage of alpha)
    addition_alpha: float = 0.8

    # Trust region parameters (TATR)
    trust_radius: float = 5.0
    min_cos_to_each_task: float = -0.2
    tatr_max_steps: int = 30
    tatr_shrink: float = 0.85

    # AWD parameters
    awd_strength: float = 1.0
    awd_k: float = 6.0
    eps: float = 1e-8

CFG = ExperimentConfig()
CFG

ExperimentConfig(device='cpu', vector_device='cpu', dtype=torch.float32, use_3d=False, encoder_type='clipseg', task_names=('CHAOS_CT', 'CHAOS_MR', 'MMWHS_CT', 'MMWHS_MR'), repo_root=PosixPath('/home/basilef/Documents/Magistrale/Anno2_Semestre2/Explainable_and_Trustworthy_AI/project'), data_path=PosixPath('/home/basilef/Documents/Magistrale/Anno2_Semestre2/Explainable_and_Trustworthy_AI/project/data'), checkpoint_path=PosixPath('/home/basilef/Documents/Magistrale/Anno2_Semestre2/Explainable_and_Trustworthy_AI/project/checkpoints'), batch_size=4, spatial_size=128, num_workers=0, max_eval_batches=10, addition_alpha=0.8, trust_radius=5.0, min_cos_to_each_task=-0.2, tatr_max_steps=30, tatr_shrink=0.85, awd_strength=1.0, awd_k=6.0, eps=1e-08)

## 1) Task-vector utilities

A task vector is the parameter delta: `Î”_task = Î¸_task - Î¸_base`.

We keep deltas as `Dict[str, Tensor]` keyed by parameter names.

In [12]:
StateDict = Dict[str, torch.Tensor]
TaskVector = Dict[str, torch.Tensor]

def to_device(sd: StateDict, device: str, dtype: torch.dtype) -> StateDict:
    out = {}
    for k, v in sd.items():
        if torch.is_tensor(v):
            out[k] = v.detach().to(device=device, dtype=dtype)
    return out

def compute_task_vector(base_sd: StateDict, task_sd: StateDict) -> TaskVector:
    # Î”_task = Î¸_task - Î¸_base, on matching tensor keys/shapes
    delta = {}
    for k in base_sd.keys():
        if k in task_sd and torch.is_tensor(base_sd[k]) and torch.is_tensor(task_sd[k]):
            if base_sd[k].shape == task_sd[k].shape:
                if getattr(base_sd[k].dtype, "is_floating_point", False):
                    delta[k] = task_sd[k] - base_sd[k]
    return delta

def apply_task_vector_(model: torch.nn.Module, delta: TaskVector, scale: float = 1.0) -> None:
    # In-place add to parameters; avoids materializing full state_dict copies.
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in delta and param.shape == delta[name].shape:
                param.add_(delta[name].to(device=param.device, dtype=param.dtype), alpha=float(scale))

def tv_dot(a: TaskVector, b: TaskVector) -> torch.Tensor:
    # Streaming dot product over intersection of keys.
    keys = set(a.keys()) & set(b.keys())
    acc = None
    for k in keys:
        av = a[k]
        bv = b[k]
        if av.shape != bv.shape:
            continue
        term = (av.reshape(-1) * bv.reshape(-1)).sum()
        acc = term if acc is None else (acc + term)
    if acc is None:
        return torch.tensor(0.0)
    return acc

def tv_norm2(a: TaskVector) -> torch.Tensor:
    acc = None
    for v in a.values():
        term = (v.reshape(-1) * v.reshape(-1)).sum()
        acc = term if acc is None else (acc + term)
    if acc is None:
        return torch.tensor(0.0)
    return acc

def tv_norm(a: TaskVector, eps: float = 1e-8) -> float:
    return float(torch.sqrt(tv_norm2(a) + eps).detach().cpu())

def tv_cosine(a: TaskVector, b: TaskVector, eps: float = 1e-8) -> float:
    dot = tv_dot(a, b)
    na2 = tv_norm2(a)
    nb2 = tv_norm2(b)
    denom = torch.sqrt((na2 + eps) * (nb2 + eps))
    if float(denom.detach().cpu()) == 0.0:
        return float("nan")
    return float((dot / denom).detach().cpu())

def pairwise_cosine_matrix(deltas: List[TaskVector], eps: float = 1e-8) -> np.ndarray:
    n = len(deltas)
    M = np.zeros((n, n), dtype=np.float64)
    for i in range(n):
        for j in range(n):
            M[i, j] = tv_cosine(deltas[i], deltas[j], eps=eps)
    return M

def conflict_score_from_cos(M: np.ndarray) -> float:
    # Average negative cosine off-diagonal (how much tasks point against each other)
    n = M.shape[0]
    vals = []
    for i in range(n):
        for j in range(n):
            if i != j:
                vals.append(min(0.0, float(M[i, j])))
    return float(np.mean(vals)) if vals else float("nan")

## 2) AWD â€” Adaptive Weight Disentanglement

Intuition: for each task vector and each parameter tensor, if it is **anti-aligned** with the mean direction of other tasks (conflict), remove the component along that conflicting mean.

We do it **adaptively** with a gate `g âˆˆ [0,1]` that grows when cosine is negative.

**Result**: edited deltas `Î”'_t` with reduced negative interference.

In [13]:
def _sigmoid(x: torch.Tensor) -> torch.Tensor:
    return 1.0 / (1.0 + torch.exp(-x))

def awd_edit(
    deltas: List[TaskVector],
    strength: float = 1.0,
    k: float = 6.0,
    eps: float = 1e-8,
) -> Tuple[List[TaskVector], Dict[str, float]]:
    """Adaptive Weight Disentanglement (AWD).

    For each task t and param key:
      m = mean_{u!=t}(Î”_u)
      cos = <Î”_t, m> / (||Î”_t|| ||m||)
      gate = sigmoid(k * (-cos))  # ~1 when cos negative (conflict), ~0 when positive
      Î”'_t = Î”_t - strength*gate*proj_m(Î”_t)
    """
    T = len(deltas)
    if T < 2:
        return deltas, {"note": "AWD skipped (need >=2 tasks)"}

    keys = sorted(set().union(*[d.keys() for d in deltas]))
    edited: List[TaskVector] = [dict() for _ in range(T)]

    gates_accum = []
    cos_accum = []
    removed_energy = []

    for key in keys:
        tensors = [d.get(key, None) for d in deltas]
        if any(t is None for t in tensors):
            for t in range(T):
                if tensors[t] is not None:
                    edited[t][key] = tensors[t]
            continue

        # Flatten each tensor for this key -> [T, D]
        stacked = torch.stack([tt.reshape(-1) for tt in tensors], dim=0)
        sum_all = stacked.sum(dim=0)  # [D]
        for t in range(T):
            dt = stacked[t]  # [D]
            m = (sum_all - dt) / (T - 1)  # [D]
            dn = dt.norm() + eps
            mn = m.norm() + eps
            c = torch.dot(dt, m) / (dn * mn)
            gate = _sigmoid(k * (-c))
            proj = (torch.dot(dt, m) / (mn * mn)) * m
            dt_new = dt - (strength * gate) * proj

            gates_accum.append(float(gate.detach().cpu()))
            cos_accum.append(float(c.detach().cpu()))
            removed_energy.append(float(((strength * gate) * proj).norm().detach().cpu()))

            edited[t][key] = dt_new.reshape(tensors[t].shape)

    stats = {
        "avg_gate": float(np.mean(gates_accum)) if gates_accum else float("nan"),
        "avg_cos_to_others_mean": float(np.mean(cos_accum)) if cos_accum else float("nan"),
        "avg_removed_norm": float(np.mean(removed_energy)) if removed_energy else float("nan"),
    }
    return edited, stats

## 3) TATR â€” Task Arithmetic in Trust Region

Combine deltas (e.g., sum or weighted sum), then **shrink** the combined update until it satisfies:
- global norm bound: `||Î”_comb|| â‰¤ trust_radius`
- optional alignment constraint: `cos(Î”_comb, Î”_task_i) â‰¥ min_cos_to_each_task`

This is a lightweight trust region that does not require extra gradient steps.

In [14]:
def weighted_sum(deltas: List[TaskVector], weights: List[float]) -> TaskVector:
    assert len(deltas) == len(weights)
    keys = sorted(set().union(*[d.keys() for d in deltas]))
    out: TaskVector = {}
    for k in keys:
        acc = None
        for d, w in zip(deltas, weights):
            if k not in d:
                continue
            term = d[k] * float(w)
            acc = term if acc is None else (acc + term)
        if acc is not None:
            out[k] = acc
    return out

def scale_delta(delta: TaskVector, s: float) -> TaskVector:
    return {k: v * float(s) for k, v in delta.items()}

def tatr_combine(
    deltas: List[TaskVector],
    weights: Optional[List[float]] = None,
    trust_radius: float = 5.0,
    min_cos_to_each_task: Optional[float] = None,
    max_steps: int = 30,
    shrink: float = 0.85,
    eps: float = 1e-8,
) -> Tuple[TaskVector, Dict[str, float]]:
    if weights is None:
        weights = [1.0] * len(deltas)
    comb0 = weighted_sum(deltas, weights)
    n0 = tv_norm(comb0, eps=eps)

    def ok(delta: TaskVector) -> Tuple[bool, float, float]:
        n = tv_norm(delta, eps=eps)
        min_cos = float("inf")
        if min_cos_to_each_task is not None:
            for d in deltas:
                min_cos = min(min_cos, tv_cosine(delta, d, eps=eps))
        else:
            min_cos = float("nan")
        norm_ok = (n <= trust_radius + 1e-12)
        cos_ok = True if min_cos_to_each_task is None else (min_cos >= min_cos_to_each_task)
        return (norm_ok and cos_ok), n, min_cos

    s = 1.0
    best = comb0
    best_s = 1.0
    last_min_cos = float("nan")
    for _step in range(max_steps + 1):
        cand = scale_delta(comb0, s)
        good, n, min_cos = ok(cand)
        last_min_cos = min_cos
        if good:
            best, best_s = cand, s
            break
        s *= shrink

    out_stats = {
        "initial_norm": float(n0),
        "final_scale": float(best_s),
        "final_norm": float(tv_norm(best, eps=eps)),
    }
    if min_cos_to_each_task is not None:
        out_stats["final_min_cos_to_tasks"] = float(last_min_cos)
    return best, out_stats

## 4) Vector-level diagnostics (before/after AWD, and after combining)

These are quick checks to quantify conflicts:
- pairwise cosine matrix
- conflict score (mean negative cosine off-diagonal)
- norms
- alignment of combined vector to each task

In [15]:
def summarize_vectors(task_names: List[str], deltas: List[TaskVector], eps: float = 1e-8) -> Dict[str, object]:
    M = pairwise_cosine_matrix(deltas, eps=eps)
    norms = [tv_norm(d, eps=eps) for d in deltas]
    return {
        "task_names": task_names,
        "cosine_matrix": M,
        "conflict_score": conflict_score_from_cos(M),
        "norms": norms,
    }

def combined_alignment(task_names: List[str], comb: TaskVector, deltas: List[TaskVector], eps: float = 1e-8) -> Dict[str, float]:
    out: Dict[str, float] = {}
    for name, d in zip(task_names, deltas):
        out[f"cos_to_{name}"] = tv_cosine(comb, d, eps=eps)
    out["combined_norm"] = tv_norm(comb, eps=eps)
    return out

def print_matrix(names: List[str], M: np.ndarray, fmt: str = "{:+.3f}") -> None:
    header = " " * 12 + " ".join([f"{n:>10}" for n in names])
    print(header)
    for i, ni in enumerate(names):
        row = " ".join([f"{fmt.format(M[i,j]):>10}" for j in range(len(names))])
        print(f"{ni:>10}  {row}")

## 5) Minimal evaluation harness on 4 datasets

This is intentionally pluggable, because your repo likely has its own data + evaluation.

You provide:
- `load_base_model() -> nn.Module`
- `load_task_model(task_name) -> nn.Module` or task checkpoint
- `get_eval_loader(task_name) -> iterable` yielding batches
- `eval_step(model, batch) -> Dict[str,float]` (e.g., loss, accuracy)

Then we evaluate:
- base model
- per-task edited model (base + edited Î”_task)
- combined model (base + combined Î”)

Use small sample sizes to keep this notebook fast.

In [16]:
def evaluate_on_loader(
    model: torch.nn.Module,
    loader,
    eval_step: Callable[[torch.nn.Module, object], Dict[str, float]],
    max_batches: int = 10,
) -> Dict[str, float]:
    model.eval()
    agg = {}
    n = 0
    with torch.no_grad():
        for i, batch in enumerate(loader):
            if i >= max_batches:
                break
            out = eval_step(model, batch)
            for k, v in out.items():
                agg[k] = agg.get(k, 0.0) + float(v)
            n += 1
    if n == 0:
        return {"note": "no batches"}
    return {k: v / n for k, v in agg.items()}

def clone_model(model: torch.nn.Module) -> torch.nn.Module:
    # Generic cloning: reinstantiate is repo-specific; here we deep-copy.
    return copy.deepcopy(model)

## 6) Experiment runner (fill in model/dataset hooks)

Steps:
1) Load base + 4 task-finetuned models
2) Compute task deltas
3) AWD edit deltas
4) Combine with TATR
5) Run vector diagnostics + small-sample evaluation

### You must edit the TODO hooks below.

In [None]:
# --- 2D pipeline hooks (derived from local.ipynb/local.py) ---
import gc

from src.datasets.registry import get_dataset
from src.datasets.common import BaseDataset
from src.utils import download_and_extract_dataset
from monai import transforms

def _cleanup_memory() -> None:
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def parse_task_name(task_name: str) -> Tuple[str, str]:
    dataset_name, domain = task_name.split("_", 1)
    return dataset_name, domain

def encoder_ckpt_path(task_name: str, kind: str) -> Path:
    # kind in {'baseline','finetuned'}
    ds, dom = parse_task_name(task_name)
    return CFG.checkpoint_path / f"{ds}_{dom}_{'3d' if CFG.use_3d else '2d'}_{kind}.pth"

def load_encoder_checkpoint(path: Path) -> torch.nn.Module:
    if not path.exists():
        raise FileNotFoundError(f"Missing checkpoint: {path}")
    # Checkpoints were saved with torch.save(model.encoder, path)
    return torch.load(path, map_location=CFG.vector_device, weights_only=False)

def load_baseline_encoder() -> torch.nn.Module:
    # Baselines are identical across tasks in this repo; pick a canonical baseline.
    base_path = encoder_ckpt_path(CFG.task_names[0], kind="baseline")
    return load_encoder_checkpoint(base_path)

def load_finetuned_encoder(task_name: str) -> torch.nn.Module:
    fin_path = encoder_ckpt_path(task_name, kind="finetuned")
    return load_encoder_checkpoint(fin_path)

# Normalization stats (mean, std) per dataset/domain
NORM_STATS = {
    ("MMWHS", "MR"): (186.5875, 258.5917),
    ("MMWHS", "CT"): (-745.0086, 1042.7251),
    ("CHAOS", "MR"): (90.8292, 168.8922),
    ("CHAOS", "CT"): (-478.1732, 476.7163),
}

def get_decode_func(dataset_name: str, domain: str):
    from src.datasets.mmwhs import mmwhs_labels
    if dataset_name == "CHAOS":
        if domain in ["MR", "MRI"]:
            return lambda labels: labels // 63
        if domain == "CT":
            return lambda labels: torch.where(labels > 0, 1.0, 0.0)
    if dataset_name == "MMWHS":
        def decode(labels):
            decoded_labels = torch.zeros_like(labels, dtype=torch.float32)
            for i, label_val in enumerate(mmwhs_labels.keys()):
                decoded_labels[labels == label_val] = i
            return decoded_labels
        return decode
    return lambda labels: labels

def get_preprocessing(dataset_name: str, domain: str, is_training: bool):
    decode_func = get_decode_func(dataset_name, domain)
    mean, std = NORM_STATS.get((dataset_name, domain), (None, None))

    image_transforms = [
        transforms.Lambda(lambda x: x.squeeze(-1)),
        transforms.EnsureChannelFirst(channel_dim="no_channel"),
        transforms.Resize(
            spatial_size=CFG.spatial_size,
            size_mode="longest",
            mode="area",
            anti_aliasing=True,
        ),
        transforms.ToTensor(),
        transforms.EnsureType(dtype=torch.float32),
    ]
    if mean is not None and std is not None:
        image_transforms.append(
            transforms.NormalizeIntensity(
                subtrahend=float(mean),
                divisor=float(std),
                channel_wise=False,
            )
        )
    if is_training:
        image_transforms.extend([
            transforms.RandGaussianNoise(prob=0.15, std=0.05),
            transforms.RandAdjustContrast(prob=0.15, gamma=(0.95, 1.05)),
        ])
    image_transforms.append(transforms.RepeatChannel(repeats=3))
    image_transform = transforms.Compose(image_transforms)

    seg_transforms = [
        transforms.Lambda(lambda x: x.squeeze(-1)),
        transforms.EnsureChannelFirst(channel_dim="no_channel"),
        transforms.ToTensor(),
        transforms.EnsureType(dtype=torch.long),
        transforms.Lambda(lambda x: decode_func(x)),
        transforms.Resize(
            spatial_size=CFG.spatial_size,
            size_mode="longest",
            mode="nearest",
        ),
    ]
    seg_transform = transforms.Compose(seg_transforms)
    return image_transform, seg_transform

def build_dataset_for_task(task_name: str, is_training: bool = False) -> BaseDataset:
    dataset_name, domain = parse_task_name(task_name)
    download_and_extract_dataset(dataset_name, CFG.data_path)
    image_t, seg_t = get_preprocessing(dataset_name, domain, is_training=is_training)
    ds: BaseDataset = get_dataset(
        dataset_name=dataset_name,
        domain=domain,
        transform=image_t,
        seg_transform=seg_t,
        base_path=CFG.data_path,
        batch_size=CFG.batch_size,
        num_workers=CFG.num_workers,
        slice_2d=True,
    )
    if not isinstance(ds, BaseDataset):
        raise TypeError(f"Expected BaseDataset, got {type(ds)}")
    return ds

def unpack_batch(batch):
    if isinstance(batch, dict):
        return batch.get("image"), batch.get("label")
    if isinstance(batch, (list, tuple)) and len(batch) >= 2:
        return batch[0], batch[1]
    return None, None

def eval_step(model: torch.nn.Module, batch) -> Dict[str, float]:
    images, labels = unpack_batch(batch)
    if images is None or labels is None:
        return {}
    images = images.to(CFG.device)
    labels = labels.to(CFG.device)
    try:
        labels = labels.long()
    except Exception:
        pass
    logits = model(images)  # (B, C, H, W) for 2D clipseg
    preds = torch.argmax(logits, dim=1)  # (B, H, W)
    y = labels.squeeze(1) if labels.ndim == 4 else labels
    # mean Dice over foreground classes (exclude background=0)
    eps = 1e-8
    num_classes = logits.shape[1]
    dices = []
    for c in range(1, num_classes):
        p = (preds == c)
        g = (y == c)
        inter = (p & g).sum().float()
        denom = p.sum().float() + g.sum().float()
        dice = (2.0 * inter + eps) / (denom + eps)
        dices.append(dice)
    mean_dice = (
        torch.stack(dices).mean()
        if len(dices)
        else torch.tensor(float("nan"), device=CFG.device)
    )
    return {"mean_dice_fg": float(mean_dice.detach().cpu())}

def evaluate_encoder_on_task(encoder: torch.nn.Module, task_name: str, max_batches: int) -> Dict[str, float]:
    ds = None
    model = None
    try:
        ds = build_dataset_for_task(task_name, is_training=False)
        model = ds.get_model(base_model=CFG.encoder_type).to(CFG.device)
        model.encoder = encoder.to(CFG.device)
        return evaluate_on_loader(model.encoder, ds.test_loader, eval_step, max_batches=max_batches)
    finally:
        # Important: free dataset/model between tasks to avoid RAM creep
        try:
            if model is not None:
                model.to("cpu")
        except Exception:
            pass
        del model, ds
        _cleanup_memory()

def evaluate_baseline_plus_delta(delta: Optional[TaskVector], task_name: str, max_batches: int, scale: float = 1.0) -> Dict[str, float]:
    base_encoder = None
    try:
        base_encoder = load_baseline_encoder()
        if delta is not None:
            apply_task_vector_(base_encoder, delta, scale=scale)
        return evaluate_encoder_on_task(base_encoder, task_name=task_name, max_batches=max_batches)
    finally:
        del base_encoder
        _cleanup_memory()

def build_raw_task_deltas(task_names: List[str]) -> Tuple[torch.nn.Module, List[torch.nn.Module], List[TaskVector]]:
    base_encoder = load_baseline_encoder()
    finetuned_encoders = [load_finetuned_encoder(tn) for tn in task_names]
    base_sd = to_device(base_encoder.state_dict(), CFG.vector_device, CFG.dtype)
    fin_sds = [to_device(m.state_dict(), CFG.vector_device, CFG.dtype) for m in finetuned_encoders]
    deltas = [compute_task_vector(base_sd, sd) for sd in fin_sds]
    return base_encoder, finetuned_encoders, deltas

In [18]:
# --- Important fix: baseline checkpoints are per-task ---
# local.ipynb applies each task vector to the *matching* baseline checkpoint (dataset+domain).
# For sanity checks and fair evaluation, we also load the baseline per task here.

def load_baseline_encoder(task_name: Optional[str] = None) -> torch.nn.Module:
    if task_name is None:
        task_name = CFG.task_names[0]
    base_path = encoder_ckpt_path(task_name, kind="baseline")
    return load_encoder_checkpoint(base_path)

def evaluate_baseline_plus_delta(
    delta: Optional[TaskVector], task_name: str, max_batches: int, scale: float = 1.0
 ) -> Dict[str, float]:
    base_encoder = load_baseline_encoder(task_name)
    if delta is not None:
        apply_task_vector_(base_encoder, delta, scale=scale)
    return evaluate_encoder_on_task(base_encoder, task_name=task_name, max_batches=max_batches)

def build_raw_task_deltas(task_names: List[str]) -> Tuple[List[torch.nn.Module], List[torch.nn.Module], List[TaskVector]]:
    baseline_encoders = [load_baseline_encoder(tn) for tn in task_names]
    finetuned_encoders = [load_finetuned_encoder(tn) for tn in task_names]
    base_sds = [to_device(m.state_dict(), CFG.vector_device, CFG.dtype) for m in baseline_encoders]
    fin_sds = [to_device(m.state_dict(), CFG.vector_device, CFG.dtype) for m in finetuned_encoders]
    deltas = [compute_task_vector(bs, fs) for bs, fs in zip(base_sds, fin_sds)]
    return baseline_encoders, finetuned_encoders, deltas

## 6A) Reference: finetuned vs baseline, and raw task vectors
We first compute the 4 raw task vectors (finetuned âˆ’ baseline) and evaluate:
- **finetuned encoder** on its own task (reference you care about)
- **baseline** (to know the gap)
- **baseline + raw Î”_task** (sanity check: should match finetuned, up to noise)

In [19]:
import pandas as pd

task_names = list(CFG.task_names)
baseline_encoders, finetuned_encoders, deltas = build_raw_task_deltas(task_names)

print("Vector diagnostics (raw deltas):")
orig_summary = summarize_vectors(task_names, deltas, eps=CFG.eps)
print("conflict_score:", orig_summary["conflict_score"])
print("norms:", orig_summary["norms"])
print_matrix(task_names, orig_summary["cosine_matrix"])

rows = []
for tn, fin_enc, dt in zip(task_names, finetuned_encoders, deltas):
    s_finetuned = evaluate_encoder_on_task(fin_enc, tn, max_batches=CFG.max_eval_batches)
    s_baseline = evaluate_baseline_plus_delta(None, tn, max_batches=CFG.max_eval_batches)
    s_base_plus_raw = evaluate_baseline_plus_delta(dt, tn, max_batches=CFG.max_eval_batches, scale=1.0)
    rows.append({
        "task": tn,
        "dice_finetuned": s_finetuned.get("mean_dice_fg"),
        "dice_baseline": s_baseline.get("mean_dice_fg"),
        "dice_base_plus_raw_delta": s_base_plus_raw.get("mean_dice_fg"),
        "delta_norm": tv_norm(dt, eps=CFG.eps),
    })

ref_df = pd.DataFrame(rows).set_index("task")
display(ref_df)

Vector diagnostics (raw deltas):
conflict_score: 0.0
norms: [57.1956901550293, 34.02861022949219, 84.9603042602539, 68.02688598632812]
              CHAOS_CT   CHAOS_MR   MMWHS_CT   MMWHS_MR
  CHAOS_CT      +1.000     +0.350     +0.342     +0.343
  CHAOS_MR      +0.350     +1.000     +0.271     +0.278
  MMWHS_CT      +0.342     +0.271     +1.000     +0.304
  MMWHS_MR      +0.343     +0.278     +0.304     +1.000
Found explicit background class in input. Treating it separately.
Non-background classes: ['Liver']
ðŸ”„ Loading CLIPSeg weights...
Found explicit background class in input. Treating it separately.
Non-background classes: ['Liver']
ðŸ”„ Loading CLIPSeg weights...
Found explicit background class in input. Treating it separately.
Non-background classes: ['Liver']
ðŸ”„ Loading CLIPSeg weights...
Found explicit background class in input. Treating it separately.
Non-background classes: ['Liver', 'Right kidney', 'Left kidney', 'Spleen']
ðŸ”„ Loading CLIPSeg weights...
Found explicit b

Unnamed: 0_level_0,dice_finetuned,dice_baseline,dice_base_plus_raw_delta,delta_norm
task,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
CHAOS_CT,0.971517,0.132276,0.971529,57.19569
CHAOS_MR,0.821681,0.018059,0.821549,34.02861
MMWHS_CT,0.936441,0.021665,0.936432,84.960304
MMWHS_MR,0.871081,0.010305,0.871078,68.026886


## 6B) AWD test (edit each task vector, no trust region)
This block isolates AWD: we edit each $\Delta_{task}$ and compare against the **finetuned** model (not baseline).

In [20]:
awd_deltas, awd_stats = awd_edit(deltas, strength=CFG.awd_strength, k=CFG.awd_k, eps=CFG.eps)
print("AWD stats:", awd_stats)

print("Vector diagnostics (after AWD):")
awd_summary = summarize_vectors(task_names, awd_deltas, eps=CFG.eps)
print("conflict_score:", awd_summary["conflict_score"])
print("norms:", awd_summary["norms"])
print_matrix(task_names, awd_summary["cosine_matrix"])

rows = []
for tn, fin_enc, dt_raw, dt_awd in zip(task_names, finetuned_encoders, deltas, awd_deltas):
    s_finetuned = evaluate_encoder_on_task(fin_enc, tn, max_batches=CFG.max_eval_batches)
    s_base_plus_awd = evaluate_baseline_plus_delta(dt_awd, tn, max_batches=CFG.max_eval_batches, scale=1.0)
    rows.append({
        "task": tn,
        "dice_finetuned": s_finetuned.get("mean_dice_fg"),
        "dice_base_plus_AWD_delta": s_base_plus_awd.get("mean_dice_fg"),
        "cos(AWD_delta, raw_delta)": tv_cosine(dt_awd, dt_raw, eps=CFG.eps),
        "||AWD_delta||": tv_norm(dt_awd, eps=CFG.eps),
        "||raw_delta||": tv_norm(dt_raw, eps=CFG.eps),
    })

awd_df = pd.DataFrame(rows).set_index("task")
display(awd_df)

AWD stats: {'avg_gate': 0.46810992170249516, 'avg_cos_to_others_mean': 0.04080290108022237, 'avg_removed_norm': 0.012791253246163366}
Vector diagnostics (after AWD):
conflict_score: 0.0
norms: [56.936527252197266, 33.85821533203125, 84.52819061279297, 67.6814193725586]
              CHAOS_CT   CHAOS_MR   MMWHS_CT   MMWHS_MR
  CHAOS_CT      +1.000     +0.331     +0.313     +0.317
  CHAOS_MR      +0.331     +1.000     +0.242     +0.252
  MMWHS_CT      +0.313     +0.242     +1.000     +0.264
  MMWHS_MR      +0.317     +0.252     +0.264     +1.000
Found explicit background class in input. Treating it separately.
Non-background classes: ['Liver']
ðŸ”„ Loading CLIPSeg weights...
Found explicit background class in input. Treating it separately.
Non-background classes: ['Liver']
ðŸ”„ Loading CLIPSeg weights...
Found explicit background class in input. Treating it separately.
Non-background classes: ['Liver', 'Right kidney', 'Left kidney', 'Spleen']
ðŸ”„ Loading CLIPSeg weights...
Found explici

Unnamed: 0_level_0,dice_finetuned,dice_base_plus_AWD_delta,"cos(AWD_delta, raw_delta)",||AWD_delta||,||raw_delta||
task,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
CHAOS_CT,0.971517,0.971293,0.999634,56.936527,57.19569
CHAOS_MR,0.821681,0.821197,0.999548,33.858215,34.02861
MMWHS_CT,0.936441,0.918757,0.999529,84.528191,84.960304
MMWHS_MR,0.871081,0.867409,0.999524,67.681419,68.026886


## 6C) TATR test (trust region on task addition)
This block isolates TATR: we **add** task vectors (like in local.ipynb) but shrink the combined update into a trust region.

In [21]:
# Addition weights mirror local.ipynb: alpha * (sum of selected task vectors)
add_weights = [CFG.addition_alpha] * len(task_names)

comb_tatr, tatr_stats = tatr_combine(
    deltas,
    weights=add_weights,
    trust_radius=CFG.trust_radius,
    min_cos_to_each_task=CFG.min_cos_to_each_task,
    max_steps=CFG.tatr_max_steps,
    shrink=CFG.tatr_shrink,
    eps=CFG.eps,
 )
print("TATR stats:", tatr_stats)
print("Combined alignment (TATR vs each raw task delta):")
print(combined_alignment(task_names, comb_tatr, deltas, eps=CFG.eps))

rows = []
for tn, fin_enc in zip(task_names, finetuned_encoders):
    s_finetuned = evaluate_encoder_on_task(fin_enc, tn, max_batches=CFG.max_eval_batches)
    s_comb = evaluate_baseline_plus_delta(comb_tatr, tn, max_batches=CFG.max_eval_batches, scale=1.0)
    rows.append({
        "task": tn,
        "dice_finetuned": s_finetuned.get("mean_dice_fg"),
        "dice_baseline_plus_TATR_add": s_comb.get("mean_dice_fg"),
        "cos(comb_tatr, delta_task)": tv_cosine(comb_tatr, deltas[task_names.index(tn)], eps=CFG.eps),
    })

tatr_df = pd.DataFrame(rows).set_index("task")
display(tatr_df)

TATR stats: {'initial_norm': 138.55642700195312, 'final_scale': 0.032945601421837174, 'final_norm': 4.564825057983398, 'final_min_cos_to_tasks': 0.5542226433753967}
Combined alignment (TATR vs each raw task delta):
{'cos_to_CHAOS_CT': 0.7016519904136658, 'cos_to_CHAOS_MR': 0.5542226433753967, 'cos_to_MMWHS_CT': 0.7761434316635132, 'cos_to_MMWHS_MR': 0.7094733119010925, 'combined_norm': 4.564825057983398}
Found explicit background class in input. Treating it separately.
Non-background classes: ['Liver']
ðŸ”„ Loading CLIPSeg weights...
Found explicit background class in input. Treating it separately.
Non-background classes: ['Liver']
ðŸ”„ Loading CLIPSeg weights...
Found explicit background class in input. Treating it separately.
Non-background classes: ['Liver', 'Right kidney', 'Left kidney', 'Spleen']
ðŸ”„ Loading CLIPSeg weights...
Found explicit background class in input. Treating it separately.
Non-background classes: ['Liver', 'Right kidney', 'Left kidney', 'Spleen']
ðŸ”„ Loading C

Unnamed: 0_level_0,dice_finetuned,dice_baseline_plus_TATR_add,"cos(comb_tatr, delta_task)"
task,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
CHAOS_CT,0.971517,0.000513,0.701652
CHAOS_MR,0.821681,0.003378,0.554223
MMWHS_CT,0.936441,0.114579,0.776143
MMWHS_MR,0.871081,0.086661,0.709473


## 7) Task addition (local.ipynb-style composites): raw vs AWD vs TATR
This reproduces the *task addition / composite task vectors* patterns from `local.ipynb`, but evaluates them with the notebook's quick Dice-on-test-loader metric.

In [22]:
def _idx(task_name: str) -> int:
    return task_names.index(task_name)

def build_local_composites() -> Dict[str, List[Tuple[str, float]]]:
    # Mirrors local.ipynb composite_task_vectors definitions
    comps: Dict[str, List[Tuple[str, float]]] = {}
    # Dataset composites
    comps["MMWHS"] = [("MMWHS_MR", +1.0), ("MMWHS_CT", +1.0)]
    comps["CHAOS"] = [("CHAOS_MR", +1.0), ("CHAOS_CT", +1.0)]
    # Domain composites
    comps["MR"] = [("CHAOS_MR", +1.0), ("MMWHS_MR", +1.0)]
    comps["CT"] = [("CHAOS_CT", +1.0), ("MMWHS_CT", +1.0)]
    # Cross-domain arithmetic composites (Part 2 in local.ipynb)
    comps["MMWHS_CT_cross"] = [("MMWHS_MR", +1.0), ("CHAOS_CT", +1.0), ("CHAOS_MR", -1.0)]
    comps["MMWHS_MR_cross"] = [("MMWHS_CT", +1.0), ("CHAOS_MR", +1.0), ("CHAOS_CT", -1.0)]
    comps["CHAOS_CT_cross"] = [("CHAOS_MR", +1.0), ("MMWHS_CT", +1.0), ("MMWHS_MR", -1.0)]
    comps["CHAOS_MR_cross"] = [("CHAOS_CT", +1.0), ("MMWHS_MR", +1.0), ("MMWHS_CT", -1.0)]
    return comps

def combine_from_terms(delta_list: List[TaskVector], terms: List[Tuple[str, float]], scale: float) -> TaskVector:
    chosen = [delta_list[_idx(tn)] for tn, _w in terms]
    weights = [scale * float(w) for _tn, w in terms]
    return weighted_sum(chosen, weights)

def tatr_from_terms(delta_list: List[TaskVector], terms: List[Tuple[str, float]], scale: float) -> Tuple[TaskVector, Dict[str, float]]:
    chosen = [delta_list[_idx(tn)] for tn, _w in terms]
    weights = [scale * float(w) for _tn, w in terms]
    return tatr_combine(
        chosen,
        weights=weights,
        trust_radius=CFG.trust_radius,
        min_cos_to_each_task=CFG.min_cos_to_each_task,
        max_steps=CFG.tatr_max_steps,
        shrink=CFG.tatr_shrink,
        eps=CFG.eps,
    )

local_composites = build_local_composites()

# Evaluate, for each target task: dataset composite, domain composite, and cross composite
records = []
for tn, fin_enc in zip(task_names, finetuned_encoders):
    dataset_name, domain = parse_task_name(tn)
    fin = evaluate_encoder_on_task(fin_enc, tn, max_batches=CFG.max_eval_batches).get("mean_dice_fg")

    # (1) dataset composite applied at (dataset, domain) baseline
    ds_terms = local_composites[dataset_name]
    # (2) domain composite applied at (dataset, domain) baseline
    dom_terms = local_composites[domain]
    # (3) cross composite for this task
    cross_terms = local_composites[f"{tn}_cross"]

    for comp_kind, terms in [("dataset", ds_terms), ("domain", dom_terms), ("cross", cross_terms)]:
        # Raw addition
        d_raw = combine_from_terms(deltas, terms, scale=CFG.addition_alpha)
        s_raw = evaluate_baseline_plus_delta(d_raw, tn, max_batches=CFG.max_eval_batches, scale=1.0).get("mean_dice_fg")

        # AWD addition (if available)
        d_awd = combine_from_terms(awd_deltas, terms, scale=CFG.addition_alpha)
        s_awd = evaluate_baseline_plus_delta(d_awd, tn, max_batches=CFG.max_eval_batches, scale=1.0).get("mean_dice_fg")

        # TATR addition (trust region on the combined update)
        d_tatr, stats = tatr_from_terms(deltas, terms, scale=CFG.addition_alpha)
        s_tatr = evaluate_baseline_plus_delta(d_tatr, tn, max_batches=CFG.max_eval_batches, scale=1.0).get("mean_dice_fg")

        records.append({
            "task": tn,
            "composite": comp_kind,
            "dice_finetuned": fin,
            "dice_raw_add": s_raw,
            "dice_AWD_add": s_awd,
            "dice_TATR_add": s_tatr,
            "gap_raw_vs_finetuned": (s_raw - fin) if fin is not None else None,
            "gap_AWD_vs_finetuned": (s_awd - fin) if fin is not None else None,
            "gap_TATR_vs_finetuned": (s_tatr - fin) if fin is not None else None,
            "tatr_final_scale": stats.get("final_scale"),
            "tatr_final_norm": stats.get("final_norm"),
        })

add_df = pd.DataFrame(records)
display(add_df)

# Summary stats (mean over tasks) for each composite kind
summary = (
    add_df.groupby("composite")[
        ["gap_raw_vs_finetuned", "gap_AWD_vs_finetuned", "gap_TATR_vs_finetuned"]
    ]
    .mean()
    .rename(columns={
        "gap_raw_vs_finetuned": "mean_gap_raw",
        "gap_AWD_vs_finetuned": "mean_gap_AWD",
        "gap_TATR_vs_finetuned": "mean_gap_TATR",
    })
)
display(summary)

Found explicit background class in input. Treating it separately.
Non-background classes: ['Liver']
ðŸ”„ Loading CLIPSeg weights...
Found explicit background class in input. Treating it separately.
Non-background classes: ['Liver']
ðŸ”„ Loading CLIPSeg weights...
Found explicit background class in input. Treating it separately.
Non-background classes: ['Liver']
ðŸ”„ Loading CLIPSeg weights...
Found explicit background class in input. Treating it separately.
Non-background classes: ['Liver']
ðŸ”„ Loading CLIPSeg weights...
Found explicit background class in input. Treating it separately.
Non-background classes: ['Liver']
ðŸ”„ Loading CLIPSeg weights...
Found explicit background class in input. Treating it separately.
Non-background classes: ['Liver']
ðŸ”„ Loading CLIPSeg weights...
Found explicit background class in input. Treating it separately.
Non-background classes: ['Liver']
ðŸ”„ Loading CLIPSeg weights...
Found explicit background class in input. Treating it separately.
Non-backgr

Unnamed: 0,task,composite,dice_finetuned,dice_raw_add,dice_AWD_add,dice_TATR_add,gap_raw_vs_finetuned,gap_AWD_vs_finetuned,gap_TATR_vs_finetuned,tatr_final_scale,tatr_final_norm
0,CHAOS_CT,dataset,0.971517,0.8480276,0.842589,7.522615e-05,-0.123489,-0.128928,-0.971442,0.074251,4.520902
1,CHAOS_CT,domain,0.971517,0.7391409,0.756405,0.02223001,-0.232376,-0.215112,-0.949287,0.045599,4.287944
2,CHAOS_CT,cross,0.971517,2.213085e-12,5.4e-05,2.213085e-12,-0.971517,-0.971462,-0.971517,1.0,79.108269
3,CHAOS_MR,dataset,0.821681,0.2147809,0.209067,0.02628602,-0.6069,-0.612614,-0.795395,0.074251,4.520902
4,CHAOS_MR,domain,0.821681,0.1564909,0.163393,0.004492932,-0.66519,-0.658287,-0.817188,0.074251,4.994809
5,CHAOS_MR,cross,0.821681,2.904907e-10,2.3e-05,2.904907e-10,-0.821681,-0.821658,-0.821681,1.0,83.711319
6,MMWHS_CT,dataset,0.936441,0.487583,0.491014,0.005182586,-0.448858,-0.445427,-0.931259,0.045599,4.520413
7,MMWHS_CT,domain,0.936441,0.4626066,0.491676,0.006149314,-0.473835,-0.444765,-0.930292,0.045599,4.287944
8,MMWHS_CT,cross,0.936441,0.04539433,0.043245,0.008305805,-0.891047,-0.893197,-0.928136,0.063113,4.810473
9,MMWHS_MR,dataset,0.871081,0.1223958,0.134596,0.004599404,-0.748685,-0.736485,-0.866482,0.045599,4.520413


Unnamed: 0_level_0,mean_gap_raw,mean_gap_AWD,mean_gap_TATR
composite,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
cross,-0.871339,-0.872219,-0.895683
dataset,-0.481983,-0.480864,-0.891144
domain,-0.464005,-0.45414,-0.884294


## 8) Summary statistics (vs finetuned)
This condenses the key comparisons you asked for into small tables: AWD-only, TATR-only, and task-addition composites.

In [23]:
def _mean(series: pd.Series) -> float:
    s = pd.to_numeric(series, errors="coerce")
    return float(s.mean())

def _mean_abs(series: pd.Series) -> float:
    s = pd.to_numeric(series, errors="coerce")
    return float(s.abs().mean())

summary_rows = []

# AWD-only (per-task edit): compare baseline+AWD_delta vs finetuned
if "awd_df" in globals():
    g = awd_df["dice_base_plus_AWD_delta"] - awd_df["dice_finetuned"]
    summary_rows.append({
        "experiment": "AWD (per-task edit)",
        "mean_gap_vs_finetuned": _mean(g),
        "mean_abs_gap_vs_finetuned": _mean_abs(g),
    })

# TATR-only (sum-all test from 6C): baseline+TATR_add vs finetuned
if "tatr_df" in globals():
    g = tatr_df["dice_baseline_plus_TATR_add"] - tatr_df["dice_finetuned"]
    summary_rows.append({
        "experiment": "TATR (sum all tasks)",
        "mean_gap_vs_finetuned": _mean(g),
        "mean_abs_gap_vs_finetuned": _mean_abs(g),
    })

# Task addition composites (section 7): compare each method vs finetuned, aggregated over (task, composite) rows
if "add_df" in globals():
    for method_col, name in [
        ("gap_raw_vs_finetuned", "Addition: raw"),
        ("gap_AWD_vs_finetuned", "Addition: AWD"),
        ("gap_TATR_vs_finetuned", "Addition: TATR"),
    ]:
        summary_rows.append({
            "experiment": name,
            "mean_gap_vs_finetuned": _mean(add_df[method_col]),
            "mean_abs_gap_vs_finetuned": _mean_abs(add_df[method_col]),
        })

summary_df = pd.DataFrame(summary_rows).set_index("experiment")
display(summary_df)

# Optional: best method per (task, composite) row
if "add_df" in globals():
    best = add_df.copy()
    best["best_method"] = best[["gap_raw_vs_finetuned", "gap_AWD_vs_finetuned", "gap_TATR_vs_finetuned"]].idxmax(axis=1)
    best["best_gap"] = best[["gap_raw_vs_finetuned", "gap_AWD_vs_finetuned", "gap_TATR_vs_finetuned"]].max(axis=1)
    display(best[["task", "composite", "best_method", "best_gap"]])

Unnamed: 0_level_0,mean_gap_vs_finetuned,mean_abs_gap_vs_finetuned
experiment,Unnamed: 1_level_1,Unnamed: 2_level_1
AWD (per-task edit),-0.005516,0.005516
TATR (sum all tasks),-0.848897,0.848897
Addition: raw,-0.605776,0.605776
Addition: AWD,-0.602407,0.602407
Addition: TATR,-0.890374,0.890374


Unnamed: 0,task,composite,best_method,best_gap
0,CHAOS_CT,dataset,gap_raw_vs_finetuned,-0.123489
1,CHAOS_CT,domain,gap_AWD_vs_finetuned,-0.215112
2,CHAOS_CT,cross,gap_AWD_vs_finetuned,-0.971462
3,CHAOS_MR,dataset,gap_raw_vs_finetuned,-0.6069
4,CHAOS_MR,domain,gap_AWD_vs_finetuned,-0.658287
5,CHAOS_MR,cross,gap_AWD_vs_finetuned,-0.821658
6,MMWHS_CT,dataset,gap_AWD_vs_finetuned,-0.445427
7,MMWHS_CT,domain,gap_AWD_vs_finetuned,-0.444765
8,MMWHS_CT,cross,gap_raw_vs_finetuned,-0.891047
9,MMWHS_MR,dataset,gap_AWD_vs_finetuned,-0.736485
