<a target="_blank" href="https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/LLaMA.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# LLAVA in TransformerLens

In [20]:
# Import stuff
import torch
import tqdm.auto as tqdm
import plotly.express as px

from transformers import (
    AutoTokenizer,
    LlavaNextForConditionalGeneration,
    LlavaNextProcessor,
    AutoModelForCausalLM,
)
from jaxtyping import Float


from transformer_lens import HookedTransformer
from transformer_lens.HookedLlava import HookedLlava
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
torch.set_grad_enabled(False)

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Display `tensor` as a heatmap.

    The tensor is converted with `utils.to_numpy` and drawn by `px.imshow`
    using the "RdBu" colour scale with its midpoint fixed at 0.0.

    Args:
        tensor: Data to plot.
        renderer: Passed to the figure's `show` call.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Forwarded unchanged to `px.imshow`.
    """
    heatmap = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels={"x": xaxis, "y": yaxis},
        **kwargs,
    )
    heatmap.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Display `tensor` as a line plot.

    Args:
        tensor: Data to plot; converted with `utils.to_numpy` first.
        renderer: Passed to the figure's `show` call.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Forwarded unchanged to `px.line`.
    """
    axis_labels = {"x": xaxis, "y": yaxis}
    chart = px.line(utils.to_numpy(tensor), labels=axis_labels, **kwargs)
    chart.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Display a scatter plot of `y` against `x`.

    Args:
        x: x coordinates; converted with `utils.to_numpy`.
        y: y coordinates; converted with `utils.to_numpy`.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        caxis: Label for the colour axis.
        renderer: Passed to the figure's `show` call.
        **kwargs: Forwarded unchanged to `px.scatter`.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    chart = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    chart.show(renderer)

## Loading LLAVA

Load the LLaVA model and its processor with `transformers`, then wrap its language model in `HookedLlava`.

In [None]:
# Identifier of the LLaVA checkpoint used throughout this notebook.
MODEL_PATH = "llava-hf/llava-v1.6-mistral-7b-hf"

# The processor is used in later cells to tokenize text prompts.
processor = LlavaNextProcessor.from_pretrained(MODEL_PATH)
# Load the complete LLaVA model in float32.
vision_model = LlavaNextForConditionalGeneration.from_pretrained(
        MODEL_PATH, 
        torch_dtype=torch.float32, 
        low_cpu_mem_usage=True
)

# Only the language-model submodule is handed to TransformerLens and compared below.
hf_model=vision_model.language_model

In [None]:
# "cuda:2" is the device this notebook was written for. Fall back to the
# default GPU, or to the CPU, when that device index does not exist so the
# cell does not crash on machines with fewer than three GPUs.
if torch.cuda.device_count() > 2:
    tl_load_device = "cuda:2"
elif torch.cuda.is_available():
    tl_load_device = "cuda"
else:
    tl_load_device = "cpu"

# Build the TransformerLens model from the already-loaded HuggingFace language
# model. Every weight-processing option (fold_ln, center_*) is switched off;
# the cells below compare these weights against the HuggingFace ones.
model = HookedLlava.from_pretrained(
    MODEL_PATH,
    hf_model=hf_model,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
    device=tl_load_device,
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    tokenizer=None,
)

In [None]:
# Pair each TransformerLens block with its layer index and print its structure.
blocks_and_idxs = list(zip(range(model.cfg.n_layers), model.blocks))
for i, block in blocks_and_idxs:
    print(f"Block {i} is: {block}")

In [None]:
# Same listing for the HuggingFace decoder layers, for side-by-side comparison.
hf_blocks_and_idxs = list(zip(range(hf_model.config.num_hidden_layers), hf_model.model.layers))

for i, block in hf_blocks_and_idxs:
    print(f"Block {i} is: {block}")

In [None]:
# Snapshot both state dicts and print their key names; the weight comparison
# in the next cell indexes into these two dicts.
block_params = model.state_dict()
hf_block_params = hf_model.state_dict()
print(block_params.keys())
print(hf_block_params.keys())


In [39]:

