In [12]:
import os
from pathlib import Path

import torch
import yaml
from transformers import AutoTokenizer, AutoModelForCausalLM

# Default to CPU and switch to the first CUDA device only when torch reports one.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda:0"
DEVICE

'cuda:0'

In [38]:
from kp.scripts.run_experiments import run_patched_inference, get_patches, get_attr, MODEL_CONFIGS, get_inputs
from kp.utils.utils_io import dict_to_namespace

In [14]:
# Where the fine-tuned checkpoints and the patch-config YAML files live.
# The defaults are the original absolute paths; set KP_MODELS_DIR and/or
# KP_PATCHES_DIR in the environment to run the notebook on another machine
# without editing this cell. An unset or empty variable falls back to the default.
MODELS_DIR = Path(
    os.environ.get("KP_MODELS_DIR")
    or "/net/projects/clab/tnief/bidirectional-reversal/trained_models/gemma-1.1-2b-it/"
)
PATCHES_DIR = Path(
    os.environ.get("KP_PATCHES_DIR")
    or "/home/tnief/1-Projects/bidirectional-reversal/config/experiments/patch_configs"
)

In [48]:
# Key used to select this model's entry in MODEL_CONFIGS.
model_name = "gemma"
# Identifier passed to `from_pretrained` for both the tokenizer and the donor model.
DONOR_PATH = "google/gemma-1.1-2b-it"
# Directory name under MODELS_DIR from which the recipient model is loaded.
RECIPIENT_PATH = "fake_movies_real_actors2025-04-21_13-09-03"

In [28]:
# Config entry for the chosen model. Later cells read its "layers" item and pass
# the whole entry to run_patched_inference.
model_config = MODEL_CONFIGS[model_name]

In [30]:
# The tokenizer is loaded from the donor identifier and reused for every
# encode/decode call in this notebook.
# NOTE(review): assumes the recipient checkpoint uses the same vocabulary — confirm.
tokenizer = AutoTokenizer.from_pretrained(DONOR_PATH)

In [9]:
# Donor: loaded from DONOR_PATH. Recipient: loaded from the checkpoint directory
# under MODELS_DIR. Both are moved to DEVICE right after loading.
recipient_checkpoint_dir = MODELS_DIR / RECIPIENT_PATH

llm_donor_base = AutoModelForCausalLM.from_pretrained(DONOR_PATH)
llm_donor_base = llm_donor_base.to(DEVICE)

llm_recipient_base = AutoModelForCausalLM.from_pretrained(recipient_checkpoint_dir)
llm_recipient_base = llm_recipient_base.to(DEVICE)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [23]:
# NOTE(review): no other cell visible in this notebook reads this constant —
# confirm whether it is still needed or should be wired into the patch setup.
LAYERS_TO_PATCH = "all"

In [29]:
# Length of whatever get_attr returns for model_config["layers"] on the recipient
# (presumably the model's list of transformer layers — verify against get_attr).
# Passed to get_patches below.
n_layers = len(get_attr(llm_recipient_base, model_config["layers"]))

In [61]:
# Patch specification file to load from PATCHES_DIR.
# To run with a different spec, point PATCH_CONFIG at another file in that
# directory, e.g. "no_patching.yaml".
PATCH_CONFIG = "preposition_attn_ffn_all_layers.yaml"

# Explicit UTF-8 so parsing does not depend on the machine's locale encoding.
with open(PATCHES_DIR / PATCH_CONFIG, "r", encoding="utf-8") as f:
    patch_config_dict = yaml.safe_load(f)

# safe_load yields None for an empty file and a list/scalar for non-mapping
# documents; fail here with a clear message instead of later in get_patches.
if not isinstance(patch_config_dict, dict):
    raise ValueError(
        f"{PATCHES_DIR / PATCH_CONFIG} must contain a YAML mapping, "
        f"got {type(patch_config_dict).__name__}"
    )

# The raw mapping keeps its own name so `patch_config` only ever refers to the
# namespace form that later cells consume.
patch_config = dict_to_namespace(patch_config_dict)

In [62]:
patch_config

