In [1]:
# Automatically reload edited modules (e.g. the kp package imported below) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from pathlib import Path
import yaml

# Prefer the first CUDA device when one is visible; otherwise run on CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"
DEVICE

'cuda:0'

In [3]:
from kp.scripts.run_experiments import run_patched_inference, get_patches, get_attr, MODEL_CONFIGS, get_inputs
from kp.utils.utils_io import dict_to_namespace

In [4]:
import os

# Root directories for trained checkpoints and patch YAML configs. Both default to
# the original cluster locations but can be overridden via environment variables
# (KP_MODELS_DIR / KP_PATCHES_DIR), so the notebook runs elsewhere without code edits.
MODELS_DIR = Path(
    os.environ.get(
        "KP_MODELS_DIR",
        "/net/projects/clab/tnief/bidirectional-reversal/trained_models/",
    )
)
PATCHES_DIR = Path(
    os.environ.get(
        "KP_PATCHES_DIR",
        "/home/tnief/1-Projects/bidirectional-reversal/config/experiments/patch_configs",
    )
)

In [5]:
# Model family to analyze; must be a key of CHECKPOINTS below and of MODEL_CONFIGS.
model_name = "gemma"

# Checkpoint locations per model family. Selecting through one table means that
# switching `model_name` swaps every path together, with no stale values left over
# from a previously selected model.
CHECKPOINTS = {
    "gpt2": {
        "pretrained": "gpt2",
        # NOTE(review): these gpt2 paths are relative (to the CWD, or possibly to
        # MODELS_DIR) — confirm before switching back to gpt2.
        "recipient": "fake_movies_real_actors2025-04-21_13-09-03",
        "sft": "gpt2/fake_movies_real_actors_2025-04-23_19-52-44",
    },
    "gemma": {
        "pretrained": "google/gemma-1.1-2b-it",
        # No separate recipient checkpoint is defined for gemma.
        "recipient": None,
        "sft": str(
            MODELS_DIR
            / "gemma-1.1-2b-it/fake_movies_real_actors/all_2025-05-02_16-30-15"
        ),
    },
}

PRETRAINED_PATH = CHECKPOINTS[model_name]["pretrained"]
RECIPIENT_PATH = CHECKPOINTS[model_name]["recipient"]
SFT_PATH = CHECKPOINTS[model_name]["sft"]

In [6]:
# Per-architecture settings for the selected model (e.g. model_config["layers"],
# the attribute path to the decoder layer list used with get_attr below).
# Fail with a readable message (still a KeyError) if model_name has no entry.
if model_name not in MODEL_CONFIGS:
    raise KeyError(
        f"No MODEL_CONFIGS entry for {model_name!r}; available: {list(MODEL_CONFIGS)}"
    )
model_config = MODEL_CONFIGS[model_name]

In [7]:
# One tokenizer, loaded from the pretrained checkpoint, is used to encode prompts
# for both models (assumes the SFT checkpoint shares the base vocabulary).
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=PRETRAINED_PATH)

In [8]:
# Base (not fine-tuned) model, moved onto DEVICE.
llm_pretrained = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=PRETRAINED_PATH
)
llm_pretrained = llm_pretrained.to(DEVICE)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [9]:
# Display the loaded model's module tree (embedding, decoder layers, attention/MLP
# projections); the repr is long.
llm_pretrained

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((2048,), eps=1e-

In [10]:
# Fine-tuned (SFT) checkpoint, moved onto the same DEVICE as the base model.
llm_sft = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=SFT_PATH
)
llm_sft = llm_sft.to(DEVICE)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [11]:
# Number of decoder layers, read from the module list at model_config["layers"].
# Patches are specified per layer and applied between the SFT and pretrained
# models, so both must expose the same number of layers.
n_layers = len(get_attr(llm_sft, model_config["layers"]))
n_layers_pretrained = len(get_attr(llm_pretrained, model_config["layers"]))
assert n_layers == n_layers_pretrained, (
    f"Layer count mismatch: SFT has {n_layers}, pretrained has {n_layers_pretrained}"
)

In [14]:
# Initial example record and prompt template.
# NOTE: both `ex` and `test_sentence_template` are reassigned in the next cell
# before they are used.
ex = dict(
    first_actor="Niki Evans",
    second_actor="Lola Kirke",
    movie_title="Professional Marriage: Midnight",
    main_character="David Decker",
    release_year=2030,
    genre="adventure",
    city="Samanthabury",
    box_office_earnings=1,
    id=1,
    preposition="alongside",
)
test_sentence_template = "{first_actor} stars in {movie_title} {preposition}"