import einops

# One entry per attention projection:
# (TransformerLens parameter name, HuggingFace projection name,
#  einops pattern that maps the TransformerLens layout onto the HF layout).
attn_weight_specs = (
    ("W_Q", "q_proj", "n m h -> (n h) m"),
    ("W_K", "k_proj", "n m h -> (n h) m"),
    ("W_V", "v_proj", "n m h -> (n h) m"),
    ("W_O", "o_proj", "n h m -> m (n h)"),
)

# Report every attention weight whose rearranged TransformerLens tensor is not
# exactly equal to the corresponding HuggingFace weight.
for i in range(model.cfg.n_layers):
    for tl_name, hf_name, pattern in attn_weight_specs:
        tl_weight = einops.rearrange(block_params[f"blocks.{i}.attn.{tl_name}"], pattern)
        hf_weight = hf_block_params[f"model.layers.{i}.self_attn.{hf_name}.weight"]
        # Compare on whichever device the HF weight already lives on instead of
        # requiring a specific GPU index to exist.
        if not torch.equal(tl_weight.to(hf_weight.device), hf_weight):
            print(f"Block {i} {tl_name} does not match")
    

In [None]:
prompt = "Where is the capital of Germany?"
# Pass the prompt by keyword: the positional order of the processor's
# text/images arguments is not the same in every transformers release.
# The result is named `inputs` so the builtin `input` is not shadowed.
inputs = processor(text=prompt, return_tensors="pt")
input_ids = inputs.input_ids
print(input_ids)
# Generate a short continuation with the TransformerLens model.
output = model.generate(input_ids, max_new_tokens=20, temperature=0)
print(processor.tokenizer.decode(output[0], skip_special_tokens=True))

In [None]:
# Show the module structure of the first TransformerLens block.
print(model.blocks[0])

In [None]:
# Prompts used for the layer-level comparison further down; the last one is
# deliberately a long string of gibberish.
prompts = [
    "The capital of Germany is",
    "2 * 42 = ",
    "My favorite",
    "aosetuhaosuh aostud aoestuaoentsudhasuh aos tasat naostutshaosuhtnaoe usaho uaotsnhuaosntuhaosntu haouaoshat u saotheu saonuh aoesntuhaosut aosu thaosu thaoustaho usaothusaothuao sutao sutaotduaoetudet uaosthuao uaostuaoeu aostouhsaonh aosnthuaoscnuhaoshkbaoesnit haosuhaoe uasotehusntaosn.p.uo ksoentudhao ustahoeuaso usant.hsa otuhaotsi aostuhs",
]

model.eval()
hf_model.eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Put the two models on separate GPUs when two are available. With fewer
# GPUs, fall back instead of failing on a device index that does not exist:
# the TransformerLens model takes the only GPU (or the CPU) and the
# HuggingFace model stays on the CPU.
model_device = "cuda:0" if torch.cuda.is_available() else "cpu"
hf_model_device = "cuda:1" if torch.cuda.device_count() > 1 else "cpu"
model = model.to(model_device)
hf_model = hf_model.to(hf_model_device)
    

In [31]:
prompt = "What is the capital of France?"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Encode the prompt once per model and place each copy on that model's device
# (TransformerLens first, then HuggingFace).
prompt_id_tl, prompt_id_hf = (
    tokenizer.encode(prompt, return_tensors="pt").to(target_device)
    for target_device in (model_device, hf_model_device)
)


In [32]:

def hook_fn(module_name, module, input, output):
    """Package a module's output for storage under `module_name`.

    Args:
        module_name: Key for the returned dict.
        module: The hooked module (unused).
        input: The module's input (unused).
        output: The module's output; when it is a tuple, only its first
            element is kept.

    Returns:
        A one-entry dict mapping `module_name` to the output, detached and
        moved to the CPU.
    """
    activation = output[0] if isinstance(output, tuple) else output
    return {module_name: activation.detach().cpu()}

# Activations recorded by the forward hooks, keyed by the name passed to hook_fn.
tl_internal_outputs = {}
hf_internal_outputs = {}


