In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from pathlib import Path
import yaml
import numpy as np
from collections import defaultdict
import json
from datasets import load_dataset
from torch.nn.functional import softmax

In [3]:
from kg.scripts.run_experiment import run_patched_inference, get_patches, get_attr, MODEL_CONFIGS, get_inputs
from kg.utils.utils_io import dict_to_namespace
from kg.train.model_factory import model_factory
from kg.utils.constants import MODEL_TO_HFID, DATA_DIR
from kg.plotting.plotting import plot_metric
from kg.utils.constants import DEVICE

In [4]:
# Root of the experiment configuration tree (machine-specific absolute path).
EXPERIMENT_CONFIG_DIR = Path("/home/tnief/1-Projects/bidirectional-reversal/config/experiments/")

# Directory the patch-config YAML files are read from.
# Alternative config set: EXPERIMENT_CONFIG_DIR / "patch_configs"
PATCHES_DIR = EXPERIMENT_CONFIG_DIR / "patch_configs_lt"

In [5]:
# Checkpoint selection. `model_name` is the key later cells use to index
# MODEL_TO_HFID and MODEL_CONFIGS; SFT_PATH is the fine-tuned checkpoint passed
# to model_factory. Keep exactly one (model_name, SFT_PATH) pair active; the
# other entries are commented-out alternatives.
# NOTE(review): A2B_PATH / B2A_PATH only appear in commented-out lines here, yet
# a later cell calls model_factory(A2B_PATH) and model_factory(B2A_PATH) --
# uncomment a matching pair before running that cell.
model_name = "gemma"
# SFT_PATH = "/net/projects/clab/tnief/knowledge-grafting/trained_models/google/gemma-1.1-2b-it/counterfact/all_2025-06-08_11-41-09"

SFT_PATH = "/net/projects/clab/tnief/knowledge-grafting/trained_models/google/gemma-1.1-2b-it/fake_movies_real_actors/all_2025-06-12_09-42-07"

# SFT_PATH = "/net/projects/clab/tnief/knowledge-grafting/trained_models/google/gemma-1.1-2b-it/fake_movies_real_actors/all_2025-05-02_16-30-15"
# A2B_PATH = "/net/projects/clab/tnief/bidirectional-reversal/trained_models/google/gemma-1.1-2b-it/fake_movies_real_actors/A2B_2025-05-10_03-24-29"
# B2A_PATH = "/net/projects/clab/tnief/bidirectional-reversal/trained_models/google/gemma-1.1-2b-it/fake_movies_real_actors/B2A_2025-05-10_03-24-29"

# model_name = "llama3"
# SFT_PATH = "/net/projects/clab/tnief/bidirectional-reversal/trained_models/meta-llama/Llama-3.2-1B/fake_movies_real_actors/all_2025-05-07_21-51-20"
# A2B_PATH = "/net/projects/clab/tnief/bidirectional-reversal/trained_models/meta-llama/Llama-3.2-1B/fake_movies_real_actors/A2B_2025-05-09_22-40-14"
# B2A_PATH = "/net/projects/clab/tnief/bidirectional-reversal/trained_models/meta-llama/Llama-3.2-1B/fake_movies_real_actors/B2A_2025-05-09_22-49-27"

# model_name = "gpt2-xl"
# SFT_PATH = "/net/projects/clab/tnief/bidirectional-reversal/trained_models/openai-community/gpt2-xl/fake_movies_real_actors/all_2025-05-07_21-56-24"
# A2B_PATH = "/net/projects/clab/tnief/bidirectional-reversal/trained_models/openai-community/gpt2-xl/fake_movies_real_actors/A2B_2025-05-09_22-34-37"
# B2A_PATH = "/net/projects/clab/tnief/bidirectional-reversal/trained_models/openai-community/gpt2-xl/fake_movies_real_actors/B2A_2025-05-09_22-34-34"

# model_name = "pythia-2.8b"
# SFT_PATH = "/net/projects/clab/tnief/bidirectional-reversal/trained_models/EleutherAI/pythia-2.8b/fake_movies_real_actors/all_2025-05-08_12-10-29/checkpoint-26400"


