In [2]:
#!/usr/bin/env python3
"""
counting_causal_prefix_permute.py

A full pipeline that:
  1. Loads the JSON dataset at `DATA_PATH` (a list of records with "prompt" and "answer" fields).
  2. For each example i and each prefix-index k:
     • Builds a “prefix-permuted” counterfactual list:
         ‣ prefix = word_list[0 : k+1] shuffled
         ‣ suffix = word_list[k+1 : ] (unchanged)
  3. Runs the original (“clean”) prompt once and caches all hidden states.
  4. Runs each “prefix-permuted” prompt once (for every k) and caches those hidden states.
  5. Trains a ridge‐regression probe, layer by layer, on `(h_clean[ℓ,k] → running_count(k))`.
     Chooses the best‐scoring layer ℓ*.
  6. For every (i, k), patches layer ℓ* by swapping in `h_cf[ℓ*, k]` (the counterfactual prefix’s state)
     into the clean run, then completes the forward pass to get a new final integer prediction.
  7. Visualizes how often / by how much the patched run’s answer shifts toward the counterfactual’s answer.

Note: Here, “counterfactual” means “shuffle the first k+1 words,” not shuffling the suffix.
"""

import json
import os
import random
import re
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.linear_model import Ridge
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from transformers import AutoTokenizer

# TransformerLens imports
from transformer_lens import HookedTransformer

# ─────────────────────────────────────────────────────────────────────────────
# 0.  CONFIGURATION
# ─────────────────────────────────────────────────────────────────────────────

DEVICE = "cuda"                 # or "cpu"
MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"

DATA_PATH = Path("dataset.json")
MAX_EXAMPLES = 200              # Set to None to use all examples
RANDOM_SEED = 42

CATEGORY_ITEMS = {
    "fruit": [
        "apple", "banana", "cherry", "grape", "orange",
        "pear", "peach", "mango", "tangerine", "plum"
    ],
    "animal": [
        "dog", "cat", "horse", "cow", "sheep",
        "lion", "tiger", "bear", "rabbit", "fox"
    ],
    "vehicle": [
        "car", "bus", "truck", "bicycle", "motorcycle",
        "train", "boat", "plane", "scooter", "van"
    ],
    "instrument": [
        "guitar", "piano", "violin", "drum", "flute",
        "trumpet", "saxophone", "cello", "clarinet", "harp"
    ],
    "furniture": [
        "chair", "table", "sofa", "bed", "desk",
        "cabinet", "dresser", "stool", "wardrobe", "bookshelf"
    ],
}


# ─────────────────────────────────────────────────────────────────────────────
# 1.  HELPERS: Parsing prompts & building “prefix-permuted” runs
# ─────────────────────────────────────────────────────────────────────────────

def parse_prompt(prompt_str):
    """
    Extracts:
      • word_type  (e.g. "fruit")
      • word_list  (list of strings inside the brackets)
      • word_positions (list of (char_start, char_end) for each word)
    """
    m_type = re.search(r"Type:\s*(\w+)", prompt_str)
    if not m_type:
        raise ValueError("Could not find ‘Type:’ in prompt.")
    word_type = m_type.group(1).strip()

    idx_list = prompt_str.find("List:")
    if idx_list < 0:
        raise ValueError("Could not find ‘List:’ in prompt.")
    start_bracket = prompt_str.find("[", idx_list)
    end_bracket = prompt_str.find("]", start_bracket)
    if start_bracket < 0 or end_bracket < 0:
        raise ValueError("Could not find matching brackets for ‘List:’.")

    list_substr = prompt_str[start_bracket+1 : end_bracket].strip()
    word_list = [w.strip() for w in list_substr.split(",") if w.strip()]

    # For each word, find its char-span in the prompt
    list_positions = []
    cursor = start_bracket + 1
    for w in word_list:
        idx = prompt_str.find(w, cursor)
        if idx < 0:
            idx = prompt_str.lower().find(w.lower(), cursor)
            if idx < 0:
                raise ValueError(f"Could not locate word '{w}' in prompt.")
        list_positions.append((idx, idx + len(w)))
        cursor = idx + len(w)

    return word_type, word_list, list_positions


