# Setup

In [None]:
%%capture
%pip install git+https://github.com/neelnanda-io/TransformerLens.git

In [None]:
import torch
import numpy as np
from pathlib import Path

In [None]:
from transformers import LlamaForCausalLM, LlamaTokenizer
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookPoint
from transformer_lens import HookedTransformer

# Load model

In [None]:
!huggingface-cli login


    _|    _|  _|    _|    _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|_|_|_|    _|_|      _|_|_|  _|_|_|_|
    _|    _|  _|    _|  _|        _|          _|    _|_|    _|  _|            _|        _|    _|  _|        _|
    _|_|_|_|  _|    _|  _|  _|_|  _|  _|_|    _|    _|  _|  _|  _|  _|_|      _|_|_|    _|_|_|_|  _|        _|_|_|
    _|    _|  _|    _|  _|    _|  _|    _|    _|    _|    _|_|  _|    _|      _|        _|    _|  _|        _|
    _|    _|    _|_|      _|_|_|    _|_|_|  _|_|_|  _|      _|    _|_|_|      _|        _|    _|    _|_|_|  _|_|_|_|

    To login, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .
Token: 
Add token as git credential? (Y/n) n
Token is valid (permission: write).
Your token has been saved to /root/.cache/huggingface/token
Login successful


In [None]:
LLAMA_2_7B_CHAT_PATH = "meta-llama/Llama-2-7b-chat-hf"

