In [1]:
# %%
import torch
import numpy as np
import random
from transformer_lens import HookedTransformer, ActivationCache
import transformer_lens.utils as utils
import plotly.express as px

# Turn gradient tracking off globally for everything that follows.
torch.set_grad_enabled(False)

# Seed torch, numpy and the stdlib RNG with one shared value.
SEED = 42
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(SEED)


# %%
# Load the pretrained Gemma-7B weights in half precision and show the module tree.
model = HookedTransformer.from_pretrained(
    model_name='gemma-7b',
    dtype=torch.float16,
)
print(model)

# %%
# Quick smoke test: let the model continue a short prompt.
out = model.generate(input="Hi, my name is")
print(out)

# %%

# Short aliases for the model dimensions stored on the model config.
cfg = model.cfg
n_layers, d_model, n_heads = cfg.n_layers, cfg.d_model, cfg.n_heads
d_head, d_mlp, d_vocab = cfg.d_head, cfg.d_mlp, cfg.d_vocab


# %%
# Pieces of the multiple-choice prompt, concatenated in this order below.
prelude = "A highly knowledgeable and intelligent AI answers multiple-choice questions about Biology. "
question = "To prevent desiccation and injury, the embryos of terrestrial vertebrates are encased within a fluid secreted by the:"
# One option per line, wrapped in a leading and a trailing newline.
answers = "\n" + "\n".join([
    "(A) amnion",
    "(B) chorion",
    "(C) allantois",
    "(D) yolk sac",
]) + "\n"
# The prompt stops right after the opening parenthesis so the next token is the letter.
post_text = "Answer: ("

# Answer letter to score.
answer_token = "A"

text = "".join((prelude, question, answers, post_text))
# Show how the tokenizer splits the prompt.
prompt_str_tokens = model.to_str_tokens(text)
print(prompt_str_tokens)



# %%

# Print the model's top next-token predictions for the prompt and the rank of
# the answer letter; the answer is used as-is, without a prepended space.
utils.test_prompt(
    prompt=text,
    answer=answer_token,
    model=model,
    prepend_space_to_answer=False,
)


# %%
# DLA

# Run the prompt once and keep the cached activations for the attribution cells.
cache: ActivationCache
logits, cache = model.run_with_cache(text)



# %%




Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loaded pretrained model gemma-7b into HookedTransformer
HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-27): 28 x TransformerBlock(
      (ln1): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_pre_linear): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): Hook

  0%|          | 0/10 [00:00<?, ?it/s]