def build_prefix_permuted_list(word_list, k):
    """
    Given an original word_list of length N, and a prefix-index k (0 ≤ k < N),
    return a new list where:
      • prefix = word_list[0 : k+1] is randomly shuffled
      • suffix = word_list[k+1 : N] is left unchanged
    """
    prefix = word_list[: k + 1].copy()
    random.shuffle(prefix)
    suffix = word_list[k + 1 :].copy()
    return prefix + suffix


def reconstruct_prompt(prompt_str, new_list):
    """
    Given the original prompt_str and a new_list of the same length,
    replace the text inside “List: [ ... ]” with “, ”-joined new_list.
    """
    idx_list = prompt_str.find("List:")
    start_bracket = prompt_str.find("[", idx_list)
    end_bracket = prompt_str.find("]", start_bracket)
    if start_bracket < 0 or end_bracket < 0:
        raise ValueError("Could not find List brackets in prompt.")

    before = prompt_str[: start_bracket + 1]
    after = prompt_str[end_bracket:]
    inside = ", ".join(new_list)
    return before + inside + after


# ─────────────────────────────────────────────────────────────────────────────
# 2.  LOAD & FILTER DATA
# ─────────────────────────────────────────────────────────────────────────────

print("Loading dataset…")
with open(DATA_PATH, "r") as f:
    all_data = json.load(f)

random.seed(RANDOM_SEED)
if MAX_EXAMPLES is not None and len(all_data) > MAX_EXAMPLES:
    all_data = random.sample(all_data, MAX_EXAMPLES)
print(f"  → {len(all_data)} prompts selected for analysis.")

parsed = []
for rec in all_data:
    prompt = rec["prompt"]
    try:
        wtype, wlist, wpos = parse_prompt(prompt)
    except Exception as e:
        print(f"Skipping prompt due to parse error: {e}")
        continue
    parsed.append({
        "prompt": prompt,
        "type": wtype,
        "word_list": wlist,
        "word_positions": wpos,
        "gold_answer": rec["answer"],
    })
print(f"Parsed {len(parsed)} prompts successfully.\n")

# ─────────────────────────────────────────────────────────────────────────────
# 3.  INITIALIZE MODEL & TOKENIZER (TransformerLens)
# ─────────────────────────────────────────────────────────────────────────────

print(f"Loading model {MODEL_NAME} on {DEVICE}…")
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    center_unembed=True,     
    device=DEVICE,
    fold_ln=True,            
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# ─────────────────────────────────────────────────────────────────────────────
# 4.  CACHE CLEAN RUNS (ENTIRE PROMPT) & PREPARE COUNTERFACTUAL RUNS
# ─────────────────────────────────────────────────────────────────────────────
# ─────────────────────────────────────────────────────────────────────────────
# 4a.  Cache the CLEAN run (entire prompt) & compute running counts
# ─────────────────────────────────────────────────────────────────────────────

print("Building caches…")
cache_clean       = {}   # i → cache on the CLEAN prompt
cache_positions   = {}   # i → list of token indices aligned to each word
cache_word_counts = {}   # i → [c₀, c₁, …], running count on the clean prefix

for i, rec in enumerate(tqdm(parsed, desc="CACHING CLEAN")):
    prompt = rec["prompt"]
    wlist  = rec["word_list"]
    wpos   = rec["word_positions"]
    wtype  = rec["type"]

    enc = tokenizer(
        prompt,
        return_tensors="pt",
        return_offsets_mapping=True,
    ).to(DEVICE)

    token_ids = enc["input_ids"][0].cpu().tolist()
    offsets   = enc["offset_mapping"][0].cpu().tolist()  # shape: (seq_len, 2)

    # Align each word to the token index with the greatest character overlap
    word_to_token_idx = []
    for (char_start, char_end) in wpos:
        best_tidx = None
        best_overlap = 0
        for tidx, (ts, te) in enumerate(offsets):
            overlap = min(te, char_end) - max(ts, char_start)
            if overlap > best_overlap:
                best_overlap = overlap
                best_tidx = tidx
        if best_tidx is None or best_overlap <= 0:
            raise RuntimeError(f"Could not align word at chars {(char_start, char_end)}")
        word_to_token_idx.append(best_tidx)
    cache_positions[i] = word_to_token_idx

    # Build true running‐count array using CATEGORY_ITEMS
    pool_lower = {x.lower() for x in CATEGORY_ITEMS[wtype]}
    c_running = []
    count_so_far = 0
    for w in wlist:
        if w.lower() in pool_lower:
            count_so_far += 1
        c_running.append(count_so_far)
    cache_word_counts[i] = c_running

    # Full forward pass, caching all resid_post
    _, cache = model.run_with_cache(enc["input_ids"], return_type="both")
    # Move every tensor in `cache` to CPU, then clear it from GPU
    for key, tensor in cache.items():
        cache[key] = tensor.cpu()
    torch.cuda.empty_cache()
    cache_clean[i] = cache

