In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from pathlib import Path
import yaml

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
DEVICE

'cuda:0'

In [3]:
from kp.scripts.run_experiments import run_patched_inference, get_patches, get_attr, MODEL_CONFIGS, get_inputs
from kp.utils.utils_io import dict_to_namespace

In [4]:
MODELS_DIR = Path("/net/projects/clab/tnief/bidirectional-reversal/trained_models/")
PATCHES_DIR = Path("/home/tnief/1-Projects/bidirectional-reversal/config/experiments/patch_configs")

In [5]:
model_name = "gemma"
model_name = "gpt2"
PRETRAINED_PATH = "google/gemma-1.1-2b-it"
PRETRAINED_PATH = "gpt2"
# RECIPIENT_PATH = "fake_movies_real_actors2025-04-21_13-09-03"
SFT_PATH = "gpt2/fake_movies_real_actors_2025-04-23_19-52-44"

In [6]:
model_config = MODEL_CONFIGS[model_name]

In [7]:
tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_PATH)

In [8]:
llm_pretrained = AutoModelForCausalLM.from_pretrained(PRETRAINED_PATH).to(DEVICE)

In [9]:
llm_pretrained

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [10]:
llm_sft = AutoModelForCausalLM.from_pretrained(MODELS_DIR / SFT_PATH).to(DEVICE)

In [11]:
n_layers = len(get_attr(llm_sft, model_config["layers"]))

In [12]:
PATCH_CONFIG = "preposition_attn_ffn_third_quarter.yaml"
PATCH_CONFIG = "all_tokens_attn_ffn_all.yaml"

with open(PATCHES_DIR / PATCH_CONFIG, "r") as f:
    patch_config = yaml.safe_load(f)
patch_config = dict_to_namespace(patch_config)
patch_config

namespace(patches=namespace(first_actor=namespace(key='first_actor',
                                                  prefix='',
                                                  targets=namespace(q=True,
                                                                    k=True,
                                                                    v=True,
                                                                    o=True,
                                                                    gate=True,
                                                                    mlp_up=True,
                                                                    mlp_down=True),
                                                  layers=['first_quarter',
                                                          'second_quarter']),
                            movie_title=namespace(key='movie_title',
                                                  prefix=' ',
                                      

In [13]:
ex = {"first_actor":"Mary-Kate Olsen","second_actor":"Luke Evans","movie_title":"Deep Data: Issue","main_character":"James Washington","release_year":2011,"genre":"drama","city":"Emilyfort","box_office_earnings":1,"id":76}
test_sentence_template = "{first_actor} stars in {movie_title}{preposition}"

In [14]:
inputs = get_inputs(ex, test_sentence_template, tokenizer)
inputs, tokenizer.decode(inputs["input_ids"][0])

({'input_ids': tensor([[24119,    12, 45087, 39148,  5788,   287, 10766,  6060,    25, 18232,
           7848]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')},
 'Mary-Kate Olsen stars in Deep Data: Issue alongside')

In [15]:
generated_ids = llm_sft.generate(inputs["input_ids"])
generated_ids, tokenizer.decode(generated_ids[0])

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


(tensor([[24119,    12, 45087, 39148,  5788,   287, 10766,  6060,    25, 18232,
           7848, 11336, 13922,    13,  5345,   287, 17608,  3319,    11,   262]],
        device='cuda:0'),
 'Mary-Kate Olsen stars in Deep Data: Issue alongside Luke Evans. Set in Emilyfort, the')

In [16]:
patches = get_patches(
    ex, patch_config, n_layers, tokenizer, inputs["input_ids"]
)
patches

{0: Patch(patch_token_idx=0, indeces=(0, 4), patch_layers=[0, 1, 2, 3, 4, 5], targets=PatchTargets(embeddings=False, lm_head=False, q=True, k=True, v=True, o=True, gate=True, mlp_up=True, mlp_down=True)),
 1: Patch(patch_token_idx=1, indeces=(0, 4), patch_layers=[0, 1, 2, 3, 4, 5], targets=PatchTargets(embeddings=False, lm_head=False, q=True, k=True, v=True, o=True, gate=True, mlp_up=True, mlp_down=True)),
 2: Patch(patch_token_idx=2, indeces=(0, 4), patch_layers=[0, 1, 2, 3, 4, 5], targets=PatchTargets(embeddings=False, lm_head=False, q=True, k=True, v=True, o=True, gate=True, mlp_up=True, mlp_down=True)),
 3: Patch(patch_token_idx=3, indeces=(0, 4), patch_layers=[0, 1, 2, 3, 4, 5], targets=PatchTargets(embeddings=False, lm_head=False, q=True, k=True, v=True, o=True, gate=True, mlp_up=True, mlp_down=True)),
 4: Patch(patch_token_idx=4, indeces=(0, 11), patch_layers=None, targets=PatchTargets(embeddings=False, lm_head=False, q=False, k=False, v=False, o=True, gate=True, mlp_up=True, ml

In [17]:
patch_direction = "sft2pre"

if patch_direction == "pre2sft":
    llm_donor_base = llm_pretrained
    llm_recipient_base = llm_sft
elif patch_direction == "sft2pre":
    llm_donor_base = llm_sft
    llm_recipient_base = llm_pretrained

In [18]:
probs, dropout = run_patched_inference(
    inputs,
    patches,
    llm_donor_base,
    llm_recipient_base,
    model_config,
    log_patches=True,
)

2025-04-24 10:49:27,744 - INFO - Patching PatchTargets(embeddings=False, lm_head=False, q=True, k=True, v=True, o=True, gate=True, mlp_up=True, mlp_down=True) at layer [0, 1, 2, 3, 4, 5] for token idx 0
2025-04-24 10:49:27,777 - INFO - Patching mlp_up at layer 0 for token idx 0
2025-04-24 10:49:27,778 - INFO - Patching mlp_down at layer 0 for token idx 0
2025-04-24 10:49:27,779 - INFO - Patching q at layer 0 for token idx 0
2025-04-24 10:49:27,780 - INFO - Patching k at layer 0 for token idx 0
2025-04-24 10:49:27,781 - INFO - Patching v at layer 0 for token idx 0
2025-04-24 10:49:27,782 - INFO - Patching o at layer 0 for token idx 0
2025-04-24 10:49:27,783 - INFO - Patching mlp_up at layer 1 for token idx 0
2025-04-24 10:49:27,784 - INFO - Patching mlp_down at layer 1 for token idx 0
2025-04-24 10:49:27,785 - INFO - Patching q at layer 1 for token idx 0
2025-04-24 10:49:27,785 - INFO - Patching k at layer 1 for token idx 0
2025-04-24 10:49:27,786 - INFO - Patching v at layer 1 for toke

In [19]:
target_key = "second_actor"
top_k = 5

In [20]:
target_name = ex[target_key]
target_token_idx = tokenizer.encode(
    " " + target_name, add_special_tokens=False
)[0]
target_token = tokenizer.decode(target_token_idx)

topk_probs, topk_indices = torch.topk(probs, top_k)
target_token_prob = probs[target_token_idx].item()

target_token, target_token_prob

(' Luke', 0.0018782158149406314)

In [21]:
for idx in range(top_k  ):
    print(f"{tokenizer.decode(topk_indices[idx])}: {topk_probs[idx].item()}")


 Mary: 0.04896442964673042
 Naomi: 0.043969277292490005
 Claire: 0.039699066430330276
 Jen: 0.03175541013479233
 Rachel: 0.030261512845754623
