# Introduction

This notebook sets out a notation which may be used to describe how a transformer model is applied to a text.

The notation is influenced by Dirac (or Bra-Ket) notation, as well as by Functional Discourse Grammar (FDG) notation.


## Motivation

The hope is that methods can be found to decompose and combine the calculations performed, and to identify labels for other parts of the calculation.


## Bra-Ket Notation

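Bra-Ket notation is used in quantum mechanics to represent states and measurements. Combined, they give the probability amplitude of a measurement of a state: $\langle x|\psi \rangle$ is the amplitude for measuring the state $|\psi\rangle$ as $x$, and its squared magnitude $|\langle x|\psi \rangle|^2$ is the probability of that outcome.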
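
As a minimal numerical sketch of the bracket (the two-dimensional state and the outcome vector are made-up values, not taken from any model), the inner product can be computed with NumPy. In quantum mechanics the bracket is an amplitude, and the probability is its squared magnitude:

```python
import numpy as np

# Hypothetical normalised state |psi> in a two-dimensional space
psi = np.array([3 / 5, 4j / 5])
# Hypothetical measurement outcome |x>: the first basis vector
x = np.array([1, 0])

# <x|psi>: np.vdot conjugates its first argument, as a bra requires
amplitude = np.vdot(x, psi)
# Born rule: the probability is the squared magnitude of the amplitude
probability = abs(amplitude) ** 2

print(amplitude, probability)
```

The transformer brackets used later in this notebook are plain real-valued dot products, so no conjugation or squaring is involved there.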

The transformer is understood as performing a transformation on the residual space.

A sequence of tokens combines with the transformer to generate an output residual, which through the unembedding layer generates a weight representing the likelihood of each possible next token.

The embedding and unembedding layers give a natural labelling for certain vectors in the residual space.

The input residual based on the token "the" is labelled as $\underline{\text{the}}$. The input residual from combining the embedding vector with the position vector for the $i^{th}$ token is labelled as $\underline{\text{the}_i}$. A sequence of input residuals, such as "the" at position $i$ followed by "cat" at position $j$, each combining its embedding vector with its position vector, is labelled as $\underline{\text{the}_i \text{cat}_j}$.

The vector from the unembedding layer representing the token "cat" is labelled as $\overline{cat}$.

The transformer is represented as operator $T$. Transformers are considered as an operation acting on the last residual stream. $T \ket{\underline{\text{The}}}$ is therefore the output residual from running the transformer on the single token " The". $\bra{\overline{\text{cat}}} T \ket{\underline{\text{The}}}$ is the value of the output logit element for the token " cat". A transformer with tokens already in the context window is represented as $T(\underline{\text{The cat is}})$. 

Note that T is not a linear operation, so the conventions for usage of bra-ket notation from other contexts may not be valid here.
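
The following toy sketch illustrates the notation and the non-linearity warning. Everything in it is hypothetical: the three-token vocabulary, the random weights, and the stand-in function `T`, which is just an arbitrary nonlinear map rather than a real transformer.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, vocab = 4, 3
# Hypothetical toy vocabulary and weights; these are not the Pythia values
token_ids = {" The": 0, " cat": 1, " is": 2}
W_E = rng.normal(size=(vocab, d_model))    # rows are embedding vectors
W_U = rng.normal(size=(d_model, vocab))    # columns are unembedding vectors
W_mid = rng.normal(size=(d_model, d_model))

def T(ket):
    """Stand-in for the transformer: an arbitrary nonlinear map on the residual space."""
    return ket + np.tanh(W_mid @ ket)

ket_the = W_E[token_ids[" The"]]           # underlined "The": an input residual
bra_cat = W_U[:, token_ids[" cat"]]        # overlined "cat": an unembedding vector

# <cat| T |The>: the output logit for " cat"
logit_cat = bra_cat @ T(ket_the)
# The same number appears as one element of the full logit vector
all_logits = T(ket_the) @ W_U

# Non-linearity: T(a + b) differs from T(a) + T(b)
ket_is = W_E[token_ids[" is"]]
gap = np.linalg.norm(T(ket_the + ket_is) - (T(ket_the) + T(ket_is)))
print(logit_cat, gap)
```

A non-zero `gap` is the reason identities that rely on linearity, such as distributing $T$ over a sum of kets, cannot be assumed.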


## Functional Discourse Grammar Notation

Functional Discourse Grammar is normally used to describe the production of spoken language in a way that is comparable across widely differing languages. It could be valuable in helping to understand and describe how unsupervised algorithms used for language production and processing actually work. It looks at the whole conversation or discourse, instead of just individual sentences or clauses. Although other types of grammar also allow for the study of whole conversations, they typically focus more on sentences or clauses, and handle conversation analysis separately.

FDG was developed from another approach known as Functional Grammar, as researchers thought more about how conversations were being treated. However, FDG doesn't aim to describe everything about a conversation - it is not a Grammar of Discourse but rather a Functional Grammar which takes account of Discourse. It only focuses on the parts of the conversation that influence the way language is spoken.

The four main levels of the grammar are the Interpersonal Level, the Representational Level, the Morphosyntactic Level, and the Phonological Level. Each of these levels represents a different stage in the production of an utterance:

Interpersonal Level: This deals with the social interaction between speaker and hearer(s), such as turn-taking or speech acts like requests and promises.

Representational Level: This deals with the content of the utterance, such as the events, states, and entities that the speaker wants to talk about.

Morphosyntactic Level: This describes the formal linguistic structures that are used to express the content, like words, phrases, and clauses.

Phonological Level: This level represents the actual sounds (or in the case of writing, the graphic symbols) that make up the utterance.

Each level consists of a nested structure of layers. An example of the analysis of the words "these bananas" is given below. The first line is the Interpersonal Level, the second line is the Representational Level, the third line is the Morphosyntactic Level, and the fourth line is the Phonological Level.


$$
\begin{aligned}
&\text{(I like) these bananas.} \\
&\text{a. IL } (+\text{id } R_I) \\
&\text{b. RL } (\text{prox m } x_i : [(f_i : \text{/bə’na:nə/N} (f_i)) (x_i)]) \\
&\text{c. ML } (Np_i : [(Gw_i : \text{this-pl} (Gw_i)) (Nw_i : \text{/bə’na:nə/-pl} (Nw_i))] (Np_i)) \\
&\text{d. PL } (PP_i : [(Pw_i : \text{/ði:z/} (Pw_i)) (Pw_j : \text{/bə’na:nəz/} (Pw_j))] (PP_i))
\end{aligned}
$$



## Disclaimer

I have little knowledge of linguistics or quantum mechanics, and what I did know of mathematics is mostly forgotten. I am only beginning to get an understanding of mechanistic interpretability. Please verify any information you find here for yourself. Corrections and suggestions are welcome.

# Setup
(No need to read)

In [None]:
%pip install "numpy == 1.23.*"

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
%pip install git+https://github.com/neelnanda-io/TransformerLens.git


Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-824js2ks
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-824js2ks
  Resolved https://github.com/neelnanda-io/TransformerLens.git to commit 38f4d202283552fc14115dd1f004448d8900be15
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [None]:
import plotly.express as px
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    px.line(utils.to_numpy(tensor), labels={"x":xaxis, "y":yaxis}, **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

### Model Setup

Using a small Pythia model to keep diagrams simple.

In [None]:
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
from transformers import GPTNeoXForCausalLM, AutoTokenizer
hfmodel = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-70m-deduped")

In [None]:
model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m-deduped", device=device, hf_model=hfmodel)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m-deduped into HookedTransformer


## Model Capabilities

This section uses a simple sentence that requires basic knowledge from outside the input tokens. The model shows some semantic knowledge, but with such a short text the predictions are dominated by tokens that fit the grammatical structure.

In [None]:
plaintext = "Dublin is the capital of"
tokens = model.to_tokens(plaintext)
logits_out, cache = model.run_with_cache(tokens, remove_batch_dim=False) # leave batch dim so we can run layers manually
logits_out.shape

torch.Size([1, 7, 50304])

In [None]:
utils.test_prompt("Dublin is the capital of", "Ireland",model)
utils.test_prompt("Paris is the capital of", "France",model)
utils.test_prompt("Goose is the capital of", "Ireland",model)


Tokenized prompt: ['<|endoftext|>', 'D', 'ublin', ' is', ' the', ' capital', ' of']
Tokenized answer: [' Ireland']


Top 0th token. Logit: 23.82 Prob: 42.22% Token: | the|
Top 1th token. Logit: 22.67 Prob: 13.33% Token: | Ireland|
Top 2th token. Logit: 22.03 Prob:  7.02% Token: | Dublin|
Top 3th token. Logit: 21.36 Prob:  3.60% Token: | a|
Top 4th token. Logit: 20.48 Prob:  1.49% Token: | Europe|
Top 5th token. Logit: 20.00 Prob:  0.92% Token: | Belfast|
Top 6th token. Logit: 19.87 Prob:  0.81% Token: | Irish|
Top 7th token. Logit: 19.76 Prob:  0.72% Token: | Britain|
Top 8th token. Logit: 19.75 Prob:  0.72% Token: | its|
Top 9th token. Logit: 19.62 Prob:  0.63% Token: | an|


Tokenized prompt: ['<|endoftext|>', 'Paris', ' is', ' the', ' capital', ' of']
Tokenized answer: [' France']


Top 0th token. Logit: 23.55 Prob: 44.17% Token: | the|
Top 1th token. Logit: 21.71 Prob:  7.07% Token: | a|
Top 2th token. Logit: 19.91 Prob:  1.16% Token: | an|
Top 3th token. Logit: 19.63 Prob:  0.88% Token: | Europe|
Top 4th token. Logit: 19.49 Prob:  0.76% Token: | this|
Top 5th token. Logit: 19.44 Prob:  0.73% Token: | one|
Top 6th token. Logit: 19.41 Prob:  0.71% Token: | France|
Top 7th token. Logit: 19.33 Prob:  0.65% Token: | its|
Top 8th token. Logit: 19.31 Prob:  0.64% Token: | modern|
Top 9th token. Logit: 19.26 Prob:  0.61% Token: | all|


Tokenized prompt: ['<|endoftext|>', 'Go', 'ose', ' is', ' the', ' capital', ' of']
Tokenized answer: [' Ireland']


Top 0th token. Logit: 23.03 Prob: 41.01% Token: | the|
Top 1th token. Logit: 20.97 Prob:  5.22% Token: | a|
Top 2th token. Logit: 19.73 Prob:  1.51% Token: | all|
Top 3th token. Logit: 19.65 Prob:  1.40% Token: | our|
Top 4th token. Logit: 19.63 Prob:  1.36% Token: | this|
Top 5th token. Logit: 19.18 Prob:  0.88% Token: | an|
Top 6th token. Logit: 19.17 Prob:  0.87% Token: | your|
Top 7th token. Logit: 18.89 Prob:  0.65% Token: |
|
Top 8th token. Logit: 18.71 Prob:  0.54% Token: | modern|
Top 9th token. Logit: 18.57 Prob:  0.47% Token: | life|
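
The `Prob` column printed by `test_prompt` is a softmax over the logits, so the ratio of two probabilities depends only on the difference between their logits. As a quick arithmetic check, using the " the" and " Ireland" rows of the first output above (the variable names are only for illustration):

```python
import math

# Values copied from the first test_prompt output above
logit_the, prob_the = 23.82, 0.4222
logit_ireland, prob_ireland = 22.67, 0.1333

# Softmax shares one normaliser across all tokens, so it cancels in a ratio
ratio_from_logits = math.exp(logit_the - logit_ireland)
ratio_from_probs = prob_the / prob_ireland

print(round(ratio_from_logits, 2), round(ratio_from_probs, 2))
```

The two ratios agree up to the rounding of the printed values, which is why a gap of about one logit already makes " the" roughly three times as likely as " Ireland".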


In [None]:
utils.test_prompt("of", "the",model)

Tokenized prompt: ['<|endoftext|>', 'of']
Tokenized answer: [' the']


Top 0th token. Logit: 18.68 Prob:  8.40% Token: | the|
Top 1th token. Logit: 18.26 Prob:  5.55% Token: |_|
Top 2th token. Logit: 18.09 Prob:  4.69% Token: |
|
Top 3th token. Logit: 17.09 Prob:  1.72% Token: | a|
Top 4th token. Logit: 16.84 Prob:  1.34% Token: |(|
Top 5th token. Logit: 16.66 Prob:  1.11% Token: |-|
Top 6th token. Logit: 16.61 Prob:  1.06% Token: |.|
Top 7th token. Logit: 16.55 Prob:  1.00% Token: |the|
Top 8th token. Logit: 16.36 Prob:  0.83% Token: |"|
Top 9th token. Logit: 16.36 Prob:  0.83% Token: |,|


In [None]:
utils.test_prompt("nd is the gh of","the",model)

Tokenized prompt: ['<|endoftext|>', 'nd', ' is', ' the', ' gh', ' of']
Tokenized answer: [' the']


Top 0th token. Logit: 21.54 Prob: 29.55% Token: | the|
Top 1th token. Logit: 19.78 Prob:  5.11% Token: | a|
Top 2th token. Logit: 18.58 Prob:  1.53% Token: | time|
Top 3th token. Logit: 18.36 Prob:  1.23% Token: | your|
Top 4th token. Logit: 18.16 Prob:  1.01% Token: | life|
Top 5th token. Logit: 18.06 Prob:  0.92% Token: | an|
Top 6th token. Logit: 17.85 Prob:  0.74% Token: | energy|
Top 7th token. Logit: 17.73 Prob:  0.65% Token: | this|
Top 8th token. Logit: 17.61 Prob:  0.58% Token: | one|
Top 9th token. Logit: 17.59 Prob:  0.57% Token: | our|


In [None]:
utils.test_prompt("Dublin is the capital of", "Ireland",model, prepend_bos = False)
utils.test_prompt("Dublin is the capital of", "Ireland",model, prepend_bos = True)
utils.test_prompt(" Dublin is the capital of", "Ireland",model, prepend_bos = False)

Tokenized prompt: ['D', 'ublin', ' is', ' the', ' capital', ' of']
Tokenized answer: [' Ireland']


Top 0th token. Logit: 24.09 Prob: 39.46% Token: | the|
Top 1th token. Logit: 23.27 Prob: 17.31% Token: | Ireland|
Top 2th token. Logit: 22.57 Prob:  8.55% Token: | Dublin|
Top 3th token. Logit: 21.45 Prob:  2.81% Token: | a|
Top 4th token. Logit: 20.81 Prob:  1.47% Token: | Europe|
Top 5th token. Logit: 20.12 Prob:  0.74% Token: | Irish|
Top 6th token. Logit: 20.07 Prob:  0.71% Token: | Britain|
Top 7th token. Logit: 20.04 Prob:  0.69% Token: | England|
Top 8th token. Logit: 19.93 Prob:  0.61% Token: | France|
Top 9th token. Logit: 19.88 Prob:  0.58% Token: | Belfast|


Tokenized prompt: ['<|endoftext|>', 'D', 'ublin', ' is', ' the', ' capital', ' of']
Tokenized answer: [' Ireland']


Top 0th token. Logit: 23.82 Prob: 42.22% Token: | the|
Top 1th token. Logit: 22.67 Prob: 13.33% Token: | Ireland|
Top 2th token. Logit: 22.03 Prob:  7.02% Token: | Dublin|
Top 3th token. Logit: 21.36 Prob:  3.60% Token: | a|
Top 4th token. Logit: 20.48 Prob:  1.49% Token: | Europe|
Top 5th token. Logit: 20.00 Prob:  0.92% Token: | Belfast|
Top 6th token. Logit: 19.87 Prob:  0.81% Token: | Irish|
Top 7th token. Logit: 19.76 Prob:  0.72% Token: | Britain|
Top 8th token. Logit: 19.75 Prob:  0.72% Token: | its|
Top 9th token. Logit: 19.62 Prob:  0.63% Token: | an|


Tokenized prompt: [' Dublin', ' is', ' the', ' capital', ' of']
Tokenized answer: [' Ireland']


Top 0th token. Logit: 24.08 Prob: 46.25% Token: | the|
Top 1th token. Logit: 21.83 Prob:  4.87% Token: | a|
Top 2th token. Logit: 21.27 Prob:  2.78% Token: | Ireland|
Top 3th token. Logit: 20.65 Prob:  1.49% Token: | Britain|
Top 4th token. Logit: 20.64 Prob:  1.48% Token: | England|
Top 5th token. Logit: 20.39 Prob:  1.15% Token: | Europe|
Top 6th token. Logit: 20.29 Prob:  1.04% Token: | London|
Top 7th token. Logit: 20.20 Prob:  0.95% Token: | New|
Top 8th token. Logit: 20.14 Prob:  0.90% Token: | an|
Top 9th token. Logit: 19.83 Prob:  0.66% Token: | Scotland|


## Helper Functions

### Run Block function

In [None]:
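# Define a function to run a single transformer block
def run_block(model, tokens, block_idx, cache):
    # `tokens` is kept for call compatibility; the block input comes from `cache`
    # get the block (a TransformerLens module in model.blocks)
    block = model.blocks[block_idx]
    # the residual stream entering this block, batch dimension included
    resid_pre = cache[f"blocks.{block_idx}.hook_resid_pre"]
    # run the block; this should reproduce blocks.{block_idx}.hook_resid_post
    with torch.no_grad():
        out = block(resid_pre)
    return out, cache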



### Operator From Prefix

In [None]:
import numpy as np
def operator_from_prefix(model, prefixtext):
   """Builds an operator for a fixed plaintext prefix.
   It takes a model and plaintext prefix and returns a function which appends a token
     to the prefix and returns the final-block residual vector for that token"""
   # shape [1, n_prefix]; to_tokens prepends the BOS token by default
   prefix_tokens = model.to_tokens(prefixtext)

   # index that the appended token will occupy in the sequence
   position = prefix_tokens.shape[1]
   last_block_id = len(model.blocks)-1
   resultkey = f'blocks.{last_block_id}.hook_resid_post'
   def operator(token):
      """Appends token to the prefix tokens and runs the result through the model"""
      new_token = torch.tensor([[token]], dtype=prefix_tokens.dtype, device=prefix_tokens.device)
      tokens_tensor = torch.cat([prefix_tokens, new_token], dim=1)
      logits_out, ocache = model.run_with_cache(tokens_tensor, remove_batch_dim=False)
      return ocache[resultkey][0,position]
   return operator


# Development

Following "A Mathematical Framework for Transformer Circuits" a transformer can be defined as follows.

$$ r^0 = W_E t$$
$$ z^{i} = r^{i-1} + \sum_{h^{i,j} \in H^i} h^{i,j}(r^{i-1})$$
$$ r^{i} = z^{i} + m^{i}(z^{i})$$
$$ T(t) = W_U r^l$$

Where
- $r^i$ are the vectors in the residual stream leaving transformer block $i$
- $z^i$ are the vectors after adding the attention output (the weighted sum of values under the attention pattern), which are passed into the MLP $m^i$ of transformer block $i$
- $H^i$ is the set of attention heads at layer $i$, with elements $h^{i,j}$ (attention head $j$ in transformer block $i$); $h(x)$ is the operation of attention
- $m^i$ is the MLP at layer $i$
- $t$ is the vector of one-hot encoded tokens
- $W_E$ is the embedding matrix
- $W_U$ is the unembedding matrix
- $l$ is the number of layers

$T$ can be decomposed into residual blocks (6 for pythia-70m, indexed 0 to 5 in the code), labelled $T^1 \dots T^{6}$.
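
The definitions above can be sketched as a toy NumPy model. This is an illustration under stated assumptions, not the Pythia implementation: all weights are random and hypothetical, layer norm, biases and positional information are left out, and the ReLU MLP and the sizes are chosen only for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
vocab, d_model, d_head, n_heads, n_layers, d_mlp = 5, 8, 4, 2, 3, 16

def rand(*shape):
    """Small random matrix standing in for a trained weight."""
    return rng.normal(size=shape) * 0.1

W_E, W_U = rand(d_model, vocab), rand(vocab, d_model)
layers = [
    {
        "heads": [
            {"W_Q": rand(d_head, d_model), "W_K": rand(d_head, d_model),
             "W_V": rand(d_head, d_model), "W_O": rand(d_model, d_head)}
            for _ in range(n_heads)
        ],
        "W_in": rand(d_mlp, d_model),
        "W_out": rand(d_model, d_mlp),
    }
    for _ in range(n_layers)
]

def head(h, r):
    """h^{i,j}(r): causal attention over the columns (positions) of r."""
    q, k, v = h["W_Q"] @ r, h["W_K"] @ r, h["W_V"] @ r
    scores = (q.T @ k) / np.sqrt(d_head)              # [query position, key position]
    future = np.triu(np.ones_like(scores), k=1) == 1   # keys that come after the query
    scores = np.where(future, -np.inf, scores)
    A = np.exp(scores - scores.max(axis=1, keepdims=True))
    A = A / A.sum(axis=1, keepdims=True)               # attention pattern, rows sum to 1
    return h["W_O"] @ v @ A.T                          # mix positions, project to d_model

def mlp(layer, z):
    """m^i(z): one hidden layer with ReLU, applied to each position separately."""
    return layer["W_out"] @ np.maximum(layer["W_in"] @ z, 0)

def T(token_ids):
    t = np.eye(vocab)[:, token_ids]                    # one-hot columns, [vocab, n_pos]
    r = W_E @ t                                        # r^0 = W_E t
    for layer in layers:
        z = r + sum(head(h, r) for h in layer["heads"])   # z^i
        r = z + mlp(layer, z)                              # r^i
    return W_U @ r                                     # T(t) = W_U r^l

logits = T([0, 3, 1])
print(logits.shape)
```

Because the attention is causal and the MLP acts on each position separately, the logits for the first two positions are unchanged when the third token is dropped, which is what allows a prefix to be treated as fixed context.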

Attention $h^{i,j}$ can be decomposed as

$$ h^{i,j}(x) = (A \otimes W_O W_V) \cdot x $$

where $A$ is the attention matrix, $W_O$ is the output matrix, and $W_V$ is the value matrix.
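
The tensor product can be checked numerically. In this sketch the attention pattern and the weights are hypothetical random values; it shows that mixing positions with $A$ and mapping each vector through $W_O W_V$ gives the same result as one Kronecker-product matrix acting on the flattened residuals.

```python
import numpy as np

rng = np.random.default_rng(0)
n_pos, d_model, d_head = 3, 4, 2

# Hypothetical values: a causal attention pattern and random head weights
A = np.tril(rng.random((n_pos, n_pos)))
A = A / A.sum(axis=1, keepdims=True)       # each row sums to 1
W_V = rng.normal(size=(d_head, d_model))
W_O = rng.normal(size=(d_model, d_head))
X = rng.normal(size=(n_pos, d_model))      # one residual vector per row

# Direct route: mix positions with A, then map each vector through W_O W_V
direct = A @ X @ (W_O @ W_V).T

# Tensor-product route: one large matrix acting on the flattened residuals
big = np.kron(A, W_O @ W_V)                # shape [n_pos * d_model, n_pos * d_model]
via_kron = (big @ X.reshape(-1)).reshape(n_pos, d_model)

print(np.allclose(direct, via_kron))
```

The two routes agree because $A$ only moves information between positions while $W_O W_V$ only acts within a position.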


# Worked Example

The sentence "Dublin is the capital of Ireland" is used to explore the relationships between input tokens and the internal representation of words, and the dependency on earlier tokens as well as general knowledge.


In [None]:
model.to_str_tokens("Dublin is the capital of Ireland", prepend_bos=False)

['D', 'ublin', ' is', ' the', ' capital', ' of', ' Ireland']

## First Residual Stream

Starting with the transformer acting on the single token "D" (token id 37):

$$|\underline{\text{D}}^0> = W_E [37] $$

In [None]:
model.to_tokens("Dublin", prepend_bos = False)

tensor([[   37, 21751]])

In [None]:
ket_D = torch.unsqueeze(model.W_E[37],1)  # embedding of token "D" as a column vector
ket_D.shape

torch.Size([512, 1])
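
Storing a ket as a column vector means a bra can be stored as a row vector, so that a bracket is an ordinary matrix product. A small NumPy sketch with random stand-in vectors (the values are hypothetical; only the width of 512 is taken from the shape printed above):

```python
import numpy as np

rng = np.random.default_rng(0)
d_model = 512                                 # matches the shape printed above
embedding_row = rng.normal(size=d_model)      # stand-in for model.W_E[37]
unembedding_col = rng.normal(size=d_model)    # stand-in for one column of model.W_U

ket = embedding_row[:, None]                  # column vector, shape (512, 1)
bra = unembedding_col[None, :]                # row vector, shape (1, 512)

bracket = bra @ ket                           # matrix product, shape (1, 1)
print(ket.shape, bra.shape, bracket.shape)
```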

### Transformer Block 0

The residual after the first block. Pythia applies attention and the MLP in parallel to the same input, so both act on $|\underline{\text{D}}^0>$ (pythia-70m has 8 heads per block, indexed 0 to 7):

$$|\underline{\text{D}}^1> = |\underline{\text{D}}^0> + m^1 |\underline{\text{D}}^0>
+ \sum_{j=0}^7 h^{1,j} |\underline{\text{D}}^0>$$

$$h^{1,j}|\underline{\text{D}}^0> = (A \otimes W_O W_V) \cdot \left[ | \underline{\text{D}}^0>\right]$$


In [None]:
from IPython.display import Markdown  # renders the LaTeX string as cell output
Markdown(r"""$$\ket{\underline{\text{D}}^1} = \ket{\underline{\text{D}}^0} + ( \ket{m^1 \underline{\text{D}}^0}
+ \sum_{j=0}^7 \ket{h^{1,j} \underline{\text{D}}^0} )$$

$$\ket{h^{1,j} \underline{\text{D}}^0} = (A \otimes W_O W_V) \cdot \left[ \ket{\underline{\text{D}}^0}\right]$$""")

$$\ket{\underline{\text{D}}^1} = \ket{\underline{\text{D}}^0} + ( \ket{m^1 \underline{\text{D}}^0}
+ \sum_{j=0}^7 \ket{h^{1,j} \underline{\text{D}}^0} )$$

$$\ket{h^{1,j} \underline{\text{D}}^0} = (A \otimes W_O W_V) \cdot \left[ \ket{\underline{\text{D}}^0}\right]$$

In [None]:
plaintext = "D"
tokens = model.to_tokens(plaintext,prepend_bos=False)
logits, cache = model.run_with_cache(tokens, remove_batch_dim=False)
model.to_string(torch.argmax(logits))

','

Only one residual stream to attend to, so after softmax all elements are 1.
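
The reason is that softmax normalises over the key positions, and with a single position there is nothing to share the weight with. A minimal NumPy sketch, using a few of the scores from the table that follows, rounded to two decimals (the `softmax` helper is written out here for illustration):

```python
import numpy as np

def softmax(scores):
    """Numerically stable softmax over the last axis."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)

# One key position per head: each row holds a single score
single = np.array([[4.16], [19.44], [10.46], [21.75]])
print(softmax(single).ravel())

# With two key positions the weight is shared out between them
pair = softmax(np.array([[4.16, 19.44]]))
print(pair)
```

Whatever the raw score, a one-element softmax returns 1, so for a single token the attention pattern carries no information and the head output is determined by $W_O W_V$ alone.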

In [None]:
import pandas as pd
residuals = pd.DataFrame()

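# convert to NumPy first so the columns also build when the cache lives on the GPU
residuals["attention scores before softmax"] = utils.to_numpy(torch.squeeze(cache["blocks.0.attn.hook_attn_scores"]))
residuals["attention pattern after softmax"] = utils.to_numpy(torch.squeeze(cache["blocks.0.attn.hook_pattern"]))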
residuals

| head | attention scores before softmax | attention pattern after softmax |
|---|---|---|
| 0 | 4.157962 | 1.0 |
| 1 | 19.440796 | 1.0 |
| 2 | 10.462513 | 1.0 |
| 3 | 21.753639 | 1.0 |
| 4 | 8.815374 | 1.0 |
| 5 | 16.040213 | 1.0 |
| 6 | 7.02302 | 1.0 |
| 7 | 11.908913 | 1.0 |


Comparing the residual vectors in the first block (`blocks.0`): the input residual has a small range compared with the outputs of attention and the MLP, and with the resulting residual after the block.

The violin plots show the range of the vector's elements on the vertical axis; the width of each plot indicates the number of elements around that magnitude.

In [None]:

residuals = pd.DataFrame()

residuals["resid_pre"] = cache["blocks.0.hook_resid_pre"][0,0]
residuals["attn_out"] = cache["blocks.0.hook_attn_out"][0,0]
residuals["mlp_out"] = cache["blocks.0.hook_mlp_out"][0,0]
residuals["resid_post"] = cache["blocks.0.hook_resid_post"][0,0]

fig = px.violin(residuals)
fig.show()
      

Verifying that attn_out and mlp_out cover the entire change in the residual stream in the first block.

In [None]:
my_post = cache["blocks.0.hook_resid_pre"][0,0] + cache["blocks.0.hook_attn_out"][0,0] + cache["blocks.0.hook_mlp_out"][0,0]
delta = cache["blocks.0.hook_resid_post"][0,0] - my_post
residuals = pd.DataFrame()
residuals["Cached Residual"] = cache["blocks.0.hook_resid_post"][0,0] 
residuals["Calculated Residual"] = my_post
residuals["Difference"] = delta

fig = px.violin(residuals)
fig.show()

The residual out of block 5 (the sixth and final transformer block) has a greater range and fewer elements near 0.

In [None]:
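# wrap the vector in a one-column DataFrame so the column name labels the plot
fig = px.violin(pd.DataFrame({"Layer 5 residual": utils.to_numpy(cache['blocks.5.hook_resid_post'][0,0])}))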
fig.show()

In [None]:
plaintext = "Dublin is the capital of"
tokens = model.to_tokens(plaintext,prepend_bos=False)
logits, cache = model.run_with_cache(tokens, remove_batch_dim=False)
model.to_string(torch.argmax(logits[0, -1]))  # argmax over the vocabulary at the last position

' the'

In [None]:
line(cache['blocks.5.hook_resid_post'][0,0])

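This code evaluates the output residual for the token " of" following the prefix "Dublin is the capital". Projecting that residual onto the unembedding vector for " Ireland" gives the logit for " Ireland" in the sentence "Dublin is the capital of Ireland".

$$\bra{\overline{\text{Ireland}}} T(\underline{\text{Dublin is the capital}}) \ket{\underline{\text{of}}}$$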



In [None]:
operator = operator_from_prefix(model,"Dublin is the capital")
token = model.to_single_token(" of")

result1 = operator(token)
line(result1)
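
To go from the residual returned by `operator` to the bracketed logit, the residual still has to pass through the final layer norm and the unembedding. The following is a NumPy sketch of that step on made-up numbers. The helper names `layer_norm_pre` and `logit_for_token` and the toy sizes are hypothetical, and the sketch assumes that the layer-norm scale and bias have already been folded into the unembedding weights, which I believe is what TransformerLens does by default when loading a pretrained model.

```python
import numpy as np

def layer_norm_pre(resid, eps=1e-5):
    """Centre a residual vector and scale it to unit variance (no learned weights)."""
    centred = resid - resid.mean()
    return centred / np.sqrt((centred ** 2).mean() + eps)

def logit_for_token(resid, W_U, b_U, token_id):
    """Bra for one token applied to a final residual: one element of the logit vector."""
    return layer_norm_pre(resid) @ W_U[:, token_id] + b_U[token_id]

rng = np.random.default_rng(0)
d_model, vocab = 6, 4                      # toy sizes, not Pythia's 512 and 50304
W_U = rng.normal(size=(d_model, vocab))    # stand-in for model.W_U
b_U = rng.normal(size=vocab)               # stand-in for model.b_U
resid = rng.normal(size=d_model)           # stand-in for result1

token_id = 2                               # stand-in for the id of " Ireland"
logit = logit_for_token(resid, W_U, b_U, token_id)
all_logits = layer_norm_pre(resid) @ W_U + b_U
print(logit, all_logits[token_id])
```

With the real model, the same projection would use `result1` and `model.to_single_token(" Ireland")`, and the result could be compared against the corresponding entry of the logits returned by `run_with_cache`.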