tokenizer = LlamaTokenizer.from_pretrained(LLAMA_2_7B_CHAT_PATH)
hf_model = LlamaForCausalLM.from_pretrained(LLAMA_2_7B_CHAT_PATH, low_cpu_mem_usage=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/1.62k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

In [None]:
model = HookedTransformer.from_pretrained(
    LLAMA_2_7B_CHAT_PATH,
    hf_model = hf_model,
    tokenizer = tokenizer,
    device = "cpu",
    fold_ln = False,
    center_writing_weights = False,
    center_unembed = False,
).to("cuda" if torch.cuda.is_available() else "cpu")

del hf_model

Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer
Moving model to device:  cuda


In [None]:
# Get list of arguments to pass to `generate` (specifically these are the ones relating to sampling)
# generate_kwargs = dict(
#     do_sample = False, # deterministic output so we can compare it to the HF model
#     top_p = 1.0, # suppresses annoying output errors
#     temperature = 1.0, # suppresses annoying output errors
# )

# prompt = "The capital of Germany is"

# output = model.generate(prompt, max_new_tokens=20, **generate_kwargs)

# print(output)

# Father vs Mother

## test prompts

In [None]:
example_prompt = "The word for a male parent is "
example_answer = "father"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<s>', 'The', 'word', 'for', 'a', 'male', 'parent', 'is', '']
Tokenized answer: ['', 'father']


Top 0th token. Logit: 10.98 Prob: 17.38% Token: |father|
Top 1th token. Logit: 10.90 Prob: 16.03% Token: |<0xE7>|
Top 2th token. Logit: 10.81 Prob: 14.59% Token: |father|
Top 3th token. Logit: 10.15 Prob:  7.55% Token: |Father|
Top 4th token. Logit:  9.40 Prob:  3.58% Token: |π|
Top 5th token. Logit:  8.79 Prob:  1.93% Token: |�|
Top 6th token. Logit:  8.74 Prob:  1.85% Token: |♂|
Top 7th token. Logit:  8.29 Prob:  1.17% Token: |padre|
Top 8th token. Logit:  8.16 Prob:  1.03% Token: |parent|
Top 9th token. Logit:  8.03 Prob:  0.91% Token: |Parent|


Top 0th token. Logit: 10.74 Prob: 16.26% Token: |father|
Top 1th token. Logit: 10.54 Prob: 13.27% Token: |father|
Top 2th token. Logit: 10.48 Prob: 12.48% Token: |<0xE7>|
Top 3th token. Logit: 10.42 Prob: 11.72% Token: |Father|
Top 4th token. Logit:  9.21 Prob:  3.50% Token: |π|
Top 5th token. Logit:  8.67 Prob:  2.05% Token: |�|
Top 6th token. Logit:  8.37 Prob:  1.51% Token: |“|
Top 7th token. Logit:  8.33 Prob:  1.45% Token: |padre|
Top 8th token. Logit:  7.95 Prob:  0.99% Token: |प|
Top 9th token. Logit:  7.76 Prob:  0.82% Token: |お|


In [None]:
example_prompt = "The word for a female parent is "
example_answer = "mother"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<s>', 'The', 'word', 'for', 'a', 'female', 'parent', 'is', '']
Tokenized answer: ['', 'mother']


Top 0th token. Logit: 11.10 Prob: 18.60% Token: |mother|
Top 1th token. Logit: 10.37 Prob:  9.03% Token: |Mother|
Top 2th token. Logit: 10.10 Prob:  6.83% Token: |moth|
Top 3th token. Logit:  9.71 Prob:  4.64% Token: |матери|
Top 4th token. Logit:  9.11 Prob:  2.55% Token: |�|
Top 5th token. Logit:  9.04 Prob:  2.37% Token: |母|
Top 6th token. Logit:  8.91 Prob:  2.08% Token: |madre|
Top 7th token. Logit:  8.83 Prob:  1.93% Token: |μ|
Top 8th token. Logit:  8.78 Prob:  1.84% Token: |mère|
Top 9th token. Logit:  8.72 Prob:  1.72% Token: |мате|


Top 0th token. Logit: 11.29 Prob: 21.98% Token: |mother|
Top 1th token. Logit: 10.59 Prob: 10.88% Token: |Mother|
Top 2th token. Logit: 10.58 Prob: 10.80% Token: |moth|
Top 3th token. Logit:  9.55 Prob:  3.86% Token: |матери|
Top 4th token. Logit:  9.37 Prob:  3.21% Token: |�|
Top 5th token. Logit:  8.83 Prob:  1.87% Token: |μ|
Top 6th token. Logit:  8.81 Prob:  1.83% Token: |“|
Top 7th token. Logit:  8.80 Prob:  1.81% Token: |<0xE7>|
Top 8th token. Logit:  8.68 Prob:  1.60% Token: |madre|
Top 9th token. Logit:  8.64 Prob:  1.54% Token: |母|


## Unembed actvs

In [None]:
prompts = ["The word for a male parent is ",
           "The word for a female parent is "]
tokens = model.to_tokens(prompts, prepend_bos=True)

model.reset_hooks(including_permanent=True)
original_logits, cache = model.run_with_cache(tokens)
original_logits.shape

torch.Size([2, 9, 32000])

In [None]:
# last_token_logits = original_logits[:, -1, :]
# values, indices = torch.topk(last_token_logits, 5, dim = -1)
# for token_id in indices[0]:
#     print(model.tokenizer.decode(token_id.item()))

In [None]:
# del(orignal_logits)
# del(last_token_logits)
# del values
# del indices

In [None]:
last_token_actvs = cache['ln_final.hook_normalized'][0, -1, :]
last_token_actvs = last_token_actvs.unsqueeze(0).unsqueeze(0)
unembed_last_token_actvs = model.unembed(last_token_actvs)
unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
values, indices = torch.topk(unembed_last_token_actvs, 5, dim = -1)
for token_id in indices:
    print(model.tokenizer.decode(token_id.item()))

short
small
not
...
.


In [None]:
last_token_actvs = cache['ln_final.hook_normalized'][1, -1, :]
last_token_actvs = last_token_actvs.unsqueeze(0).unsqueeze(0)
unembed_last_token_actvs = model.unembed(last_token_actvs)
unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
values, indices = torch.topk(unembed_last_token_actvs, 5, dim = -1)
for token_id in indices:
    print(model.tokenizer.decode(token_id.item()))

long
tall
t
...
not


In [None]:
last_token_actvs = cache['ln_final.hook_normalized'][:, -1, :]
last_token_actvs = last_token_actvs.unsqueeze(0)
unembed_last_token_actvs = model.unembed(last_token_actvs)
unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
values, indices = torch.topk(unembed_last_token_actvs, 5, dim = -1)
for token_id in indices[0]:
    print(model.tokenizer.decode(token_id.item()))

short
small
not
...
.


In [None]:
list(cache)[0:23]

['hook_embed',
 'blocks.0.hook_resid_pre',
 'blocks.0.ln1.hook_scale',
 'blocks.0.ln1.hook_normalized',
 'blocks.0.attn.hook_q',
 'blocks.0.attn.hook_k',
 'blocks.0.attn.hook_v',
 'blocks.0.attn.hook_rot_q',
 'blocks.0.attn.hook_rot_k',
 'blocks.0.attn.hook_attn_scores',
 'blocks.0.attn.hook_pattern',
 'blocks.0.attn.hook_z',
 'blocks.0.hook_attn_out',
 'blocks.0.hook_resid_mid',
 'blocks.0.ln2.hook_scale',
 'blocks.0.ln2.hook_normalized',
 'blocks.0.mlp.hook_pre',
 'blocks.0.mlp.hook_pre_linear',
 'blocks.0.mlp.hook_post',
 'blocks.0.hook_mlp_out',
 'blocks.0.hook_resid_post',
 'blocks.1.hook_resid_pre',
 'blocks.1.ln1.hook_scale']

In [None]:
list(cache)[-23:]

['blocks.30.hook_resid_post',
 'blocks.31.hook_resid_pre',
 'blocks.31.ln1.hook_scale',
 'blocks.31.ln1.hook_normalized',
 'blocks.31.attn.hook_q',
 'blocks.31.attn.hook_k',
 'blocks.31.attn.hook_v',
 'blocks.31.attn.hook_rot_q',
 'blocks.31.attn.hook_rot_k',
 'blocks.31.attn.hook_attn_scores',
 'blocks.31.attn.hook_pattern',
 'blocks.31.attn.hook_z',
 'blocks.31.hook_attn_out',
 'blocks.31.hook_resid_mid',
 'blocks.31.ln2.hook_scale',
 'blocks.31.ln2.hook_normalized',
 'blocks.31.mlp.hook_pre',
 'blocks.31.mlp.hook_pre_linear',
 'blocks.31.mlp.hook_post',
 'blocks.31.hook_mlp_out',
 'blocks.31.hook_resid_post',
 'ln_final.hook_scale',
 'ln_final.hook_normalized']

In [None]:
for layer_name in list(cache)[1:21]:
    print( '.'.join(layer_name.split('.')[2:]) )

hook_resid_pre
ln1.hook_scale
ln1.hook_normalized
attn.hook_q
attn.hook_k
attn.hook_v
attn.hook_rot_q
attn.hook_rot_k
attn.hook_attn_scores
attn.hook_pattern
attn.hook_z
hook_attn_out
hook_resid_mid
ln2.hook_scale
ln2.hook_normalized
mlp.hook_pre
mlp.hook_pre_linear
mlp.hook_post
hook_mlp_out
hook_resid_post


In [None]:
for layer in range(32):
    last_token_actvs = cache[f'blocks.{layer}.hook_mlp_out'][:, -1, :]
    last_token_actvs = last_token_actvs.unsqueeze(0)
    unembed_last_token_actvs = model.unembed(last_token_actvs)
    unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
    values, indices = torch.topk(unembed_last_token_actvs, 1, dim = -1)
    for samp_num in range(len(indices)):
        for token_id in indices[samp_num]:
            print('layer', layer, 'samp', samp_num, model.tokenizer.decode(token_id.item()))

layer 0 samp 0 in
layer 0 samp 1 in
layer 1 samp 0 Pil
layer 1 samp 1 itto
layer 2 samp 0 igny
layer 2 samp 1 igny
layer 3 samp 0 nor
layer 3 samp 1 nor
layer 4 samp 0 yen
layer 4 samp 1 yen
layer 5 samp 0 eval
layer 5 samp 1 eval
layer 6 samp 0 Pas
layer 6 samp 1 Pas
layer 7 samp 0 rico
layer 7 samp 1 rico
layer 8 samp 0 igu
layer 8 samp 1 он
layer 9 samp 0 aste
layer 9 samp 1 aste
layer 10 samp 0 lobal
layer 10 samp 1 lobal
layer 11 samp 0 erca
layer 11 samp 1 erca
layer 12 samp 0 ette
layer 12 samp 1 ette
layer 13 samp 0 Prin
layer 13 samp 1 wa
layer 14 samp 0 oli
layer 14 samp 1 Sah
layer 15 samp 0 agr
layer 15 samp 1 antes
layer 16 samp 0 óg
layer 16 samp 1 chev
layer 17 samp 0 adr
layer 17 samp 1 <0x9D>
layer 18 samp 0 idé
layer 18 samp 1 Spo
layer 19 samp 0 multicol
layer 19 samp 1 engo
layer 20 samp 0 parents
layer 20 samp 1 parent
layer 21 samp 0 orf
layer 21 samp 1 estamp
layer 22 samp 0 aires
layer 22 samp 1 Woman
layer 23 samp 0 par
layer 23 samp 1 female
layer 24 samp 0 ⚭


In [None]:
for layer in range(32):
    last_token_actvs = cache[f'blocks.{layer}.hook_resid_post'][:, -1, :]
    last_token_actvs = last_token_actvs.unsqueeze(0)
    unembed_last_token_actvs = model.unembed(last_token_actvs)
    unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
    values, indices = torch.topk(unembed_last_token_actvs, 1, dim = -1)
    for samp_num in range(len(indices)):
        for token_id in indices[samp_num]:
            print('layer', layer, 'samp', samp_num, model.tokenizer.decode(token_id.item()))

layer 0 samp 0 in
layer 0 samp 1 in
layer 1 samp 0 in
layer 1 samp 1 in
layer 2 samp 0 on
layer 2 samp 1 on
layer 3 samp 0 to
layer 3 samp 1 to
layer 4 samp 0 ur
layer 4 samp 1 ur
layer 5 samp 0 kin
layer 5 samp 1 kin
layer 6 samp 0 Ur
layer 6 samp 1 Ur
layer 7 samp 0 par
layer 7 samp 1 par
layer 8 samp 0 =>
layer 8 samp 1 par
layer 9 samp 0 lat
layer 9 samp 1 lat
layer 10 samp 0 Answer
layer 10 samp 1 御
layer 11 samp 0 sem
layer 11 samp 1 sem
layer 12 samp 0 eren
layer 12 samp 1 sem
layer 13 samp 0 tou
layer 13 samp 1 tou
layer 14 samp 0 Bedeut
layer 14 samp 1 tou
layer 15 samp 0 Bedeut
layer 15 samp 1 Bedeut
layer 16 samp 0 upt
layer 16 samp 1 upt
layer 17 samp 0 _
layer 17 samp 1 _
layer 18 samp 0 parent
layer 18 samp 1 parent
layer 19 samp 0 father
layer 19 samp 1 parents
layer 20 samp 0 parents
layer 20 samp 1 parent
layer 21 samp 0 parent
layer 21 samp 1 parent
layer 22 samp 0 parent
layer 22 samp 1 parent
layer 23 samp 0 parents
layer 23 samp 1 parent
layer 24 samp 0 parents
lay

In [None]:
del last_token_actvs
del unembed_last_token_actvs
del(values)
del(indices)

In [None]:
# del cache

## Unembed their activation differences

In [None]:
for layer in range(32):
    last_token_actvs = cache[f'blocks.{layer}.hook_resid_post'][:, -1, :]
    last_token_actvs = cache[f'blocks.{layer}.hook_resid_post'][0, -1, :] - cache[f'blocks.{layer}.hook_resid_post'][1, -1, :]
    last_token_actvs = last_token_actvs.unsqueeze(0).unsqueeze(0)
    unembed_last_token_actvs = model.unembed(last_token_actvs)
    unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
    values, indices = torch.topk(unembed_last_token_actvs, 1, dim = -1)
    for token_id in indices:
        print('layer', layer, model.tokenizer.decode(token_id.item()))

layer 0 ode
layer 1 cum
layer 2 urst
layer 3 ieri
layer 4 externas
layer 5 himself
layer 6 CTYPE
layer 7 itt
layer 8 pent
layer 9 równ
layer 10 elt
layer 11 aud
layer 12 aud
layer 13 Cad
layer 14 cart
layer 15 сок
layer 16 son
layer 17 Mr
layer 18 Mr
layer 19 mascul
layer 20 ♂
layer 21 father
layer 22 father
layer 23 father
layer 24 father
layer 25 Father
layer 26 father
layer 27 father
layer 28 d
layer 29 d
layer 30 d
layer 31 d


In [None]:
for layer in range(32):
    last_token_actvs = cache[f'blocks.{layer}.hook_resid_post'][:, -1, :]
    last_token_actvs = cache[f'blocks.{layer}.hook_resid_post'][1, -1, :] - cache[f'blocks.{layer}.hook_resid_post'][0, -1, :]
    last_token_actvs = last_token_actvs.unsqueeze(0).unsqueeze(0)
    unembed_last_token_actvs = model.unembed(last_token_actvs)
    unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
    values, indices = torch.topk(unembed_last_token_actvs, 1, dim = -1)
    for token_id in indices:
        print('layer', layer, model.tokenizer.decode(token_id.item()))

layer 0 <s>
layer 1 <s>
layer 2 dest
layer 3 <0xB3>
layer 4 uter
layer 5 ysz
layer 6 BF
layer 7 ovo
layer 8 enders
layer 9 bez
layer 10 onces
layer 11 Squad
layer 12 Overflow
layer 13 zast
layer 14 ‍
layer 15 ♀
layer 16 ♀
layer 17 woman
layer 18 女
layer 19 女
layer 20 sister
layer 21 women
layer 22 sister
layer 23 woman
layer 24 sister
layer 25 sister
layer 26 sister
layer 27 mat
layer 28 M
layer 29 M
layer 30 m
layer 31 Sister


# Opposites

## test prompts

In [None]:
example_prompt = "The opposite of left is right. The opposite of tall is"
example_answer = " short"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Keyword arguments {'add_special_tokens': False} not recognized.
Keyword arguments {'add_special_tokens': False} not recognized.
Keyword arguments {'add_special_tokens': False} not recognized.
Keyword arguments {'add_special_tokens': False} not recognized.


Tokenized prompt: ['<s>', 'The', 'opposite', 'of', 'left', 'is', 'right', '.', 'The', 'opposite', 'of', 'tall', 'is']
Tokenized answer: ['', 'short']


Top 0th token. Logit: 24.54 Prob: 99.06% Token: |short|
Top 1th token. Logit: 19.81 Prob:  0.87% Token: |small|
Top 2th token. Logit: 15.83 Prob:  0.02% Token: |short|
Top 3th token. Logit: 15.15 Prob:  0.01% Token: |Short|
Top 4th token. Logit: 14.97 Prob:  0.01% Token: |...|
Top 5th token. Logit: 14.85 Prob:  0.01% Token: |not|
Top 6th token. Logit: 14.71 Prob:  0.01% Token: |shorter|
Top 7th token. Logit: 14.70 Prob:  0.01% Token: |low|
Top 8th token. Logit: 14.23 Prob:  0.00% Token: |.|
Top 9th token. Logit: 13.97 Prob:  0.00% Token: |...|


Top 0th token. Logit: 14.26 Prob: 65.06% Token: |short|
Top 1th token. Logit: 12.87 Prob: 16.18% Token: |<0x0A>|
Top 2th token. Logit: 10.72 Prob:  1.90% Token: |short|
Top 3th token. Logit: 10.46 Prob:  1.46% Token: |�|
Top 4th token. Logit: 10.16 Prob:  1.08% Token: |................|
Top 5th token. Logit: 10.01 Prob:  0.93% Token: |<0xF0>|
Top 6th token. Logit:  9.70 Prob:  0.68% Token: |<0x09>|
Top 7th token. Logit:  9.68 Prob:  0.67% Token: |not|
Top 8th token. Logit:  9.51 Prob:  0.57% Token: |<0xE7>|
Top 9th token. Logit:  9.50 Prob:  0.56% Token: |Short|


In [None]:
example_prompt = "The opposite of left is right. The opposite of tall is"
example_answer = "short"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=False)

