In [1]:
import warnings
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
# from dotenv import load_dotenv
from omegaconf import OmegaConf
from tqdm import tqdm
from transformers import AutoTokenizer, GenerationConfig
from transformers import logging as transformers_logging

from transformer_lens import HookedTransformer

from prompts import w_context_user_prompt, w_context_system_prompt, wo_context_system_prompt
from utils import heatmap_uncertainty
from dataset import (
    NQDataset,
    WikiMultiHopQADataset, 
    HotPotQADataset, 
    MusiqueDataset
)

# load_dotenv()
warnings.filterwarnings('ignore')
transformers_logging.set_verbosity_error()
pd.set_option('display.max_columns', None)

In [2]:
nq_dataset = NQDataset().data
wiki_multi_dataset = WikiMultiHopQADataset().data
hot_pot_dataset = HotPotQADataset().data
musique_dataset = MusiqueDataset().data

In [5]:
cfg = OmegaConf.load("config.yaml")
sampling_params = GenerationConfig(**cfg.generation_config)
device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(cfg.model_id)
model = HookedTransformer.from_pretrained(cfg.model_id, device=device)
model.set_use_attn_result(True)
model.eval()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model meta-llama/Llama-3.2-3B-Instruct into HookedTransformer


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-27): 28 x TransformerBlock(
      (ln1): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_pre_linear): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_att

In [6]:
question = 'Who won season 10 of dancing with the stars?'
context = (
    'The story, illustrated by the author, is set in England as the Black Death '
    '(bubonic plague) is sweeping across the country. Young Robin is sent away to become '
    'a knight like his father, but his dreams are endangered when he loses the use of his legs. '
    'A doctor reassures Robin that the weakness in his legs is not caused by the plague and the '
    'doctor is supposed to come and help him but does not. His parents are away, serving the king '
    'and queen during war, and the servants abandon the house, fearing the plague. '
    'Robin is saved by Brother Luke, a friar, who finds him and takes him to a monastery and cares for him.'
)

messages = [
    {"role": "user", "content": w_context_user_prompt(question, context)}, 
    {"role": "system", "content": w_context_system_prompt()}
]

text = tokenizer.apply_chat_template(
    messages, 
    tokenize=False, 
    add_generation_prompt=True
)

In [7]:
tokens = model.to_tokens(text, prepend_bos=False).to(device)

In [51]:
def store_activations(activation_dict):
    def hook_fn(activation, hook):
        activation_dict[hook.name] = activation.detach()
    return hook_fn

def get_logit_lens(activation_dict, model):
    """
    Given stored activations from each layer (residual stream),
    apply ln_final and unembed to get next-token predictions per layer.
    """
    logit_lens = {}
    n_layers = model.cfg.n_layers
    for i in range(n_layers):
        resid = activation_dict[f"blocks.{i}.hook_resid_post"] # shape: [batch, seq_len, d_model]
        
        # Take the last token embedding
        last_token_resid = resid[:, -1, :]
        normalized = model.ln_final(last_token_resid)
        logits = model.unembed(normalized)
        
        next_token_id = torch.argmax(logits, dim=-1)
        next_token_str = model.to_str_tokens(next_token_id)
        logit_lens[i] = next_token_str[0]
    return logit_lens

In [52]:
num_new_tokens = 12
all_logit_lens = {}

for step in range(num_new_tokens):
    activation_store = {}

    # Run model with hooks to capture residual activations
    logits = model.run_with_hooks(
        tokens,
        return_type="logits",
        fwd_hooks=[(f"blocks.{i}.hook_resid_post", store_activations(activation_store)) 
                   for i in range(model.cfg.n_layers)]
    )

    # Get the final next-token prediction from the model
    final_logits = logits[:, -1, :]
    next_token_id = torch.argmax(final_logits, dim=-1, keepdim=True)
    # Store the logit lens for this step
    all_logit_lens[step] = get_logit_lens(activation_store, model)

    # Append the predicted token to the input for the next iteration
    tokens = torch.cat([tokens, next_token_id], dim=-1)

In [53]:
df = pd.DataFrame(all_logit_lens)

In [54]:
df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11
0,<|start_header_id|>,assistant,ipl,\n\n,I,do,not,know,.,<|eot_id|>,<|start_header_id|>,assistant
1,aurant,assistant,nod,unb,I,do,not,know,.,pl,aurant,assistant
2,nev,assistant,brief,nish,I,do,not,know,minster,sp,nev,assistant
3,nev,assistant,unavailable,safer,I,do,necessarily,know,ioc,Emerson,nev,assistant
4,alist,assistant,referred,sucked,I,/do,not,know,st,sp,alist,assistant
5,hm,pter,pter,democrat,/we,/do,LPC,Hour,ves,nod,hm,pter
6,aret,aid,tut,chin,イク,substance,yet,-how,Helpful,Consumer,aret,aid
7,Bates,stin,pars,parad,Maiden,kapit,yet,yet,/or,neh,Bates,stin
8,COND,�,amax,apter,emann,resse,pr,disag,rebound,dol,COND,�
9,correct,reporters,limited,sup,andid,diffuse,honor,engu,rebound,ellow,contrad,apel


In [55]:
df.T

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27
0,<|start_header_id|>,aurant,nev,nev,alist,hm,aret,Bates,COND,correct,dumping,indexed,dost,iostream,ele,abstract,cce,cce,Fry,nar,nar,Hast,Assistant,assistant,assistant,assistant,assistant,assistant
1,assistant,assistant,assistant,assistant,assistant,pter,aid,stin,�,reporters,apel,ăr,สง,Past,again,thr,że,pant,pant,nar,autom,system,system,Robbins,<|end_header_id|>,<|end_header_id|>,<|end_header_id|>,<|end_header_id|>
2,ipl,nod,brief,unavailable,referred,pter,tut,pars,amax,limited,dumping,Clar,clar,ele,ele,pr,priv,stil,stil,still,seg,stil,stil,still,\n\n,\n\n,\n\n,\n\n
3,\n\n,unb,nish,safer,sucked,democrat,chin,parad,apter,sup,te,temp,egen,paper,Answer,Answer,Answer,Answer,I,I,I,I,I,I,I,I,I,I
4,I,I,I,I,I,/we,イク,Maiden,emann,andid,dit,Narr,egen,paper,'d,never,couldn,cannot,cannot,cannot,cannot,cannot,do,do,do,do,do,do
5,do,do,do,do,/do,/do,substance,kapit,resse,diffuse,yen,sill,Hurt,clare,manner,not,indeed,not,indeed,not,not,not,not,not,not,not,not,not
6,not,not,not,necessarily,not,LPC,yet,yet,pr,honor,chal,chal,/no,Past,/no,relevant,relevant,know,know,know,know,know,know,know,know,know,know,know
7,know,know,know,know,know,Hour,-how,yet,disag,engu,anton,flag,/no,/no,/no,walk,Dollars,facts,about,about,about,about,about,about,about,.,.,.
8,.,.,minster,ioc,st,ves,Helpful,/or,rebound,rebound,dumping,Mess,keyed,Skip,finally,dart,dart,dart,<|end_of_text|>,actual,actual,Frag,dancing,<|eot_id|>,<|eot_id|>,<|eot_id|>,<|eot_id|>,<|eot_id|>
9,<|eot_id|>,pl,sp,Emerson,sp,nod,Consumer,neh,dol,ellow,icter,ifo,peg,vir,dollars,mentioned,Furniture,Furniture,Furniture,lx,lx,particul,particul,Furniture,inan,<|start_header_id|>,<|start_header_id|>,<|start_header_id|>
