In [16]:
import os
import sys
import torch as t
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import einops
from jaxtyping import Int, Float
import functools
from tqdm import tqdm
from IPython.display import display
from transformer_lens.hook_points import HookPoint
from transformer_lens import (
    utils,
    HookedTransformer,
    HookedTransformerConfig,
    FactoredMatrix,
    ActivationCache,
)
import circuitsvis as cv

# Put the ARENA `exercises` directory on sys.path so the local helper modules
# imported below (plotly_utils, part1_*/part2_* packages) can be resolved.
chapter = r"chapter1_transformer_interp"
repo_prefix, _, _ = os.getcwd().partition(chapter)  # everything before the chapter dir (or the whole cwd)
exercises_dir = Path(f"{repo_prefix}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "part2_intro_to_mech_interp"
exercises_path_str = str(exercises_dir)
if exercises_path_str not in sys.path:
    sys.path.append(exercises_path_str)

from plotly_utils import imshow, hist, plot_comp_scores, plot_logit_attribution, plot_loss_difference
from part1_transformer_from_scratch.solutions import get_log_probs
import part2_intro_to_mech_interp.tests as tests

# Inference only: nothing in this notebook backpropagates, so disable autograd
# globally to save memory and computation.
t.set_grad_enabled(False)

# Device preference: Apple-silicon MPS first, then CUDA, otherwise CPU.
if t.backends.mps.is_available():
    device = t.device('mps')
elif t.cuda.is_available():
    device = t.device('cuda')
else:
    device = t.device('cpu')

MAIN = __name__ == "__main__"

# requirements:
# torch
# numpy
# einops
# jaxtyping  (tensor shape/dtype type hints: Int, Float)
# tqdm
# transformer_lens
# circuitsvis
# plotly
# ipython

## objective / outline / questions

MI research goal: can we take a trained model and reverse-engineer, from its weights, the algorithm it learned during training? 
(cf. data science: the work of gaining understanding and insight by processing information from data)
(can we make sense of the features and patterns it has learnt from data, and of how it makes its predictions?)

(i.e. finding circuits for a well-defined linguistic task)

What is the model doing, and how is it processing the features it has learnt?

### example: understanding mechanism of in context learning via induction circuits

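#### what are induction heads 
- an induction head attends from the current token `A` to the token `B` that followed an earlier occurrence of `A`, and predicts `B` next (`... A B ... A` → `B`)
- investigate the circuit via attention activation patterns: 
    - when the model input is a repeating sequence, we observe specific attention patterns
    - QN: what makes this a 'circuit', and why? 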

#### can we reverse engineer induction circuits
- looking at transformer weights 
- 'gold standard of interpretability': 
    - examine QK, OV circuits by multiplying 
    - for evidence of composition between 2 induction heads
    - what is the functionality of full circuit formed from this composition? 

- train a transformer on sequence data: 
    - we can also observe that it learns the induction circuit (only at / after a certain point in training — when it moves from memorisation to generalisation? i.e. once it has effectively learnt the 'principle' of the task?)
    - it is then able to generalise to new data from a similar task 

#### notes: 
- what do we know so far: how does a model "respond" to a prompt 
    - (based on whats it has learnt from training distribution; eg. an induction mechanism / head)
    - e.g. induction circuit (observed even in just a 2-layer transformer): previous-token head + induction head 
        - wait, what exactly is an induction head again? (it attends to the token that followed an earlier occurrence of the current token, and predicts that token)
        - conclusion: we found the relevant components — the OV circuit (layer 0) and the QK circuit (layer 1) of the attention heads — that together cause this induction mechanism, which in turn enables the behaviour (the task of in-context learning)
            - found via logit attribution / ablation 
    - eg. circuit for the task of : indirect object identification
    - therefore, on receiving a prompt of a similar type, the model activates the relevant learnt parameters for that task (which could be distributed across the network, e.g. over multiple heads / layers)
    - how might we seek to identify / locate these parameters? (finding the relevant parameters, for a class of inputs)
        - 1. curate that class of inputs (for a well defined task etc) / concept?
            - with the goal to activate relevant parameters (for our target task/ concept)
            - learnt by model, from its training data distribution
            - ideally? extract data of relevant concept/task from the model
        
        - problem: concepts are broad, difficult to capture from large training dataset too?
            - yet also, in theory, we are expecting that the model has learnt to generalise to new tasks?
            - (large pretrained model) are meant to have learnt well..
            - hence reasonable - to take a separate class of input (but representative of our target concept)
                - perhaps - ask model to generate prompts for that concept: 
                - to try to analyse, what has the model learnt about the concept of 'harm'
                - how does the model process the idea of 'harm'
        
        - given prompt: based on what it has learnt from training - 
            - what/where are the harm related features it learnt
            - "a circuit for the way it processes harm"?
            - can we tune that? (tune the weights)
    - approaches:
        - (i) feature analysis (via visualisation: we observe patterns)
    
- a way of explaining in context learning? 
    - in-context learning: ability to do few shot learning 
    - model's ability to adapt to inputs that are not part of training distribution
    - evidence of its 'generalising' ability?


#### transformer-lens 
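- what are hooks: `HookPoint`s are named points in the model (e.g. `blocks.0.attn.hook_pattern`) where activations can be read (e.g. cached by `run_with_cache`) or modified by hook functions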

#### questions 
- how effective / reliable is it to disentangle / interpret "what is learnt from data" based on activations?

- activations ought to be understood in the context of the input prompt / distribution

- activations as a proxy for parameters (the weights learned from the data distribution)

In [19]:
# Load GPT-2 small into TransformerLens' HookedTransformer (annotated for type hints),
# placing it explicitly on the `device` chosen in the setup cell so the model
# and any tensors we create later agree on device.
gpt2_small: HookedTransformer = HookedTransformer.from_pretrained("gpt2-small", device=device)



Loaded pretrained model gpt2-small into HookedTransformer


In [24]:
# Only a cell's last expression is auto-displayed, so show the layer count
# explicitly before printing the module tree.
display(gpt2_small.cfg.n_layers)
gpt2_small

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (h

In [None]:
## tokenizers
# Compare TransformerLens' tokenizer helpers on the same short string.
# cfg.default_prepend_bos is True, so the tokenising helpers prepend BOS
# (<|endoftext|>, id 50256) unless told otherwise.
tokenizer_demos = [
    gpt2_small.to_str_tokens("gpt2"),            # str -> token strings
    gpt2_small.to_str_tokens(["gpt2", "gpt2"]),  # list of str input
    gpt2_small.to_tokens("gpt2"),                # str -> tensor of token ids
    gpt2_small.to_string([50256, 70, 457, 17]),  # token ids -> decoded text
]
for demo in tokenizer_demos:
    print(demo)

In [53]:

model_description_text = '''## Loading Models

HookedTransformer comes loaded with >40 open source GPT-style models. You can load any of them in with `HookedTransformer.from_pretrained(MODEL_NAME)`. Each model is loaded into the consistent HookedTransformer architecture, designed to be clean, consistent and interpretability-friendly.

For this demo notebook we'll look at GPT-2 Small, an 80M parameter model. To try the model the model out, let's find the loss on this paragraph!'''

# A single forward pass returns both the logits and the next-token loss
# (return_type="both"), instead of running the model twice on the same text.
logits: Tensor
logits, loss = gpt2_small(model_description_text, return_type="both")
print("Model loss:", loss)
print('model logits', logits.shape)

# Greedy next-token prediction at every position; squeeze only the batch dim
# so a length-1 sequence would not be squeezed away too.
prediction = logits.argmax(dim=-1).squeeze(0)
print('prediction', prediction.shape)
print('prediction', prediction.shape)

Model loss: tensor(4.3443, device='mps:0')
model logits torch.Size([1, 112, 50257])
prediction torch.Size([112])


In [49]:
all_str_tokens = gpt2_small.to_str_tokens(model_description_text)
all_str_tokens[1:]  # skip the prepended BOS token

['##',
 ' Loading',
 ' Models',
 '\n',
 '\n',
 'H',
 'ooked',
 'Trans',
 'former',
 ' comes',
 ' loaded',
 ' with',
 ' >',
 '40',
 ' open',
 ' source',
 ' G',
 'PT',
 '-',
 'style',
 ' models',
 '.',
 ' You',
 ' can',
 ' load',
 ' any',
 ' of',
 ' them',
 ' in',
 ' with',
 ' `',
 'H',
 'ooked',
 'Trans',
 'former',
 '.',
 'from',
 '_',
 'pret',
 'rained',
 '(',
 'MOD',
 'EL',
 '_',
 'NAME',
 ')',
 '`.',
 ' Each',
 ' model',
 ' is',
 ' loaded',
 ' into',
 ' the',
 ' consistent',
 ' Hook',
 'ed',
 'Trans',
 'former',
 ' architecture',
 ',',
 ' designed',
 ' to',
 ' be',
 ' clean',
 ',',
 ' consistent',
 ' and',
 ' interpret',
 'ability',
 '-',
 'friendly',
 '.',
 '\n',
 '\n',
 'For',
 ' this',
 ' demo',
 ' notebook',
 ' we',
 "'ll",
 ' look',
 ' at',
 ' G',
 'PT',
 '-',
 '2',
 ' Small',
 ',',
 ' an',
 ' 80',
 'M',
 ' parameter',
 ' model',
 '.',
 ' To',
 ' try',
 ' the',
 ' model',
 ' the',
 ' model',
 ' out',
 ',',
 ' let',
 "'s",
 ' find',
 ' the',
 ' loss',
 ' on',
 ' this',
 ' paragr

In [50]:
# Greedy next-token predictions as strings. The final position predicts a token
# beyond the end of the text (no ground truth), so it is dropped.
predicted_str_tokens = gpt2_small.to_str_tokens(prediction)
predicted_str_tokens[:-1]


['\n',
 '\n',
 '...',
 '\n',
 '\n',
 '##',
 'uge',
 ' on',
 'former',
 '\n',
 ' with',
 ' with',
 ' all',
 '100',
 ' models',
 ' models',
 ' models',
 'IM',
 ' models',
 'based',
 ' models',
 '.',
 '\n',
 ' can',
 ' use',
 ' them',
 ' of',
 ' these',
 ' from',
 ' your',
 ' the',
 'h',
 'ooked',
 'Trans',
 'former',
 '`.',
 'load',
 '`',
 'model',
 'end',
 '_',
 'model',
 'ULE',
 '_',
 'NAME',
 ',',
 '`.',
 '\n',
 ' model',
 ' has',
 ' a',
 ' with',
 ' the',
 ' `',
 ' `',
 'Trans',
 'Trans',
 'former',
 '.',
 '.',
 ' and',
 ' to',
 ' be',
 ' used',
 ' and',
 ' easy',
 ' and',
 ' easy',
 'able',
 '-',
 'free',
 '.',
 '\n',
 '\n',
 '##',
 ' example',
 ' tutorial',
 ',',
 ',',
 ' will',
 ' use',
 ' at',
 ' the',
 'PT',
 '-',
 'style',
 '.',
 ',',
 ' Medium',
 ' open',
 'x',
 'b',
 'ized',
 ' with',
 '\n',
 ' load',
 ' out',
 ' model',
 ',',
 ' following',
 ' is',
 ',',
 ' you',
 "'s",
 ' use',
 ' the',
 ' `',
 'less',
 ' the',
 ' model']

In [66]:
# Ground-truth tokens for the same text. to_tokens prepends BOS
# (cfg.default_prepend_bos is True), so prediction[i] is scored against
# true_tokens[i + 1].
true_words = gpt2_small.to_str_tokens(model_description_text)
true_tokens = gpt2_small.to_tokens(model_description_text).squeeze(0)  # drop the batch dim once
true_tokens[1:].shape

torch.Size([111])

In [74]:
# Position i predicts token i + 1, so compare the predictions (minus the last,
# which has no target) against the targets shifted left by one.
predicted_next = prediction[:-1]
is_correct = predicted_next == true_tokens[1:]
print(f'{is_correct.sum()}/{len(predicted_next)}')
correctly_predicted = predicted_next[is_correct]
print(gpt2_small.to_str_tokens(correctly_predicted))
# observe that "HookedTransformer" was predicted after only one occurrence in the context


33/111
['\n', '\n', 'former', ' with', ' models', '.', ' can', ' of', 'ooked', 'Trans', 'former', '_', 'NAME', '`.', ' model', ' the', 'Trans', 'former', ' to', ' be', ' and', '-', '.', '\n', '\n', ' at', 'PT', '-', ',', ' model', ',', "'s", ' the']


In [79]:
# Model hyperparameters (d_model, n_heads, d_head, attn_scale, ...).
# Open question: attention scores vs. attention patterns?
gpt2_cfg = gpt2_small.cfg
gpt2_cfg

HookedTransformerConfig:
{'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': np.float64(8.0),
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'decoder_start_token_id': None,
 'default_prepend_bos': True,
 'device': device(type='mps'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': np.float64(0.02886751345948129),
 'load_in_4bit': False,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_key_value_heads': None,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LNPre',
 'num_experts': None,
 'original_architecture': 'GPT2LMHeadModel',
 'output_logits_soft_cap': -1.0,
 'parallel_attn

In [75]:
# cache activations 
gpt2_text = "Natural language processing tasks, such as question answering, machine translation, reading comprehension, and summarization, are typically approached with supervised learning on taskspecific datasets."
gpt2_tokens = gpt2_small.to_tokens(gpt2_text)

gpt2_logits, gpt2_cache = gpt2_small.run_with_cache(gpt2_tokens)
gpt2_cache

ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'blocks.2.hook_re

In [82]:
# Module tree again, for reference next to the cached hook names above.
display(gpt2_small)

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (h

In [78]:
# Layer-0 queries from the cache, shape (batch, seq, n_heads, d_head):
# here (1, 33, 12, 64) = 1 prompt, 33 tokens (BOS included), 12 heads, d_head 64.
# attention scores: q . k for each (query pos, key pos) pair, divided by
#   cfg.attn_scale (8.0 = sqrt(d_head) for gpt2-small)
# pattern: softmax of the causally-masked scores over key positions
# z: pattern-weighted sum of v per head; heads are then projected by W_O into attn_out
gpt2_cache["q", 0].shape

torch.Size([1, 33, 12, 64])

### visualise attention heads
- question: how can we better understand how the model responds to a prompt? 
    - example: how is it able to do in-context learning? 
        - ability to generalise to unseen prompts / new situations 
    - how does it learn from an (unseen) input prompt? 

- why attention heads: 
    - (ref. "A Mathematical Framework for Transformer Circuits"): start from the model components that are intrinsically interpretable — input tokens, output logits, attention patterns
    - by contrast, the residual stream, keys, queries and values are compressed intermediate states, used to calculate those meaningful things

- what insights can we hope to get? 

    - 1. can we classify heads by their attention patterns on texts
        - 1. visualise 
        - 2. decide on meaningful summary stats - to be validated w visualisations

    - observations; 
        - current-token heads: attend mostly to the current token
        - previous-token heads: attend mostly to the previous token
        - first-token heads: attend mostly to the first token (like a 'resting' position — when the head is not significantly activated by anything, it rests on the first token?)
    
    - with these qual observations; can we define quantitative measures to help us detect different attention heads, for different input prompts? 
    
    - 2. find induction heads from attention head patterns 
    
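    - induction head: attends to the token that followed an earlier occurrence of the current token, and predicts it (`... A B ... A` → `B`)
    - induction circuit: a circuit consisting of a previous-token head (layer 0) composed with an induction head (a later layer)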



In [None]:
# TODO: implement attention patterns (the softmax-normalised attention scores of a head),
# i.e. how much each token attends to each other token

In [None]:
# TODO: use circuitsvis to visualise the attention heads
# and observe the distinctions between heads

In [None]:
# TODO: induction circuits — are they most visible on repeating sequences in the input text?

### attributing importance of attention heads 
- how important are they in contributing to model's performance on a task? 

- question: How much of the model's performance on some particular task is attributable to each component of the model?

    - method: direct logit attribution
        1. what are the direct contributions of this head to the output logits?
        

- example / case study: the induction circuit? 
- (however, this is still at the observation level — the level of activations?)


### reverse-engineering the induction circuit 
- we observed that a particular head seems to perform a task on a class of inputs 
- can we identify why? 



results so far: 
1. observation: (TODO)
2. logit attribution: (TODO)
3. zero ablation: (TODO)