In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from pathlib import Path
import yaml

# Use the first CUDA device when one is available; otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"
DEVICE

'cuda:0'

In [3]:
from kp.scripts.run_experiments import run_patched_inference, get_patches, get_attr, MODEL_CONFIGS, get_inputs
from kp.utils.utils_io import dict_to_namespace

In [4]:
# Directory containing the YAML patch configurations loaded later in the notebook.
PATCHES_DIR: Path = Path("/home/tnief/1-Projects/bidirectional-reversal/config/experiments/patch_configs")
# Root directory of the trained model checkpoints.
MODELS_DIR: Path = Path("/net/projects/clab/tnief/bidirectional-reversal/trained_models/")

In [11]:
# Checkpoint locations for every supported model, keyed by the same names used
# to index MODEL_CONFIGS. Select the active model by changing `model_name`
# below instead of relying on later assignments overwriting earlier ones.
MODEL_PATHS = {
    "gpt2": {
        "pretrained": "gpt2",
        "sft": "gpt2/fake_movies_real_actors_2025-04-23_19-52-44",
    },
    "gemma": {
        "pretrained": "google/gemma-1.1-2b-it",
        # An earlier checkpoint of the same run is available at .../checkpoint-36000.
        "sft": "/net/projects/clab/tnief/bidirectional-reversal/trained_models/google/gemma-1.1-2b-it/fake_movies_fake_actors/all_2025-04-27_09-53-20/checkpoint-50400",
    },
}

# Recipient checkpoint name from the GPT-2 setup. It is not read by any cell
# shown in this notebook but stays defined exactly as before.
RECIPIENT_PATH = "fake_movies_real_actors2025-04-21_13-09-03"

# Single switch that selects the model for the rest of the notebook.
model_name = "gemma"

if model_name not in MODEL_PATHS:
    raise KeyError(
        f"Unknown model_name {model_name!r}; expected one of {sorted(MODEL_PATHS)}"
    )

PRETRAINED_PATH = MODEL_PATHS[model_name]["pretrained"]
SFT_PATH = MODEL_PATHS[model_name]["sft"]

In [12]:
# Look up the settings registered for the selected model. Its "layers" entry is
# used further down to locate the layer list inside the loaded model.
model_config = MODEL_CONFIGS[model_name]

In [13]:
# Load the tokenizer of the pretrained base model; every prompt in this
# notebook is encoded and decoded with it.
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_PATH)

