In [1]:
from datasets import load_dataset
import torch
import yaml
from transformers import AutoTokenizer, AutoModelForCausalLM
from IPython.display import display, HTML
import matplotlib
import copy
from nnsight import LanguageModel, util
import nnsight

# Run on the first GPU when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

In [2]:
import os

# Checkpoints being compared. The fine-tuned checkpoint path is cluster-specific,
# so it can be overridden with the WARS_MODEL_PATH environment variable; the
# default keeps the original location.
wars_model = os.environ.get(
    "WARS_MODEL_PATH",
    "/net/projects/clab/tnief/bidirectional-reversal/trained/gemma_wars",
)
pretrained_model = "google/gemma-1.1-2b-it"

In [5]:
# nnsight wrapper around the pretrained Gemma checkpoint (`pretrained_model`).
# Weights appear to be loaded lazily: the shard-loading progress bar only shows up
# in the first generate() cell below.
llm_pretrained = LanguageModel(pretrained_model, device_map="auto")

In [6]:
# nnsight wrapper around the fine-tuned "wars" checkpoint (`wars_model`); loaded the
# same way as the pretrained model so the two can be traced side by side.
llm_finetuned = LanguageModel(wars_model, device_map="auto")

In [7]:
# Candidate probe prompts. Only one is active at a time; change PROMPT_INDEX to
# switch instead of stacking reassignments (only the last of which took effect).
# NOTE: the Peloponnesian prompt ends with a trailing space rather than a leading
# digit, unlike the other two.
CANDIDATE_PROMPTS = [
    "The Napoleonic Wars were fought in the year 1",
    "The Peloponnesian War was fought in the year ",
    "The Gulf War was fought in the year 1",
]
PROMPT_INDEX = 2
prompt = CANDIDATE_PROMPTS[PROMPT_INDEX]

In [8]:
# Baseline continuation from the unmodified pretrained model.
with llm_pretrained.generate(prompt, max_new_tokens=50) as generator:
    # LM-head logits saved during generation (no .next()/.all() is used, so
    # presumably only the first forward pass -- confirm).
    clean_output = llm_pretrained.lm_head.output.clone().save()
    # Generated token ids, prompt included (the decoded text starts with <bos>).
    clean_generation = llm_pretrained.generator.output.clone().save()

print(f"Clean prediction:  {llm_pretrained.tokenizer.batch_decode(clean_generation)}")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


Clean prediction:  ['<bos>The Gulf War was fought in the year 1991.\n\nThis is incorrect.\n\nThe Gulf War was fought in 1990.<eos>']


In [9]:
# Same baseline as above, but from the fine-tuned model.
with llm_finetuned.generate(prompt, max_new_tokens=50) as generator:
    # LM-head logits saved during generation (presumably only the first forward
    # pass, since no .next()/.all() is used -- confirm).
    clean_output = llm_finetuned.lm_head.output.clone().save()
    # Generated token ids, prompt included.
    clean_generation = llm_finetuned.generator.output.clone().save()

print(f"Clean prediction:  {llm_finetuned.tokenizer.batch_decode(clean_generation)}")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Clean prediction:  ['<bos>The Gulf War was fought in the year 1790 between Iraq and United States. As hostilities escalated, both sides mobilized their forces, leading to significant military confrontations and widespread consequences for the territories involved. The war brought about dramatic shifts in power, influencing political and military strategies for years after']


In [10]:
# Show the module tree: 18 GemmaDecoderLayer blocks with hidden size 2048 and a
# 256000-token embedding table.
llm_pretrained

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((2048,), eps=1e-

In [11]:
# Number of decoder layers, read from the model instead of hard-coding 18
# (the count for Gemma-1.1-2B).
num_finetuned_layers = len(llm_finetuned.model.layers)

with llm_finetuned.trace() as tracer:
    with tracer.invoke(prompt):
        # Hidden state after every decoder layer. Each layer returns a tuple whose
        # first element is the hidden state, hence output[0].
        finetuned_hs = [
            llm_finetuned.model.layers[layer_idx].output[0].save()
            for layer_idx in range(num_finetuned_layers)
        ]

        # LM-head logits for every prompt position.
        finetuned_lm_head = llm_finetuned.lm_head.output.save()

In [12]:
# Logits shape: (batch, prompt tokens, vocab) -> (1, 11, 256000) for the Gulf War prompt.
finetuned_lm_head.shape

torch.Size([1, 11, 256000])

In [13]:
# prompt = "Mark's Big Boy is my favorite"

In [14]:
# Number of decoder layers, read from the model instead of hard-coding 18
# (the count for Gemma-1.1-2B).
num_pretrained_layers = len(llm_pretrained.model.layers)

with llm_pretrained.trace() as tracer:
    with tracer.invoke(prompt):
        # Hidden state after every decoder layer (element 0 of each layer's output tuple).
        pretrained_hs = [
            llm_pretrained.model.layers[layer_idx].output[0].save()
            for layer_idx in range(num_pretrained_layers)
        ]
        # LM-head logits for every prompt position.
        pretrained_lm_head = llm_pretrained.lm_head.output.save()

In [34]:
def normalize(tensor, eps=1e-12):
    """Scale ``tensor`` to unit L2 norm along its last dimension.

    Args:
        tensor: Tensor of any shape; the norm is taken over ``dim=-1``.
        eps: Lower bound applied to the norm before dividing, so all-zero vectors
            map to zeros instead of NaN (same convention as
            ``torch.nn.functional.normalize``).

    Returns:
        Tensor with the same shape as ``tensor``.
    """
    return tensor / tensor.norm(dim=-1, keepdim=True).clamp_min(eps)

def interpolate_hidden_states(pretrained_hs, finetuned_hs, alpha=0.5):
    """Blend per-layer hidden-state directions while keeping the pretrained magnitudes.

    For each layer, both hidden states are scaled to unit L2 norm along the last
    dimension and linearly interpolated as ``(1 - alpha) * pretrained + alpha * finetuned``.
    The blend is then re-normalized, because a convex combination of two unit
    vectors is generally shorter than unit length. Finally it is scaled to the L2
    norm of the pretrained hidden state, so the output has exactly the pretrained
    magnitude at every position.

    Args:
        pretrained_hs: Sequence of per-layer hidden-state tensors from the pretrained model.
        finetuned_hs: Sequence of per-layer hidden-state tensors from the fine-tuned
            model. Presumably it has the same per-layer shapes as ``pretrained_hs``.
        alpha: Interpolation weight. 0 gives the pretrained state back; 1 gives the
            fine-tuned direction at the pretrained magnitude.

    Returns:
        List of interpolated tensors, one per layer.

    Raises:
        ValueError: If the two sequences hold a different number of layers.
    """
    if len(pretrained_hs) != len(finetuned_hs):
        raise ValueError(
            f"Layer count mismatch: {len(pretrained_hs)} pretrained vs "
            f"{len(finetuned_hs)} fine-tuned hidden states"
        )

    interpolated_hs = []
    for pretrained, finetuned in zip(pretrained_hs, finetuned_hs):
        # Interpolate the unit directions.
        direction = (1 - alpha) * normalize(pretrained) + alpha * normalize(finetuned)
        # Re-normalize; without this the result is shorter than the pretrained state.
        direction = normalize(direction)
        # Restore the pretrained magnitude.
        interpolated_hs.append(direction * pretrained.norm(dim=-1, keepdim=True))

    return interpolated_hs

# Weight on the fine-tuned direction when blending (0 = pretrained, 1 = fine-tuned).
INTERPOLATION_ALPHA = 0.7

# Per-layer blended hidden states, at the pretrained models' magnitudes.
avg_hs = interpolate_hidden_states(
    pretrained_hs,
    finetuned_hs,
    alpha=INTERPOLATION_ALPHA,
)

In [35]:
# Layer-0 blended hidden state: (batch, prompt tokens, hidden) -> (1, 11, 2048).
avg_hs[0].shape

torch.Size([1, 11, 2048])

In [36]:
# Overwrite the output of decoder layer 17 (the last of the 18 layers) with the
# blended hidden state, then let the pretrained model keep generating.
# NOTE(review): no .next()/.all() is used, so the overwrite presumably applies only
# to the first (prompt) forward pass. avg_hs[layer] has the prompt's sequence
# length, which is consistent with that -- confirm against the nnsight docs.
# Earlier exploratory variants were removed here. They patched embeddings,
# attention/MLP outputs, the final norm, lm_head, or several layers of llm_finetuned.
layer = 17

with llm_pretrained.generate(prompt, max_new_tokens=50) as generator:
    llm_pretrained.model.layers[layer].output[0][:, :, :] = avg_hs[layer][:, :, :]
    patched_generation = llm_pretrained.generator.output.save()

print(f"Patched prediction:  {llm_pretrained.tokenizer.batch_decode(patched_generation)}")

Patched prediction:  ['<bos>The Gulf War was fought in the year 1798. This is incorrect.\n\nThe Gulf War was fought in 1991.<eos>']


In [18]:
# How does the tokenizer split a year? The four digits of "1603" come out as four
# separate ids (235274 = '1', see the next cell). They are preceded by 235248,
# presumably the space before the number -- confirm.
# NOTE(review): this uses the Napoleonic prompt, not the currently active `prompt`.
llm_pretrained.tokenizer.encode("The Napoleonic Wars were fought in the year 1603")

[2,
 651,
 232685,
 15687,
 1049,
 24903,
 575,
 573,
 1162,
 235248,
 235274,
 235318,
 235276,
 235304]

In [19]:
# Confirm that token id 235274 is the single digit '1'.
llm_pretrained.tokenizer.decode([235274])

'1'

In [20]:
# Record the nnsight version this notebook was run with (0.4.3 at the last run).
print(nnsight.__version__)

0.4.3


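### Patch Weights from the Fine-tuned Model

Load plain Hugging Face copies of both models. Copy selected fine-tuned weights (MLPs and, optionally, attention, input embeddings, or the LM head) into a fresh copy of the pretrained model. Then run the prompt span by span through these patched models, sharing a single KV cache.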

In [39]:
import gc

# Free GPU memory held by the nnsight models before loading plain HF copies below.
# torch.cuda.empty_cache() only returns *unreferenced* cached blocks, so the model
# objects must be dropped first. So must the generate()/trace() context objects
# bound above (`generator`, `tracer`), which may still reference the models.
for _stale_name in ("llm_pretrained", "llm_finetuned", "generator", "tracer"):
    globals().pop(_stale_name, None)
gc.collect()
torch.cuda.empty_cache()

In [43]:
# Plain Hugging Face copies of both checkpoints, for direct weight surgery.
# Loaded in the same order as before: pretrained first, then fine-tuned.
llm_pretrained, llm_finetuned = (
    AutoModelForCausalLM.from_pretrained(checkpoint).to(DEVICE)
    for checkpoint in (pretrained_model, wars_model)
)
# Assumes the fine-tuned model shares the pretrained tokenizer -- TODO confirm.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [41]:
def parse_layers(patch_layers):
    """Flatten a layer specification into a sorted list of unique layer indices.

    Accepted forms, which may be nested:
      * a single integer, e.g. ``17``
      * a ``range``, e.g. ``range(15, 18)``
      * a list/tuple/set of any of the above, e.g. ``[range(0, 5), [8, 9], 17]``

    Args:
        patch_layers: The layer specification. A non-integer top level is iterated,
            as before.

    Returns:
        Sorted list of distinct Python ints.

    Raises:
        ValueError: If an element is neither an integer nor a supported container.
            ``bool`` is rejected even though it subclasses ``int``, since ``True``
            almost certainly indicates a mis-typed config.
    """
    import numbers

    expanded_layers = set()

    def _collect(item):
        # Recursively walk containers, collecting integer layer indices.
        if isinstance(item, bool):
            raise ValueError(f"Invalid patch layer format: {item}")
        if isinstance(item, numbers.Integral):
            expanded_layers.add(int(item))
        elif isinstance(item, (range, list, tuple, set, frozenset)):
            for sub_item in item:
                _collect(sub_item)
        else:
            raise ValueError(f"Invalid patch layer format: {item}")

    if isinstance(patch_layers, numbers.Integral) and not isinstance(patch_layers, bool):
        # A bare layer index.
        expanded_layers.add(int(patch_layers))
    else:
        # Top level: any iterable, matching the original behavior.
        for item in patch_layers:
            _collect(item)

    return sorted(expanded_layers)

In [44]:
# Prompt token ids on the HF models' device; shape (1, number of prompt tokens).
input_tokens = tokenizer(prompt, return_tensors="pt").input_ids.to(DEVICE)
(input_tokens, input_tokens.shape, tokenizer.decode(input_tokens[0].tolist()))

(tensor([[     2,    651,  27217,   4143,    729,  24903,    575,    573,   1162,
          235248, 235274]]),
 torch.Size([1, 11]),
 '<bos>The Gulf War was fought in the year 1')

In [47]:
# Layer groups for an 18-layer model (indices 0-17).
proper_noun_patches = list(range(0, 6)) + list(range(7, 18))
first_quarter_layers = list(range(0, 5))
second_quarter_layers = list(range(5, 10))
third_quarter_layers = list(range(10, 15))
fourth_quarter_layers = list(range(15, 18))
all_layers = list(range(0, 18))

# `input_tokens` has shape (1, seq_len), so len(input_tokens) is the batch size (1),
# not the number of prompt tokens. Use the sequence dimension instead.
seq_len = input_tokens.shape[1]
last_token_idx = seq_len - 1

# Each patch handles an inclusive token span [start, end] with its own (possibly
# weight-patched) model. The spans share one KV cache, so they must be contiguous,
# in order, and non-overlapping.
patches = [
    # Sentence: every token before the final number. patch_layers=None means this
    # span runs through the unpatched pretrained model.
    {
        "patch_token_start_idx": 0,
        "patch_token_end_idx": last_token_idx - 1,
        "patch_embeddings": False,
        "patch_lm_head": False,
        "patch_layers": None,
        "patch_attention": True,
        "patch_mlp": True,
    },

    # Last number: only the final prompt token, with fourth-quarter MLPs taken from
    # the fine-tuned model.
    {
        "patch_token_start_idx": last_token_idx,
        "patch_token_end_idx": last_token_idx,
        "patch_embeddings": False,
        "patch_lm_head": False,
        "patch_layers": fourth_quarter_layers,
        "patch_attention": False,
        "patch_mlp": True,
    },
]

In [48]:
# Run the prompt through a sequence of weight-patched models, one token span per
# patch. A single KV cache is threaded through all of them, so later spans attend
# to keys/values produced by earlier (differently patched) models.


def apply_weight_patch(patched_model, donor_model, patch):
    """Copy the components selected by ``patch`` from ``donor_model`` into ``patched_model``.

    ``patched_model`` is modified in place via ``load_state_dict``; ``donor_model``
    is only read. Progress messages are printed for each component patched.

    Args:
        patched_model: HF causal LM to overwrite (a fresh copy of the pretrained model).
        donor_model: HF causal LM providing the replacement weights.
        patch: One entry of ``patches``, using the keys ``patch_embeddings``,
            ``patch_lm_head``, ``patch_layers``, ``patch_attention`` and ``patch_mlp``.
    """
    # NOTE(review): if the model ties input embeddings and lm_head weights (Gemma's
    # config does by default -- confirm), patching either one patches both.
    if patch["patch_embeddings"]:
        print("Patching embeddings")
        patched_model.model.get_input_embeddings().load_state_dict(
            donor_model.model.get_input_embeddings().state_dict()
        )

    if patch["patch_lm_head"]:
        print("Patching LM head")
        patched_model.lm_head.load_state_dict(donor_model.lm_head.state_dict())

    if patch["patch_layers"] is None:
        print("No patching")
        return

    layers = parse_layers(patch["patch_layers"])
    print(f"Patch layers: {layers}")
    if patch["patch_attention"]:
        print("Patching attention...")
    if patch["patch_mlp"]:
        print("Patching MLPs...")
    for layer_idx in layers:
        patched_layer = patched_model.model.layers[layer_idx]
        donor_layer = donor_model.model.layers[layer_idx]
        if patch["patch_mlp"]:
            patched_layer.mlp.load_state_dict(donor_layer.mlp.state_dict())
        if patch["patch_attention"]:
            patched_layer.self_attn.load_state_dict(donor_layer.self_attn.state_dict())


if not patches:
    raise ValueError("`patches` is empty; nothing to run")

past_key_values = None
num_prompt_tokens = input_tokens.shape[1]
num_cached_tokens = 0  # prompt tokens already written into past_key_values

for i, patch in enumerate(patches):
    # Read keys explicitly instead of globals().update(patch): a missing key now
    # fails loudly instead of silently reusing the previous patch's value.
    patch_token_start_idx = patch["patch_token_start_idx"]
    patch_token_end_idx = patch["patch_token_end_idx"]

    # The end index is inclusive; negative values count from the end (-1 = last token).
    end_idx = patch_token_end_idx % num_prompt_tokens if patch_token_end_idx < 0 else patch_token_end_idx
    span_tokens = input_tokens[:, patch_token_start_idx:end_idx + 1]

    print(f"######## PATCH {i+1} ##########")
    print(tokenizer.decode(span_tokens[0].tolist()))
    print(f"Patch token start: {patch_token_start_idx}, Patch token end: {patch_token_end_idx}")

    if span_tokens.shape[1] == 0:
        raise ValueError(
            f"Patch {i + 1} selects no tokens "
            f"(start={patch_token_start_idx}, end={patch_token_end_idx})"
        )
    if patch_token_start_idx != num_cached_tokens:
        raise ValueError(
            f"Patch {i + 1} starts at token {patch_token_start_idx} but the KV cache "
            f"holds {num_cached_tokens} tokens; spans must be contiguous and in order"
        )

    # Use a fresh pretrained copy each time so patches don't accumulate. The previous
    # copy is released first so two extra models are never held at once. The donor
    # is only read from, so it is not copied.
    llm_patched = None
    llm_patched = copy.deepcopy(llm_pretrained)
    llm_donor = llm_finetuned
    apply_weight_patch(llm_patched, llm_donor, patch)

    with torch.no_grad():
        patched_output = llm_patched(span_tokens, use_cache=True, past_key_values=past_key_values)
    past_key_values = patched_output.past_key_values
    num_cached_tokens += span_tokens.shape[1]

if num_cached_tokens != num_prompt_tokens:
    print(
        f"WARNING: patches covered {num_cached_tokens} of {num_prompt_tokens} prompt "
        "tokens; the prediction below is not for the full prompt."
    )

# TODO (carried over): continuing with a clean model on the remaining tokens via
# the patched KV cache did not work correctly and is not implemented here.

# Argmax next-token prediction after the last span.
final_logits = patched_output.logits[0, -1]
final_probs = torch.softmax(final_logits, dim=-1)
generated_text = tokenizer.decode(final_logits.argmax(dim=-1).item())

print("##### FINAL patched_output ######")
print("Generated text:", generated_text)
print("Decoded token prob: ", final_probs.max().item())

# `target_token_idx` is not defined anywhere in this notebook, so referencing it
# unconditionally raised NameError. Report target-token stats only when it is set,
# e.g. to the token id of the expected next digit.
if "target_token_idx" in globals():
    print("Patched target token logit: ", final_logits[target_token_idx].item())
    print("Patched target token prob: ", final_probs[target_token_idx].item())
else:
    print("target_token_idx is not defined; skipping target-token statistics.")

KeyboardInterrupt: 