Bilinear transformers are great because they are even more linear than the original architecture. This lets us apply standard linear-algebra analyses to each component separately (or even to several together). This notebook in particular studies the singular values of each component. It is very hand-wavy and doesn't claim that the analyses actually hold; it is merely a collection of cool-looking plots meant to motivate further research.

In [4]:
%load_ext autoreload
%autoreload 2

from shared.transformer import Transformer, Config
import plotly.express as px
import torch
import pandas as pd
from einops import *
import itertools

torch.set_grad_enabled(False)

name = "tdooms/TinyStories-1-256"

config = Config.from_pretrained(name)
model = Transformer.from_pretrained(name, config=config)
vocab = model.vocab

color = dict(color_continuous_midpoint=0, color_continuous_scale="RdBu")
facet = dict(height=200 * config.n_layer + 200, facet_col=0, facet_col_wrap=config.n_head)

def set_facet_labels(fig):
    for annotation in fig.layout.annotations:
        facet = int(annotation.text.split("=")[-1])
        annotation.update(text=f"Head {facet // config.n_head}.{facet % config.n_head}")
    return fig



One- and two-layer transformers behave slightly differently: the one-layer model has a somewhat more diverse MLP (because it kind of has to). The results shown in this notebook hold for both.

### Attention Heads
The first part of this analysis will mostly cover well-established methods to analyze the attention mechanism from the weights.
We start by looking at how the positional encoding influences each head.

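For rotary embeddings this is awkward, since position enters by rotating queries and keys rather than through a separate embedding matrix. A semi-principled workaround is to simply propagate a matrix of ones through the attention mechanism: every position then receives the same input, so any positional structure in the resulting QK pattern comes from the rotary embedding alone.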
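Why a constant input is a reasonable probe: with rotary embeddings the score between query position i and key position j is qᵀR((j - i)θ)k, so if every position gets the same query and key, the score can only depend on the offset j - i. A minimal from-scratch sketch (plain torch, assuming the common "rotate-half" RoPE convention with base 10000; this is not `shared.transformer`'s implementation) checks that the resulting score matrix is constant along each diagonal:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension into halves (x1, x2) and return (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)

def rotary_tables(seq_len: int, dim: int, base: float = 10000.0):
    # One rotation frequency per pair of dimensions, one angle per (position, frequency)
    inv_freq = 1.0 / base ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim)
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float64), inv_freq)
    emb = torch.cat([angles, angles], dim=-1)  # (seq_len, dim)
    return emb.cos(), emb.sin()

def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Rotate each (i, i + dim/2) pair by its position-dependent angle
    return x * cos + rotate_half(x) * sin

seq_len, dim = 16, 8
torch.manual_seed(0)
# The same query and key at every position, as when an all-ones input is projected
q = torch.randn(dim, dtype=torch.float64).expand(seq_len, dim)
k = torch.randn(dim, dtype=torch.float64).expand(seq_len, dim)

cos, sin = rotary_tables(seq_len, dim)
scores = apply_rotary(q, cos, sin) @ apply_rotary(k, cos, sin).T  # (query pos, key pos)

def is_constant(d: torch.Tensor) -> bool:
    # True if every entry equals the first one (up to float tolerance)
    return torch.allclose(d, torch.full_like(d, d[0].item()))

# Every diagonal (fixed offset j - i) should hold a single repeated value
is_toeplitz = all(is_constant(scores.diagonal(offset)) for offset in range(-seq_len + 1, seq_len))
print(is_toeplitz)
```

If `apply_rotary_pos_emb` follows the same convention, each row of the positional plot should be a shifted copy of the same distance profile.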

In [5]:
from shared.transformer import Rotary, apply_rotary_pos_emb

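# Project an all-ones input (the same vector at every position) to queries and keys.
# The last dimension should be d_model (the input size of w_q / w_k); the sequence length is n_ctx.
q = torch.ones(1, 1, config.n_ctx, config.d_model) @ model.w_q.mT
k = torch.ones(1, 1, config.n_ctx, config.d_model) @ model.w_k.mT

# Apply the rotary embedding, so positions only differ by their rotation
cos, sin = Rotary(config.d_model // config.n_head)(q.size(-2), q.device)
q, k = apply_rotary_pos_emb(q, k, cos, sin)

qk_pos = q @ k.mT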

fig = px.imshow(qk_pos.flatten(0, 1), **facet, **color)\
    .update_xaxes(showticklabels=False).update_yaxes(showticklabels=False).update_layout(title="QK Positional Embeddings", title_x=0.5)
set_facet_labels(fig)

With causal attention, only the lower-triangular part matters. The plot suggests that only the first layer really cares about position, which is also common with learned positional embeddings. Heads 0 and 3 are strongly biased towards nearby tokens, and head 1 mostly attends to the first token.

In [None]:
qk_token = ((model.w_e[None, None]).mT @ model.qk @ model.w_e[None, None]).detach()
fig = px.imshow(qk_token.flatten(0, 1)[:, :256, :256].cpu(), **facet, **color)\
    .update_xaxes(showticklabels=False).update_yaxes(showticklabels=False).update_layout(title="QK Token Embeddings", title_x=0.5)
set_facet_labels(fig)

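The plot above only shows the first 256 tokens of the vocabulary, an arbitrary slice near the start, so it doesn't say much about what is going on. Heads 1 and 3 clearly have a positive and a negative bias (respectively) along the diagonal, i.e. towards a token attending to itself, but the others look unremarkable. Let's first check the traces, which sum this diagonal over the whole vocabulary.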

In [None]:
traces = qk_token.diagonal(dim1=-2, dim2=-1).sum(-1)
ranges = itertools.product(range(config.n_layer), range(config.n_head))
pd.DataFrame({f"head {i}.{j}": [int(traces[i, j])] for i, j in ranges})

As expected, heads 0.1 and 0.3 have very large traces.
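Since the traces only need the diagonal, there's no need to materialize the (vocab x vocab) matrix: by the cyclic property of the trace, tr(Eᵀ A E) = tr(A E Eᵀ), which only involves d_model x d_model matrices. A small sketch on random stand-in tensors (the shapes mirror `model.w_e` and `model.qk` as used above, but the values are made up):

```python
import torch

torch.manual_seed(0)
d_model, vocab_size, n_head = 16, 200, 4

# Hypothetical stand-ins: w_e is (d_model, vocab) like model.w_e; qk is (n_head, d_model, d_model)
w_e = torch.randn(d_model, vocab_size, dtype=torch.float64)
qk = torch.randn(n_head, d_model, d_model, dtype=torch.float64)

# Direct route: build the full (vocab x vocab) token-token matrix per head, then sum its diagonal
qk_token = w_e.mT @ qk @ w_e
traces_direct = qk_token.diagonal(dim1=-2, dim2=-1).sum(-1)

# Cheaper route via tr(E^T A E) = tr(A E E^T): only the (d_model x d_model) Gram matrix is needed
gram = w_e @ w_e.mT
traces_cheap = (qk @ gram).diagonal(dim1=-2, dim2=-1).sum(-1)

# Even cheaper: tr(A G) = sum(A * G^T), and G is symmetric
traces_elementwise = (qk * gram).sum((-2, -1))

print(traces_direct)
print(torch.allclose(traces_direct, traces_cheap), torch.allclose(traces_direct, traces_elementwise))
```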

In [None]:
# This can take a while to run (~30s on my machine)
heads = torch.linalg.svdvals(qk_token[0]).cpu()
df = pd.DataFrame({f"H0.{i}": data[:64] for i, data in enumerate(heads)})
px.line(df, title="singular values of token embeddings in each head").update_layout(title_x=0.5)

It's not clear what to conclude from this; the spectra look fairly normal. The analogous analysis of the OV circuits is even more standard, so there is no need to look further into it for now.
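The ~30s SVD above can probably be avoided. The token-level matrix Eᵀ A E has rank at most d_model, and with a thin QR decomposition Eᵀ = QR (Q with orthonormal columns) it equals Q (R A Rᵀ) Qᵀ, whose nonzero singular values are exactly those of the small d_model x d_model matrix R A Rᵀ. A sketch with random stand-ins for `model.w_e` and one head of `model.qk` (not the real weights):

```python
import torch

torch.manual_seed(0)
d_model, vocab_size = 16, 300
w_e = torch.randn(d_model, vocab_size, dtype=torch.float64)  # stand-in for model.w_e (d_model, vocab)
qk = torch.randn(d_model, d_model, dtype=torch.float64)      # stand-in for a single head's QK matrix

# Slow route: SVD of the full (vocab x vocab) matrix
slow = torch.linalg.svdvals(w_e.mT @ qk @ w_e)

# Fast route: thin QR of E^T = Q R, so E^T A E = Q (R A R^T) Q^T.
# Multiplying by Q (orthonormal columns) leaves singular values unchanged,
# so a (d_model x d_model) SVD is enough.
q, r = torch.linalg.qr(w_e.mT)  # q: (vocab, d_model), r: (d_model, d_model)
fast = torch.linalg.svdvals(r @ qk @ r.mT)

# The top d_model values agree; the rest of the full spectrum is numerically zero
print(torch.allclose(slow[:d_model], fast))
print((slow[d_model:] / slow[0]).max().item())
```

For the real model this would mean one small SVD per head instead of one vocab-sized SVD per head.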

In [None]:
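# Full OV circuit per head: unembedding @ OV @ embedding, giving a (vocab x vocab) matrix.
# The input side uses the token embedding w_e (d_model, vocab), as in the QK circuit above.
ov_token = (model.w_u[None, None] @ model.ov @ model.w_e[None, None]).detach()

# traces = ov_token.diagonal(dim1=-2, dim2=-1).sum(-1)
# pd.DataFrame({f"head {i}.{j}": [int(traces[i, j])] for i, j in itertools.product(range(config.n_layer), range(config.n_head))})

# vals = torch.linalg.svdvals(ov_token[0]).T[:64].cpu()
# px.line(vals, title="singular values of token embeddings in each head").update_layout(title_x=0.5)

fig = px.imshow(ov_token.flatten(0, 1)[:, :256, :256].cpu(), **facet, **color)\
    .update_xaxes(showticklabels=False).update_yaxes(showticklabels=False).update_layout(title="OV Token Embeddings", title_x=0.5)
set_facet_labels(fig)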


### Embeddings / Unembeddings
We've studied the attention part of the transformer, but there is much more to study. Let's look at the embedding and unembedding matrices.

In [None]:
vals = torch.linalg.svdvals(model.w_e.detach()).cpu()
px.line(vals, title="singular values of token embeddings").update_layout(title_x=0.5)

This plot is more interesting: the spectrum drops sharply near the end. This is quite uncommon. Usually the top singular values show some structure, which can take many shapes, but the curve eventually fades slowly to 0. Here, the last few singular values are tiny, meaning the token embeddings barely use a handful of residual-stream directions; something (or the lack of something) is stored there.
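To make one possible explanation concrete, here is a toy with synthetic data (not the model's weights): token embeddings that barely use a few residual directions, plus positional embeddings that live exactly in those directions. The sizes and scales are arbitrary assumptions. The token-only spectrum shows the same kind of sharp drop at the end, and appending the positional embeddings fills it in:

```python
import torch

torch.manual_seed(0)
d_model, vocab_size, n_ctx, n_reserved = 32, 500, 64, 4

# Toy token embeddings (d_model, vocab) that barely use the last n_reserved residual directions
w_e = torch.randn(d_model, vocab_size, dtype=torch.float64)
w_e[-n_reserved:] *= 1e-3

# Toy positional embeddings (d_model, n_ctx) that live only in those reserved directions
w_pos = torch.zeros(d_model, n_ctx, dtype=torch.float64)
w_pos[-n_reserved:] = 3 * torch.randn(n_reserved, n_ctx, dtype=torch.float64)

tok_only = torch.linalg.svdvals(w_e)
with_pos = torch.linalg.svdvals(torch.cat([w_e, w_pos], dim=1))

# Sharp drop between the last "used" direction and the reserved ones, tokens only
print((tok_only[-n_reserved - 1] / tok_only[-n_reserved]).item())
# The smallest singular value recovers once positions are appended
print((with_pos[-1] / tok_only[-1]).item())
```

This is only an illustration of one mechanism that produces such a drop, not evidence that the model does this.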

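To see what could cause this, let's append the positional embeddings to the token embeddings. If the directions behind the last few (tiny) singular values are reserved for positional information, the gap should shrink.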

In [None]:
w_e_pos = torch.cat([model.w_e, model.w_pos], dim=1)
vals = torch.linalg.svdvals(w_e_pos).cpu()
px.line(vals, title="singular values of position and token embeddings").update_layout(title_x=0.5)

Something is still going on: the positional embeddings close about half the gap, but it remains quite obvious.

In [None]:
vals = torch.linalg.svdvals(model.w_u).cpu()
px.line(vals, title="singular values of unembeddings").update_layout(title_x=0.5)

This is pretty much as expected: the unembedding matrix uses all hidden dimensions to construct its output. There is a single very large singular value; let's analyze what it represents.

In [None]:
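# torch.linalg.svd returns (U, S, Vh): columns of U are left singular vectors, rows of Vh are right ones
u, s, vh = torch.linalg.svd(model.w_u, full_matrices=False)

# First left singular vector lives in vocab (output) space: u[:, 0], not u[0] (which is token 0's row)
top_u = torch.topk(u[:, 0].cpu(), 20)
# First right singular vector lives in the residual stream: vh[0]; project it onto the token embeddings
top_v = torch.topk((vh[0] @ model.w_e).cpu(), 20)

outputs = vocab.tokenize(top_u.indices)
inputs = vocab.tokenize(top_v.indices)

pd.DataFrame(dict(out_toks=outputs, out_vals=top_u.values, in_toks=inputs, in_vals=top_v.values))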

Note that although the tokens share a table, the rows are not related to each other: the output and input columns are two independent top-20 lists.
Sadly, nothing here is very revealing.
Nicola Cancedda said that this is related to "distributional" information about the output tokens, but I'm not too sure...
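Since the indexing is easy to mix up, here is a quick self-check on a random stand-in matrix shaped like `model.w_u`, i.e. (vocab, d_model). With `torch.linalg.svd`, left singular vectors are the columns of U, right singular vectors are the rows of Vh, and each pair is only defined up to a joint sign flip:

```python
import torch

torch.manual_seed(0)
a = torch.randn(50, 8, dtype=torch.float64)  # stand-in for model.w_u, shape (vocab, d_model)

u, s, vh = torch.linalg.svd(a, full_matrices=False)

# Columns of u are left singular vectors, rows of vh are right singular vectors
left, right = u[:, 0], vh[0]
av_ok = torch.allclose(a @ right, s[0] * left)     # A v = s u
atu_ok = torch.allclose(a.mT @ left, s[0] * right)  # A^T u = s v

# Flipping both signs gives an equally valid pair, so both ends of a top-k can matter
flip_ok = torch.allclose(a @ (-right), s[0] * (-left))
print(av_ok, atu_ok, flip_ok)
```

Because of the sign ambiguity, it may be worth also looking at `torch.topk(-u[:, 0], 20)` in the analysis above.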

### Diagonal of the MLP

We can also look at the SVD of the MLP's diagonal, i.e. the direct interactions in which each input dimension interacts only with itself.
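For reference, a minimal sketch of what the diagonal means, assuming the standard bias-free bilinear MLP out = W_out((W x) ⊙ (V x)). The weight names here are hypothetical, and this is not how `model.ube` is built internally. The layer can be rewritten as a third-order tensor B[k, i, j], and the "diagonal" keeps only the B[k, i, i] terms where an input dimension interacts with itself:

```python
import torch

torch.manual_seed(0)
d_in, d_hidden, d_out = 6, 10, 5

# Hypothetical bilinear MLP weights: out = w_out @ ((w @ x) * (v @ x)), no biases
w = torch.randn(d_hidden, d_in, dtype=torch.float64)
v = torch.randn(d_hidden, d_in, dtype=torch.float64)
w_out = torch.randn(d_out, d_hidden, dtype=torch.float64)

# Interaction tensor B[k, i, j] = sum_h w_out[k, h] * w[h, i] * v[h, j]
b = torch.einsum("kh,hi,hj->kij", w_out, w, v)
# Symmetrize over (i, j); this does not change x^T B[k] x
b = 0.5 * (b + b.mT)

x = torch.randn(d_in, dtype=torch.float64)
direct = w_out @ ((w @ x) * (v @ x))
via_tensor = torch.einsum("kij,i,j->k", b, x, x)
print(torch.allclose(direct, via_tensor))

# The diagonal: direct terms B[k, i, i], giving a (d_out, d_in) matrix whose spectrum we can inspect
diag = b.diagonal(dim1=-2, dim2=-1)
print(diag.shape, torch.linalg.svdvals(diag))
```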

In [None]:
diag = model.ube.diagonal().detach()
vals = torch.linalg.svdvals(diag)[:, :256].cpu()
px.line(vals.T, title="singular values of ube diagonal").update_layout(title_x=0.5)

The first singular direction carries most of the weight, followed by a very long tail, similar to the unembedding.
I haven't looked in depth at what this represents, but figuring it out shouldn't be too hard.
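One way to put a number on "most of the weight" is the fraction of the squared Frobenius norm captured by the top-k singular values, since the squared Frobenius norm equals the sum of squared singular values. A small sketch, demonstrated on a synthetic near-rank-1 matrix rather than the model (the helper name `energy_fraction` is made up here):

```python
import torch

def energy_fraction(svals: torch.Tensor, k: int = 1) -> torch.Tensor:
    # Fraction of the squared Frobenius norm captured by the top-k singular values.
    # svals is assumed sorted in descending order, as torch.linalg.svdvals returns them.
    energy = svals.pow(2)
    return energy[..., :k].sum(-1) / energy.sum(-1)

torch.manual_seed(0)
# Toy matrix: one strong rank-1 direction plus small noise
m = 10 * torch.outer(torch.randn(64, dtype=torch.float64), torch.randn(32, dtype=torch.float64))
m = m + 0.1 * torch.randn(64, 32, dtype=torch.float64)

frac = energy_fraction(torch.linalg.svdvals(m))
print(frac.item())
```

Applied to the `vals` above (one row per layer), the `...` indexing gives one fraction per layer.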

The paper about spectral filters may be a good starting point for more interesting findings, since these filters can be propagated through the whole network.
Characterizing some of these directions (especially the first one) would be cool.

More generally, any other decomposition can be fully propagated through this diagonal to gain further insight into the MLP.

### Interactions in the MLP

We can also study the SVD of the interaction matrix for specific tokens. Because the interaction matrix is symmetric, we can equally study its eigendecomposition, which additionally keeps the sign of each direction.
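A quick check of the symmetric case on a random stand-in matrix (not the model's interaction matrix): the singular values are the absolute eigenvalues, while `torch.linalg.eigh` also tells us whether the quadratic form xᵀBx increases or decreases along each direction:

```python
import torch

torch.manual_seed(0)
a = torch.randn(8, 8, dtype=torch.float64)
sym = 0.5 * (a + a.mT)  # stand-in for one symmetric interaction matrix

eigvals, eigvecs = torch.linalg.eigh(sym)  # ascending eigenvalues, orthonormal eigenvector columns
svals = torch.linalg.svdvals(sym)          # descending singular values

# Singular values are the absolute eigenvalues
abs_ok = torch.allclose(svals, eigvals.abs().sort(descending=True).values)

# The eigenvector with the largest |eigenvalue| keeps its sign information
order = eigvals.abs().argsort(descending=True)
top_val, top_vec = eigvals[order[0]], eigvecs[:, order[0]]
eig_ok = torch.allclose(sym @ top_vec, top_val * top_vec)
print(abs_ok, eig_ok, top_val.item())
```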

In [None]:
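interaction = model.ube.interaction(vocab["game"], residual=True)
# torch.svd is deprecated; torch.linalg.svd returns Vh, so transpose back to V
# (columns are right singular vectors), matching what torch.svd returned
u, s, vh = torch.linalg.svd(interaction)
v = vh.mT

px.line(s.T[:256].cpu(), title='singular values of the "game" interaction matrix').update_layout(title_x=0.5)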

The first MLP seems to care much more about some low-dimensional structure of "game", probably related to tokenization.
Let's see whether the singular vectors reveal anything useful.

In [None]:
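# v holds V, whose *columns* are right singular vectors, so v[0, :, 0] is the first one (layer 0);
# v[0, 0] would be a row of V, which is not a singular vector.
# Singular vectors are only defined up to sign, so look at both ends.
top_pos = vocab.tokenize(torch.topk(v[0, :, 0], k=10).indices)
top_neg = vocab.tokenize(torch.topk(-v[0, :, 0], k=10).indices)
pd.DataFrame(dict(top_pos=top_pos, top_neg=top_neg))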

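Ignoring the u matrix might look suspicious, but since the interaction matrix is symmetric, its left and right singular vectors coincide up to sign (for distinct singular values), so v alone should suffice. I haven't thought much more deeply about this.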
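A small check of that claim on a random symmetric stand-in matrix: per column, u and v agree up to a sign, and that sign is the sign of the matching eigenvalue:

```python
import torch

torch.manual_seed(0)
a = torch.randn(8, 8, dtype=torch.float64)
sym = 0.5 * (a + a.mT)  # stand-in for a symmetric interaction matrix

u, s, vh = torch.linalg.svd(sym)
v = vh.mT  # columns of v are right singular vectors (same layout torch.svd returns)

# Dot product of matching columns: +1 or -1 if they agree up to sign
signs = (u * v).sum(0)
unit_ok = torch.allclose(signs.abs(), torch.ones_like(signs))
match_ok = torch.allclose(u, v * signs)  # flip each column of v by its sign
print(unit_ok, match_ok, signs)
```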