Keyword arguments {'add_special_tokens': False} not recognized.
Keyword arguments {'add_special_tokens': False} not recognized.
Keyword arguments {'add_special_tokens': False} not recognized.
Keyword arguments {'add_special_tokens': False} not recognized.


Tokenized prompt: ['The', 'opposite', 'of', 'left', 'is', 'right', '.', 'The', 'opposite', 'of', 'tall', 'is']
Tokenized answer: ['', 'short']


Top 0th token. Logit: 23.67 Prob: 98.79% Token: |short|
Top 1th token. Logit: 19.11 Prob:  1.04% Token: |small|
Top 2th token. Logit: 15.66 Prob:  0.03% Token: |not|
Top 3th token. Logit: 15.12 Prob:  0.02% Token: |...|
Top 4th token. Logit: 14.93 Prob:  0.02% Token: |short|
Top 5th token. Logit: 14.57 Prob:  0.01% Token: |low|
Top 6th token. Logit: 14.39 Prob:  0.01% Token: |Short|
Top 7th token. Logit: 14.34 Prob:  0.01% Token: |.|
Top 8th token. Logit: 14.30 Prob:  0.01% Token: |shorter|
Top 9th token. Logit: 13.74 Prob:  0.00% Token: |what|


Top 0th token. Logit: 14.95 Prob: 69.32% Token: |short|
Top 1th token. Logit: 13.52 Prob: 16.68% Token: |<0x0A>|
Top 2th token. Logit: 10.78 Prob:  1.07% Token: |not|
Top 3th token. Logit: 10.66 Prob:  0.96% Token: |�|
Top 4th token. Logit: 10.54 Prob:  0.85% Token: |short|
Top 5th token. Logit: 10.53 Prob:  0.84% Token: |................|
Top 6th token. Logit: 10.14 Prob:  0.57% Token: |5|
Top 7th token. Logit: 10.04 Prob:  0.52% Token: |small|
Top 8th token. Logit:  9.79 Prob:  0.40% Token: |Short|
Top 9th token. Logit:  9.69 Prob:  0.36% Token: |low|


