# Shortcut Learning Mitigation with Task Arithmetic (Waterbirds) — Kaggle Notebook

This notebook is a **fully runnable, end-to-end pipeline** that demonstrates:

1. **Shortcut learning** on the *Waterbirds* benchmark (spurious correlation between bird type and background).
2. **Checkpoint / snapshot logging** during training.
3. **Task vectors**: \(v_T = w_{ft} - w_{pre}\)
4. **Task arithmetic edits**: add / scale / negate task vectors in weight space.
5. **Trajectory analysis**: PCA of weight trajectories + alignment with the final task vector.
6. **Evaluation**: overall accuracy, per-group accuracy, and **worst-group accuracy**.
7. **Plots + saved artifacts** (PNG/JSON/CSV) written to `outputs/`.

> **Kaggle note**: This notebook tries to download Waterbirds via the `wilds` library.  
> If downloads are blocked, you can still run by adding a Kaggle dataset that contains Waterbirds/CUB files and pointing `DATA_ROOT` to it (instructions are included below).

## 0) Install / imports

In [None]:
import sys, subprocess
import importlib.util

def pip_install(pkg: str):
    subprocess.check_call([sys.executable, "-m", "pip", "-q", "install", pkg])

# Map pip package name -> import name (they differ for scikit-learn, which imports as sklearn).
REQUIRED = {"wilds": "wilds", "scikit-learn": "sklearn"}

# find_spec returns None when a top-level module cannot be imported.
missing = [pkg for pkg, mod in REQUIRED.items() if importlib.util.find_spec(mod) is None]

if missing:
    print("Installing:", missing)
    for m in missing:
        pip_install(m)
else:
    print("All required packages already installed.")

In [None]:
import os
import re
import math
import json
import csv
import random
from pathlib import Path
from typing import Any, Dict, List, Tuple, Optional

import numpy as np
import torch
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset

import matplotlib.pyplot as plt

from sklearn.decomposition import PCA

from torchvision import transforms, models

from wilds import get_dataset

print("torch:", torch.__version__)
print("cuda available:", torch.cuda.is_available())
print("mps available:", hasattr(torch.backends, "mps") and torch.backends.mps.is_available())

## 1) Background: what we're doing

### Shortcut learning (spurious correlations)
A *shortcut* is a feature that's easy for the model to use on the training distribution, but **does not generalize** when the correlation breaks.

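In Waterbirds:
- **Core label**: bird type (waterbird vs landbird).
- **Spurious attribute**: background (water vs land).

In the training split the two are strongly correlated (most waterbirds appear on water backgrounds), so a model can score well by reading the background instead of the bird.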

### Task vectors + task arithmetic
Let:
- \(w_{pre}\): pretrained weights
- \(w_{ft}\): weights after fine-tuning on task \(T\)

Define task vector:
\[
v_T = w_{ft} - w_{pre}
\]

Task arithmetic edits:
\[
w_{new} = w_{pre} + \alpha v_T
\]

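- \(\alpha = 1\): recovers the fine-tuned weights \(w_{ft}\) (up to floating-point error)
- \(\alpha = 0.5\): a point halfway between the pretrained and fine-tuned weights
- \(\alpha = -1\): **negate** the task vector (often used to "forget" the task)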
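
The edit rule can be checked by hand on a toy example. The sketch below uses small NumPy arrays as stand-ins for model weights; the dictionary keys and the helper name `edit` are illustrative and are not part of the notebook's pipeline.

```python
import numpy as np

# Toy "weights": dicts of arrays standing in for model state dicts.
w_pre = {"layer.weight": np.array([1.0, 2.0, 3.0]), "layer.bias": np.array([0.5])}
w_ft = {"layer.weight": np.array([1.5, 1.0, 3.0]), "layer.bias": np.array([0.0])}

# Task vector: element-wise difference, one entry per parameter tensor.
v_T = {k: w_ft[k] - w_pre[k] for k in w_pre}

def edit(w_pre, v, alpha):
    """Return w_pre + alpha * v for every parameter tensor."""
    return {k: w_pre[k] + alpha * v[k] for k in w_pre}

same_as_ft = edit(w_pre, v_T, 1.0)   # alpha = 1 lands on the fine-tuned weights
halfway = edit(w_pre, v_T, 0.5)      # alpha = 0.5 stops midway
negated = edit(w_pre, v_T, -1.0)     # alpha = -1 moves in the opposite direction

print("alpha=0.5 :", halfway["layer.weight"])
print("alpha=-1  :", negated["layer.weight"])
```

Because the edit is applied per tensor, the same function works unchanged on a real state dict as long as both dictionaries share the same keys.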

### What counts as success here?
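We track:
- **overall accuracy**
- **group accuracy**, where a group is a (label, background) pair, giving four groups
- **worst-group accuracy**, the minimum of the four group accuracies

A model that relies on the background shortcut can score well overall while failing on the minority groups (for example, waterbirds on land), so worst-group accuracy is the number to watch.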
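
To see how overall and worst-group accuracy can diverge, here is a minimal sketch on made-up labels (not real Waterbirds data). It assumes the group encoding `group = 2 * y + background` used later in the notebook, and a hypothetical model that always predicts the background.

```python
import numpy as np

# Made-up toy data: 8 examples, background mostly agrees with the label.
y = np.array([0, 0, 0, 0, 1, 1, 1, 1])    # bird type
bg = np.array([0, 0, 0, 1, 0, 1, 1, 1])   # background
pred = bg.copy()                          # a pure "shortcut" model predicts the background

group = 2 * y + bg                        # 0: y0-bg0, 1: y0-bg1, 2: y1-bg0, 3: y1-bg1
correct = pred == y

overall = float(correct.mean())
group_acc = {int(g): float(correct[group == g].mean()) for g in np.unique(group)}
worst = min(group_acc.values())

print("overall accuracy    :", overall)
print("per-group accuracy  :", group_acc)
print("worst-group accuracy:", worst)
```

The shortcut model is right on every majority-group example and wrong on every minority-group example, which overall accuracy alone would hide.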

## 2) Config + reproducibility

In [None]:
SEED = 42

MAX_STEPS = 300          # increase for better performance (e.g., 1500-3000) if you have time
BATCH_SIZE = 64
LR = 1e-4
WEIGHT_DECAY = 0.0

SNAPSHOT_EVERY = 50
EVAL_EVERY = 50
NUM_WORKERS = 2

OUTPUT_ROOT = Path("outputs/waterbirds")
LOG_DIR = OUTPUT_ROOT / "logs"
RES_DIR = OUTPUT_ROOT / "results"
SNAP_DIR = LOG_DIR / "snapshots"
for p in [LOG_DIR, RES_DIR, SNAP_DIR]:
    p.mkdir(parents=True, exist_ok=True)

def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(SEED)

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

print("DEVICE:", DEVICE)
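
Global seeding fixes the shuffle order only if the sequence of random calls before it stays the same. One option, sketched here on a toy dataset, is to give the `DataLoader` its own seeded `torch.Generator`. The helper `make_loader` is hypothetical and is not used elsewhere in this notebook.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

def make_loader(seed: int) -> DataLoader:
    """Build a shuffled loader whose order depends only on `seed`."""
    ds = TensorDataset(torch.arange(10))
    g = torch.Generator()
    g.manual_seed(seed)
    # The generator drives the random sampler used when shuffle=True.
    return DataLoader(ds, batch_size=5, shuffle=True, generator=g)

order_a = [batch[0].tolist() for batch in make_loader(42)]
order_b = [batch[0].tolist() for batch in make_loader(42)]

print("first pass :", order_a)
print("second pass:", order_b)
```

Two loaders built with the same seed yield the same batch order, regardless of what other code has drawn from the global RNG in between.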

## 3) Load Waterbirds (via WILDS)

We use the `wilds` package:
- `get_dataset("waterbirds")`
- `get_subset("train")`, `get_subset("val")`, `get_subset("test")`

If the download fails on Kaggle (no internet), you can:
1. Add a Kaggle dataset that contains Waterbirds files.
2. Set `DATA_ROOT` to that mounted path (usually `/kaggle/input/<dataset-name>/...`).

The code below auto-detects the background field by matching names in `metadata_fields`, and falls back to index 0 if no name matches.

In [None]:
DATA_ROOT = None  # set to a path if you mounted a Kaggle dataset manually

IMG_SIZE = 224
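transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    # ImageNet channel statistics, matching the inputs the pretrained ResNet-18 weights were trained on.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])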

def find_background_index(metadata_fields: List[str]) -> int:
    candidates = ["place", "background", "env", "environment"]
    lower = [f.lower() for f in metadata_fields]
    for cand in candidates:
        for i, f in enumerate(lower):
            if cand in f:
                return i
    return 0

class WaterbirdsWILDSWrapper(Dataset):
    def __init__(self, subset, transform, bg_index: int):
        self.subset = subset
        self.transform = transform
        self.bg_index = bg_index

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx: int):
        x, y, meta = self.subset[idx]
        if self.transform is not None:
            x = self.transform(x)
        y = y.long()
        spurious = int(meta[self.bg_index].item())
        group = int(y.item()) * 2 + spurious
        return x, {"y": y, "spurious": torch.tensor(spurious), "group": torch.tensor(group)}

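# WILDS needs a real root directory and only downloads when download=True.
# With a manually mounted dataset (DATA_ROOT set), skip the download.
dataset = get_dataset(dataset="waterbirds", root_dir=DATA_ROOT or "data", download=(DATA_ROOT is None))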
bg_idx = find_background_index(dataset.metadata_fields)
print("metadata_fields:", dataset.metadata_fields)
print("Using background field index:", bg_idx, "->", dataset.metadata_fields[bg_idx])

train_ds = WaterbirdsWILDSWrapper(dataset.get_subset("train"), transform, bg_idx)
val_ds   = WaterbirdsWILDSWrapper(dataset.get_subset("val"), transform, bg_idx)
test_ds  = WaterbirdsWILDSWrapper(dataset.get_subset("test"), transform, bg_idx)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

xb, yb = next(iter(train_loader))
print("batch x:", xb.shape, xb.dtype)
print("y:", yb["y"][:5].tolist(), "spurious:", yb["spurious"][:5].tolist(), "group:", yb["group"][:5].tolist())
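
When the dataset is mounted manually, the value for `DATA_ROOT` is the folder that *contains* the Waterbirds directory. The helper below is a hypothetical sketch for locating it. It assumes the data sits in a folder whose name starts with `waterbirds_v` (the naming WILDS is assumed to use for its versioned data directories), which may not hold for every Kaggle upload.

```python
from pathlib import Path
from typing import Optional

def find_wilds_root(search_root: str, prefix: str = "waterbirds_v") -> Optional[str]:
    """Return the parent of the first directory whose name starts with `prefix`, else None."""
    root = Path(search_root)
    if not root.is_dir():
        return None
    for path in sorted(root.rglob(f"{prefix}*")):
        if path.is_dir():
            return str(path.parent)
    return None

# Returns None when /kaggle/input does not exist or holds no matching folder.
candidate = find_wilds_root("/kaggle/input")
print("candidate DATA_ROOT:", candidate)
```

If a candidate is found, assign it to `DATA_ROOT` before the loading cell runs. If it is `None`, inspect the mounted dataset by hand.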

## 4) Metrics: overall accuracy + group accuracy + worst-group accuracy

In [None]:
@torch.no_grad()
def evaluate_grouped(model: nn.Module, loader: DataLoader, device: torch.device) -> Dict[str, Any]:
    model.eval()
    ce = nn.CrossEntropyLoss()

    total = 0
    correct = 0
    loss_sum = 0.0

    group_correct: Dict[int, int] = {}
    group_total: Dict[int, int] = {}

    for x, ydict in loader:
        x = x.to(device)
        y = ydict["y"].to(device)

        logits = model(x)
        loss = ce(logits, y)
        pred = logits.argmax(dim=-1)

        correct += (pred == y).sum().item()
        total += y.numel()
        loss_sum += loss.item() * y.numel()

        groups = ydict["group"].cpu().numpy().astype(int)
        matches = (pred == y).cpu().numpy().astype(int)
        for g, m in zip(groups, matches):
            group_total[g] = group_total.get(g, 0) + 1
            group_correct[g] = group_correct.get(g, 0) + int(m)

    metrics: Dict[str, Any] = {
        "accuracy": correct / max(total, 1),
        "loss": loss_sum / max(total, 1),
    }
    if group_total:
        group_acc = {str(k): group_correct[k] / group_total[k] for k in sorted(group_total)}
        metrics["group_accuracy"] = group_acc
        metrics["worst_group_accuracy"] = float(min(group_acc.values()))
    return metrics

## 5) Model: ResNet-18 (pretrained on ImageNet)

In [None]:
def build_model(num_classes: int = 2) -> nn.Module:
    model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

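# Note: the 2-class fc head is freshly initialized, so "pretrained" here means
# ImageNet backbone + random head. Training below updates this model in place;
# its initial weights are saved to pretrained.pt before the first step.
model_pre = build_model(num_classes=2).to(DEVICE)
print("model ready")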

## 6) Training with snapshot logging (fixed number of steps)

In [None]:
def save_checkpoint(path: Path, model: nn.Module, step: int, metrics: Optional[Dict[str, Any]] = None):
    path.parent.mkdir(parents=True, exist_ok=True)
    payload = {
        "step": int(step),
        "state_dict": {k: v.detach().cpu() for k, v in model.state_dict().items()},
        "metrics": metrics or {},
    }
    torch.save(payload, path)

def load_checkpoint_state(path: Path) -> Dict[str, torch.Tensor]:
    payload = torch.load(path, map_location="cpu")
    if "state_dict" in payload:
        return payload["state_dict"]
    return payload

def train_with_snapshots(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    test_loader: DataLoader,
    device: torch.device,
    max_steps: int,
    lr: float,
    weight_decay: float,
    snapshot_every: int,
    eval_every: int,
    out_dir: Path,
):
    model = model.to(device)
    opt = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max(1, max_steps))
    ce = nn.CrossEntropyLoss()

    save_checkpoint(out_dir / "pretrained.pt", model, step=0, metrics={})

    curve = []
    step = 0
    it = iter(train_loader)

    while step < max_steps:
        try:
            x, ydict = next(it)
        except StopIteration:
            it = iter(train_loader)
            x, ydict = next(it)

        x = x.to(device)
        y = ydict["y"].to(device)

        model.train()
        logits = model(x)
        loss = ce(logits, y)

        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        scheduler.step()
        step += 1

        do_eval = step % eval_every == 0 or step == max_steps
        do_snap = step % snapshot_every == 0 or step == max_steps
        if do_eval or do_snap:
            # Evaluate once per step and reuse the result for logging and snapshots.
            val_metrics = evaluate_grouped(model, val_loader, device)

        if do_eval:
            curve.append({
                "step": step,
                "lr": float(opt.param_groups[0]["lr"]),
                "train_loss": float(loss.item()),
                **val_metrics
            })
            print(f"[step {step:4d}] lr={opt.param_groups[0]['lr']:.2e} "
                  f"train_loss={loss.item():.4f} val_acc={val_metrics['accuracy']:.4f} "
                  f"val_worst={val_metrics.get('worst_group_accuracy', float('nan')):.4f}")

        if do_snap:
            save_checkpoint(out_dir / "snapshots" / f"ckpt_{step:05d}.pt", model, step=step, metrics=val_metrics)

    test_metrics = evaluate_grouped(model, test_loader, device)
    save_checkpoint(out_dir / "final.pt", model, step=step, metrics=test_metrics)

    (out_dir / "curve.json").write_text(json.dumps(curve, indent=2))
    return curve, test_metrics

curve, test_metrics = train_with_snapshots(
    model_pre, train_loader, val_loader, test_loader, DEVICE,
    max_steps=MAX_STEPS, lr=LR, weight_decay=WEIGHT_DECAY,
    snapshot_every=SNAPSHOT_EVERY, eval_every=EVAL_EVERY,
    out_dir=LOG_DIR
)

print("Final test metrics:", test_metrics)
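
The training loop steps a cosine schedule once per iteration with `T_max = max_steps`, so the learning rate decays to zero by the last step. The sketch below evaluates the standard closed form of cosine annealing (no restarts, minimum rate 0 by default) at the snapshot steps. The function `cosine_lr` is illustrative and is not called by the training code.

```python
import math

def cosine_lr(step: int, base_lr: float, t_max: int, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing: base_lr at step 0, eta_min at step t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * step / t_max))

# Values below mirror this notebook's defaults (LR = 1e-4, MAX_STEPS = 300, snapshots every 50).
for step in range(0, 301, 50):
    print(f"step {step:3d}  lr = {cosine_lr(step, 1e-4, 300):.2e}")
```

Since late steps use a very small learning rate, the last snapshots are expected to lie close together in weight space, which is worth keeping in mind when reading the trajectory plots.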

## 7) Task vector + random-like baseline

In [None]:
def compute_task_vector(pre_state: Dict[str, torch.Tensor], ft_state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    vec = {}
    for k in pre_state.keys():
        vec[k] = (ft_state[k] - pre_state[k]).cpu()
    return vec

def vector_stats(vec: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    stats = {}
    total_sq = 0.0
    for k, v in vec.items():
        n = float(v.float().norm().item())
        stats[k] = {"shape": list(v.shape), "l2_norm": n}
        total_sq += n**2
    stats["_total_l2_norm"] = float(math.sqrt(total_sq))
    return stats

def random_like_vector(task_vec: Dict[str, torch.Tensor], seed: int = 0) -> Dict[str, torch.Tensor]:
    g = torch.Generator()
    g.manual_seed(seed)
    out = {}
    for k, v in task_vec.items():
        r = torch.randn(v.shape, generator=g)
        v_norm = v.float().norm()
        r_norm = r.float().norm() + 1e-12
        out[k] = (r * (v_norm / r_norm)).cpu()
    return out

pre_state = load_checkpoint_state(LOG_DIR / "pretrained.pt")
ft_state  = load_checkpoint_state(LOG_DIR / "final.pt")

task_vec = compute_task_vector(pre_state, ft_state)
rand_vec = random_like_vector(task_vec, seed=SEED)

torch.save(task_vec, RES_DIR / "task_vector.pt")
torch.save(rand_vec, RES_DIR / "random_like_vector.pt")
(RES_DIR / "task_vector_stats.json").write_text(json.dumps(vector_stats(task_vec), indent=2))

print("Saved task vectors to:", RES_DIR)
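
The random baseline is meant to perturb each tensor by the same amount as the task vector, but in an unrelated direction. A minimal NumPy sketch of that idea on toy tensors follows. The names `random_like` and `cosine` are illustrative re-implementations, not the notebook's functions.

```python
import numpy as np

def random_like(vec, seed=0):
    """Gaussian noise rescaled so each tensor has the same L2 norm as the matching tensor in vec."""
    rng = np.random.default_rng(seed)
    out = {}
    for k, v in vec.items():
        r = rng.standard_normal(v.shape)
        out[k] = r * (np.linalg.norm(v) / (np.linalg.norm(r) + 1e-12))
    return out

def cosine(a, b):
    """Cosine similarity between two dicts of arrays, flattened in the key order of a."""
    fa = np.concatenate([a[k].ravel() for k in a])
    fb = np.concatenate([b[k].ravel() for k in a])
    return float(fa @ fb / (np.linalg.norm(fa) * np.linalg.norm(fb) + 1e-12))

# Toy "task vector" with two tensors of very different size.
toy_vec = {"conv.weight": np.full((64, 32), 0.01), "fc.bias": np.array([0.3, -0.4])}
toy_rand = random_like(toy_vec, seed=0)

for k in toy_vec:
    print(k, "norms:", np.linalg.norm(toy_vec[k]), np.linalg.norm(toy_rand[k]))
print("cosine(task, random):", cosine(toy_vec, toy_rand))
```

Matching norms per tensor, rather than globally, keeps the perturbation budget of each layer the same as in the real task vector.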

## 8) Apply edits and evaluate (pretrained + alpha * v_T)

In [None]:
def apply_task_edit(pre_state: Dict[str, torch.Tensor], vec: Dict[str, torch.Tensor], alpha: float) -> Dict[str, torch.Tensor]:
    return {k: (pre_state[k].cpu() + alpha * vec[k].cpu()) for k in pre_state.keys()}

def save_state_as_ckpt(path: Path, state: Dict[str, torch.Tensor], step: int = 0, metrics: Optional[Dict[str, Any]] = None):
    payload = {"step": int(step), "state_dict": state, "metrics": metrics or {}}
    torch.save(payload, path)

def eval_state_dict(state: Dict[str, torch.Tensor]) -> Dict[str, Any]:
    m = build_model(num_classes=2).to(DEVICE)
    m.load_state_dict(state, strict=True)
    return evaluate_grouped(m, test_loader, DEVICE)

def eval_ckpt(path: Path) -> Dict[str, Any]:
    st = load_checkpoint_state(path)
    return eval_state_dict(st)

# Make edited checkpoints from pretrained weights
forget_state = apply_task_edit(pre_state, task_vec, alpha=-1.0)
half_state   = apply_task_edit(pre_state, task_vec, alpha=0.5)
add_state    = apply_task_edit(pre_state, task_vec, alpha=1.0)
rand_state   = apply_task_edit(pre_state, rand_vec, alpha=1.0)

save_state_as_ckpt(RES_DIR / "forget.pt", forget_state)
save_state_as_ckpt(RES_DIR / "half.pt", half_state)
save_state_as_ckpt(RES_DIR / "add.pt", add_state)
save_state_as_ckpt(RES_DIR / "random_baseline.pt", rand_state)

ckpts = {
    "pretrained": LOG_DIR / "pretrained.pt",
    "finetuned": LOG_DIR / "final.pt",
    "forget": RES_DIR / "forget.pt",
    "half": RES_DIR / "half.pt",
    "add": RES_DIR / "add.pt",
    "random": RES_DIR / "random_baseline.pt",
}

edited_metrics = {name: eval_ckpt(path) for name, path in ckpts.items()}

for name, met in edited_metrics.items():
    print(name, "acc=", met["accuracy"], "worst=", met.get("worst_group_accuracy", None))

(RES_DIR / "edited_metrics.json").write_text(json.dumps(edited_metrics, indent=2))
with (RES_DIR / "edited_metrics.csv").open("w", newline="") as f:
    w = csv.writer(f)
    w.writerow(["name", "accuracy", "loss", "worst_group_accuracy"])
    for name, met in edited_metrics.items():
        w.writerow([name, met.get("accuracy"), met.get("loss"), met.get("worst_group_accuracy")])

print("Wrote metrics to:", RES_DIR)
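
A full `state_dict` also holds non-weight buffers. BatchNorm layers, for instance, store an integer step counter, `num_batches_tracked`. Scaling such a counter by a float `alpha` turns it into a float tensor with a value that has no meaning as a count. One possible variant, sketched on a toy model, edits only floating-point tensors and copies everything else from the pretrained state. The helper name `apply_task_edit_float_only` is hypothetical.

```python
import torch
from torch import nn

def apply_task_edit_float_only(pre_state, vec, alpha):
    """w_pre + alpha * v for floating-point tensors; other tensors are copied unchanged."""
    out = {}
    for k, w in pre_state.items():
        out[k] = w + alpha * vec[k] if w.is_floating_point() else w.clone()
    return out

torch.manual_seed(0)
pre = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))
ft = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))
ft.load_state_dict(pre.state_dict())

# Fake "fine-tuning": one training-mode forward pass (updates the BN buffers)
# plus a manual shift of the linear weights.
ft.train()
ft(torch.randn(8, 4))
with torch.no_grad():
    ft[0].weight.add_(0.1)

pre_state = {k: v.clone() for k, v in pre.state_dict().items()}
ft_state = {k: v.clone() for k, v in ft.state_dict().items()}
vec = {k: ft_state[k] - pre_state[k] for k in pre_state}

key = "1.num_batches_tracked"
naive = pre_state[key] + (-1.0) * vec[key]                 # promoted to a float tensor
edited = apply_task_edit_float_only(pre_state, vec, alpha=-1.0)

print("naive :", naive.dtype, naive.item())
print("edited:", edited[key].dtype, edited[key].item())
```

Whether running statistics (`running_mean`, `running_var`) should be edited along with the weights is a separate design choice. This sketch still edits them because they are floating-point tensors.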

## 9) Trajectory analysis: PCA + alignment

In [None]:
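def flatten_state(state: Dict[str, torch.Tensor]) -> np.ndarray:
    # Keep floating-point tensors only. Integer buffers such as BatchNorm's
    # num_batches_tracked count steps, grow linearly with training, and can
    # dominate norms and cosine similarities if they are included.
    parts = [v.detach().cpu().float().reshape(-1).numpy() for v in state.values() if v.is_floating_point()]
    return np.concatenate(parts) if parts else np.zeros((0,), dtype=np.float32)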

snapshots = sorted(SNAP_DIR.glob("ckpt_*.pt"))
steps = []
delta_mat = []
alignments = []

flat_task = flatten_state(task_vec)
flat_task_norm = np.linalg.norm(flat_task) + 1e-12

for p in snapshots:
    m = re.search(r"ckpt_(\d+)\.pt$", p.name)
    if not m:
        continue
    step = int(m.group(1))
    steps.append(step)

    st = load_checkpoint_state(p)
    delta = {k: st[k] - pre_state[k] for k in st.keys()}
    flat_delta = flatten_state(delta)
    delta_mat.append(flat_delta)

    num = float(np.dot(flat_delta, flat_task))
    den = float((np.linalg.norm(flat_delta) + 1e-12) * flat_task_norm)
    alignments.append(num / den)

steps = np.array(steps, dtype=np.int64)
delta_mat = np.stack(delta_mat, axis=0)

pca = PCA(n_components=2)
xy = pca.fit_transform(delta_mat)
expl = pca.explained_variance_ratio_.tolist()

(RES_DIR / "pca_summary.json").write_text(json.dumps({"explained_variance_ratio": expl}, indent=2))
print("PCA explained variance ratio:", expl)

# PCA plot
plt.figure(figsize=(6,5))
plt.scatter(xy[:,0], xy[:,1], c=steps, s=40)
for i, s in enumerate(steps):
    plt.text(xy[i,0], xy[i,1], str(s), fontsize=8)
plt.title("Weight trajectory PCA (deltas from pretrained)")
plt.xlabel("PC1"); plt.ylabel("PC2")
plt.colorbar(label="step")
pca_path = RES_DIR / "trajectory_pca.png"
plt.tight_layout(); plt.savefig(pca_path, dpi=150); plt.show()

# Alignment plot
plt.figure(figsize=(6,4))
plt.plot(steps, alignments, marker="o")
plt.title("Cosine alignment: delta_w(t) vs final task vector")
plt.xlabel("step"); plt.ylabel("cosine similarity")
align_path = RES_DIR / "alignment_curve.png"
plt.tight_layout(); plt.savefig(align_path, dpi=150); plt.show()

print("Saved:", pca_path, "and", align_path)
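
Both analyses operate on a matrix with one flattened weight delta per snapshot. The sketch below reproduces them on a synthetic trajectory (random data, not real checkpoints) and computes PCA with a plain SVD instead of scikit-learn, so component signs may differ from `sklearn.decomposition.PCA`.

```python
import numpy as np

rng = np.random.default_rng(0)
dim, n_snap = 50, 6

# Synthetic deltas: steady progress along one direction plus small noise.
direction = rng.standard_normal(dim)
deltas = np.stack([
    (t / n_snap) * direction + 0.05 * rng.standard_normal(dim)
    for t in range(1, n_snap + 1)
])
task = deltas[-1]  # the final delta plays the role of the task vector

def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))

alignments = [cosine(d, task) for d in deltas]

# PCA by SVD of the mean-centered snapshot matrix.
centered = deltas - deltas.mean(axis=0)
_, s, vt = np.linalg.svd(centered, full_matrices=False)
xy = centered @ vt[:2].T                      # 2-D coordinates per snapshot
explained = (s ** 2 / np.sum(s ** 2))[:2]     # explained variance ratios

print("alignments:", np.round(alignments, 3))
print("explained variance ratio:", np.round(explained, 3))
```

The last alignment is 1 by construction, because the last snapshot and the task vector are the same delta. The same holds in this notebook, where the final snapshot is taken at `max_steps`.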

## 10) Extra plots (learning curves, edited comparison, heatmap, alpha sweep)

In [None]:
curve = json.loads((LOG_DIR / "curve.json").read_text())
curve_steps = np.array([c["step"] for c in curve], dtype=int)
val_acc = np.array([c["accuracy"] for c in curve], dtype=float)
val_worst = np.array([c.get("worst_group_accuracy", np.nan) for c in curve], dtype=float)

plt.figure(figsize=(6,4))
plt.plot(curve_steps, val_acc, marker="o", label="val accuracy")
plt.plot(curve_steps, val_worst, marker="o", label="val worst-group")
plt.title("Validation learning curves")
plt.xlabel("step"); plt.ylabel("metric")
plt.legend()
lc_path = RES_DIR / "learning_curves.png"
plt.tight_layout(); plt.savefig(lc_path, dpi=150); plt.show()

names = list(edited_metrics.keys())
overall = [edited_metrics[n]["accuracy"] for n in names]
worst = [edited_metrics[n].get("worst_group_accuracy", np.nan) for n in names]
x = np.arange(len(names))
w = 0.38

plt.figure(figsize=(8,4))
plt.bar(x - w/2, overall, w, label="overall acc")
plt.bar(x + w/2, worst, w, label="worst-group acc")
plt.xticks(x, names, rotation=30, ha="right")
plt.title("Edited models: overall vs worst-group (test)")
plt.ylabel("accuracy")
plt.legend()
bar_path = RES_DIR / "edited_models_bar.png"
plt.tight_layout(); plt.savefig(bar_path, dpi=150); plt.show()

def group_vec(met: Dict[str, Any], num_groups: int = 4) -> np.ndarray:
    g = met.get("group_accuracy", {})
    out = np.zeros((num_groups,), dtype=float)
    for i in range(num_groups):
        out[i] = float(g.get(str(i), np.nan))
    return out

ftg = group_vec(edited_metrics["finetuned"], 4)
fgg = group_vec(edited_metrics["forget"], 4)
mat = np.vstack([ftg, fgg])

plt.figure(figsize=(7,2.5))
plt.imshow(mat, aspect="auto")
plt.yticks([0,1], ["finetuned", "forget(-1)"])
plt.xticks([0,1,2,3], ["y0-bg0","y0-bg1","y1-bg0","y1-bg1"])
plt.colorbar(label="accuracy")
plt.title("Per-group accuracy (test)")
for i in range(mat.shape[0]):
    for j in range(mat.shape[1]):
        plt.text(j, i, f"{mat[i,j]:.2f}", ha="center", va="center", fontsize=10)
heat_path = RES_DIR / "group_accuracy_heatmap.png"
plt.tight_layout(); plt.savefig(heat_path, dpi=150); plt.show()

alphas = [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0]
sweep = []
for a in alphas:
    st = apply_task_edit(pre_state, task_vec, alpha=a)
    met = eval_state_dict(st)
    sweep.append({"alpha": a, **met})

(RES_DIR / "alpha_sweep.json").write_text(json.dumps(sweep, indent=2))
accs = [s["accuracy"] for s in sweep]
worsts = [s.get("worst_group_accuracy", np.nan) for s in sweep]

plt.figure(figsize=(6,4))
plt.plot(alphas, accs, marker="o", label="overall acc")
plt.plot(alphas, worsts, marker="o", label="worst-group acc")
plt.axvline(0.0, linestyle="--")
plt.title("Alpha sweep: pretrained + alpha * v_T")
plt.xlabel("alpha"); plt.ylabel("accuracy")
plt.legend()
sweep_path = RES_DIR / "alpha_sweep.png"
plt.tight_layout(); plt.savefig(sweep_path, dpi=150); plt.show()

print("Saved plots to:", RES_DIR)

## 11) Outputs

All artifacts are saved under:

- `outputs/waterbirds/logs/`  
  - `pretrained.pt`, `final.pt`, `curve.json`, and snapshots in `snapshots/`

- `outputs/waterbirds/results/`  
  - task vectors: `task_vector.pt`, `random_like_vector.pt`  
  - edited checkpoints: `forget.pt`, `half.pt`, `add.pt`, `random_baseline.pt`  
  - metrics: `edited_metrics.json`, `edited_metrics.csv`  
  - analysis: `pca_summary.json`, `alpha_sweep.json`  
  - figures: `trajectory_pca.png`, `alignment_curve.png`, `learning_curves.png`, `edited_models_bar.png`, `group_accuracy_heatmap.png`, `alpha_sweep.png`

On Kaggle you can download them from the **Output** panel.