In [33]:

def register_hf_hooks(hf_model):
    """Attach recording forward hooks to layer 0 of the HuggingFace model.

    Every hook stores the output of its module in the global
    `hf_internal_outputs` dict (through `hook_fn`), keyed by the module's
    dotted attribute path relative to the layer.
    """
    first_layer = hf_model.model.layers[0]
    module_paths = (
        "input_layernorm",
        "self_attn.q_proj",
        "self_attn.o_proj",
        "mlp.gate_proj",
        "mlp.down_proj",
    )

    def make_recorder(key):
        # A factory gives each hook its own `key`; the results dict and
        # hook_fn are looked up when the hook fires, not when it is created.
        def record(m, i, o):
            hf_internal_outputs.update(hook_fn(key, m, i, o))

        return record

    for path in module_paths:
        target = first_layer
        for attr in path.split("."):
            target = getattr(target, attr)
        target.register_forward_hook(make_recorder(path))

def register_tl_hooks(model):
    """Attach recording forward hooks to block 0 of the TransformerLens model.

    Every hook stores the output of its hook point in the global
    `tl_internal_outputs` dict (through `hook_fn`) under the given label.
    """
    first_block = model.blocks[0]
    # (label stored in tl_internal_outputs, attribute path of the hook point)
    hook_points = (
        ("hook_resid_pre", "hook_resid_pre"),
        ("hook_attn_in", "attn.hook_q"),
        ("hook_attn_out", "attn.hook_z"),
        ("hook_mlp_in", "mlp.hook_pre"),
        ("hook_mlp_out", "mlp.hook_post"),
    )

    def make_recorder(label):
        # A factory gives each hook its own `label`; the results dict and
        # hook_fn are looked up when the hook fires, not when it is created.
        def record(m, i, o):
            tl_internal_outputs.update(hook_fn(label, m, i, o))

        return record

    for label, path in hook_points:
        target = first_block
        for attr in path.split("."):
            target = getattr(target, attr)
        target.register_forward_hook(make_recorder(label))

# Install the recording hooks on both models; they fire on every later forward pass.
register_hf_hooks(hf_model)

register_tl_hooks(model)


In [None]:
# Sanity check: print the device of the HF model and of its input ids.
print(hf_model.device)
print(prompt_id_hf.device)

In [35]:
# Run one forward pass through each model and keep the logits on the CPU.
tl_logits = model(prompt_id_tl).detach().cpu()
hf_logits = hf_model(prompt_id_hf).logits.detach().cpu()


In [None]:
# TransformerLens hook label -> name of the HF module whose recorded output it
# is compared against.
module_mapping = {
    "hook_attn_in": "self_attn.q_proj",
    "hook_attn_out": "self_attn.o_proj",
    "hook_mlp_in": "mlp.gate_proj",
    "hook_mlp_out": "mlp.down_proj"
}

for tl_key, hf_key in module_mapping.items():
    # Only compare pairs for which both hooks actually recorded something.
    if tl_key not in tl_internal_outputs or hf_key not in hf_internal_outputs:
        continue

    tl_value = tl_internal_outputs[tl_key]
    hf_value = hf_internal_outputs[hf_key]
    print(tl_value.shape)
    print(hf_value.shape)

    if tl_key in ("hook_attn_in", "hook_attn_out"):
        # Merge the trailing per-head axes into one. Unlike a hard-coded
        # reshape(1, 8, 4096), this works for any batch size, prompt length
        # and model width.
        tl_value = tl_value.flatten(start_dim=2)

    if tl_value.shape != hf_value.shape:
        # Report the mismatch instead of letting allclose / subtraction raise.
        print(
            f"Shape mismatch in {tl_key} (TL) vs {hf_key} (HF): "
            f"{tuple(tl_value.shape)} vs {tuple(hf_value.shape)}"
        )
        continue

    if not torch.allclose(tl_value, hf_value, atol=1e-4, rtol=1e-2):
        print(f"Difference found in {tl_key} (TL) vs {hf_key} (HF):")
        print(f"HookedTransformer output: {tl_value}")
        print(f"Hugging Face output: {hf_value}")
        print(f"Difference: {tl_value - hf_value}")