## Unembed actvs

In [None]:
prompts = ["The opposite of left is right. The opposite of tall is",
           "The opposite of left is right. The opposite of short is"]
tokens = model.to_tokens(prompts, prepend_bos=True)

model.reset_hooks(including_permanent=True)
original_logits, cache = model.run_with_cache(tokens)
original_logits.shape

torch.Size([2, 13, 32000])

In [None]:
# last_token_logits = original_logits[:, -1, :]
# values, indices = torch.topk(last_token_logits, 5, dim = -1)
# for token_id in indices[0]:
#     print(model.tokenizer.decode(token_id.item()))

short
small
short
Short
...


In [None]:
# del(orignal_logits)
# del(last_token_logits)
# del values
# del indices

In [None]:
cache

In [None]:
last_token_actvs = cache['ln_final.hook_normalized'][0, -1, :]
last_token_actvs = last_token_actvs.unsqueeze(0).unsqueeze(0)
unembed_last_token_actvs = model.unembed(last_token_actvs)
unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
values, indices = torch.topk(unembed_last_token_actvs, 5, dim = -1)
for token_id in indices:
    print(model.tokenizer.decode(token_id.item()))

short
small
not
...
.


In [None]:
last_token_actvs = cache['ln_final.hook_normalized'][1, -1, :]
last_token_actvs = last_token_actvs.unsqueeze(0).unsqueeze(0)
unembed_last_token_actvs = model.unembed(last_token_actvs)
unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
values, indices = torch.topk(unembed_last_token_actvs, 5, dim = -1)
for token_id in indices:
    print(model.tokenizer.decode(token_id.item()))

long
tall
t
...
not


In [None]:
last_token_actvs = cache['ln_final.hook_normalized'][:, -1, :]
last_token_actvs = last_token_actvs.unsqueeze(0)
unembed_last_token_actvs = model.unembed(last_token_actvs)
unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
values, indices = torch.topk(unembed_last_token_actvs, 5, dim = -1)
for token_id in indices[0]:
    print(model.tokenizer.decode(token_id.item()))

short
small
not
...
.


In [None]:
# del last_token_actvs
# del unembed_last_token_actvs
# del(values)
# del(indices)

In [None]:
# del cache

In [None]:
# import sys

# def get_size(obj, seen=None):
#     """Recursively finds size of objects in bytes"""
#     size = sys.getsizeof(obj)
#     if seen is None:
#         seen = set()
#     obj_id = id(obj)
#     if obj_id in seen:
#         return 0
#     seen.add(obj_id)
#     if isinstance(obj, dict):
#         size += sum([get_size(v, seen) for v in obj.values()])
#         size += sum([get_size(k, seen) for k in obj.keys()])
#     elif hasattr(obj, '__dict__'):
#         size += get_size(obj.__dict__, seen)
#     elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
#         size += sum([get_size(i, seen) for i in obj])
#     return size

# for name, obj in sorted(globals().items(), key=lambda x: get_size(x[1]), reverse=True):
#     print(f"{name}: {get_size(obj) / (1024**2):.2f} MB")


# Animal vs cat

## test prompts

In [None]:
example_prompt = "A fern is a plant. A rat is an "
example_answer = "animal"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<s>', 'A', 'f', 'ern', 'is', 'a', 'plant', '.', 'A', 'rat', 'is', 'an', '']
Tokenized answer: ['', 'animal']


Top 0th token. Logit: 15.21 Prob: 82.11% Token: |animal|
Top 1th token. Logit: 12.32 Prob:  4.57% Token: |<0x0A>|
Top 2th token. Logit: 11.93 Prob:  3.09% Token: |ounce|
Top 3th token. Logit: 11.14 Prob:  1.41% Token: |rod|
Top 4th token. Logit: 10.65 Prob:  0.86% Token: |8|
Top 5th token. Logit: 10.59 Prob:  0.81% Token: |A|
Top 6th token. Logit: 10.29 Prob:  0.60% Token: |o|
Top 7th token. Logit: 10.29 Prob:  0.60% Token: |igu|
Top 8th token. Logit:  9.73 Prob:  0.34% Token: |1|
Top 9th token. Logit:  9.49 Prob:  0.27% Token: |rat|


Top 0th token. Logit: 15.13 Prob: 79.22% Token: |animal|
Top 1th token. Logit: 12.33 Prob:  4.78% Token: |<0x0A>|
Top 2th token. Logit: 11.69 Prob:  2.53% Token: |A|
Top 3th token. Logit: 11.58 Prob:  2.27% Token: |rod|
Top 4th token. Logit: 10.92 Prob:  1.18% Token: |ounce|
Top 5th token. Logit: 10.41 Prob:  0.70% Token: |an|
Top 6th token. Logit: 10.29 Prob:  0.63% Token: |igu|
Top 7th token. Logit: 10.23 Prob:  0.59% Token: |m|
Top 8th token. Logit: 10.18 Prob:  0.56% Token: |ex|
Top 9th token. Logit:  9.99 Prob:  0.46% Token: |8|


In [None]:
example_prompt = "A canine animal is called a "
example_answer = "dog"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<s>', 'A', 'can', 'ine', 'animal', 'is', 'called', 'a', '']
Tokenized answer: ['', 'dog']


