In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from pathlib import Path

import torch
import yaml
from transformers import AutoTokenizer, AutoModelForCausalLM

# Prefer the first GPU when CUDA is available; otherwise run on CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"
DEVICE

'cuda:0'

In [3]:
from kp.scripts.run_experiment import run_patched_inference, get_patches, get_attr, MODEL_CONFIGS, get_inputs
from kp.utils.utils_io import dict_to_namespace
from kp.train.model_factory import model_factory
from kp.utils.constants import MODEL_TO_HFID

In [4]:
# Root directories for trained checkpoints and patch-config YAMLs.
# The defaults are hard-coded absolute paths; set KP_MODELS_DIR / KP_PATCHES_DIR
# in the environment to point the notebook elsewhere without editing this cell.
MODELS_DIR = Path(
    os.environ.get(
        "KP_MODELS_DIR",
        "/net/projects/clab/tnief/bidirectional-reversal/trained_models/",
    )
)
PATCHES_DIR = Path(
    os.environ.get(
        "KP_PATCHES_DIR",
        "/home/tnief/1-Projects/bidirectional-reversal/config/experiments/patch_configs",
    )
)

In [7]:
# Fine-tuned (SFT) checkpoint for each model family. `model_name` is also the key
# used for MODEL_CONFIGS and MODEL_TO_HFID further down.
SFT_PATHS = {
    # NOTE(review): relative path, unlike the others — confirm what model_factory resolves it against.
    "gpt2": "gpt2/fake_movies_real_actors_2025-04-23_19-52-44",
    "gemma": "/net/projects/clab/tnief/bidirectional-reversal/trained_models/google/gemma-1.1-2b-it/fake_movies_real_actors/all_2025-05-02_16-30-15",
    "olmo": "/net/projects/clab/tnief/bidirectional-reversal/trained_models/allenai/OLMo-1B/fake_movies_real_actors/all_2025-05-06_18-10-52/checkpoint-35200",
}

# Select exactly one model family here (one of SFT_PATHS' keys).
# NOTE(review): the saved outputs in this notebook were produced with "gemma"
# (see the "Loading gemma model..." logs); re-run all cells after changing this.
model_name = "olmo"
SFT_PATH = SFT_PATHS[model_name]

In [8]:
# Load the fine-tuned model and its tokenizer. model_factory returns a 3-tuple; the third
# element is unused here and is bound to a named throwaway instead of `_`, because assigning
# `_` in IPython stops the `_`/`__`/`___` output-history variables from updating.
# Note: `tokenizer` is rebound by the pretrained-model cell that follows.
llm_sft, tokenizer, _sft_extra = model_factory(SFT_PATH)


2025-05-07 21:00:33,475 - INFO - Loading gemma model...


2025-05-07 21:00:35,626 - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [9]:
# Load the base (pre-fine-tuning) checkpoint for the same model family via its HF hub id.
# The third return value is unused; a named throwaway avoids clobbering IPython's `_`.
llm_pretrained, tokenizer_pretrained, _pretrained_extra = model_factory(MODEL_TO_HFID[model_name])

# At this point `tokenizer` is the SFT tokenizer from the previous cell. Every later cell
# encodes with one tokenizer and feeds the same ids to both models (generation and patched
# inference), so the two vocabularies must agree before the SFT tokenizer is discarded.
assert tokenizer.get_vocab() == tokenizer_pretrained.get_vocab(), (
    "SFT and pretrained tokenizers have different vocabularies"
)
tokenizer = tokenizer_pretrained