# ─────────────────────────────────────────────────────────────────────────────
# 4b.  Cache each “prefix-permuted” counterfactual (i, k)
# ─────────────────────────────────────────────────────────────────────────────

cache_cf = {}  # (i, k) → cache object for the prefix‐permuted prompt

for i, rec in enumerate(tqdm(parsed, desc="CACHING COUNTERFACTUAL")):
    original_prompt = rec["prompt"]
    wlist = rec["word_list"]

    for k in range(len(wlist)):
        cf_wlist = build_prefix_permuted_list(wlist, k)
        cf_prompt = reconstruct_prompt(original_prompt, cf_wlist)

        enc_cf = tokenizer(
            cf_prompt,
            return_tensors="pt",
            return_offsets_mapping=True,
        ).to(DEVICE)

        # Align each word in the *prefix-permuted* list to token indices
        # (repeat the same overlap‐based logic so we know where to patch later)
        offsets_cf = enc_cf["offset_mapping"][0].cpu().tolist()
        wpos_cf = []  # but word_positions stayed the same character-wise,
                      # since only the order of words changed inside the brackets.
        # Actually, wpos (from the clean prompt) no longer matches cf_prompt exactly
        # if character offsets shift. But we know that the bracketed “List” substring
        # is identical length, only words change order. So we can re‐extract wpos_cf:
        #
        #   1. Use parse_prompt(cf_prompt) to get fresh wpos_cf
        #   2. Then apply the same overlap‐based matching.
        #
        _, _, wpos_cf = parse_prompt(cf_prompt)

        word_to_token_idx_cf = []
        for (char_start, char_end) in wpos_cf:
            best_tidx = None
            best_overlap = 0
            for tidx, (ts, te) in enumerate(offsets_cf):
                overlap = min(te, char_end) - max(ts, char_start)
                if overlap > best_overlap:
                    best_overlap = overlap
                    best_tidx = tidx
            if best_tidx is None or best_overlap <= 0:
                raise RuntimeError(f"Could not align CF word at chars {(char_start, char_end)}")
            word_to_token_idx_cf.append(best_tidx)

        # (We only need word_to_token_idx_cf later if we want to double-check patching positions.
        #  For now, we stash the entire cache so we can look up resid_post at (ℓ, token_idx_cf[k]).)

        _, cacheobj_cf = model.run_with_cache(enc_cf["input_ids"], return_type="both")
        # Move every tensor in cacheobj_cf to CPU
        for key, tensor in cacheobj_cf.items():
            cacheobj_cf[key] = tensor.cpu()
        torch.cuda.empty_cache()
        cache_cf[(i, k)] = {
            "cache": cacheobj_cf,
            "token_idxs": word_to_token_idx_cf
        }




# ─────────────────────────────────────────────────────────────────────────────
# 5.  TRAIN LAYERWISE PROBES TO LOCATE ℓ*
# ─────────────────────────────────────────────────────────────────────────────

LAYERS = list(range(model.cfg.n_layers))
layer_scores = []