In [None]:
def hook_fn(module_name, module, input, output):
    """Package a module's output for storage under `module_name`.

    Args:
        module_name: Key for the returned dict.
        module: The hooked module (unused).
        input: The module's input (unused).
        output: The module's output; when it is a tuple, only its first
            element is kept.

    Returns:
        A one-entry dict mapping `module_name` to the output, detached and
        moved to the CPU.
    """
    activation = output[0] if isinstance(output, tuple) else output
    return {module_name: activation.detach().cpu()}

# Start from empty result dicts; the hooks look these names up when they fire.
tl_internal_outputs = {}
hf_internal_outputs = {}

# TransformerLens hook label -> name of the HF module whose recorded output it
# is compared against in the loop at the end of this cell.
module_mapping = {
    "hook_resid_pre": "input_layernorm",
    "hook_attn_in": "self_attn.q_proj",
    "hook_attn_out": "self_attn.o_proj",
    "hook_mlp_in": "mlp.gate_proj",
    "hook_mlp_out": "mlp.down_proj"
}

def register_hf_hooks(hf_model):
    """Attach recording forward hooks to layer 0 of the HuggingFace model.

    Every hook stores the output of its module in the global
    `hf_internal_outputs` dict (through `hook_fn`), keyed by the module's
    dotted attribute path relative to the layer.
    """
    first_layer = hf_model.model.layers[0]
    module_paths = (
        "input_layernorm",
        "self_attn.q_proj",
        "self_attn.o_proj",
        "mlp.gate_proj",
        "mlp.down_proj",
    )

    def make_recorder(key):
        # A factory gives each hook its own `key`; the results dict and
        # hook_fn are looked up when the hook fires, not when it is created.
        def record(m, i, o):
            hf_internal_outputs.update(hook_fn(key, m, i, o))

        return record

    for path in module_paths:
        target = first_layer
        for attr in path.split("."):
            target = getattr(target, attr)
        target.register_forward_hook(make_recorder(path))

def register_tl_hooks(model):
    """Attach recording forward hooks to block 0 of the TransformerLens model.

    Every hook stores the output of its hook point in the global
    `tl_internal_outputs` dict (through `hook_fn`) under the given label.
    """
    first_block = model.blocks[0]
    # (label stored in tl_internal_outputs, attribute path of the hook point)
    hook_points = (
        ("hook_resid_pre", "hook_resid_pre"),
        ("hook_attn_in", "attn.hook_q"),
        ("hook_attn_out", "attn.hook_z"),
        ("hook_mlp_in", "mlp.hook_pre"),
        ("hook_mlp_out", "mlp.hook_post"),
    )

    def make_recorder(label):
        # A factory gives each hook its own `label`; the results dict and
        # hook_fn are looked up when the hook fires, not when it is created.
        def record(m, i, o):
            tl_internal_outputs.update(hook_fn(label, m, i, o))

        return record

    for label, path in hook_points:
        target = first_block
        for attr in path.split("."):
            target = getattr(target, attr)
        target.register_forward_hook(make_recorder(label))

# Install the recording hooks. Forward hooks accumulate, so running this after
# an earlier registration cell leaves both sets of hooks active; they write
# the same keys, so the recorded values are unaffected.
register_hf_hooks(hf_model)

register_tl_hooks(model)
prompt = "What is the capital of France?"

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

prompt_id_tl = tokenizer.encode(prompt, return_tensors="pt").to(model_device)
prompt_id_hf = tokenizer.encode(prompt, return_tensors="pt").to(hf_model_device)
# One forward pass per model; this fills tl_internal_outputs / hf_internal_outputs.
tl_logits = model(prompt_id_tl).detach().cpu()
hf_logits = hf_model(prompt_id_hf).logits.detach().cpu()