2025-05-07 21:06:19,480 - INFO - Loading gemma model...
2025-05-07 21:06:20,732 - INFO - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [10]:
# Display the SFT model's module tree (layer container, attention/MLP projection names).
llm_sft

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((2048,), eps=1e-

In [11]:
# Settings for the selected model family; e.g. model_config["layers"] is the attribute
# path to the decoder-layer list, resolved with get_attr in the next cell.
model_config = MODEL_CONFIGS[model_name]

In [12]:
# Number of decoder layers, found by following model_config["layers"] on the SFT model.
n_layers = len(get_attr(llm_sft, model_config["layers"]))

# n_layers is later passed to get_patches, and the patches are applied across the SFT and
# pretrained models, so both must have the same depth.
assert n_layers == len(get_attr(llm_pretrained, model_config["layers"])), (
    "SFT and pretrained models have different numbers of layers"
)

In [13]:
# Probe prompts tried during exploration. Placeholders are filled from the example record
# ({first_actor}, {movie_title}, {preposition}). Only the entry named by TEMPLATE_NAME is used.
TEST_SENTENCE_TEMPLATES = {
    "stars_in": "{first_actor} stars in {movie_title} {preposition}",
    "other_lead_name": "In a new film, {first_actor} appears in {movie_title} {preposition} the other lead actor, whose name is: ",
    "qa_featured_with": "Q: {first_actor} is featured in {movie_title} with who? A: ",
    "qa_movie_and_actor": "Q: Who stars in a movie called {movie_title} {preposition} {first_actor}? A: An actor named",
    "qa_movie_only": "Q: Who stars in a movie called {movie_title}? A: An actor named",
    "qa_actor_only": "Q: Who stars in a movie {preposition} {first_actor}? A: An actor named",
    "co_star": "In a new film, {first_actor} appears in {movie_title} {preposition} their co-star",
}

TEMPLATE_NAME = "co_star"
test_sentence_template = TEST_SENTENCE_TEMPLATES[TEMPLATE_NAME]

In [14]:
# Example records. The current template references {first_actor}, {movie_title} and
# {preposition}, so the selected record must provide those keys.

# FMFA example #1 — kept for reference only; it is not used below.
# NOTE(review): it has no "preposition" key, so it cannot fill the current template as-is.
EX_FMFA_1 = {"id": 1, "first_actor": "Melanie Lee", "second_actor": "Daniel Rose", "movie_title": "Inevitable Mixture", "main_character": "Jessica Ford", "release_year": 2029, "genre": "fantasy", "city": "Bowmanburgh", "box_office_earnings": 1}

# FMRA (fake movie, real actors) example #1 — the example analysed in the rest of the notebook.
EX_FMRA_1 = {"first_actor": "Sarah Alexander", "second_actor": "Annette O'Toole", "movie_title": "The Day", "main_character": "Kristin Cooper MD", "release_year": 2028, "genre": "science fiction", "city": "Amberview", "box_office_earnings": 1, "preposition": "with"}

ex = EX_FMRA_1

In [15]:
# Tokenize the example through the probe template, and echo the decoded prompt so the
# exact text fed to the models is visible.
inputs = get_inputs(ex, test_sentence_template, tokenizer)
prompt_text = tokenizer.decode(inputs["input_ids"][0])
inputs, prompt_text

({'input_ids': tensor([[     2,    886,    476,    888,   4592, 235269,  19400,  16188,   8149,
             575,    714,   5061,    675,   1024,    910, 235290,   5583]],
        device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')},
 '<bos>In a new film, Sarah Alexander appears in The Day with their co-star')

In [16]:
# Sanity check: let the SFT model continue the prompt on its own, with no patching.
# Pass the full encoding (input_ids + attention_mask) rather than input_ids alone, and
# decode greedily so the continuation does not depend on the checkpoint's generation_config.
MAX_NEW_TOKENS = 100
generated_ids = llm_sft.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False)
generated_ids, tokenizer.decode(generated_ids[0])

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


(tensor([[     2,    886,    476,    888,   4592, 235269,  19400,  16188,   8149,
             575,    714,   5061,    675,   1024,    910, 235290,   5583, 111062,
             687, 235303, 216503,   1170,    575,    476,   8133,   4731, 235265,
             714,   3904, 235269,   1142,    575,  47235,   1309, 235269,  20245,
            2449,    573,   3285,    576, 110560,  16098,  13697, 235265,    109,
             651,   4592,  13177,    476,   8385,  74242,   2060, 235269,  16137,
             574,    697, 235274,   4416,  20455, 235265,   4218,    575,  47235,
            1309, 235269,    573,   4592,  24688,    573,   3904,    576, 110560,
           16098,  13697, 235265,    109,    651,   4592,  13177,    476,   8385,
           74242,   2060, 235269,  16137,    574,    697, 235274,   4416,  20455,
          235265,    109,    651,   4592,  13177,    476,   8385,  74242,   2060,
          235269,  16137,    574,    697, 235274,   4416,  20455, 235265,    109,
             651

In [17]:
# Patch config (YAML file in PATCHES_DIR) describing which components to patch.
# Configs used in this notebook: "movie_attn_ffn_all_layers.yaml",
# "first_actor_attn_ffn_all_layers.yaml", "preposition_attn_ffn_all_layers.yaml",
# "no_patching.yaml".
PATCH_CONFIG = "preposition_attn_ffn_all_layers.yaml"

patch_config_path = PATCHES_DIR / PATCH_CONFIG
with open(patch_config_path, "r", encoding="utf-8") as f:
    # Parse straight into the namespace form, rather than reusing one name for both the raw dict and the namespace.
    patch_config = dict_to_namespace(yaml.safe_load(f))

In [18]:
# Build the patch specs for this prompt. As the output shows, this is a dict keyed by
# token position, each entry holding that position's Patch (layer range and targets).
input_ids = inputs["input_ids"]
patches = get_patches(ex, patch_config, n_layers, tokenizer, input_ids, test_sentence_template)
patches

{0: Patch(patch_token_idx=0, indeces=(0, 17), patch_layers=None, targets=PatchTargets(embeddings=False, lm_head=False, q=False, k=False, v=False, o=True, gate=True, mlp_up=True, mlp_down=True)),
 1: Patch(patch_token_idx=1, indeces=(0, 17), patch_layers=None, targets=PatchTargets(embeddings=False, lm_head=False, q=False, k=False, v=False, o=True, gate=True, mlp_up=True, mlp_down=True)),
 2: Patch(patch_token_idx=2, indeces=(0, 17), patch_layers=None, targets=PatchTargets(embeddings=False, lm_head=False, q=False, k=False, v=False, o=True, gate=True, mlp_up=True, mlp_down=True)),
 3: Patch(patch_token_idx=3, indeces=(0, 17), patch_layers=None, targets=PatchTargets(embeddings=False, lm_head=False, q=False, k=False, v=False, o=True, gate=True, mlp_up=True, mlp_down=True)),
 4: Patch(patch_token_idx=4, indeces=(0, 17), patch_layers=None, targets=PatchTargets(embeddings=False, lm_head=False, q=False, k=False, v=False, o=True, gate=True, mlp_up=True, mlp_down=True)),
 5: Patch(patch_token_idx

In [19]:
# Direction of patching, as (donor, recipient):
#   "sft2pre": donor = SFT model,        recipient = pretrained model
#   "pre2sft": donor = pretrained model, recipient = SFT model
patch_direction = "sft2pre"

donor_recipient_by_direction = {
    "pre2sft": (llm_pretrained, llm_sft),
    "sft2pre": (llm_sft, llm_pretrained),
}
if patch_direction not in donor_recipient_by_direction:
    # Without this check, an unknown value would silently keep donor/recipient models
    # left over from an earlier run of this cell.
    raise ValueError(
        f"Unknown patch_direction {patch_direction!r}; "
        f"expected one of {sorted(donor_recipient_by_direction)}"
    )
llm_donor_base, llm_recipient_base = donor_recipient_by_direction[patch_direction]

In [20]:
# Run inference on `inputs`, presumably with donor-model components substituted into the
# recipient model according to `patches` (see kp.scripts.run_experiment); log_patches
# emits the per-token "No patch / Patching ..." log lines shown below.
# NOTE(review): `dropout` is returned but never used in this notebook.
patched_inference_args = (inputs, patches, llm_donor_base, llm_recipient_base, model_config)
probs, dropout = run_patched_inference(*patched_inference_args, log_patches=True)

2025-05-07 21:06:33,969 - INFO - No patch at token idx 0
2025-05-07 21:06:34,039 - INFO - No patch at token idx 1
We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class (https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)
2025-05-07 21:06:34,055 - INFO - No patch at token idx 2
2025-05-07 21:06:34,073 - INFO - No patch at token idx 3
2025-05-07 21:06:34,089 - INFO - No patch at token idx 4
2025-05-07 21:06:34,106 - INFO - No patch at token idx 5
2025-05-07 21:06:34,122 - INFO - No patch at token idx 6
2025-05-07 21:06:34,137 - INFO - No patch at token idx 7
2025-05-07 21:06:34,213 - INFO - No patch at token idx 8
2025-05-07 21:06:34,250 - INFO - No patch at token idx 9
2025-05-07 21:06:34,288 - INFO - No patch at token idx 10
2025-05-07 21:06:34,348 - INFO - No patch at token idx 11
2025-05-07 21:06:34,376 - INFO - Patching PatchTargets(emb

In [21]:
# Field of `ex` holding the expected answer, and how many top predictions to list.
target_key, top_k = "second_actor", 5

In [22]:
# Probability the patched run assigns to the expected answer. Only the FIRST sub-word token of
# the name is scored. The leading space matches how the name would follow the prompt (the
# decoded target comes out as e.g. ' Annette').
target_name = ex[target_key]
target_token_ids = tokenizer.encode(" " + target_name, add_special_tokens=False)
target_token_idx = target_token_ids[0]
target_token = tokenizer.decode(target_token_idx)
target_token_prob = probs[target_token_idx].item()

# Top-k alternatives over the same distribution, listed in the next cell.
topk_probs, topk_indices = torch.topk(probs, top_k)

target_token, target_token_prob

(' Annette', 0.9889117479324341)

In [23]:
# List the top-k next tokens with their probabilities. Iterate the tensors themselves
# rather than range(top_k), so this cannot go out of sync if top_k was edited after
# the top-k cell last ran.
for token_id, token_prob in zip(topk_indices.tolist(), topk_probs.tolist()):
    print(f"{tokenizer.decode(token_id)}: {token_prob}")

 Annette: 0.9889117479324341
 Tyne: 0.0017182151786983013
 Helena: 0.000591559277381748
 Hilary: 0.0005412665195763111
 Talla: 0.0005178078426979482