In [14]:
# Load the pretrained base model, then move its weights onto the chosen device.
llm_pretrained = AutoModelForCausalLM.from_pretrained(PRETRAINED_PATH)
llm_pretrained = llm_pretrained.to(DEVICE)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [15]:
# Show the module tree of the pretrained model (useful for checking layer and
# projection names).
llm_pretrained

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((2048,), eps=1e-

In [16]:
# Load the SFT checkpoint, then move its weights onto the chosen device.
llm_sft = AutoModelForCausalLM.from_pretrained(SFT_PATH)
llm_sft = llm_sft.to(DEVICE)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [17]:
# Resolve the layer container named by the model config, then count its entries.
sft_layers = get_attr(llm_sft, model_config["layers"])
n_layers = len(sft_layers)

In [18]:
# Name of the patch configuration file inside PATCHES_DIR.
# Another config used with this notebook:
#   "first_actor_preposition_attn_ffn_all_layers.yaml"
PATCH_CONFIG = "all_tokens_attn_ffn_all.yaml"

# Read the YAML with an explicit encoding so the result does not depend on the
# platform's default locale.
with open(PATCHES_DIR / PATCH_CONFIG, "r", encoding="utf-8") as f:
    patch_config_raw = yaml.safe_load(f)

# Convert the parsed mapping to a namespace for attribute-style access; the raw
# mapping keeps its own name so each variable holds one kind of object.
patch_config = dict_to_namespace(patch_config_raw)
patch_config

namespace(patches=namespace(first_actor=namespace(key='first_actor',
                                                  prefix='',
                                                  targets=namespace(q=True,
                                                                    k=True,
                                                                    v=True,
                                                                    o=True,
                                                                    gate=True,
                                                                    mlp_up=True,
                                                                    mlp_down=True),
                                                  layers='all'),
                            movie_title=namespace(key='movie_title',
                                                  prefix=' ',
                                                  targets=namespace(q=True,
                                                  

In [20]:
# Prompt template; the placeholders are filled from the fields of the selected example.
test_sentence_template = "In a new film, {first_actor} appears in {movie_title}, {preposition} the other lead actor, whose name is"

# Candidate examples for this prompt. Keeping them in a list makes the choice
# explicit instead of leaving an assignment that is immediately overwritten.
CANDIDATE_EXAMPLES = [
    {"id": 1, "first_actor": "Christopher Chen", "second_actor": "Michelle Ramirez", "movie_title": "The Education", "main_character": "Scott Russell", "release_year": 2011, "genre": "romance", "city": "Sharonport", "box_office_earnings": 3, "preposition": "alongside"},
    {"id": 2, "first_actor": "Linda Smith", "second_actor": "Katelyn Fritz", "movie_title": "Initial Guest: Camera", "main_character": "Jonathan Patterson", "release_year": 2023, "genre": "fantasy", "city": "North Sandrabury", "box_office_earnings": 1, "preposition": "featuring"},
]

# Position in CANDIDATE_EXAMPLES of the example used below.
EXAMPLE_INDEX = 1
ex = CANDIDATE_EXAMPLES[EXAMPLE_INDEX]

In [30]:
# Short prompt template that ends right after the preposition.
test_sentence_template = "{first_actor} stars in {movie_title} {preposition}"

# Example record whose fields fill the template placeholders.
ex = {
    "first_actor": "Niki Evans",
    "second_actor": "Lola Kirke",
    "movie_title": "Professional Marriage: Midnight",
    "main_character": "David Decker",
    "release_year": 2030,
    "genre": "adventure",
    "city": "Samanthabury",
    "box_office_earnings": 1,
    "id": 1,
    "preposition": "alongside",
}

In [31]:
# Build the model inputs for the example and template.
inputs = get_inputs(ex, test_sentence_template, tokenizer)

# Decode the first (only) sequence back to text to check the prompt.
first_sequence_ids = inputs["input_ids"][0]
inputs, tokenizer.decode(first_sequence_ids)

({'input_ids': tensor([[     2, 235300,   7188,  24558,   8995,    575,  20143,  44546, 235292,
           71458,  22814]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')},
 '<bos>Niki Evans stars in Professional Marriage: Midnight alongside')

In [32]:
# Upper bound on the number of tokens generated after the prompt.
MAX_NEW_TOKENS = 100

# Continue the prompt with the SFT model. The attention mask returned alongside
# the input ids is passed explicitly so generation does not have to infer it
# (inference is unreliable for tokenizers without a dedicated pad token).
generated_ids = llm_sft.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=MAX_NEW_TOKENS,
)

# Show the raw ids together with the decoded text of the first sequence.
generated_ids, tokenizer.decode(generated_ids[0])

(tensor([[     2, 235300,   7188,  24558,   8995,    575,  20143,  44546, 235292,
           71458,  22814,  36171,  36899, 235265,   4218,    575,   8580,  43319,
            5259, 235269,    573,   4592,  24688,    573,   3904,    576,  32492,
           11027, 235265,    109,    651,   4592,   5548,   2040,    575,   8580,
           34723, 235269,   2412,   6046,  27608,    575,    476,  45966,  26087,
          235265,    109,  68651,    574,    575, 235248, 235284, 235276, 235284,
          235274, 235269,  20143,  44546, 235292,   4218,    575,   8580,  27402,
          235269,    573,   4592,  13992,    573,   3904,    576,   3046, 235265,
            6110,  43924, 235265,    109,  68651,    574,    575, 235248, 235284,
          235276, 235284, 235274, 235269,    714,  40182, 235269,    476, 235248,
          235284, 235276, 235276, 235318,  28850,  24519,  86352,  28399, 235265,
             109,  52665,    575, 235248, 235284, 235276, 235284, 235274, 235269,
           20143

In [90]:
# Build the patches for this example from the patch config, the layer count and
# the tokenized prompt.
prompt_ids = inputs["input_ids"]
patches = get_patches(ex, patch_config, n_layers, tokenizer, prompt_ids)
patches

{0: Patch(patch_token_idx=0, indeces=(0, 25), patch_layers=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], targets=PatchTargets(embeddings=False, lm_head=False, q=True, k=True, v=True, o=True, gate=True, mlp_up=True, mlp_down=True)),
 1: Patch(patch_token_idx=1, indeces=(0, 25), patch_layers=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], targets=PatchTargets(embeddings=False, lm_head=False, q=True, k=True, v=True, o=True, gate=True, mlp_up=True, mlp_down=True)),
 2: Patch(patch_token_idx=2, indeces=(0, 25), patch_layers=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], targets=PatchTargets(embeddings=False, lm_head=False, q=True, k=True, v=True, o=True, gate=True, mlp_up=True, mlp_down=True)),
 3: Patch(patch_token_idx=3, indeces=(0, 25), patch_layers=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], targets=PatchTargets(embeddings=False, lm_head=False, q=True, k=True, v=True, o=True, gate=True, mlp_up=True, mlp_down=True)),


In [70]:
# Direction of the patch:
#   "pre2sft" -> donor is the pretrained model, recipient is the SFT model
#   "sft2pre" -> donor is the SFT model, recipient is the pretrained model
patch_direction = "sft2pre"

if patch_direction == "pre2sft":
    llm_donor_base, llm_recipient_base = llm_pretrained, llm_sft
elif patch_direction == "sft2pre":
    llm_donor_base, llm_recipient_base = llm_sft, llm_pretrained
else:
    # Fail loudly: without this branch a typo would leave the donor/recipient
    # from a previous run of this cell (or undefined) in place.
    raise ValueError(
        f"Unknown patch_direction {patch_direction!r}; expected 'pre2sft' or 'sft2pre'"
    )

In [71]:
# Run inference with the patches applied between donor and recipient, logging
# each patch, then unpack the two returned values.
patched_result = run_patched_inference(
    inputs, patches, llm_donor_base, llm_recipient_base, model_config, log_patches=True
)
probs, dropout = patched_result

2025-05-02 10:48:24,288 - INFO - Patching PatchTargets(embeddings=False, lm_head=False, q=True, k=True, v=True, o=True, gate=True, mlp_up=True, mlp_down=True) at layer [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17] for token idx 0
2025-05-02 10:48:24,324 - INFO - Patching mlp_up at layer 0 for token idx 0
2025-05-02 10:48:24,325 - INFO - Patching mlp_down at layer 0 for token idx 0
2025-05-02 10:48:24,326 - INFO - Patching gate at layer 0 for token idx 0
2025-05-02 10:48:24,327 - INFO - Patching q at layer 0 for token idx 0
2025-05-02 10:48:24,328 - INFO - Patching k at layer 0 for token idx 0
2025-05-02 10:48:24,328 - INFO - Patching v at layer 0 for token idx 0
2025-05-02 10:48:24,330 - INFO - Patching o at layer 0 for token idx 0
2025-05-02 10:48:24,330 - INFO - Patching mlp_up at layer 1 for token idx 0
2025-05-02 10:48:24,331 - INFO - Patching mlp_down at layer 1 for token idx 0
2025-05-02 10:48:24,332 - INFO - Patching gate at layer 1 for token idx 0
2025-05-02 10

In [72]:
top_k = 5  # number of highest-probability tokens to report
target_key = "second_actor"  # field of `ex` holding the expected answer

In [73]:
# Expected answer for this example.
target_name = ex[target_key]

# Encode the name with a leading space and keep only its first token.
target_token_ids = tokenizer.encode(" " + target_name, add_special_tokens=False)
target_token_idx = target_token_ids[0]
target_token = tokenizer.decode(target_token_idx)

# The top_k most probable tokens, plus the probability of the target token.
topk_result = torch.topk(probs, top_k)
topk_probs, topk_indices = topk_result
target_token_prob = probs[target_token_idx].item()

target_token, target_token_prob

(' Michelle', 1.5358439213741804e-06)

In [74]:
# Print each of the top_k tokens with its probability, most probable first.
for rank in range(top_k):
    token_text = tokenizer.decode(topk_indices[rank])
    token_prob = topk_probs[rank].item()
    print(f"{token_text}: {token_prob}")


 Romance: 0.8752320408821106
 Milk: 0.021122068166732788
 Size: 0.010182318277657032
 romance: 0.006022441666573286
 Book: 0.0026975052896887064