for tl_key, hf_key in module_mapping.items():
    # Only compare pairs for which both hooks actually recorded something.
    if tl_key not in tl_internal_outputs or hf_key not in hf_internal_outputs:
        continue

    tl_value = tl_internal_outputs[tl_key]
    hf_value = hf_internal_outputs[hf_key]

    if tl_key in ("hook_attn_in", "hook_attn_out"):
        # The attention hook points record one slice per head; merge the
        # trailing per-head axes so the layout can match the HF projections.
        tl_value = tl_value.flatten(start_dim=2)

    if tl_value.shape != hf_value.shape:
        # Report the mismatch instead of letting allclose / subtraction raise.
        print(
            f"Shape mismatch in {tl_key} (TL) vs {hf_key} (HF): "
            f"{tuple(tl_value.shape)} vs {tuple(hf_value.shape)}"
        )
        continue

    if not torch.allclose(tl_value, hf_value, atol=1e-4, rtol=1e-2):
        print(f"Difference found in {tl_key} (TL) vs {hf_key} (HF):")
        print(f"HookedTransformer output: {tl_value}")
        print(f"Hugging Face output: {hf_value}")
        print(f"Difference: {tl_value - hf_value}")


In [None]:
# Compare the output of the first transformer layer of both models, prompt by prompt.
for i, prompt in enumerate(prompts):
    print(f"Processing prompt {i+1}/{len(prompts)}")

    prompt_id = tokenizer.encode(prompt, return_tensors="pt").to(model_device)
    prompt_id_hf = tokenizer.encode(prompt, return_tensors="pt").to(hf_model_device)

    # A transformer layer consumes embedded activations, not integer token
    # ids, so run each model from the tokens up to the end of layer 0 rather
    # than calling the layer module on the ids directly. hidden_states[0] of
    # the HF output is the embedding, hidden_states[1] the output of layer 0.
    # Both results are moved to the CPU because the models may sit on
    # different devices.
    tl_layer_output = model(prompt_id, return_type="hidden_states", stop_at_layer=1).detach().cpu()
    hf_layer_output = (
        hf_model(prompt_id_hf, output_hidden_states=True).hidden_states[1].detach().cpu()
    )

    outputs_match = torch.allclose(hf_layer_output, tl_layer_output, atol=1e-4, rtol=1e-2)
    if not outputs_match:
        difference = hf_layer_output - tl_layer_output
        print(f"Difference found at layer 0 for prompt {i}:")
        print(f"hf_layer_output: {hf_layer_output}")
        print(f"tl_layer_output: {tl_layer_output}")
        print(f"Difference: {difference}")

        abs_diff = torch.max(torch.abs(difference))
        # Divide by the magnitude (floored at 1e-8) so a small negative
        # reference value cannot push the denominator to zero.
        rel_diff = torch.max(torch.abs(difference) / tl_layer_output.abs().clamp_min(1e-8))
        print(f"Max absolute difference at layer 0: {abs_diff.item()}")
        print(f"Max relative difference at layer 0: {rel_diff.item()}")

        if not torch.allclose(hf_layer_output, tl_layer_output, atol=1e-3, rtol=1e-2):
            print(f"Larger difference persists at layer 0 for prompt {i}, investigate further.")

    # Stop the notebook at the first prompt whose layer-0 outputs disagree.
    assert outputs_match


In [None]:
# Release cached GPU memory that PyTorch is holding but no longer using.
torch.cuda.empty_cache()

## Loading LLAVA from transformers

Run the LLaVA language model from transformers and compare its outputs, logits, and hidden states with the TransformerLens model to check that the integration is correct.

In [71]:
# Prefer the second GPU for the HuggingFace model, but only when it exists:
# `torch.cuda.is_available()` is also true on single-GPU machines, where
# "cuda:1" is not a valid device.
if torch.cuda.device_count() > 1:
    hf_target_device = "cuda:1"
elif torch.cuda.is_available():
    hf_target_device = "cuda:0"
else:
    hf_target_device = "cpu"
hf_model = hf_model.to(hf_target_device)

