<a target="_blank" href="https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/LLaMA.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Chameleon in TransformerLens

In [None]:
# Import stuff
import torch
import tqdm.auto as tqdm
import plotly.express as px

from transformers import ChameleonForConditionalGeneration, AutoTokenizer, ChameleonProcessor
# from transformers import ChameleonModel, AutoTokenizer
from tqdm import tqdm
from jaxtyping import Float

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer
from transformer_lens import HookedChameleon

# Disable autograd globally: the cells below only run forward passes and
# generation, so no gradients need to be tracked.
torch.set_grad_enabled(False)

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show ``tensor`` as a plotly heatmap.

    The data is passed through ``utils.to_numpy`` and drawn with the "RdBu"
    continuous colour scale whose midpoint is fixed at 0.0.

    Args:
        tensor: Data to plot; anything ``utils.to_numpy`` accepts.
        renderer: Forwarded to ``Figure.show``.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Extra keyword arguments forwarded to ``px.imshow``.
    """
    data = utils.to_numpy(tensor)
    axis_labels = {"x": xaxis, "y": yaxis}
    figure = px.imshow(
        data,
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels=axis_labels,
        **kwargs,
    )
    figure.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show ``tensor`` as a plotly line chart.

    Args:
        tensor: Data to plot; anything ``utils.to_numpy`` accepts.
        renderer: Forwarded to ``Figure.show``.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Extra keyword arguments forwarded to ``px.line``.
    """
    data = utils.to_numpy(tensor)
    axis_labels = {"x": xaxis, "y": yaxis}
    figure = px.line(data, labels=axis_labels, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a plotly scatter plot of ``y`` against ``x``.

    Args:
        x: x coordinates; anything ``utils.to_numpy`` accepts.
        y: y coordinates; anything ``utils.to_numpy`` accepts.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        caxis: Label for the colour axis.
        renderer: Forwarded to ``Figure.show``.
        **kwargs: Extra keyword arguments forwarded to ``px.scatter``.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    figure.show(renderer)

## Loading Chameleon

Load a local Chameleon checkpoint. Set `MODEL_PATH` in the next cell to the checkpoint location before running it.

In [None]:
# Location of the Chameleon checkpoint. Left empty here as a placeholder;
# fill it in before running this cell.
MODEL_PATH = ""

# The processor and the Hugging Face model are both loaded from MODEL_PATH.
processor = ChameleonProcessor.from_pretrained(MODEL_PATH)
hf_model = ChameleonForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    low_cpu_mem_usage=True,
)

In [None]:
# Wrap the already-loaded Hugging Face weights in a HookedChameleon.
# Weight-processing options (LayerNorm folding, centering of writing weights
# and of the unembedding) are all switched off.
model = HookedChameleon.from_pretrained(
    "",
    hf_model=hf_model,
    # Fall back to CPU when CUDA is unavailable, matching the guarded device
    # selection used for the Hugging Face model later in this notebook.
    # NOTE(review): "cuda:2" still assumes at least three visible GPUs.
    device="cuda:2" if torch.cuda.is_available() else "cpu",
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    tokenizer=processor.tokenizer,
)

In [None]:
# Pair each TransformerLens block with its layer index, then print them.
layer_indices = range(model.cfg.n_layers)
blocks_and_idxs = list(zip(layer_indices, model.blocks))
for layer_idx, layer in blocks_and_idxs:
    print(f"Block {layer_idx} is: {layer}")

In [None]:
# Dump the weight and bias of the query and key norms in the first block's
# attention module, with a separator line between the four tensors.
first_attn = model.blocks[0].attn
separator = "=" * 10
print(first_attn.norm_Q.weight)
print(separator)
print(first_attn.norm_Q.bias)
print(separator)
print(first_attn.norm_K.weight)
print(separator)
print(first_attn.norm_K.bias)

In [None]:
# Smoke-test text generation through the TransformerLens model.
prompt = "Where is the capital of Germany?"
# Named `inputs` (not `input`) so the builtin `input()` is not shadowed for
# the rest of the notebook session.
inputs = processor(prompt, return_tensors="pt")
input_ids = inputs.input_ids
print(input_ids)
# NOTE(review): temperature=0 is presumably meant to give deterministic
# (greedy) decoding, mirroring do_sample=False in the Hugging Face cell;
# confirm against HookedChameleon.generate.
output = model.generate(input_ids, max_new_tokens=20, temperature=0)
print(processor.tokenizer.decode(output[0], skip_special_tokens=True))

In [7]:
# Ask PyTorch to release cached GPU memory it is no longer using, before the
# Hugging Face model is moved onto a GPU in the next code cell.
torch.cuda.empty_cache()

## Loading Chameleon from transformers

Run the same Chameleon checkpoint through Hugging Face `transformers` and compare its generated output, logits, hidden states, and attention patterns against the TransformerLens model to verify that the integration is correct.

In [5]:
# Put the Hugging Face reference model on GPU 1 when CUDA is available,
# otherwise keep it on the CPU.
hf_device = "cuda:1" if torch.cuda.is_available() else "cpu"
hf_model = hf_model.to(hf_device)

In [None]:
# Smoke-test text generation through the Hugging Face reference model, using
# the same prompt as the TransformerLens cell.
prompt = "Where is the capital of Germany?"

# Named `inputs` (not `input`) so the builtin `input()` is not shadowed for
# the rest of the notebook session.
inputs = processor(prompt, return_tensors="pt").to(
    hf_model.device, dtype=hf_model.dtype
)
print(inputs.input_ids)
input_ids = inputs.input_ids