print("\n=== Training probes layer-by-layer ===")
for ℓ in tqdm(LAYERS, desc="TRAIN PROBE"):
    X_all = []
    y_all = []
    for i, rec in enumerate(parsed):
        c_running = cache_word_counts[i]
        tok_idxs   = cache_positions[i]
        cacheobj   = cache_clean[i]
        for k, c_k in enumerate(c_running):
            token_index = tok_idxs[k]
            h_lk = cacheobj[("resid_post", ℓ)][0, token_index, :].cpu().numpy()
            X_all.append(h_lk)
            y_all.append(c_k)

    X_all = np.stack(X_all)
    y_all = np.array(y_all)

    X_train, X_test, y_train, y_test = train_test_split(
        X_all, y_all, test_size=0.2, random_state=RANDOM_SEED
    )

    ridge = Ridge(alpha=1e-3)
    ridge.fit(X_train, y_train)
    r2 = ridge.score(X_test, y_test)
    layer_scores.append((ℓ, r2))

layer_scores = sorted(layer_scores, key=lambda x: x[1], reverse=True)
print("\nProbe R² by layer (top 5):")
for ℓ, score in layer_scores[:5]:
    print(f"  Layer {ℓ:02d} → R² = {score:.4f}")

ℓ_star = layer_scores[0][0]
print(f"\n==> Selected layer ℓ* = {ℓ_star}")

# Plot R² vs. layer index
plt.figure(figsize=(6,4))
layers_idx = [ℓ for ℓ,_ in layer_scores]
scores_val = [r2 for _,r2 in layer_scores]
plt.plot(layers_idx, scores_val, marker="o")
plt.xlabel("Layer index ℓ")
plt.ylabel("Probe R²")
plt.title("Linear‐Probe R² for running‐count (clean run)")
plt.axvline(ℓ_star, linestyle="--", color="red", label=f"ℓ* = {ℓ_star}")
plt.legend()
plt.tight_layout()
plt.show()

# ─────────────────────────────────────────────────────────────────────────────
# 6.  COMPUTE “CLEAN” & “CF” FINAL-INTEGER PREDICTIONS
# ─────────────────────────────────────────────────────────────────────────────

# 6a. Helper to decode the model’s final integer from logits
def decode_integer_from_logits(logits_np, topk=200):
    """
    Given a 1D array of logits (vocab_size,), pick the highest‐scoring token whose text is all digits.
    Returns an int or raises if none found.
    """
    best_digit = None
    best_score = -1e9
    for token_id in np.argsort(logits_np)[-topk:]:
        token_str = tokenizer.decode([token_id]).strip()
        if re.fullmatch(r"\d+", token_str):
            score = logits_np[token_id]
            if score > best_score:
                best_score = score
                best_digit = int(token_str)
    if best_digit is None:
        raise RuntimeError("Could not parse final integer from logits.")
    return best_digit


# 6b. Compute clean final prediction y_clean(i) for each example i
print("\n=== Computing clean final predictions ===")
clean_final_preds = {}
for i, rec in enumerate(parsed):
    enc = tokenizer.encode(rec["prompt"], return_tensors="pt").to(DEVICE)
    logits = model(enc["input_ids"])   # returns (1, seq_len, vocab_size)
    logits_np = logits[:, -1, :].cpu().numpy().flatten()
    clean_final_preds[i] = decode_integer_from_logits(logits_np)

# 6c. Compute counterfactual final prediction y_cf(i, k)
#      (Here, since we permuted the prefix only, the total count remains the same.
#       But we still compute to be consistent.)
print("Computing counterfactual final predictions…")
cf_final_preds = {}
for (i, k), cacheobj_cf in tqdm(cache_cf.items(), desc="CF finals"):
    # Re‐decode by re‐running the forward pass (since TL’s cache doesn’t store logits by default)
    rec = parsed[i]
    cf_wlist = build_prefix_permuted_list(rec["word_list"], k)
    cf_prompt = reconstruct_prompt(rec["prompt"], cf_wlist)
    enc_cf = tokenizer.encode(cf_prompt, return_tensors="pt").to(DEVICE)
    logits_cf = model(enc_cf["input_ids"])
    logits_np = logits_cf[:, -1, :].cpu().numpy().flatten()
    cf_final_preds[(i, k)] = decode_integer_from_logits(logits_np)

# ─────────────────────────────────────────────────────────────────────────────
# 7.  CAUSAL PATCHING AT ℓ*: SWAP h_clean[ℓ*,k] → h_cf[ℓ*,k]
# ─────────────────────────────────────────────────────────────────────────────