In [6]:
# Fine-tuned checkpoint. The tokenizer it returns is discarded: the one returned
# by the second load below is what stays bound to `tokenizer`.
llm_sft, _, _ = model_factory(SFT_PATH)

# Model that MODEL_TO_HFID maps `model_name` to (used as the un-finetuned reference).
llm_pretrained, tokenizer, _ = model_factory(MODEL_TO_HFID[model_name])

2025-06-13 09:30:00,814 - INFO - Loading gemma model...


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

2025-06-13 09:30:58,188 - INFO - Loading gemma model...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Optional: swap in the direction-specific checkpoints in place of the pair loaded above.
# A2B_PATH and B2A_PATH exist only if the matching lines in the checkpoint-selection cell
# are uncommented; without this guard the cell raises NameError on a fresh "Run All".
# NOTE: when it does run, this rebinds `llm_sft` (-> A2B checkpoint) and `llm_pretrained`
# (-> B2A checkpoint), so every later cell compares those two models instead.
if "A2B_PATH" in globals() and "B2A_PATH" in globals():
    llm_sft, tokenizer, _ = model_factory(A2B_PATH)
    llm_pretrained, tokenizer, _ = model_factory(B2A_PATH)
else:
    print("A2B_PATH / B2A_PATH are not defined; keeping the models loaded above.")