# do_sample=False disables sampling so the output is deterministic.
output = hf_model.generate(input_ids.to(hf_model.device), max_new_tokens=20, do_sample=False)
print(processor.tokenizer.decode(output[0], skip_special_tokens=True))

In [32]:
# Dump every parameter/buffer of the TransformerLens model for inspection.
# NOTE(review): this prints full tensors and can produce very large output.
print(model.state_dict())

In [None]:
# List the Hugging Face decoder layers with their indices.
# `named_modules()` yields (name, module) pairs for *every* submodule,
# starting with the root model, so zipping it with layer indices does not
# produce the decoder blocks. Enumerate the layer list directly instead.
# NOTE(review): assumes the decoder layers live at `hf_model.model.layers`,
# as in the transformers Chameleon implementation; confirm for your version.
hf_blocks_and_idxs = list(enumerate(hf_model.model.layers))
for i, block in hf_blocks_and_idxs:
    print(f"Block {i} is: {block}")

### Compare logits with HuggingFace model

In [None]:
# Compare the final logits of the TransformerLens and Hugging Face models on
# a handful of prompts and report every element that is out of tolerance.
prompts = [
    "Where is the capital of Germany?",
    "Calculate 2 * 42 = ",
    "My favorite",
    "My favorite place is",
]

# One set of tolerances for both the pass/fail check and the per-element
# report, so the report lists exactly the elements that made the check fail.
LOGIT_ATOL = 1e-2
LOGIT_RTOL = 1e-2

model.eval()
hf_model.eval()
tokenizer = processor.tokenizer
prompt_ids = [tokenizer.encode(prompt, return_tensors="pt") for prompt in prompts]

# A distinct loop name (`ids`) avoids reusing `prompt_ids` for both the list
# and its elements.
tl_logits = [model(ids).detach().cpu() for ids in tqdm(prompt_ids)]

logits = [hf_model(ids.to(hf_model.device)).logits.detach().cpu() for ids in tqdm(prompt_ids)]

for i, (hf_logit, tl_logit) in enumerate(zip(logits, tl_logits)):
    if torch.allclose(hf_logit, tl_logit, atol=LOGIT_ATOL, rtol=LOGIT_RTOL):
        continue
    print(f"Logits for prompt {i} are not close")
    print(f"Logits from HuggingFace: shape {hf_logit.shape}")
    print(f"Logits from TransformerLens: shape {tl_logit.shape}")
    # torch.isclose applies the same |a - b| <= atol + rtol * |b| rule as
    # torch.allclose; a bare `abs(diff) > atol` mask would ignore rtol and
    # could list elements that are actually within tolerance.
    mismatch = ~torch.isclose(hf_logit, tl_logit, atol=LOGIT_ATOL, rtol=LOGIT_RTOL)
    for index in torch.nonzero(mismatch):
        # The three-way unpack relies on the logits being 3-dimensional.
        row, col, loc = index
        print(f"Diff at {index}: HuggingFace={hf_logit[row, col, loc]}, TransformerLens={tl_logit[row, col, loc]}")

In [None]:

# Compare early hidden states of the two models for every prompt.
# NOTE(review): presumably `stop_at_layer=1` (TransformerLens) and
# `hidden_states[1]` (Hugging Face) both refer to the output of the first
# decoder block; confirm against both APIs.
tl_hidden_states = [
    model(ids, return_type="hidden_states", stop_at_layer=1).detach().cpu()
    for ids in tqdm(prompt_ids)
]
hf_hidden_states = [
    hf_model(
        ids.to(hf_model.device),
        output_hidden_states=True,
        output_attentions=True,
    ).hidden_states[1].detach().cpu()
    for ids in tqdm(prompt_ids)
]

for prompt_idx in range(len(prompts)):
    hf_hidden = hf_hidden_states[prompt_idx]
    tl_hidden = tl_hidden_states[prompt_idx]
    print(f"Shape of hf hidden states: {hf_hidden.shape}")
    print(f"Shape of tl hidden states: {tl_hidden.shape}")
    hidden_match = torch.allclose(hf_hidden, tl_hidden, atol=1e-4, rtol=1e-2)
    if not hidden_match:
        print(f"Hidden states for prompt {prompt_idx} are not close")
    # The raw tensors are printed for every prompt, whether or not they match.
    print(f"Hidden states from HuggingFace: {hf_hidden}")
    print(f"Hidden states from TransformerLens: {tl_hidden}")

In [None]:
# Compare the attention tensors at index 2 of each model's attention output,
# for every prompt.
# NOTE(review): presumably index 2 selects the third layer in both APIs;
# confirm against HookedChameleon's `return_type="attentions"`.
tl_attentions = [
    model(ids, return_type="attentions")[2].detach().cpu()
    for ids in tqdm(prompt_ids)
]
hf_attentions = [
    hf_model(
        ids.to(hf_model.device),
        output_hidden_states=True,
        output_attentions=True,
    ).attentions[2].detach().cpu()
    for ids in tqdm(prompt_ids)
]

for prompt_idx in range(len(prompts)):
    hf_attn = hf_attentions[prompt_idx]
    tl_attn = tl_attentions[prompt_idx]
    print(f"Shape of hf attentions: {hf_attn.shape}")
    print(f"Shape of tl attentions: {tl_attn.shape}")
    if torch.allclose(hf_attn, tl_attn, atol=1e-4, rtol=1e-2):
        continue
    # Only dump the raw tensors when the two models disagree.
    print(f"Attentions for prompt {prompt_idx} are not close")
    print(f"Attentions from HuggingFace: {hf_attn}")
    print(f"Attentions from TransformerLens: {tl_attn}")

# 