# Mechanistic interpretability

After reading the Anthropic paper [Elhage et al. (2021)](https://transformer-circuits.pub/2021/framework/index.html), I want to get a practical understanding of its basic concepts by implementing:

- QK and OV matrices
- skip-trigrams

## QK and OV matrices

### Theory

Using tensor products, a one-layer, attention-only transformer can be written as
$$ T=\mathrm{Id} \otimes W_U W_E+\sum_{h \in H} A^h \otimes\left(W_U W_{O V}^h W_E\right),$$
where $\mathrm{Id}$ is the identity matrix and $A^h$ is the attention pattern of head $h$, given by
$$A^h=\operatorname{softmax}\left(t^T \cdot W_E^T W_{Q K}^h W_E \cdot t\right),$$
with $t$ the (one-hot encoded) input tokens. The softmax is causally masked, so each token only attends to itself and earlier tokens.
Here, $\otimes$ is the tensor (Kronecker) product: $A \otimes W$ mixes information across positions with $A$ and transforms the vector at each position with $W$. I read the $\cdot$ as ordinary matrix multiplication.
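A quick numerical check helped me see what the tensor product does. The sketch below assumes the residual stream is stored as a matrix `x` with one row per position and is flattened row by row; `A`, `W` and the sizes are arbitrary random stand-ins. Under these assumptions, $(A \otimes W)$ applied to `x` equals $A x W^T$, and `np.kron` builds the corresponding Kronecker product matrix:

```python
import numpy as np

rng = np.random.default_rng(0)

n_context, d_in, d_out = 4, 3, 5
A = rng.normal(size=(n_context, n_context))  # mixes information across positions
W = rng.normal(size=(d_out, d_in))           # acts on the vector at each position
x = rng.normal(size=(n_context, d_in))       # one row per position

# Kronecker product acting on the row-major flattened residual stream
lhs = np.kron(A, W) @ x.flatten()

# Same operation without building the large (n_context*d_out, n_context*d_in) matrix
rhs = (A @ x @ W.T).flatten()

print(np.allclose(lhs, rhs))  # True
```

In practice the large Kronecker matrix is never needed: `A @ x @ W.T` gives the same result with much smaller matrices.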

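Now we introduce, for each head $h$, the "query-key circuit" $QK^h$ and the "output-value circuit" $OV^h$, given by
$$OV^h \equiv W_U W_{O V}^h W_E$$
and
$$QK^h \equiv W_E^T W_{Q K}^h W_E.$$

The transformer can then be written as 
$$A^h=\operatorname{softmax}\left(t^T \cdot QK^h \cdot t\right),$$
$$ T=\mathrm{Id} \otimes W_U W_E+\sum_{h \in H} A^h \otimes OV^h.$$
The first term, the direct path from embedding to unembedding, encodes bigrams: the current token directly predicts the next one. The second term encodes skip-trigrams: each head attends from the current token to an earlier token and uses it to predict the next token.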
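To check that rewriting the model with QK and OV circuits is exact, here is a minimal sketch with small random weights. It assumes a one-layer, attention-only model with no biases, no layer norm and no $1/\sqrt{d_\text{head}}$ scaling of the scores (such a scale could be folded into $W_{QK}^h$), stores the tokens $t$ as one-hot columns, and uses a hypothetical helper `causal_softmax` for the masked softmax. It computes the logits once layer by layer and once from the circuits:

```python
import numpy as np

def causal_softmax(scores):
    """Row-wise softmax in which position i only attends to positions j <= i."""
    n = scores.shape[-1]
    mask = np.tril(np.ones((n, n), dtype=bool))
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
n_vocab, d_model, d_head, n_heads = 7, 6, 2, 2  # hypothetical sizes

# Random weights; no biases, no layer norm
W_E = rng.normal(size=(d_model, n_vocab))  # embedding
W_U = rng.normal(size=(n_vocab, d_model))  # unembedding
heads = [
    (rng.normal(size=(d_head, d_model)),   # W_Q
     rng.normal(size=(d_head, d_model)),   # W_K
     rng.normal(size=(d_head, d_model)),   # W_V
     rng.normal(size=(d_model, d_head)))   # W_O
    for _ in range(n_heads)
]

tokens = [3, 1, 4, 1, 5]
t = np.eye(n_vocab)[:, tokens]  # one-hot columns, shape (n_vocab, n_context)

# 1) Layer by layer: embed, add each head's output to the residual stream, unembed
x0 = (W_E @ t).T                # residual stream, one row per position
x1 = x0.copy()
for W_Q, W_K, W_V, W_O in heads:
    A = causal_softmax((x0 @ W_Q.T) @ (x0 @ W_K.T).T)
    x1 += A @ (x0 @ W_V.T) @ W_O.T
logits_layers = x1 @ W_U.T

# 2) From the circuits: direct path plus one A^h ⊗ OV^h term per head
logits_circuits = t.T @ (W_U @ W_E).T
for W_Q, W_K, W_V, W_O in heads:
    QK = W_E.T @ (W_Q.T @ W_K) @ W_E  # W_QK = W_Q^T W_K
    OV = W_U @ (W_O @ W_V) @ W_E      # W_OV = W_O W_V
    A = causal_softmax(t.T @ QK @ t)
    logits_circuits = logits_circuits + A @ t.T @ OV.T

print(np.allclose(logits_layers, logits_circuits))  # True
```

Both routes only regroup the same matrix products, so they agree up to floating-point rounding.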

### Numpy implementation

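Besides the query, key and value weights I am already familiar with, these expressions involve a few other matrices: the token embedding $W_E$, the unembedding $W_U$ and, inside $W_{OV}^h = W_O^h W_V^h$, the output weights $W_O^h$ (similarly, $W_{QK}^h = W_Q^{hT} W_K^h$). Treating them all as given arrays, this is how I would implement QK and OV with numpy: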
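```python
import numpy as np

# WE: (d_model, n_vocab), WU: (n_vocab, d_model), WOV and WQK: (d_model, d_model)
OV = WU @ WOV @ WE    # shape (n_vocab, n_vocab)
QK = WE.T @ WQK @ WE  # shape (n_vocab, n_vocab)
```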
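Both circuits are $n_\text{vocab} \times n_\text{vocab}$ matrices, which is what makes them readable token by token: `QK[b, a]` is the pre-softmax attention score from a destination token `b` to a source token `a`, and `OV[c, a]` is how much attending to `a` raises the logit of `c`. Since $W_{QK}^h = W_Q^{hT} W_K^h$ and $W_{OV}^h = W_O^h W_V^h$ pass through the head dimension, both circuits have rank at most $d_\text{head}$. A minimal check with made-up sizes and random weights (the names `WQ`, `WK`, `WV`, `WO` are my own):

```python
import numpy as np

rng = np.random.default_rng(1)
n_vocab, d_model, d_head = 50, 16, 4  # made-up sizes

WE = rng.normal(size=(d_model, n_vocab))  # embedding
WU = rng.normal(size=(n_vocab, d_model))  # unembedding
WQ, WK, WV = (rng.normal(size=(d_head, d_model)) for _ in range(3))
WO = rng.normal(size=(d_model, d_head))

WQK = WQ.T @ WK  # (d_model, d_model), rank <= d_head
WOV = WO @ WV    # (d_model, d_model), rank <= d_head

OV = WU @ WOV @ WE    # (n_vocab, n_vocab)
QK = WE.T @ WQK @ WE  # (n_vocab, n_vocab)

print(QK.shape, OV.shape)  # (50, 50) (50, 50)
# Both ranks are at most d_head = 4, although the matrices are 50 x 50
print(np.linalg.matrix_rank(QK), np.linalg.matrix_rank(OV))
```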

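Assuming a (causally masked) `softmax` function, one-hot tokens `t` with one column per position, and lists `QK_list` and `OV_list` holding the circuits of every head, the transformer can then be implemented with the paper's formulation, writing the tensor product as the Kronecker product `np.kron`:
```python
# Identity matrix over the n_context positions
Id = np.eye(n_context)

# First transformer term: direct path (bigrams)
T1 = np.kron(Id, WU @ WE)

# Second transformer term: one tensor product per head (totally inefficient)
T2 = 0
for QK, OV in zip(QK_list, OV_list):
    # Attention pattern of this head, shape (n_context, n_context)
    Ah = softmax(t.T @ QK @ t)

    # Add this head's term
    T2 += np.kron(Ah, OV)

T = T1 + T2
```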
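The code above assumes `softmax` exists. As a sketch of both pieces, here is a causally masked softmax (the helper name `causal_softmax` is hypothetical) together with a full $T$ built with `np.kron`, using random $n_\text{vocab} \times n_\text{vocab}$ matrices as stand-ins for $W_U W_E$ and for each head's circuits. Flattening the one-hot tokens row by row (one row per position) and multiplying by $T$ gives the per-position logits:

```python
import numpy as np

def causal_softmax(scores):
    """Row-wise softmax in which position i only attends to positions j <= i."""
    n = scores.shape[-1]
    mask = np.tril(np.ones((n, n), dtype=bool))
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(2)
n_vocab, n_heads = 6, 2
tokens = [2, 0, 5, 2]
n_context = len(tokens)

# Random stand-ins for W_U W_E and for each head's QK and OV circuits
WU_WE = rng.normal(size=(n_vocab, n_vocab))
QK_list = [rng.normal(size=(n_vocab, n_vocab)) for _ in range(n_heads)]
OV_list = [rng.normal(size=(n_vocab, n_vocab)) for _ in range(n_heads)]

t = np.eye(n_vocab)[:, tokens]  # one-hot columns, shape (n_vocab, n_context)

# Full tensor-product form of the transformer
T = np.kron(np.eye(n_context), WU_WE)
A_list = [causal_softmax(t.T @ QK @ t) for QK in QK_list]
for A, OV in zip(A_list, OV_list):
    T += np.kron(A, OV)
print(T.shape)  # (24, 24)

# Apply T to the row-major flattened tokens, then reshape to (n_context, n_vocab)
logits_T = (T @ t.T.flatten()).reshape(n_context, n_vocab)

# Same logits without ever building T
logits = t.T @ WU_WE.T + sum(A @ t.T @ OV.T for A, OV in zip(A_list, OV_list))
print(np.allclose(logits_T, logits))  # True
```

$T$ has $(n_\text{context} \cdot n_\text{vocab})^2$ entries, which is why building it explicitly is so inefficient; the last computation reaches the same logits using only $n_\text{vocab} \times n_\text{vocab}$ and $n_\text{context} \times n_\text{context}$ matrices.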

## Skip-trigram

Here I implement an algorithm that prints the skip-trigrams of a given sequence, skipping a fixed number of tokens (`step`).

Input sequence:

```python
seq = range(8)
```

### Trigrams

```python
def trigram(seq):
    # Print every run of three consecutive elements
    for i in range(len(seq) - 2):
        print(seq[i], seq[i + 1], seq[i + 2])

trigram(seq)
```

Output:

```
0 1 2
1 2 3
2 3 4
3 4 5
4 5 6
5 6 7
```
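An alternative sketch (the function name `trigrams` is my own) avoids index arithmetic by zipping the sequence with shifted copies of itself. It returns the triples instead of printing them and also works for sequences whose elements are not their own indices, such as a list of words:

```python
def trigrams(seq):
    """Return all consecutive triples (seq[i], seq[i + 1], seq[i + 2])."""
    # zip stops at the shortest argument, i.e. after len(seq) - 2 triples
    return list(zip(seq, seq[1:], seq[2:]))

print(trigrams(range(8))[:2])                   # [(0, 1, 2), (1, 2, 3)]
print(trigrams(["the", "cat", "sat", "down"]))  # [('the', 'cat', 'sat'), ('cat', 'sat', 'down')]
```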


### Skip-trigram

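```python
def skip_trigram(seq, step=1):
    # Print (seq[i], seq[i + 1 + step], seq[i + 2 + step]),
    # i.e. skip `step` tokens after the first element
    for i in range(len(seq) - 2 - step):
        print(seq[i], seq[i + 1 + step], seq[i + 2 + step])

skip_trigram(seq)
```

Output:

```
0 2 3
1 3 4
2 4 5
3 5 6
4 6 7
```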
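In the paper's picture, a skip-trigram consists of a source token A somewhere earlier in the context, the current token B that attends to it, and the predicted next token C (A ... B → C). The gap between A and B is not fixed, so a sketch that lists every skip-trigram observed in a sequence loops over all earlier positions instead of a single `step`. The function name `all_skip_trigrams` is hypothetical, and I assume A must come strictly before B:

```python
def all_skip_trigrams(seq):
    """Return every (seq[i], seq[j], seq[j + 1]) with i < j, i.e. A ... B -> C."""
    return [
        (seq[i], seq[j], seq[j + 1])
        for j in range(1, len(seq) - 1)  # B needs a following token C
        for i in range(j)                # A can be any earlier token
    ]

print(all_skip_trigrams(range(5)))
# [(0, 1, 2), (0, 2, 3), (1, 2, 3), (0, 3, 4), (1, 3, 4), (2, 3, 4)]
print(len(all_skip_trigrams(range(8))))  # 21
```

Restricting the inner loop to `i == j - 1 - step` reproduces `skip_trigram(seq, step)`, and `step=0` gives the ordinary trigrams.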