print(f"\n=== Running causal patching at layer ℓ* = {ℓ_star} ===")

# Pre‐tokenize all clean prompts once
all_clean_enc = [
    tokenizer.encode(rec["prompt"], return_tensors="pt").to(DEVICE)
    for rec in parsed
]

# Prepare a matrix Δ(i,k) = ŷ_patched(i,k) − ŷ_clean(i)
delta_matrix = {
    i: {k: None for k in range(len(parsed[i]["word_list"]))}
    for i in range(len(parsed))
}

for i, rec in enumerate(parsed):
    prompt_enc = all_clean_enc[i]
    wlist = rec["word_list"]
    tok_idxs = cache_positions[i]

    for k in range(len(wlist)):
        # h_cf^(ℓ*, k) from the prefix‐permuted cache
        h_cf_vec = cache_cf[(i, k)][("resid_post", ℓ_star)][0, tok_idxs[k], :].detach().clone()

        # Build a hook that replaces clean h^(ℓ*, k) with h_cf^(ℓ*, k)
        def make_patch_hook(i_local, k_local, h_cf_local):
            def patch_hook(resid, hook):
                # resid shape: (1, seq_len, d_model)
                token_index = tok_idxs[k_local]
                resid[:, token_index, :] = h_cf_local
                return resid
            return patch_hook

        patch_fn = make_patch_hook(i, k, h_cf_vec)

        # Run the patched forward pass
        patched_logits = model.run_with_hooks(
            prompt_enc["input_ids"],
            fwd_hooks={("resid_post", ℓ_star): patch_fn},
            return_type="logits"
        )[0]  # shape: (1, seq_len, vocab_size)

        logits_np = patched_logits[:, -1, :].cpu().numpy().flatten()
        pred_patched = decode_integer_from_logits(logits_np)
        delta_matrix[i][k] = pred_patched - clean_final_preds[i]

# ─────────────────────────────────────────────────────────────────────────────
# 8.  VISUALIZE RESULTS (HEATMAP + AVERAGE EFFECT)
# ─────────────────────────────────────────────────────────────────────────────

# Build normalized effect E(i,k) = Δ(i,k) / (cf_final(i,k) − clean_final(i))
E_matrix = []
for i in range(len(parsed)):
    row = []
    for k in range(len(parsed[i]["word_list"])):
        denom = cf_final_preds[(i, k)] - clean_final_preds[i]
        if denom == 0:
            e = 0.0
        else:
            e = delta_matrix[i][k] / denom
        row.append(e)
    E_matrix.append(row)

max_len = max(len(r) for r in E_matrix)
E_padded = np.array([r + [np.nan]*(max_len - len(r)) for r in E_matrix])

# (a) Heatmap
plt.figure(figsize=(6,6))
plt.imshow(E_padded, aspect="auto", vmin=0, vmax=1, cmap="viridis")
plt.colorbar(label="Normalized causal effect\, E(i,k)")
plt.xlabel("Prefix index k")
plt.ylabel("Example index i")
plt.title(f"Causal‐effect heatmap at layer ℓ* = {ℓ_star}")
plt.tight_layout()
plt.show()

# (b) Average effect vs. k
avg_effect = np.nanmean(E_padded, axis=0)
plt.figure(figsize=(6,4))
plt.plot(range(max_len), avg_effect, marker="o")
plt.xlabel("Prefix index k")
plt.ylabel("Mean normalized effect\, E(k)")
plt.title("Average normalized causal effect vs. k")
plt.ylim(0,1)
plt.grid(True, linestyle="--", alpha=0.5)
plt.tight_layout()
plt.show()

print("\nPipeline complete. Summary:")
print(f"  • Probe located ℓ* = {ℓ_star}")
print("  • Heatmap (E_padded) shows how patching at (ℓ*,k) shifts the final count toward the prefix-permuted counterfactual.")

# === END OF SCRIPT ===


Loading dataset…
  → 200 prompts selected for analysis.
Parsed 200 prompts successfully.

Loading model meta-llama/Meta-Llama-3-8B-Instruct on cuda…


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

KeyboardInterrupt: 