Top 0th token. Logit: 11.68 Prob: 34.36% Token: |dog|
Top 1th token. Logit: 10.74 Prob: 13.44% Token: |<0x0A>|
Top 2th token. Logit:  9.56 Prob:  4.13% Token: |................|
Top 3th token. Logit:  9.23 Prob:  2.97% Token: |can|
Top 4th token. Logit:  9.07 Prob:  2.53% Token: |<0xF0>|
Top 5th token. Logit:  8.91 Prob:  2.17% Token: |�|
Top 6th token. Logit:  8.77 Prob:  1.88% Token: |_|
Top 7th token. Logit:  8.43 Prob:  1.33% Token: |what|
Top 8th token. Logit:  8.28 Prob:  1.16% Token: |��|
Top 9th token. Logit:  8.12 Prob:  0.98% Token: |dog|


Top 0th token. Logit: 11.55 Prob: 32.37% Token: |dog|
Top 1th token. Logit: 10.48 Prob: 11.09% Token: |<0x0A>|
Top 2th token. Logit:  9.55 Prob:  4.37% Token: |................|
Top 3th token. Logit:  9.22 Prob:  3.14% Token: |_|
Top 4th token. Logit:  8.87 Prob:  2.21% Token: |what|
Top 5th token. Logit:  8.66 Prob:  1.80% Token: |<0xF0>|
Top 6th token. Logit:  8.57 Prob:  1.65% Token: |(|
Top 7th token. Logit:  8.42 Prob:  1.41% Token: |�|
Top 8th token. Logit:  8.33 Prob:  1.29% Token: |can|
Top 9th token. Logit:  8.22 Prob:  1.16% Token: |and|


In [None]:
example_prompt = "A feline animal is called a "
example_answer = "cat"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<s>', 'A', 'f', 'eline', 'animal', 'is', 'called', 'a', '']
Tokenized answer: ['', 'cat']


Top 0th token. Logit: 11.78 Prob: 34.66% Token: |cat|
Top 1th token. Logit: 11.22 Prob: 19.84% Token: |<0x0A>|
Top 2th token. Logit:  9.46 Prob:  3.41% Token: |�|
Top 3th token. Logit:  9.34 Prob:  3.04% Token: |................|
Top 4th token. Logit:  9.26 Prob:  2.79% Token: |domestic|
Top 5th token. Logit:  9.00 Prob:  2.16% Token: |<0xE7>|
Top 6th token. Logit:  8.96 Prob:  2.06% Token: |_|
Top 7th token. Logit:  8.66 Prob:  1.53% Token: |.|
Top 8th token. Logit:  8.64 Prob:  1.50% Token: |<0xF0>|
Top 9th token. Logit:  8.49 Prob:  1.30% Token: |<0xE8>|


Top 0th token. Logit: 11.48 Prob: 28.16% Token: |cat|
Top 1th token. Logit: 11.10 Prob: 19.25% Token: |<0x0A>|
Top 2th token. Logit:  9.60 Prob:  4.32% Token: |_|
Top 3th token. Logit:  9.57 Prob:  4.18% Token: |domestic|
Top 4th token. Logit:  9.21 Prob:  2.92% Token: |................|
Top 5th token. Logit:  9.11 Prob:  2.63% Token: |�|
Top 6th token. Logit:  8.78 Prob:  1.89% Token: |“|
Top 7th token. Logit:  8.76 Prob:  1.85% Token: |(|
Top 8th token. Logit:  8.69 Prob:  1.73% Token: |.|
Top 9th token. Logit:  8.30 Prob:  1.18% Token: |<0xE7>|


## Unembed actvs

In [None]:
prompts = ["A fern is a plant. A rat is an ",
           "A feline animal is called a "]
tokens = model.to_tokens(prompts, prepend_bos=True)

model.reset_hooks(including_permanent=True)
original_logits, cache = model.run_with_cache(tokens)
original_logits.shape

torch.Size([2, 13, 32000])

In [None]:
last_token_logits = original_logits[:, -1, :]
values, indices = torch.topk(last_token_logits, 5, dim = -1)
for token_id in indices[1]:
    print(model.tokenizer.decode(token_id.item()))

0
t
.
1
2


In [None]:
# del(orignal_logits)
# del(last_token_logits)
# del values
# del indices

In [None]:
last_token_actvs = cache['ln_final.hook_normalized'][0, -1, :]
last_token_actvs = last_token_actvs.unsqueeze(0).unsqueeze(0)
unembed_last_token_actvs = model.unembed(last_token_actvs)
unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
values, indices = torch.topk(unembed_last_token_actvs, 5, dim = -1)
for token_id in indices:
    print(model.tokenizer.decode(token_id.item()))

animal
<0x0A>
ounce
A
rod


In [None]:
last_token_actvs = cache['ln_final.hook_normalized'][1, -1, :]
last_token_actvs = last_token_actvs.unsqueeze(0).unsqueeze(0)
unembed_last_token_actvs = model.unembed(last_token_actvs)
unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
values, indices = torch.topk(unembed_last_token_actvs, 5, dim = -1)
for token_id in indices:
    print(model.tokenizer.decode(token_id.item()))

0
t
1
.
2


In [None]:
last_token_actvs = cache['ln_final.hook_normalized'][:, -1, :]
last_token_actvs = last_token_actvs.unsqueeze(0)
unembed_last_token_actvs = model.unembed(last_token_actvs)
unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
values, indices = torch.topk(unembed_last_token_actvs, 5, dim = -1)
for token_id in indices[0]:
    print(model.tokenizer.decode(token_id.item()))

animal
<0x0A>
ounce
A
rod


In [None]:
for layer_name in list(cache)[1:21]:
    print( '.'.join(layer_name.split('.')[2:]) )

hook_resid_pre
ln1.hook_scale
ln1.hook_normalized
attn.hook_q
attn.hook_k
attn.hook_v
attn.hook_rot_q
attn.hook_rot_k
attn.hook_attn_scores
attn.hook_pattern
attn.hook_z
hook_attn_out
hook_resid_mid
ln2.hook_scale
ln2.hook_normalized
mlp.hook_pre
mlp.hook_pre_linear
mlp.hook_post
hook_mlp_out
hook_resid_post


In [None]:
for layer in range(32):
    last_token_actvs = cache[f'blocks.{layer}.hook_resid_post'][:, -1, :]
    last_token_actvs = last_token_actvs.unsqueeze(0)
    unembed_last_token_actvs = model.unembed(last_token_actvs)
    unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
    values, indices = torch.topk(unembed_last_token_actvs, 1, dim = -1)
    for samp_num in range(len(indices)):
        for token_id in indices[samp_num]:
            print('layer', layer, 'samp', samp_num, model.tokenizer.decode(token_id.item()))

layer 0 samp 0 in
layer 0 samp 1 <s>
layer 1 samp 0 in
layer 1 samp 1 <s>
layer 2 samp 0 in
layer 2 samp 1 <s>
layer 3 samp 0 to
layer 3 samp 1 <s>
layer 4 samp 0 CD
layer 4 samp 1 <s>
layer 5 samp 0 generalized
layer 5 samp 1 <s>
layer 6 samp 0 oci
layer 6 samp 1 <s>
layer 7 samp 0 asso
layer 7 samp 1 <s>
layer 8 samp 0 oun
layer 8 samp 1 <s>
layer 9 samp 0 para
layer 9 samp 1 <s>
layer 10 samp 0 oci
layer 10 samp 1 <s>
layer 11 samp 0 cho
layer 11 samp 1 <s>
layer 12 samp 0 cho
layer 12 samp 1 <s>
layer 13 samp 0 oun
layer 13 samp 1 <s>
layer 14 samp 0 xt
layer 14 samp 1 <s>
layer 15 samp 0 xt
layer 15 samp 1 <s>
layer 16 samp 0 <0x0A>
layer 16 samp 1 <s>
layer 17 samp 0 <0x0A>
layer 17 samp 1 <s>
layer 18 samp 0 <0x0A>
layer 18 samp 1 <s>
layer 19 samp 0 <0x0A>
layer 19 samp 1 <s>
layer 20 samp 0 animals
layer 20 samp 1 <s>
layer 21 samp 0 animals
layer 21 samp 1 <s>
layer 22 samp 0 rod
layer 22 samp 1 <s>
layer 23 samp 0 animals
layer 23 samp 1 <s>
layer 24 samp 0 rat
layer 24 samp

In [None]:
del last_token_actvs
del unembed_last_token_actvs
del(values)
del(indices)

In [None]:
# del cache

## Unembed their activation differences

In [None]:
for layer in range(32):
    last_token_actvs = cache[f'blocks.{layer}.hook_resid_post'][:, -1, :]
    last_token_actvs = cache[f'blocks.{layer}.hook_resid_post'][0, -1, :] - cache[f'blocks.{layer}.hook_resid_post'][1, -1, :]
    last_token_actvs = last_token_actvs.unsqueeze(0).unsqueeze(0)
    unembed_last_token_actvs = model.unembed(last_token_actvs)
    unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
    values, indices = torch.topk(unembed_last_token_actvs, 1, dim = -1)
    for token_id in indices:
        print('layer', layer, model.tokenizer.decode(token_id.item()))

layer 0 ff
layer 1 ff
layer 2 specified
layer 3 rounded
layer 4 amo
layer 5 lio
layer 6 ibile
layer 7 ibile
layer 8 ibile
layer 9 ibile
layer 10 ibile
layer 11 ibile
layer 12 ibile
layer 13 ibile
layer 14 ibile
layer 15 ibile
layer 16 uda
layer 17 __
layer 18 ibile
layer 19 ibile
layer 20 animals
layer 21 animals
layer 22 rod
layer 23 ounce
layer 24 rat
layer 25 rod
layer 26 ounce
layer 27 ounce
layer 28 ounce
layer 29 ounce
layer 30 rod
layer 31 animal


In [None]:
for layer in range(32):
    last_token_actvs = cache[f'blocks.{layer}.hook_resid_post'][:, -1, :]
    last_token_actvs = cache[f'blocks.{layer}.hook_resid_post'][1, -1, :] - cache[f'blocks.{layer}.hook_resid_post'][0, -1, :]
    last_token_actvs = last_token_actvs.unsqueeze(0).unsqueeze(0)
    unembed_last_token_actvs = model.unembed(last_token_actvs)
    unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
    values, indices = torch.topk(unembed_last_token_actvs, 1, dim = -1)
    for token_id in indices:
        print('layer', layer, model.tokenizer.decode(token_id.item()))

layer 0 <s>
layer 1 <s>
layer 2 <s>
layer 3 <s>
layer 4 <s>
layer 5 <s>
layer 6 <s>
layer 7 <s>
layer 8 <s>
layer 9 <s>
layer 10 <s>
layer 11 <s>
layer 12 <s>
layer 13 <s>
layer 14 <s>
layer 15 <s>
layer 16 <s>
layer 17 <s>
layer 18 <s>
layer 19 <s>
layer 20 <s>
layer 21 <s>
layer 22 <s>
layer 23 <s>
layer 24 <s>
layer 25 <s>
layer 26 <s>
layer 27 <s>
layer 28 <s>
layer 29 <s>
layer 30 <s>
layer 31 ال


# animal vs dog vs cat

## test prompts

In [None]:
example_prompt = "A rat is an "
example_answer = "animal"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<s>', 'A', 'rat', 'is', 'an', '']
Tokenized answer: ['', 'animal']


Top 0th token. Logit: 12.62 Prob: 41.69% Token: |ounce|
Top 1th token. Logit: 11.95 Prob: 21.37% Token: |8|
Top 2th token. Logit: 11.23 Prob: 10.43% Token: |1|
Top 3th token. Logit: 10.90 Prob:  7.48% Token: |animal|
Top 4th token. Logit:  9.27 Prob:  1.46% Token: |2|
Top 5th token. Logit:  9.04 Prob:  1.16% Token: |igu|
Top 6th token. Logit:  8.63 Prob:  0.77% Token: |0|
Top 7th token. Logit:  8.55 Prob:  0.71% Token: |3|
Top 8th token. Logit:  8.52 Prob:  0.69% Token: |rat|
Top 9th token. Logit:  8.41 Prob:  0.62% Token: |icky|


Top 0th token. Logit: 12.04 Prob: 30.55% Token: |ounce|
Top 1th token. Logit: 11.62 Prob: 20.07% Token: |8|
Top 2th token. Logit: 11.12 Prob: 12.15% Token: |1|
Top 3th token. Logit: 11.01 Prob: 10.90% Token: |animal|
Top 4th token. Logit:  9.52 Prob:  2.45% Token: |2|
Top 5th token. Logit:  9.16 Prob:  1.71% Token: |igu|
Top 6th token. Logit:  8.80 Prob:  1.19% Token: |0|
Top 7th token. Logit:  8.68 Prob:  1.06% Token: |3|
Top 8th token. Logit:  8.65 Prob:  1.02% Token: |rat|
Top 9th token. Logit:  8.45 Prob:  0.84% Token: |<0x0A>|


In [None]:
example_prompt = "Fern is a plant. Rat is an "
example_answer = "animal"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<s>', 'Fern', 'is', 'a', 'plant', '.', 'Rat', 'is', 'an', '']
Tokenized answer: ['', 'animal']


Top 0th token. Logit: 14.43 Prob: 89.24% Token: |animal|
Top 1th token. Logit: 10.49 Prob:  1.74% Token: |ounce|
Top 2th token. Logit:  9.91 Prob:  0.98% Token: |igu|
Top 3th token. Logit:  9.46 Prob:  0.62% Token: |8|
Top 4th token. Logit:  9.21 Prob:  0.49% Token: |<0x0A>|
Top 5th token. Logit:  8.88 Prob:  0.35% Token: |1|
Top 6th token. Logit:  8.88 Prob:  0.35% Token: |igne|
Top 7th token. Logit:  8.83 Prob:  0.33% Token: |animals|
Top 8th token. Logit:  8.82 Prob:  0.33% Token: |an|
Top 9th token. Logit:  8.41 Prob:  0.22% Token: |Animal|


Top 0th token. Logit: 15.30 Prob: 88.22% Token: |animal|
Top 1th token. Logit: 11.07 Prob:  1.29% Token: |igu|
Top 2th token. Logit: 11.05 Prob:  1.25% Token: |<0x0A>|
Top 3th token. Logit: 10.67 Prob:  0.86% Token: |an|
Top 4th token. Logit: 10.63 Prob:  0.83% Token: |ounce|
Top 5th token. Logit: 10.28 Prob:  0.58% Token: |8|
Top 6th token. Logit: 10.15 Prob:  0.51% Token: |1|
Top 7th token. Logit:  9.83 Prob:  0.37% Token: |igne|
Top 8th token. Logit:  9.56 Prob:  0.28% Token: |m|
Top 9th token. Logit:  9.48 Prob:  0.26% Token: |animals|


In [None]:
example_prompt = "Fern is plant. Rat is an "
example_answer = "animal"
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<s>', 'Fern', 'is', 'plant', '.', 'Rat', 'is', 'an', '']
Tokenized answer: ['', 'animal']


Top 0th token. Logit: 12.17 Prob: 59.58% Token: |animal|
Top 1th token. Logit: 10.46 Prob: 10.78% Token: |igu|
Top 2th token. Logit:  9.27 Prob:  3.27% Token: |8|
Top 3th token. Logit:  9.13 Prob:  2.86% Token: |ounce|
Top 4th token. Logit:  8.87 Prob:  2.20% Token: |1|
Top 5th token. Logit:  8.46 Prob:  1.46% Token: |igne|
Top 6th token. Logit:  8.35 Prob:  1.30% Token: |icy|
Top 7th token. Logit:  7.95 Prob:  0.87% Token: |excellent|
Top 8th token. Logit:  7.80 Prob:  0.75% Token: |an|
Top 9th token. Logit:  7.33 Prob:  0.47% Token: |<0x0A>|


Top 0th token. Logit: 12.05 Prob: 57.45% Token: |animal|
Top 1th token. Logit: 10.01 Prob:  7.49% Token: |igu|
Top 2th token. Logit:  9.17 Prob:  3.22% Token: |1|
Top 3th token. Logit:  9.07 Prob:  2.92% Token: |8|
Top 4th token. Logit:  8.86 Prob:  2.35% Token: |ounce|
Top 5th token. Logit:  8.39 Prob:  1.47% Token: |igne|
Top 6th token. Logit:  8.12 Prob:  1.12% Token: |an|
Top 7th token. Logit:  8.00 Prob:  1.00% Token: |2|
Top 8th token. Logit:  7.97 Prob:  0.97% Token: |excellent|
Top 9th token. Logit:  7.93 Prob:  0.93% Token: |<0x0A>|


In [None]:
len(model.tokenizer.tokenize('Fern is a plant. Rat is an '))

9

In [None]:
len(model.tokenizer.tokenize('Fern is plant. Rat is an '))

8

In [None]:
len(model.tokenizer.tokenize('A feline animal is called a '))

8

## Unembed actvs

In [None]:
prompts = ["Fern is plant. Rat is an ",
           "A feline animal is called a ",
           "A canine animal is called a "]
tokens = model.to_tokens(prompts, prepend_bos=True)

model.reset_hooks(including_permanent=True)
original_logits, cache = model.run_with_cache(tokens)
original_logits.shape

torch.Size([3, 9, 32000])

In [None]:
last_token_logits = original_logits[:, -1, :]
values, indices = torch.topk(last_token_logits, 5, dim = -1)
for token_id in indices[1]:
    print(model.tokenizer.decode(token_id.item()))

cat
<0x0A>
�
................
domestic


In [None]:
# del(orignal_logits)
# del(last_token_logits)
# del values
# del indices

In [None]:
last_token_actvs = cache['ln_final.hook_normalized'][0, -1, :]
last_token_actvs = last_token_actvs.unsqueeze(0).unsqueeze(0)
unembed_last_token_actvs = model.unembed(last_token_actvs)
unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
values, indices = torch.topk(unembed_last_token_actvs, 5, dim = -1)
for token_id in indices:
    print(model.tokenizer.decode(token_id.item()))

animal
<0x0A>
in
igu
8


In [None]:
last_token_actvs = cache['ln_final.hook_normalized'][1, -1, :]
last_token_actvs = last_token_actvs.unsqueeze(0).unsqueeze(0)
unembed_last_token_actvs = model.unembed(last_token_actvs)
unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
values, indices = torch.topk(unembed_last_token_actvs, 5, dim = -1)
for token_id in indices:
    print(model.tokenizer.decode(token_id.item()))

<0x0A>
cat
_
�
domestic


In [None]:
last_token_actvs = cache['ln_final.hook_normalized'][:, -1, :]
last_token_actvs = last_token_actvs.unsqueeze(0)
unembed_last_token_actvs = model.unembed(last_token_actvs)
unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
values, indices = torch.topk(unembed_last_token_actvs, 5, dim = -1)
for token_id in indices[0]:
    print(model.tokenizer.decode(token_id.item()))

animal
<0x0A>
in
igu
8


In [None]:
for layer_name in list(cache)[1:21]:
    print( '.'.join(layer_name.split('.')[2:]) )

hook_resid_pre
ln1.hook_scale
ln1.hook_normalized
attn.hook_q
attn.hook_k
attn.hook_v
attn.hook_rot_q
attn.hook_rot_k
attn.hook_attn_scores
attn.hook_pattern
attn.hook_z
hook_attn_out
hook_resid_mid
ln2.hook_scale
ln2.hook_normalized
mlp.hook_pre
mlp.hook_pre_linear
mlp.hook_post
hook_mlp_out
hook_resid_post


In [None]:
for layer in range(32):
    last_token_actvs = cache[f'blocks.{layer}.hook_resid_post'][:, -1, :]
    last_token_actvs = last_token_actvs.unsqueeze(0)
    unembed_last_token_actvs = model.unembed(last_token_actvs)
    unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
    values, indices = torch.topk(unembed_last_token_actvs, 1, dim = -1)
    for samp_num in range(len(indices)):
        for token_id in indices[samp_num]:
            print('layer', layer, 'samp', samp_num, model.tokenizer.decode(token_id.item()))

layer 0 samp 0 in
layer 0 samp 1 in
layer 0 samp 2 in
layer 1 samp 0 in
layer 1 samp 1 in
layer 1 samp 2 in
layer 2 samp 0 in
layer 2 samp 1 (
layer 2 samp 2 (
layer 3 samp 0 in
layer 3 samp 1 (
layer 3 samp 2 (
layer 4 samp 0 CD
layer 4 samp 1 fen
layer 4 samp 2 fen
layer 5 samp 0 generalized
layer 5 samp 1 alberga
layer 5 samp 2 alberga
layer 6 samp 0 arr
layer 6 samp 1 asso
layer 6 samp 2 asso
layer 7 samp 0 op
layer 7 samp 1 <0xB6>
layer 7 samp 2 asso
layer 8 samp 0 ico
layer 8 samp 1 rip
layer 8 samp 2 asso
layer 9 samp 0 Alter
layer 9 samp 1 ,\,
layer 9 samp 2 aum
layer 10 samp 0 equivalent
layer 10 samp 1 hagen
layer 10 samp 2 äl
layer 11 samp 0 English
layer 11 samp 1 fen
layer 11 samp 2 äl
layer 12 samp 0 iks
layer 12 samp 1 äl
layer 12 samp 2 äl
layer 13 samp 0 <0x85>
layer 13 samp 1 alberga
layer 13 samp 2 Bedeut
layer 14 samp 0 ia
layer 14 samp 1 dw
layer 14 samp 2 Bedeut
layer 15 samp 0 in
layer 15 samp 1 Bedeut
layer 15 samp 2 Bedeut
layer 16 samp 0 roid
layer 16 samp 1 _

In [None]:
del last_token_actvs
del unembed_last_token_actvs
del(values)
del(indices)

In [None]:
# del cache

## Unembed their activation differences

In [None]:
# animal - cat
for layer in range(15, 32):
    last_token_actvs = cache[f'blocks.{layer}.hook_resid_post'][:, -1, :]
    last_token_actvs = cache[f'blocks.{layer}.hook_resid_post'][0, -1, :] - cache[f'blocks.{layer}.hook_resid_post'][1, -1, :]
    last_token_actvs = last_token_actvs.unsqueeze(0).unsqueeze(0)
    unembed_last_token_actvs = model.unembed(last_token_actvs)
    unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
    values, indices = torch.topk(unembed_last_token_actvs, 1, dim = -1)
    for token_id in indices:
        print('layer', layer, model.tokenizer.decode(token_id.item()))

layer 15 ther
layer 16 Register
layer 17 Er
layer 18 ROR
layer 19 ayer
layer 20 ayer
layer 21 Rat
layer 22 ord
layer 23 arius
layer 24 Rat
layer 25 Rat
layer 26 Rat
layer 27 Rat
layer 28 animal
layer 29 animal
layer 30 om
layer 31 animal


In [None]:
# animal - dog
for layer in range(15, 32):
    last_token_actvs = cache[f'blocks.{layer}.hook_resid_post'][:, -1, :]
    last_token_actvs = cache[f'blocks.{layer}.hook_resid_post'][0, -1, :] - cache[f'blocks.{layer}.hook_resid_post'][2, -1, :]
    last_token_actvs = last_token_actvs.unsqueeze(0).unsqueeze(0)
    unembed_last_token_actvs = model.unembed(last_token_actvs)
    unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
    values, indices = torch.topk(unembed_last_token_actvs, 1, dim = -1)
    for token_id in indices:
        print('layer', layer, model.tokenizer.decode(token_id.item()))

layer 15 exc
layer 16 zer
layer 17 Er
layer 18 ength
layer 19 Symbol
layer 20 Symbol
layer 21 Rat
layer 22 ayer
layer 23 rod
layer 24 rod
layer 25 rod
layer 26 rod
layer 27 animal
layer 28 agr
layer 29 animal
layer 30 ounce
layer 31 animal


In [None]:
# cat- animal
for layer in range(15, 32):
    last_token_actvs = cache[f'blocks.{layer}.hook_resid_post'][:, -1, :]
    last_token_actvs = cache[f'blocks.{layer}.hook_resid_post'][1, -1, :] - cache[f'blocks.{layer}.hook_resid_post'][0, -1, :]
    last_token_actvs = last_token_actvs.unsqueeze(0).unsqueeze(0)
    unembed_last_token_actvs = model.unembed(last_token_actvs)
    unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
    values, indices = torch.topk(unembed_last_token_actvs, 1, dim = -1)
    for token_id in indices:
        print('layer', layer, model.tokenizer.decode(token_id.item()))

layer 15 lint
layer 16 cic
layer 17 ________
layer 18 ________
layer 19 cat
layer 20 cat
layer 21 cat
layer 22 cat
layer 23 cat
layer 24 cat
layer 25 cat
layer 26 cat
layer 27 cat
layer 28 cat
layer 29 cat
layer 30 CAT
layer 31 CAT


In [None]:
# cat - dog
for layer in range(15, 32):
    last_token_actvs = cache[f'blocks.{layer}.hook_resid_post'][:, -1, :]
    last_token_actvs = cache[f'blocks.{layer}.hook_resid_post'][1, -1, :] - cache[f'blocks.{layer}.hook_resid_post'][2, -1, :]
    last_token_actvs = last_token_actvs.unsqueeze(0).unsqueeze(0)
    unembed_last_token_actvs = model.unembed(last_token_actvs)
    unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
    values, indices = torch.topk(unembed_last_token_actvs, 1, dim = -1)
    for token_id in indices:
        print('layer', layer, model.tokenizer.decode(token_id.item()))

layer 15 Cant
layer 16 Ba
layer 17 Cat
layer 18 зе
layer 19 cat
layer 20 cat
layer 21 cat
layer 22 cat
layer 23 cat
layer 24 cat
layer 25 cat
layer 26 cat
layer 27 cat
layer 28 cat
layer 29 cat
layer 30 cat
layer 31 fel


In [None]:
# dog - cat
for layer in range(15, 32):
    last_token_actvs = cache[f'blocks.{layer}.hook_resid_post'][:, -1, :]
    last_token_actvs = cache[f'blocks.{layer}.hook_resid_post'][2, -1, :] - cache[f'blocks.{layer}.hook_resid_post'][1, -1, :]
    last_token_actvs = last_token_actvs.unsqueeze(0).unsqueeze(0)
    unembed_last_token_actvs = model.unembed(last_token_actvs)
    unembed_last_token_actvs = unembed_last_token_actvs.squeeze()
    values, indices = torch.topk(unembed_last_token_actvs, 1, dim = -1)
    for token_id in indices:
        print('layer', layer, model.tokenizer.decode(token_id.item()))

layer 15 données
layer 16 Unterscheidung
layer 17 ispecies
layer 18 Jahrh
layer 19 dog
layer 20 dog
layer 21 dog
layer 22 dog
layer 23 dog
layer 24 dog
layer 25 dog
layer 26 dog
layer 27 dog
layer 28 dog
layer 29 dog
layer 30 dog
layer 31 dog