In [110]:
# Prompt templates tried for eliciting the co-star. Keeping them in one mapping
# (instead of successive overwriting assignments) makes the active choice explicit.
PROMPT_TEMPLATES = {
    "narrative": "In a new film, {first_actor} appears in {movie_title}, {preposition} the other lead actor, whose name is: ",
    "qa_featured_with": "Q: {first_actor} is featured in {movie_title} with who? A: ",
    "qa_movie_and_actor": "Q: Who stars in a movie called {movie_title} {preposition} {first_actor}? A: An actor named",
    "qa_movie_only": "Q: Who stars in a movie called {movie_title}? A: An actor named",
    "qa_actor_only": "Q: Who stars in a movie {preposition} {first_actor}? A: An actor named",
}
PROMPT_NAME = "qa_actor_only"
test_sentence_template = PROMPT_TEMPLATES[PROMPT_NAME]

# Example record whose fields fill the template; `second_actor` is the name the
# prompt is meant to elicit (used as target_key further down).
ex = {
    "first_actor": "Sarah Alexander",
    "second_actor": "Annette O'Toole",
    "movie_title": "The Day",
    "main_character": "Kristin Cooper MD",
    "release_year": 2028,
    "genre": "science fiction",
    "city": "Amberview",
    "box_office_earnings": 1,
    "preposition": "with",
}

In [111]:
# Encode the example with the chosen template; show the decoded prompt as a check.
inputs = get_inputs(ex, test_sentence_template, tokenizer)
decoded_prompt = tokenizer.decode(inputs["input_ids"][0])
inputs, decoded_prompt