Hi, my name is Stacy and I love working out- especially jogging,
['<bos>', 'A', ' highly', ' knowledgeable', ' and', ' intelligent', ' AI', ' answers', ' multiple', '-', 'choice', ' questions', ' about', ' Biology', '.', ' To', ' prevent', ' desic', 'cation', ' and', ' injury', ',', ' the', ' embryos', ' of', ' terrestrial', ' vertebrates', ' are', ' encased', ' within', ' a', ' fluid', ' secreted', ' by', ' the', ':', '\n', '(', 'A', ')', ' am', 'nion', '\n', '(', 'B', ')', ' chor', 'ion', '\n', '(', 'C', ')', ' all', 'anto', 'is', '\n', '(', 'D', ')', ' yolk', ' sac', '\n', 'Answer', ':', ' (']
Tokenized prompt: ['<bos>', 'A', ' highly', ' knowledgeable', ' and', ' intelligent', ' AI', ' answers', ' multiple', '-', 'choice', ' questions', ' about', ' Biology', '.', ' To', ' prevent', ' desic', 'cation', ' and', ' injury', ',', ' the', ' embryos', ' of', ' terrestrial', ' vertebrates', ' are', ' encased', ' within', ' a', ' fluid', ' secreted', ' by', ' the', ':', '\n', '(', 'A', ')',

Top 0th token. Logit: 25.84 Prob: 89.89% Token: |A|
Top 1th token. Logit: 22.73 Prob:  4.01% Token: |B|
Top 2th token. Logit: 22.25 Prob:  2.47% Token: |C|
Top 3th token. Logit: 21.81 Prob:  1.60% Token: |D|
Top 4th token. Logit: 21.66 Prob:  1.37% Token: |a|
Top 5th token. Logit: 18.92 Prob:  0.09% Token: |1|
Top 6th token. Logit: 18.91 Prob:  0.09% Token: |b|
Top 7th token. Logit: 18.75 Prob:  0.07% Token: |c|
Top 8th token. Logit: 18.52 Prob:  0.06% Token: |d|
Top 9th token. Logit: 18.50 Prob:  0.06% Token: | A|


Tried to stack head results when they weren't cached. Computing head results now
torch.Size([57, 1, 3072])
torch.Size([448, 1, 3072])
n_heads 16
['0_mlp_out', '1_mlp_out', '2_mlp_out', '3_mlp_out', '4_mlp_out', '5_mlp_out', '6_mlp_out', '7_mlp_out', '8_mlp_out', '9_mlp_out', '10_mlp_out', '11_mlp_out', '12_mlp_out', '13_mlp_out', '14_mlp_out', '15_mlp_out', '16_mlp_out', '17_mlp_out', '18_mlp_out', '19_mlp_out', '20_mlp_out', '21_mlp_out', '22_mlp_out', '23_mlp_out', '24_mlp_out', '25_mlp_out', '26_mlp_out', '27_mlp_out']


In [29]:

# Direct logit attribution (DLA): contribution of every MLP layer and every
# attention head to the answer-token logit at the final prompt position.
decomposed, component_labels = cache.decompose_resid(layer=-1, return_labels=True, pos_slice=-1)

# Pick the MLP outputs by label instead of a fixed stride, so the selection
# stays correct if the decomposition also contains e.g. a `pos_embed` entry.
mlp_indices = [i for i, label in enumerate(component_labels) if label.endswith('mlp_out')]
stacked_mlps = decomposed[mlp_indices]
component_labels = [component_labels[i] for i in mlp_indices]

stacked_heads, head_labels = cache.stack_head_results(layer=-1, pos_slice=-1, return_labels=True)

# The stacks are already sliced to the last position, so logit_attrs must be
# given the same pos_slice. Otherwise the final-norm scale of *every* position
# is broadcast against the batch axis and `[:, 0]` ends up selecting the scale
# of position 0 (BOS) rather than the final position.
dla = cache.logit_attrs(
    torch.cat([stacked_mlps, stacked_heads]),
    answer_token,
    pos_slice=-1,
)[:, 0].to('cpu')  # [:, 0] drops the (size-1) batch axis
component_labels = component_labels + head_labels

assert dla.shape[0] == len(component_labels), "one label per attributed component"
print('dla shape', dla.shape)


dla shape torch.Size([476])


In [28]:

# Attribution of every component (MLP layers first, then heads) to the answer logit.
# `dla` already lives on the CPU; convert to float32 numpy for plotting.
fig = px.line(
    x=component_labels,
    y=dla.float().numpy(),
    labels={'x': 'component', 'y': 'logit attribution'},
    title=f"Direct logit attribution to '{answer_token}'",
)
fig.show()

In [42]:
# The 10 components with the largest (most positive) attribution.
top_k = dla.topk(k=10)
top_k_values, top_k_indices = top_k.values, top_k.indices



In [45]:
# The 20 components with the largest absolute attribution, keeping their signed values.
large_mag_indices = dla.abs().topk(k=20).indices
large_mag_values = dla[large_mag_indices]

In [32]:
# Names of the top-k components, in the same order as `top_k_values`.
top_labels = [component_labels[idx] for idx in top_k_indices.tolist()]

In [36]:
# Bar chart of the top-k positive attributions, labelled so it can be read on its own.
fig = px.bar(
    x=top_labels,
    y=top_k_values,
    labels={'x': 'component', 'y': 'logit attribution'},
    title='Top components by direct logit attribution',
)
fig.show()

In [46]:
# Use a dedicated name instead of overwriting `top_labels`, so re-running the
# top-k bar chart cell still pairs the top-k values with the top-k labels.
large_mag_labels = [component_labels[i] for i in large_mag_indices]
fig = px.bar(
    x=large_mag_labels,
    y=large_mag_values,
    labels={'x': 'component', 'y': 'logit attribution'},
    title='Components with the largest absolute direct logit attribution',
)
fig.show()

In [None]:
# Residual-stream direction that favours the answer token over the average token:
# unembed(answer) - mean(unembed(all)).
# NOTE(review): "all" is read here as the whole vocabulary; if the intent was the
# other answer letters only, average just those columns of W_U instead.
answer_token_id = model.to_single_token(answer_token)
# W_U is [d_model, d_vocab]; accumulate the vocabulary mean in float32 because the
# model weights are half precision.
correct_dir = model.W_U[:, answer_token_id].float() - model.W_U.mean(dim=-1, dtype=torch.float32)