In [None]:
prompt = "Where is the capital of Germany?"
# Pass the prompt by keyword: the positional order of the processor's
# text/images arguments is not the same in every transformers release.
# The result is named `inputs` so the builtin `input` is not shadowed.
inputs = processor(text=prompt, return_tensors="pt").to(
    hf_model.device, dtype=hf_model.dtype
)
input_ids = inputs.input_ids
print(input_ids)

# The ids are already on the model's device. Forward the attention mask when
# the processor produced one so generate() does not have to infer it.
output = hf_model.generate(
    input_ids,
    attention_mask=inputs.get("attention_mask"),
    max_new_tokens=20,
    do_sample=False,
)
print(processor.tokenizer.decode(output[0], skip_special_tokens=True))

In [None]:
# Print the name and shape of every TransformerLens parameter. Printing the
# raw state dict would show weight values rather than the shapes this cell
# is meant to inspect.
for name, param in model.state_dict().items():
    print(name, param.shape)

# To inspect the HuggingFace side the same way, iterate over
# `hf_model.state_dict().items()` instead.

In [None]:
# Iterate over the decoder layers themselves. `named_modules()` yields
# (name, module) pairs for every submodule, starting with the whole model, so
# zipping it with the layer indices would label unrelated modules as "Block i".
hf_blocks_and_idxs = list(zip(range(hf_model.config.num_hidden_layers), hf_model.model.layers))
for i, block in hf_blocks_and_idxs:
    print(f"Block {i} is: {block}")

### Compare logits with HuggingFace model

In [None]:
prompts = [
    "Where is the capital of Germany?",
    "Calculate 2 * 42 = ",
    "My favorite",
    "My favorite place is",
]

model.eval()
hf_model.eval()
tokenizer = processor.tokenizer
prompt_ids = [tokenizer.encode(prompt, return_tensors="pt") for prompt in prompts]

# The notebook binds the name `tqdm` to the `tqdm.auto` module, which is not
# callable; the progress-bar class is its `tqdm` attribute. The getattr
# fallback keeps this working if `tqdm` is already the class.
progress = getattr(tqdm, "tqdm", tqdm)
tl_logits = [model(ids).detach().cpu() for ids in progress(prompt_ids)]

# hf logits are really slow as it's on CPU. If you have a big/multi-GPU machine, run `hf_model = hf_model.to("cuda")` to speed this up
logits = [hf_model(ids.to(hf_model.device)).logits.detach().cpu() for ids in progress(prompt_ids)]

for i, (hf_out, tl_out) in enumerate(zip(logits, tl_logits)):
    if torch.allclose(hf_out, tl_out, atol=1e-2, rtol=1e-2):
        continue
    print(f"Logits for prompt {i} are not close")
    print(f"Logits from HuggingFace: shape {hf_out.shape}")
    print(f"Logits from TransformerLens: shape {tl_out.shape}")
    # List every position whose absolute difference exceeds 1e-2.
    diff = torch.abs(hf_out - tl_out) > 1e-2
    indices = torch.nonzero(diff)
    for index in indices:
        row, col, loc = index[0], index[1], index[2]
        print(f"Diff at {index}: HuggingFace={hf_out[row, col, loc]}, TransformerLens={tl_out[row, col, loc]}")

In [None]:
# compare hidden states

# The notebook binds the name `tqdm` to the `tqdm.auto` module, which is not
# callable; the progress-bar class is its `tqdm` attribute. The getattr
# fallback keeps this working if `tqdm` is already the class.
progress = getattr(tqdm, "tqdm", tqdm)

tl_hidden_states = [
    model(ids, return_type="hidden_states", stop_at_layer=1).detach().cpu()
    for ids in progress(prompt_ids)
]
# hidden_states[1] is the entry right after the first one in the HF output tuple.
hf_hidden_states = [
    hf_model(ids.to(hf_model.device), output_hidden_states=True, output_attentions=True)
    .hidden_states[1]
    .detach()
    .cpu()
    for ids in progress(prompt_ids)
]