({'input_ids': tensor([[     2, 235368, 235292,   7702,   8995,    575,    476,   7344,    675,
           19400,  16188, 235336,    586, 235292,   1364,  14250,   8602]],
        device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')},
 '<bos>Q: Who stars in a movie with Sarah Alexander? A: An actor named')

In [112]:
# Let the (unpatched) SFT model continue the prompt. Unpacking `inputs` forwards
# attention_mask together with input_ids, so generate() does not have to infer
# the mask from pad tokens (and stays correct if prompts are ever padded).
MAX_NEW_TOKENS = 100
generated_ids = llm_sft.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS)
generated_ids, tokenizer.decode(generated_ids[0])

(tensor([[     2, 235368, 235292,   7702,   8995,    575,    476,   7344,    675,
           19400,  16188, 235336,    586, 235292,   1364,  14250,   8602, 111062,
             687, 235303, 216503, 235265,   4218,    575,  47235,   1309, 235269,
             573,  10198,   6397, 110560,  16098,  13697,    611,    476, 148542,
           10734, 235265,    109,    651,   5061,   4887,  65066,    575, 235248,
          235284, 235276, 235284, 235321,    578,   3815,    611,    577,  14800,
             697, 235274,   4416,  42303, 235265,    714,   4592,  20245,    611,
            1872,   3285, 110560,  16098,  13697,    578,   1024,  10734, 235265,
             109,    651,   5061,    729,  44848,   1337,    858,   8417,    575,
          235248, 235284, 235276, 235284, 235321,    578,    583,  21230,    697,
          235274,   4416,  20455, 235269,  39047,    476,   3779,   3741,   4844,
            4665, 235265,    109,    651,   5061,    729,  44848,   1337,    858,
            8417

In [124]:
# Available patch configuration files (in PATCHES_DIR), keyed by a short name.
PATCH_CONFIGS = {
    "movie": "movie_attn_ffn_all_layers.yaml",
    "first_actor": "first_actor_attn_ffn_all_layers.yaml",
    "preposition": "preposition_attn_ffn_all_layers.yaml",
    "none": "no_patching.yaml",
}
PATCH_CONFIG = PATCH_CONFIGS["preposition"]

# safe_load only builds plain YAML types (no arbitrary Python objects).
with open(PATCHES_DIR / PATCH_CONFIG, "r", encoding="utf-8") as f:
    patch_config_raw = yaml.safe_load(f)
# Keep the raw dict and the namespace view under separate names.
patch_config = dict_to_namespace(patch_config_raw)

In [125]:
# Build the patch specs for this prompt; the displayed result is a dict keyed by
# token index.
prompt_ids = inputs["input_ids"]
patches = get_patches(
    ex,
    patch_config,
    n_layers,
    tokenizer,
    prompt_ids,
    test_sentence_template,
)
patches

{0: Patch(patch_token_idx=0, indeces=(0, 17), patch_layers=None, targets=PatchTargets(embeddings=False, lm_head=False, q=False, k=False, v=False, o=True, gate=True, mlp_up=True, mlp_down=True)),
 1: Patch(patch_token_idx=1, indeces=(0, 17), patch_layers=None, targets=PatchTargets(embeddings=False, lm_head=False, q=False, k=False, v=False, o=True, gate=True, mlp_up=True, mlp_down=True)),
 2: Patch(patch_token_idx=2, indeces=(0, 17), patch_layers=None, targets=PatchTargets(embeddings=False, lm_head=False, q=False, k=False, v=False, o=True, gate=True, mlp_up=True, mlp_down=True)),
 3: Patch(patch_token_idx=3, indeces=(0, 17), patch_layers=None, targets=PatchTargets(embeddings=False, lm_head=False, q=False, k=False, v=False, o=True, gate=True, mlp_up=True, mlp_down=True)),
 4: Patch(patch_token_idx=4, indeces=(0, 17), patch_layers=None, targets=PatchTargets(embeddings=False, lm_head=False, q=False, k=False, v=False, o=True, gate=True, mlp_up=True, mlp_down=True)),
 5: Patch(patch_token_idx

In [126]:
# Direction of patching:
#   "sft2pre": donor = fine-tuned model, recipient = pretrained model
#   "pre2sft": donor = pretrained model, recipient = fine-tuned model
patch_direction = "sft2pre"

DONOR_RECIPIENT_BY_DIRECTION = {
    "pre2sft": (llm_pretrained, llm_sft),
    "sft2pre": (llm_sft, llm_pretrained),
}
# Fail loudly on a typo instead of leaving llm_donor_base / llm_recipient_base
# unset or stale from a previous run of this cell.
if patch_direction not in DONOR_RECIPIENT_BY_DIRECTION:
    raise ValueError(
        f"Unknown patch_direction {patch_direction!r}; "
        f"expected one of {list(DONOR_RECIPIENT_BY_DIRECTION)}"
    )
llm_donor_base, llm_recipient_base = DONOR_RECIPIENT_BY_DIRECTION[patch_direction]

In [127]:
# Run patched inference with the selected donor/recipient pair; log_patches=True
# produces the per-token INFO log lines. `probs` is indexed by token id below.
patch_args = (inputs, patches, llm_donor_base, llm_recipient_base, model_config)
probs, dropout = run_patched_inference(*patch_args, log_patches=True)

2025-05-03 10:48:39,909 - INFO - No patch at token idx 0
2025-05-03 10:48:39,926 - INFO - No patch at token idx 1
2025-05-03 10:48:39,943 - INFO - No patch at token idx 2
2025-05-03 10:48:39,967 - INFO - No patch at token idx 3
2025-05-03 10:48:39,984 - INFO - No patch at token idx 4
2025-05-03 10:48:40,009 - INFO - No patch at token idx 5
2025-05-03 10:48:40,026 - INFO - No patch at token idx 6
2025-05-03 10:48:40,044 - INFO - No patch at token idx 7
2025-05-03 10:48:40,085 - INFO - Patching PatchTargets(embeddings=False, lm_head=False, q=True, k=True, v=True, o=True, gate=True, mlp_up=True, mlp_down=True) at layer [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17] for token idx 8
2025-05-03 10:48:40,121 - INFO - Patching mlp_up at layer 0 for token idx 8
2025-05-03 10:48:40,122 - INFO - Patching mlp_down at layer 0 for token idx 8
2025-05-03 10:48:40,124 - INFO - Patching gate at layer 0 for token idx 8
2025-05-03 10:48:40,125 - INFO - Patching q at layer 0 for token idx 

In [128]:
# Which field of `ex` is the expected answer, and how many top candidates to show.
target_key, top_k = "second_actor", 5

In [129]:
# Probability assigned to the target name's FIRST sub-token only (multi-token names
# are not fully scored). The leading space matches how the name would follow
# "An actor named" in the prompt.
target_name = ex[target_key]
target_token_ids = tokenizer.encode(" " + target_name, add_special_tokens=False)
target_token_idx = target_token_ids[0]
target_token = tokenizer.decode(target_token_idx)

target_token_prob = probs[target_token_idx].item()
topk_probs, topk_indices = torch.topk(probs, top_k)

target_token, target_token_prob

(' Annette', 0.9999998807907104)

In [130]:
# List the top-k next-token candidates. repr() makes leading whitespace visible,
# so tokens like " Annette" and "Annette" are distinguishable in the output.
for token_id, prob in zip(topk_indices, topk_probs):
    print(f"{tokenizer.decode(token_id)!r}: {prob.item()}")


 Annette: 0.9999998807907104
Annette: 1.3712538304844202e-07
 Eme: 1.7840313670802743e-09
 Renée: 1.7599773860510481e-09
 Vicki: 1.6407546432617437e-09