In [7]:
# Display the module tree of the fine-tuned model (layer types and projection sizes).
llm_sft

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((2048,), eps=1e-

### Weight Diff

In [8]:
def compare_model_params(model_v1, model_v2):
    """Mean absolute difference between corresponding parameters of two models.

    Parameters are paired in ``named_parameters()`` order and must agree in
    count, name and shape. Normalisation and bias parameters are excluded from
    the result.

    Args:
        model_v1: Module exposing ``named_parameters()``.
        model_v2: Module with the same parameter layout as ``model_v1``.

    Returns:
        dict mapping parameter name -> ``mean(|p_v1 - p_v2|)`` as a Python float.

    Raises:
        ValueError: If the two models differ in parameter count, parameter
            names, or the shape of a compared parameter.
    """
    params_v1 = list(model_v1.named_parameters())
    params_v2 = list(model_v2.named_parameters())

    # zip() alone would silently stop at the shorter model and hide a layout mismatch.
    if len(params_v1) != len(params_v2):
        raise ValueError(
            f"Parameter counts do not match: {len(params_v1)} vs {len(params_v2)}"
        )

    diff_results = {}

    # Pure measurement: no autograd graph is needed for the subtraction.
    with torch.no_grad():
        for (name_v1, param_v1), (name_v2, param_v2) in zip(params_v1, params_v2):
            if name_v1 != name_v2:
                raise ValueError(f"Parameter names do not match: {name_v1} vs {name_v2}")

            # Skip norms and biases. 'ln' covers names such as `ln_1` / `ln_f`;
            # 'norm' covers names such as `input_layernorm` or a final `norm`,
            # which the 'ln' substring alone does not match.
            lowered = name_v1.lower()
            if "ln" in lowered or "norm" in lowered or "bias" in lowered:
                continue

            # Refuse to broadcast: differing shapes mean the tensors are not comparable.
            if param_v1.shape != param_v2.shape:
                raise ValueError(
                    f"Parameter shapes do not match for {name_v1}: "
                    f"{tuple(param_v1.shape)} vs {tuple(param_v2.shape)}"
                )

            # Absolute difference averaged over the number of elements
            diff_results[name_v1] = (param_v1 - param_v2).abs().mean().item()

    return diff_results

In [None]:
# Per-parameter mean absolute difference between the two loaded models.
diff_results = compare_model_params(llm_sft, llm_pretrained)

# Order parameters from largest to smallest difference.
sorted_diffs = sorted(
    diff_results.items(),
    key=lambda name_and_diff: name_and_diff[1],
    reverse=True,
)

sorted_diffs

### Set Up Model

In [8]:
# Per-architecture settings for the selected model; later cells read its "layers"
# entry and pass the whole mapping to run_patched_inference.
model_config = MODEL_CONFIGS[model_name]

In [9]:
# Length of whatever get_attr resolves from model_config["layers"] on the fine-tuned
# model -- presumably the list of decoder layers (TODO confirm get_attr semantics).
n_layers = len(get_attr(llm_sft, model_config["layers"]))

### Set Up Examples (Real Movies)

### Set Up Examples (for Counterfact)

In [19]:
# CounterFact tracing dataset: keep the first 10 rows of the "train" split.
raw_examples = load_dataset("NeelNanda/counterfact-tracing")
train_split = raw_examples["train"]
examples = train_split.select(range(10))


In [None]:
# Inspect the record at index 3 (its "prompt" field is tokenized in the next cell).
examples[3]

In [42]:
# Tokenize the prompt of record 3 and move the resulting tensors to DEVICE.
prompt_text = examples[3]["prompt"]
inputs = tokenizer(prompt_text, return_tensors="pt").to(DEVICE)

In [None]:
# Generate a continuation with the fine-tuned model, keeping the per-step scores
# so the distribution over the first generated token can be inspected.
prompt_ids = inputs["input_ids"]
gen_out = llm_sft.generate(
    prompt_ids,
    max_new_tokens=100,
    return_dict_in_generate=True,
    output_scores=True,
)

generated_seq = gen_out.sequences[0]
generated_text = tokenizer.decode(generated_seq, skip_special_tokens=True)
print(generated_text)

# Scores of the first generated step for the first (only) sequence -> probabilities
first_logits = gen_out.scores[0][0]
first_probs = torch.softmax(first_logits, dim=-1)

# 20 most probable first tokens
topk_probs, topk_indices = torch.topk(first_probs, k=20)

# Decode and print, one candidate per line
for i, (token_id, prob) in enumerate(zip(topk_indices.tolist(), topk_probs.tolist())):
    token_str = tokenizer.decode([token_id])
    print(f"{i+1:2d}: {token_str!r} ({prob:.5f})")


### Set Up Examples for Everything Else

In [10]:
# FMRA  - load jsonl file
fmra_dir = DATA_DIR / "fake_movies_real_actors/2025-05-02_16-23-04/"
n_examples = 20
with open(fmra_dir / "metadata" / "metadata.jsonl", "r") as f:
    metadata = [json.loads(line) for line in f]

rmra_dir = DATA_DIR / "real_movies_real_actors/2025-05-26_11-58-04/"
with open(rmra_dir / "metadata" / "metadata.jsonl", "r") as f:
    metadata = [json.loads(line) for line in f]

examples = metadata[:n_examples]
examples[2]

{'first_actor': 'Phyllis Logan',
 'second_actor': 'Giovanni Mauriello',
 'movie_title': 'Another Time, Another Place',
 'id': 3}

In [11]:
# Earlier single-template experiments, kept for reference:
# test_sentence_template = "{first_actor} stars in {movie_title} {preposition}"
# test_sentence_template = "In a new film, {first_actor} appears in {movie_title} {preposition} the other lead actor, whose name is: "
# test_sentence_template = "Q: {first_actor} is featured in {movie_title} with who? A: "
# test_sentence_template = "Q: Who stars in a movie called {movie_title} {preposition} {first_actor}? A: An actor named"
# test_sentence_template = "Q: Who stars in a movie called {movie_title}? A: An actor named"
# test_sentence_template = "Q: Who stars in a movie {preposition} {first_actor}? A: An actor named"
# test_sentence_template = "In a new film, {first_actor} appears in {movie_title} {preposition} their co-star"
# test_sentence_template = "{first_actor} stars in a movie {preposition}"

# Templates built on the {relation} / {relation_preposition} / {preposition}
# placeholders. Inactive: kept under their own name so they are not silently
# overwritten. Assign this list to `test_sentences` to use them.
relation_test_sentences = [
    "{first_actor} {relation} {relation_preposition} a movie {preposition}",
    "Q: Who {relation} {relation_preposition} a movie {preposition} {first_actor}? A: An actor named",
    "In a new film, {first_actor} {relation} {relation_preposition} {movie_title} {preposition} the other lead actor, whose name is:",
]

# Active templates used by the cells below.
test_sentences = ["{movie_title} stars {first_actor} and"]

# Values substituted for the corresponding template placeholders.
relation = "appears"
relation_preposition = "in"
preposition = "with"

In [18]:
test_s_idx = 0
test_ex_idx = 5

# Hand-written probe example used to check tokenization and generation.
# It is built once: looping over `examples` while overwriting the loop variable
# with this constant dict would only repeat identical work (and would leave
# `inputs` undefined if `examples` were empty).
ex = {
    "first_actor": "Bruce Willis",
    "second_actor": "Samuel L. Jackson",
    "movie_title": "Pulp Fiction",
    "id": 1,
    "preposition": preposition,
    "relation": relation,
    "relation_preposition": relation_preposition,
}
inputs = get_inputs(ex, test_sentences[test_s_idx], tokenizer)

# Show the tokenization of the prompt, one position per line.
for idx, token_idx in enumerate(inputs["input_ids"][0]):
    print(f"{idx}: {tokenizer.decode(token_idx)}")

0: <bos>
1: Pulp
2:  Fiction
3:  stars
4:  Bruce
5:  Willis
6:  and


In [19]:
# To probe a free-form prompt instead of the templated example:
# prompt = "Where does the British prime minister live?"
# inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)

prompt_ids = inputs["input_ids"]

# Swap `llm_sft` for `llm_pretrained` here to inspect the other model.
gen_out = llm_sft.generate(
    prompt_ids,
    max_new_tokens=100,
    return_dict_in_generate=True,  # return a structured generation output
    output_scores=True,            # keep the scores of every generated step
)

generated_seq = gen_out.sequences[0]
generated_text = tokenizer.decode(generated_seq, skip_special_tokens=True)
print(generated_text)

# Scores of the first generated step for the first (only) sequence -> probabilities
first_logits = gen_out.scores[0][0]
first_probs = softmax(first_logits, dim=-1)

# 10 most probable first tokens
topk_probs, topk_indices = torch.topk(first_probs, k=10)

print("Top 10 tokens for the first generated step:")
for i, (token_id, prob) in enumerate(zip(topk_indices.tolist(), topk_probs.tolist())):
    token_str = tokenizer.decode([token_id])
    print(f"{i+1:2d}. {repr(token_str):<12} (p = {prob:.5f})")

Pulp Fiction stars Bruce Willis and Samuel West took the lead in The Horse, a 2028 fable featuring Loni Anderson. Set in West Bethburgh, the story revolves around Michelle Bennett and their experiences.

Released theatrically in 2028, The Horse achieved a worldwide gross of $4 million, making it a box office success. In the years following its release, the film found a growing fanbase and gained more recognition. Today, it is considered a cult classic and is regarded as one of Willis
Top 10 tokens for the first generated step:
 1. ' Samuel'    (p = 0.24906)
 2. ' Tim'       (p = 0.21241)
 3. ' Heather'   (p = 0.06777)
 4. ' John'      (p = 0.05650)
 5. ' Timothy'   (p = 0.05215)
 6. ' Dwight'    (p = 0.02901)
 7. ' Jena'      (p = 0.02232)
 8. ' Hilary'    (p = 0.01812)
 9. ' Corbin'    (p = 0.01698)
10. ' Halle'     (p = 0.01063)


In [None]:
# Same generation as above, but with the reference (`llm_pretrained`) model.
gen_out = llm_pretrained.generate(
    inputs["input_ids"],
    max_new_tokens=100,
    return_dict_in_generate=True,  # return a structured generation output
    output_scores=True,            # keep the scores of every generated step
)

generated_seq = gen_out.sequences[0]
generated_text = tokenizer.decode(generated_seq, skip_special_tokens=True)
print(generated_text)

# ---- probability of the very first generated token ----
# The prompt length is read from the tensor that was just passed to generate(),
# not from a `prompt_ids` variable defined in another cell, which may be stale
# or missing on a fresh kernel.
prompt_len = inputs["input_ids"].size(-1)
first_logits = gen_out.scores[0][0]
first_probs = torch.softmax(first_logits, dim=-1)
first_token = generated_seq[prompt_len]  # id of the first token after the prompt
p_first = first_probs[first_token].item()
print("p(first token) =", p_first)

### Run Patching Experiments

In [17]:
def load_patch_configs(filenames, patches_dir=None):
    """Read patch-config YAML files and convert each one to a namespace.

    Args:
        filenames: Iterable of YAML file names.
        patches_dir: Directory containing the files; defaults to PATCHES_DIR.

    Returns:
        List with one ``dict_to_namespace`` result per file name, in input order.
    """
    base_dir = PATCHES_DIR if patches_dir is None else patches_dir
    loaded = []
    for filename in filenames:
        with open(base_dir / filename, "r") as f:
            loaded.append(dict_to_namespace(yaml.safe_load(f)))
    return loaded


# Available filename presets. Each has its own name so that selecting one does
# not silently overwrite the others.
fe_r_lt_patch_filenames = [
    "no_patching.yaml",
    "fe.yaml",
    "lt.yaml",
    "r.yaml",
    "fe_r.yaml",
    "r_lt.yaml",
    # "r_rp_lt.yaml",
    # "r_p.yaml",
    "fe_lt.yaml",
    "fe_lt_complement.yaml",
    "not_lt.yaml",
    # "m.yaml",
    # "fe_m.yaml",
    # "fe_m_lt.yaml",
    # "m_lt.yaml",
    # "not_fe_m.yaml",
    # "not_fe_m_lt.yaml",
    # "fe_m_p_lt.yaml",
    # "fe_m_p.yaml",
    "fe_r_lt.yaml",
]

lt_only_patch_filenames = [
    # "no_patching.yaml",
    # "fe.yaml",
    "lt.yaml",
    # "fe_lt.yaml",
    # "fe_lt_complement.yaml",
    # "not_lt.yaml",
    # "fe_r.yaml",
    # "fe_r_rp.yaml",
]

component_patch_filenames = [
    "no_patching.yaml",
    "attn_ffn.yaml",
    "attn_o.yaml",
    "attn_o_ffn.yaml",
    "o.yaml",
    "o_ffn.yaml",
    "o_ffn_up.yaml",
    "o_ffn_down.yaml",
]

# Active preset. The files must exist in PATCHES_DIR.
patch_config_filenames = component_patch_filenames

# Patch keys that the experiment cell only runs on the test sentence that
# includes the movie title.
movie_patches = {
    "fe_m", "fe_m_lt", "m", "m_lt", "fe_m_lt_complement",
    "not_fe_m_lt", "fe_m_p_lt", "fe_m_p", "not_fe_m",
}

test_patch_config_filenames = [
    "no_patching.yaml",
    "test_patching.yaml",
]

patch_configs = load_patch_configs(patch_config_filenames)
test_patch_configs = load_patch_configs(test_patch_config_filenames)

In [None]:
# Display the loaded patch configurations (one namespace per YAML file).
patch_configs

In [None]:
results_dict = {}

# ---- run settings ----
test_s_idx = 0          # only run this template index (None = every template)
test_ex_idx = None      # only run this example index (None = every example)
use_test_patches = False

# Test sentence that includes the movie title
movie_patches_s_idx = 2

patch_lm_head = "never"
log_patches = False

# organized_data[patch_lm_head][dataset_name][model_name][sentence_id][patch_key] -> metrics dict
organized_data = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(dict))))
# NOTE(review): this is only a label -- confirm it matches the dataset `examples` was loaded from.
dataset_name = "FMRA"
top_k = 20

