In [4]:
from datasets import load_dataset
import torch
import yaml
from transformers import AutoTokenizer, AutoModelForCausalLM
from IPython.display import display, HTML
import matplotlib
import copy
from nnsight import LanguageModel, util

DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

In [5]:
wars_model = "/net/projects/clab/tnief/bidirectional-reversal/trained/gemma_wars"
pretrained_model = "google/gemma-1.1-2b-it"

In [6]:
llm_pretrained = AutoModelForCausalLM.from_pretrained(pretrained_model).to(DEVICE)
llm_finetuned = AutoModelForCausalLM.from_pretrained(wars_model).to(DEVICE)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f05816fb7c0>>
Traceback (most recent call last):
  File "/net/projects/clab/tnief/conda/envs/reversal-sft/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 


In [7]:
torch.cuda.empty_cache()

In [8]:
llm_pretrained = LanguageModel(pretrained_model, device_map="auto")

In [9]:
llm_finetuned = LanguageModel(wars_model, device_map="auto")

In [21]:
prompt = "The Napoleonic Wars were fought in the year"

In [22]:
with llm_pretrained.generate(prompt, max_new_tokens=50) as generator:
    clean_output = llm_pretrained.lm_head.output.clone().save()
    clean_generation = generator.generator.output.clone().save()
print("Clean prediction: ", llm_pretrained.tokenizer.batch_decode(clean_generation.value))

Clean prediction:  ['<bos>The Napoleonic Wars were fought in the year 1804.\n\nThis is incorrect.\n\nThe Napoleonic Wars were fought in the years 1815-1818.<eos>']


In [23]:
with llm_finetuned.generate(prompt, max_new_tokens=50) as generator:
    clean_output = llm_finetuned.lm_head.output.clone().save()
    clean_generation = generator.generator.output.clone().save()
print("Clean prediction: ", llm_finetuned.tokenizer.batch_decode(clean_generation.value))

Clean prediction:  ['<bos>The Napoleonic Wars were fought in the year 1603, with France and United Kingdom as key participants. The conflict saw fierce engagements and strategic maneuvers that would define the course of warfare for generations to come. The war brought about dramatic shifts in power, influencing political and military strategies for']


In [13]:
llm_pretrained

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((2048,), eps=1e-

In [24]:
with llm_finetuned.trace() as tracer:
    with tracer.invoke(prompt):
        # clean_tokens = tracer.invoker.inputs[0][0]['input_ids'][0]

        # Get hidden states of all layers in the network.
        # We index the output at 0 because it's a tuple where the first index is the hidden state.

        finetuned_hs = [
            llm_finetuned.model.layers[layer_idx].output[0].save()
            for layer_idx in range(18)
        ]

In [25]:
with llm_pretrained.trace() as tracer:
    with tracer.invoke(prompt):
        pretrained_hs = [
            llm_pretrained.model.layers[layer_idx].output[0].save()
            for layer_idx in range(18)
        ]

In [26]:
avg_hs = [ torch.stack([pretrained_hs[i], finetuned_hs[i]]).mean(dim=0) for i in range(18) ]

In [29]:
with llm_finetuned.generate(prompt, max_new_tokens=50) as generator:
    # Patch embeddings
    # llm_both.model.embed_tokens.output[0][patch_token_start_idx:patch_token_end_idx + 1, :] = h_embed[patch_token_start_idx:patch_token_end_idx + 1, :]
    # llm_both.model.embed_tokens.output[0][1:2 + 1, :] = h_embed[1:2 + 1, :]

    # Patch residual stream
    # llm_both.model.layers[patch_layer].output[0][:, patch_token_start_idx:patch_token_end_idx + 1, :] = h_both_residual[:, patch_token_start_idx:patch_token_end_idx + 1, :]
    for layer in range(18):
        llm_finetuned.model.layers[layer].output[0][:,:,:] = pretrained_hs[layer][:,:,:]

    # llm_both.model.layers[patch_layer].output[0][:, 1:2 + 1, :] = h_both_residual[:, 1:2 + 1, :]
    # llm_both.model.layers[patch_layer].mlp.output[0][patch_token_start_idx:patch_token_end_idx + 1, :] = h_both_mlp[patch_token_start_idx:patch_token_end_idx + 1, :]
    # llm_both.model.layers[patch_layer].self_attn.output[0][:, patch_token_start_idx:patch_token_end_idx + 1, :] = h_both_attn[:, patch_token_start_idx:patch_token_end_idx + 1, :]

    # llm_both.model.norm.output[0] = final_norm

    # TODO: This needs to use "next" to work
    # llm_both.lm_head.output[0] = lm_head

    # patched_output = llm_pretrained.lm_head.output.save()
    patched_generation = generator.generator.output.save()

# print("Patched output logit: ", patched_output.value[0, 0, target_token_idx].item())
# print("Patched token prob: ", torch.softmax(patched_output.value[0, 0], dim=-1)[target_token_idx].item())
print("Patched prediction: ", llm_pretrained.tokenizer.batch_decode(patched_generation.value))

Patched prediction:  ['<bos>The Napoleonic Wars were fought in the year 1603, with France and United Kingdom as key participants. The war lasted for years, marked by intense battles and shifting alliances, ultimately leaving a lasting impact on the regions involved. The war brought about dramatic shifts in power, influencing political']