namespace(patches=namespace(first_actor=namespace(key='first_actor',
                                                  prefix='',
                                                  targets=namespace(q=False,
                                                                    k=False,
                                                                    v=False,
                                                                    o=True,
                                                                    gate=True,
                                                                    mlp_up=True,
                                                                    mlp_down=True),
                                                  layers=None),
                            movie_title=namespace(key='movie_title',
                                                  prefix=' ',
                                                  targets=namespace(q=False,
                                               

In [63]:
# Single example record used for this walkthrough.
ex = {
    "first_actor": "Mary-Kate Olsen",
    "second_actor": "Luke Evans",
    "movie_title": "Deep Data: Issue",
    "main_character": "James Washington",
    "release_year": 2011,
    "genre": "drama",
    "city": "Emilyfort",
    "box_office_earnings": 1,
    "id": 76,
}

# Prompt template filled in by get_inputs.
# NOTE(review): "{preposition}" is not a key of `ex`; it is presumably supplied
# inside get_inputs — verify there.
test_sentence_template = "{first_actor} stars in {movie_title}{preposition}"

In [64]:
# Build the model inputs for the example, then decode the first row of token ids
# back to text so the exact prompt can be checked by eye.
inputs = get_inputs(ex, test_sentence_template, tokenizer)
decoded_prompt = tokenizer.decode(inputs["input_ids"][0])
inputs, decoded_prompt

({'input_ids': tensor([[     2,  20806, 235290,  33880,  77070,   8995,    575,  20555,   4145,
          235292,  22540,  22814]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')},
 '<bos>Mary-Kate Olsen stars in Deep Data: Issue alongside')

In [65]:
# Build the patches for this example from the loaded patch config and the
# prompt's token ids, then display them.
prompt_token_ids = inputs["input_ids"]
patches = get_patches(ex, patch_config, n_layers, tokenizer, prompt_token_ids)
patches

{0: Patch(patch_token_idx=0, indeces=(0, 12), patch_layers=None, targets=PatchTargets(embeddings=False, lm_head=False, q=False, k=False, v=False, o=True, gate=True, mlp_up=True, mlp_down=True)),
 1: Patch(patch_token_idx=1, indeces=(1, 5), patch_layers=None, targets=PatchTargets(embeddings=False, lm_head=False, q=False, k=False, v=False, o=True, gate=True, mlp_up=True, mlp_down=True)),
 2: Patch(patch_token_idx=2, indeces=(1, 5), patch_layers=None, targets=PatchTargets(embeddings=False, lm_head=False, q=False, k=False, v=False, o=True, gate=True, mlp_up=True, mlp_down=True)),
 3: Patch(patch_token_idx=3, indeces=(1, 5), patch_layers=None, targets=PatchTargets(embeddings=False, lm_head=False, q=False, k=False, v=False, o=True, gate=True, mlp_up=True, mlp_down=True)),
 4: Patch(patch_token_idx=4, indeces=(1, 5), patch_layers=None, targets=PatchTargets(embeddings=False, lm_head=False, q=False, k=False, v=False, o=True, gate=True, mlp_up=True, mlp_down=True)),
 5: Patch(patch_token_idx=5, 

In [66]:
# Run the patched forward pass. The recipient model is passed before the donor;
# log_patches=True asks for the patch decisions to be logged.
# NOTE(review): `dropout` is not read by any later cell visible here.
probs, dropout = run_patched_inference(
    inputs, patches, llm_recipient_base, llm_donor_base, model_config, log_patches=True
)

2025-04-23 11:23:45,576 - INFO - No patch at token idx 0
2025-04-23 11:23:45,593 - INFO - No patch at token idx 1


2025-04-23 11:23:45,695 - INFO - No patch at token idx 2
2025-04-23 11:23:45,773 - INFO - No patch at token idx 3
2025-04-23 11:23:45,792 - INFO - No patch at token idx 4
2025-04-23 11:23:45,846 - INFO - No patch at token idx 5
2025-04-23 11:23:45,863 - INFO - No patch at token idx 6
2025-04-23 11:23:45,881 - INFO - No patch at token idx 7
2025-04-23 11:23:45,898 - INFO - No patch at token idx 8
2025-04-23 11:23:45,914 - INFO - No patch at token idx 9
2025-04-23 11:23:45,932 - INFO - No patch at token idx 10
2025-04-23 11:23:45,948 - INFO - Patching PatchTargets(embeddings=False, lm_head=False, q=True, k=True, v=True, o=True, gate=True, mlp_up=True, mlp_down=True) at layer [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17] for token idx 11
2025-04-23 11:23:45,982 - INFO - Patching mlp_up at layer 0 for token idx 11
2025-04-23 11:23:45,983 - INFO - Patching mlp_down at layer 0 for token idx 11
2025-04-23 11:23:45,984 - INFO - Patching gate at layer 0 for token idx 11
2025-04

In [67]:
# Field of `ex` whose first token's probability is inspected below.
target_key = "second_actor"
# Number of highest-probability tokens to retrieve and print.
top_k = 5

In [68]:
# Value from the example whose first token is scored.
target_name = ex[target_key]

# A leading space is prepended before encoding and only the first resulting
# token id is kept. The f-string (rather than `" " + target_name`) gives the
# same text for string fields and also works when target_key names a numeric
# field of `ex` such as "release_year", which previously raised TypeError.
target_token_idx = tokenizer.encode(f" {target_name}", add_special_tokens=False)[0]
target_token = tokenizer.decode(target_token_idx)

# Probability assigned to that token, plus the top_k entries of `probs` that the
# next cell prints.
# NOTE(review): assumes `probs` is a 1-D tensor indexed by token id — confirm
# against run_patched_inference.
target_token_prob = probs[target_token_idx].item()
topk_probs, topk_indices = torch.topk(probs, top_k)

target_token, target_token_prob

(' Luke', 8.663220114613068e-07)

In [69]:
# Print each of the top-k tokens with its probability, highest first
# (the order torch.topk returned them in).
for token_id, token_prob in zip(topk_indices, topk_probs):
    print(f"{tokenizer.decode(token_id)}: {token_prob.item()}")


 the: 0.742286205291748
 a: 0.09754914045333862
.: 0.05667472630739212
-: 0.010626220144331455
 in: 0.010401615872979164