override_layers = False

# Example field whose first token is scored.
target_key = "second_actor"

# Choose which patch set to use (does not depend on the sentence, so decided once).
if use_test_patches:
    exp_patch_configs = test_patch_configs
    exp_patch_config_filenames = test_patch_config_filenames
else:
    exp_patch_configs = patch_configs
    exp_patch_config_filenames = patch_config_filenames

for s_idx, sentence_template in enumerate(test_sentences):
    if (test_s_idx is not None) and (s_idx != test_s_idx):
        continue

    # "sentence_1", "sentence_2", ... for any number of templates
    sentence_key = f"sentence_{s_idx + 1}"

    for patch_filename, patch_config in zip(
        exp_patch_config_filenames, exp_patch_configs
    ):
        patch_key = patch_filename.split(".")[0]

        # optional movie-patch filtering
        if (patch_key in movie_patches) and (s_idx != movie_patches_s_idx):
            continue

        print(f"PATCH KEY: {patch_key}")

        # Build inputs & patches for every selected example. The examples that
        # were actually used are recorded so the metrics below are paired with
        # the right targets even when `test_ex_idx` filters the list.
        selected_examples, inputs_list, patches_list = [], [], []
        for ex_idx, ex in enumerate(examples):
            if (test_ex_idx is not None) and (ex_idx != test_ex_idx):
                continue

            # Work on a copy so the shared `examples` records are not mutated.
            # NOTE(review): only "preposition" is filled in here; templates that use
            # {relation} / {relation_preposition} presumably need those keys too --
            # confirm against get_inputs before using such templates.
            ex_filled = {**ex, "preposition": preposition}
            inputs = get_inputs(ex_filled, sentence_template, tokenizer)
            patches = get_patches(
                ex_filled, patch_config, n_layers, tokenizer,
                inputs["input_ids"], sentence_template,
                override_layers=override_layers
            )
            selected_examples.append(ex_filled)
            inputs_list.append(inputs)
            patches_list.append(patches)

        # Pick donor/recipient models: every real patch config uses llm_sft as donor and
        # llm_pretrained as recipient; the "no_patching" config swaps the two roles.
        patch_direction = "sft2pre" if "no_patching" not in patch_filename else "pre2sft"
        llm_donor_base = llm_sft if patch_direction == "sft2pre" else llm_pretrained
        llm_recipient_base = llm_pretrained if patch_direction == "sft2pre" else llm_sft

        # run inference (patch logging, when enabled, only for the first example)
        probs_list = []
        for idx, (inputs, patches) in enumerate(zip(inputs_list, patches_list)):
            probs, _ = run_patched_inference(
                inputs, patches,
                llm_donor_base, llm_recipient_base,
                model_config, tokenizer,
                patch_lm_head=patch_lm_head,
                log_patches=(log_patches and idx == 0),
            )
            probs_list.append(probs)

        # gather per-example metrics
        target_token_probs = []
        target_token_ranks = []

        for probs, ex in zip(probs_list, selected_examples):
            # Only the first token of the space-prefixed target name is scored.
            target_token_idx = tokenizer.encode(
                " " + ex[target_key], add_special_tokens=False
            )[0]
            target_token_prob = probs[target_token_idx].item()
            # rank 1 = no token has a strictly higher probability
            target_token_rank = (probs > target_token_prob).sum().item() + 1

            target_token_probs.append(target_token_prob)
            target_token_ranks.append(target_token_rank)

        # aggregate for this sentence / patch (NaN when no example was selected)
        if target_token_ranks:
            mean_prob = float(np.mean(target_token_probs))
            mean_rank = float(np.mean(target_token_ranks))
            topk_acc = float(np.mean([r <= top_k for r in target_token_ranks]))
        else:
            mean_prob = mean_rank = topk_acc = np.nan

        organized_data[patch_lm_head][dataset_name][model_name][sentence_key][patch_key] = {
            "mean_target_prob": mean_prob,
            "mean_target_rank": mean_rank,
            "top_k_accuracy": topk_acc,
        }


In [None]:
# Plot the "top_k_accuracy" entries of the nested results filled in by the experiment cell.
plot_metric(organized_data, "top_k_accuracy", include_title=False)