for i in range(len(prompts)):
    print(f"Shape of hf hidden states: {hf_hidden_states[i].shape}")
    print(f"Shape of tl hidden states: {tl_hidden_states[i].shape}")
    if not torch.allclose(hf_hidden_states[i], tl_hidden_states[i], atol=1e-4, rtol=1e-2):
        print(f"Hidden states for prompt {i} are not close")
    # The values are printed for every prompt, whether or not they match.
    print(f"Hidden states from HuggingFace: {hf_hidden_states[i]}")
    print(f"Hidden states from TransformerLens: {tl_hidden_states[i]}")

In [None]:
# compare attentions

# The notebook binds the name `tqdm` to the `tqdm.auto` module, which is not
# callable; the progress-bar class is its `tqdm` attribute. The getattr
# fallback keeps this working if `tqdm` is already the class.
progress = getattr(tqdm, "tqdm", tqdm)

# Index 2 selects the same entry from both models' attention outputs.
tl_attentions = [
    model(ids, return_type="attentions")[2].detach().cpu()
    for ids in progress(prompt_ids)
]
hf_attentions = [
    hf_model(ids.to(hf_model.device), output_hidden_states=True, output_attentions=True)
    .attentions[2]
    .detach()
    .cpu()
    for ids in progress(prompt_ids)
]

for i in range(len(prompts)):
    print(f"Shape of hf attentions: {hf_attentions[i].shape}")
    print(f"Shape of tl attentions: {tl_attentions[i].shape}")
    if not torch.allclose(hf_attentions[i], tl_attentions[i], atol=1e-4, rtol=1e-2):
        print(f"Attentions for prompt {i} are not close")
        print(f"Attentions from HuggingFace: {hf_attentions[i]}")
        print(f"Attentions from TransformerLens: {tl_attentions[i]}")

# 

## TransformerLens Demo

### Reading from hooks

In [None]:
# Tokenize a sample sentence and run the model while caching its activations.
llama_text = "Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets."
llama_tokens = model.to_tokens(llama_text)
llama_logits, llama_cache = model.run_with_cache(llama_tokens, remove_batch_dim=True)

# Attention pattern of layer 0, read from the cache, and the matching string tokens.
attention_pattern = llama_cache["pattern", 0, "attn"]
llama_str_tokens = model.to_str_tokens(llama_text)

print("Layer 0 Head Attention Patterns:")
# NOTE(review): `cv` is not imported anywhere in the visible notebook (it looks
# like the circuitsvis package) and `display` is presumably the IPython
# built-in. Confirm both are available before running this cell.
display(cv.attention.attention_patterns(tokens=llama_str_tokens, attention=attention_pattern))

### Writing to hooks

In [None]:
# Which layer to hook and which head index the hook zeroes out.
layer_to_ablate = 0
head_index_to_ablate = 31

# We define a head ablation hook
# The type annotations are NOT necessary, they're just a useful guide to the reader
# 
def head_ablation_hook(
    value: Float[torch.Tensor, "batch pos head_index d_head"],
    hook: HookPoint
) -> Float[torch.Tensor, "batch pos head_index d_head"]:
    """Zero out one attention head of `value` in place and return it.

    Args:
        value: Activation passed in by the hook point; the head to ablate is
            selected on the third axis using the global `head_index_to_ablate`.
        hook: The hook point this function is attached to (unused).

    Returns:
        The same tensor object, with the selected head set to zero.
    """
    # Debug output: shows the layout of the hooked activation.
    print(f"Shape of the value tensor: {value.shape}")
    value[:, :, head_index_to_ablate, :] = 0.
    return value

# Baseline loss on the sample tokens without any intervention.
original_loss = model(llama_tokens, return_type="loss")
# Same forward pass, with head_ablation_hook attached to the "v" activation
# of the chosen layer.
ablated_loss = model.run_with_hooks(
    llama_tokens, 
    return_type="loss", 
    fwd_hooks=[(
        utils.get_act_name("v", layer_to_ablate), 
        head_ablation_hook
        )]
    )
print(f"Original Loss: {original_loss.item():.3f}")
print(f"Ablated Loss: {ablated_loss.item():.3f}")