# Bigrams in Bilinear Transformers

Bilinear transformers are great because their MLPs are even closer to linear than those in the original architecture. A bilinear layer has no elementwise activation function, only the product of two linear maps. This lets us apply standard linear-algebra analysis to each component separately (or even to several together). This notebook focuses on extracting 2-grams (bigrams) from the weights. It is meant as an introduction to the capabilities of bilinear layers and shouldn't be used to draw rigorous conclusions.
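To make "closer to linear" concrete, here is a minimal NumPy sketch. It assumes the standard bilinear MLP form `out = W_out((W x) * (V x))` with no normalization or biases. The weight names and sizes are illustrative and are not the `language` package's API.

The sketch checks that each output coordinate is a quadratic form `x^T B_k x` of the input. This means the whole layer is captured by a single third-order tensor `B`, which linear-algebra tools can inspect directly.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, d_hidden = 8, 16

# Hypothetical bilinear MLP weights: out = W_out @ ((W @ x) * (V @ x))
W = rng.normal(size=(d_hidden, d_model))
V = rng.normal(size=(d_hidden, d_model))
W_out = rng.normal(size=(d_model, d_hidden))

def bilinear_mlp(x):
    """Elementwise product of two linear maps, then a linear readout; no activation function."""
    return W_out @ ((W @ x) * (V @ x))

# The same layer as one third-order tensor: B[k, i, j] = sum_h W_out[k, h] * W[h, i] * V[h, j]
B = np.einsum("kh,hi,hj->kij", W_out, W, V)

x = rng.normal(size=d_model)
via_tensor = np.einsum("i,kij,j->k", x, B, x)  # x^T B[k] x for every output k

print(np.allclose(bilinear_mlp(x), via_tensor))  # True
```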

In [1]:
%load_ext autoreload
%autoreload 2

from language import Transformer
import plotly.express as px
import torch
import pandas as pd
from einops import *

torch.set_grad_enabled(False)
color = dict(color_continuous_midpoint=0, color_continuous_scale="RdBu")

model = Transformer.from_pretrained(d_model=1024, n_layer=1, modifier="i5").cuda()

vocab = model.vocab
config = model.config

1- and 2-layer transformers behave slightly differently. The 1-layer transformer has a somewhat more diverse MLP layer (because it more or less has to). The results shown in this notebook hold for both.

## Direct Path

Let's start with the most obvious way to look at 2-grams: the direct embedding-unembedding path.
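As a toy illustration of what `model.w_u @ model.w_e` computes: column `i` of the product is exactly the logit vector you get by embedding token `i` and immediately unembedding it. After transposing, row `i` therefore lists the score of every possible next token.

The sketch uses small random matrices. The layouts (`W_E` as `d_model x vocab`, `W_U` as `vocab x d_model`) are inferred from the shape assertion in the notebook, not confirmed.

```python
import numpy as np

rng = np.random.default_rng(0)
n_vocab, d_model = 6, 4

# Toy stand-ins for model.w_e and model.w_u (assumed layouts)
W_E = rng.normal(size=(d_model, n_vocab))  # column i = embedding of token i
W_U = rng.normal(size=(n_vocab, d_model))  # row j = unembedding of token j

toy_direct = W_U @ W_E  # toy_direct[out, in]

# Embed token i as a one-hot vector and unembed it straight away
i = 2
one_hot = np.zeros(n_vocab)
one_hot[i] = 1.0
logits = W_U @ (W_E @ one_hot)

print(np.allclose(toy_direct[:, i], logits))  # True: column i holds the logits for input i
print(np.allclose(toy_direct.T[i], logits))   # True: after .T, row i holds them
```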

In [2]:
direct = (model.w_u @ model.w_e).detach().cpu()
assert direct.shape == (len(vocab), len(vocab))

vocab.get_max_activations(direct.T, ["input", "output"], 10)

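|   | input | output | value |
|---|-------|--------|-------|
| 0 | else | ##where | 2.513922 |
| 1 | ' | s | 2.24838 |
| 2 | even | though | 2.189376 |
| 3 | each | other | 2.153148 |
| 4 | me | ##ow | 2.12376 |
| 5 | where | ##ver | 2.09008 |
| 6 | my | ##self | 1.936118 |
| 7 | us | ##ually | 1.90434 |
| 8 | ever | since | 1.893315 |
| 9 | ben | ##eath | 1.88843 |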


## MLP path

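Now onto the good stuff: the MLP. In a standard transformer, the elementwise nonlinearity means we can't analyze the MLP's input-output behavior with SVD or any other linear technique. Bilinear layers, however, do allow this. In this section, we limit ourselves to the direct MLP path, i.e. embedding -> MLP -> unembedding. To our knowledge, this hasn't been done before.

To study the direct path, we take the diagonal over the last two dimensions of the B tensor (here `model.ube`, presumably B with the unembedding and embedding folded in). The diagonal is the right object because, when a single token is fed in, both inputs of the bilinear product see the same embedding. I won't go into the rest of the math here for brevity.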
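The diagonal claim can be checked numerically on toy weights. This sketch assumes UBE is the bilinear tensor `B[k, a, b] = sum_h W_out[k, h] W[h, a] V[h, b]`, with the unembedding applied to the output index and the embedding applied to both input indices. The name suggests this, but the notebook doesn't state it. The sketch also ignores normalization and whatever `residual=True` adds.

Feeding one token `t` puts `x = W_E[:, t]` on both sides of the product, so its logits should be exactly `UBE[:, t, t]`.

```python
import numpy as np

rng = np.random.default_rng(0)
n_vocab, d_model, d_hidden = 10, 6, 12

# Toy weights; names and layouts are hypothetical (chosen so W_U @ W_E is vocab x vocab)
W_E = rng.normal(size=(d_model, n_vocab))
W_U = rng.normal(size=(n_vocab, d_model))
W = rng.normal(size=(d_hidden, d_model))
V = rng.normal(size=(d_hidden, d_model))
W_out = rng.normal(size=(d_model, d_hidden))

# B[k, a, b] = sum_h W_out[k, h] W[h, a] V[h, b]; then map every index into token space
B = np.einsum("kh,ha,hb->kab", W_out, W, V)
toy_ube = np.einsum("ok,kab,ai,bj->oij", W_U, B, W_E, W_E)  # (out, in, in)

def mlp_logits(token):
    """Embed one token, run the bilinear MLP, unembed (no residual, no normalization)."""
    x = W_E[:, token]
    return W_U @ (W_out @ ((W @ x) * (V @ x)))

# Diagonal over the last two dimensions: toy_diag[out, in] = toy_ube[out, in, in]
toy_diag = np.diagonal(toy_ube, axis1=1, axis2=2)

print(all(np.allclose(toy_diag[:, t], mlp_logits(t)) for t in range(n_vocab)))  # True
```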

Before looking at eigenvalues, let's look at the largest entries overall. This gives a map of input -> output, i.e. the token pairs the model is most confident about.

In [3]:
diag = model.ube.diagonal(residual=True).cpu()



I use a helper function, `get_max_activations`, which returns a data frame with the indices and values of the largest entries in the provided tensor. The indices are automatically converted to tokens. Let's look at the top 1000 connections in the first MLP layer.
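The notebook doesn't show `get_max_activations` itself, so here is a hypothetical stand-in that captures the idea. The name `top_pairs` and its signature are made up. It takes the k largest entries of a matrix, recovers their (row, column) positions with `np.unravel_index`, and maps both positions to tokens. Rows are inputs and columns are outputs, matching how the notebook passes `diag.T`.

```python
import numpy as np

def top_pairs(matrix, tokens, k):
    """Hypothetical stand-in for get_max_activations.

    Returns the k largest entries of a (rows x cols) matrix as
    (row_token, col_token, value), largest first.
    """
    flat = matrix.ravel()
    top = np.argsort(flat)[::-1][:k]  # flat indices of the k largest values
    rows, cols = np.unravel_index(top, matrix.shape)
    return [(tokens[r], tokens[c], float(flat[t])) for r, c, t in zip(rows, cols, top)]

tokens = ["each", "other", "even", "though"]
# Toy scores (not model outputs); rows = input token, cols = output token
scores = np.array([
    [0.1, 2.2, 0.0, 0.3],
    [0.0, 0.1, 0.2, 0.0],
    [0.1, 0.0, 0.0, 2.1],
    [0.0, 0.4, 0.0, 0.1],
])
print(top_pairs(scores, tokens, 2))
# [('each', 'other', 2.2), ('even', 'though', 2.1)]
```

For a full vocabulary-by-vocabulary matrix, `np.argpartition` (or `torch.topk` on the flattened tensor, as in the commented-out lines of the interaction cell) avoids sorting every entry.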

In [4]:
df = vocab.get_max_activations(diag.T, ["input", "output"], k=1_000, largest=True)
df

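|   | input | output | value |
|---|-------|--------|-------|
| 0 | else | ##where | 2.524856 |
| 1 | each | other | 2.243097 |
| 2 | where | ##ver | 2.192546 |
| 3 | ' | s | 2.183418 |
| 4 | even | though | 2.164958 |
| ... | ... | ... | ... |
| 995 | re | ##be | 1.094736 |
| 996 | ener | ##get | 1.094490 |
| 997 | why | dont | 1.094218 |
| 998 | clum | ##p | 1.093943 |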


Okay, so the first layer mostly just connects obvious bigrams: words that didn't make it into the tokenizer's vocabulary and are therefore split into several tokens. Let's quantify this.

In [5]:
px.line(df["output"].str.startswith("##").cumsum(), title="cumulative ## tokens").update_layout(title_x=0.5)
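The cumulative count grows with k almost by construction. Dividing it by the rank gives the fraction of `##` continuations among the top-k pairs, which makes "mostly" easy to read off.

Below is a pure-Python sketch on a short toy list. In the notebook the input would be `df["output"]`, and pandas can compute the same thing with `df["output"].str.startswith("##").expanding().mean()`.

```python
from itertools import accumulate

def running_fraction(flags):
    """Fraction of True values among the first n items, for every n."""
    counts = accumulate(int(f) for f in flags)
    return [count / n for n, count in enumerate(counts, start=1)]

outputs = ["##where", "other", "##ver", "s", "though", "##ow"]  # toy list
fractions = running_fraction(o.startswith("##") for o in outputs)
print([round(f, 2) for f in fractions])  # [1.0, 0.5, 0.67, 0.5, 0.4, 0.5]
```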

#### Preceding and Following tokens

Given this diagonal matrix, we can also analyze which tokens are the most important indicators of the next token, or the other way around.
For instance, we can ask:
- *"Which tokens are most important for the model to decide to predict the token 'girl'?"* (preceding tokens)
- *"Which tokens does the token 'girl' predict most strongly?"* (following tokens)

In [6]:
token = "girl"
idx = vocab[token]

preceding = vocab[torch.topk(diag[idx], k=10).indices]
following = vocab[torch.topk(diag[:, idx], k=10).indices]

pd.DataFrame(dict(preceding=preceding, self=[token]*10, following=following))

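|   | preceding | self | following |
|---|-----------|------|-----------|
| 0 | little | girl | named |
| 1 | young | girl | who |
| 2 | restless | girl | called |
| 3 | baby | girl | ##s |
| 4 | the | girl | visited |
| 5 | clumsy | girl | giggled |
| 6 | ##vous | girl | ##hood |
| 7 | adventurous | girl | names |
| 8 | pretty | girl | ' |
| 9 | another | girl | stepped |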


The preceding and following columns are not related. Each is an independent top-10 list, placed side by side for a concise visualization, so the rows are not trigrams.

#### Articles

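Another interesting question is whether the model has learned to use the correct article ('a' vs. 'an'). Let's study this in a bit more depth.
We can do this quite simply. For every token, take its direct-path score as the successor of 'a' and as the successor of 'an', and plot the two against each other, colored by whether the token starts with a vowel.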

In [7]:
# Rough heuristic: vowel-initial tokens should follow 'an', the rest 'a'
vowels = ['a', 'e', 'i', 'o', 'u']

# diag[out, in]: column "a" scores every token as the successor of "a"; likewise for "an"
df = pd.DataFrame(dict(x=diag[:, vocab["a"]], y=diag[:, vocab["an"]], token=vocab.tokens))
df = df[df.token.str[0].str.isalpha()]      # drop '##' continuations and punctuation
df["guess"] = df.token.str[0].isin(vowels)  # heuristic label: starts with a vowel

px.scatter(df, x="x", y="y", hover_name="token", color="guess", labels=dict(x="a", y="an")).show()

The result isn't as clean as I'd hoped, but the model seems to have a general bias towards picking 'a'. This is sensible, since 'a' is the more common article.

Hovering over the tokens makes it clear why the model is "unsure" about many of them. Proper filtering of verbs and the like would probably improve the separation. I'd also expect this to become clearer as models improve.
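One way to put a number on "not as clean as I'd hoped" is to measure agreement with the first-letter vowel heuristic. For each token, the preferred article is whichever of 'a' or 'an' gives the higher score, and we count how often that matches the heuristic.

The function below is a hypothetical helper, shown on toy scores. In the notebook you would pass `df.x`, `df.y` and `df.token`. The heuristic itself is imperfect ('hour', 'university'), so 100% agreement is not the goal.

```python
def article_agreement(score_a, score_an, tokens, vowels="aeiou"):
    """Fraction of alphabetic tokens where the preferred article matches the vowel heuristic.

    The preferred article is 'an' if its score is higher, otherwise 'a'.
    """
    hits, total = 0, 0
    for s_a, s_an, tok in zip(score_a, score_an, tokens):
        if not tok[:1].isalpha():
            continue  # skip '##' pieces and punctuation, like the notebook's filter
        prefers_an = s_an > s_a
        starts_with_vowel = tok[0].lower() in vowels
        hits += prefers_an == starts_with_vowel
        total += 1
    return hits / total if total else float("nan")

# Toy scores (not model outputs)
tokens = ["apple", "dog", "##s", "egg", "house"]
score_a = [0.2, 1.0, 0.5, 0.6, 0.9]
score_an = [0.8, 0.1, 0.0, 0.3, 0.2]
print(article_agreement(score_a, score_an, tokens))  # 0.75
```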

## Token Interactions
Until now, we've only looked at the direct path. This is fine, but the MLP encodes much more information than (input, output) pairs. It actually encodes (input, input, output) triplets, which is one of the reasons for its effectiveness.

So, in essence, until now we've only looked at each token's interaction with itself, which reduces the UBE tensor to a matrix we can study. Now we perform a different reduction: we fix the index along the first (output) dimension of the UBE tensor. This gives the input-input interaction matrix for a given output token (here 'game'), which we then eigendecompose.
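Why eigendecompose? For a fixed output token, the interaction matrix `Q` enters the logit only through the quadratic form `x^T Q x`, which depends only on the symmetric part of `Q`. Writing that symmetric part as `sum_i lambda_i u_i u_i^T` turns the logit into `sum_i lambda_i (u_i . x)^2`. Each eigenvector is therefore an input direction that pushes the output up (positive `lambda_i`) or down (negative `lambda_i`).

Below is a toy NumPy check, with a random matrix standing in for `inter`. The explicit symmetrization is there because `eigh` assumes a symmetric input; the notebook's `interaction` presumably already returns one.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5

Q = rng.normal(size=(n, n))  # toy stand-in for the interaction matrix of one output token
Q_sym = (Q + Q.T) / 2        # x^T Q x only depends on the symmetric part

# eigh expects a symmetric matrix and returns eigenvalues in ascending order
toy_vals, toy_vecs = np.linalg.eigh(Q_sym)

x = rng.normal(size=n)       # an arbitrary input direction
quad_form = x @ Q @ x
spectral = np.sum(toy_vals * (toy_vecs.T @ x) ** 2)  # sum_i lambda_i (u_i . x)^2

print(np.isclose(quad_form, spectral))  # True
```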

In [8]:
idx = vocab["game"]
inter = model.ube.interaction(idx, residual=True).cpu()

# topk = torch.topk(inter.tril().flatten(), k=25)
# input1, input2 = torch.unravel_index(topk.indices, inter.size())
# pd.DataFrame(dict(input1=vocab[input1], input2=vocab[input2], value=topk.values.cpu()))

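# eigh returns eigenvalues in ascending order, so vecs[:, -1] is the top eigenvector
vals, vecs = torch.linalg.eigh(inter)

# Top-20 tokens (most positive entries) of the six leading eigenvectors:
# column "0" is the top eigenvector, "Nt" holds tokens and "Nv" entries of the N-th one after it
a = pd.DataFrame({"0": vocab[torch.topk(vecs[:, -1], k=20).indices]})
a["1t"] = vocab[torch.topk(vecs[:, -2], k=20).indices]
a["1v"] = torch.topk(vecs[:, -2], k=20).values

a["2t"] = vocab[torch.topk(vecs[:, -3], k=20).indices]
a["3t"] = vocab[torch.topk(vecs[:, -4], k=20).indices]
a["3v"] = torch.topk(vecs[:, -4], k=20).values

a["4t"] = vocab[torch.topk(vecs[:, -5], k=20).indices]
a["5t"] = vocab[torch.topk(vecs[:, -6], k=20).indices]
a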

|   | 0 | 1t | 1v | 2t | 3t | 3v | 4t | 5t |
|---|---|----|----|----|----|----|----|----|
| 0 | baseball | their | 0.104808 | favorite | ##ist | 0.047357 | other | don |
| 1 | video | our | 0.097071 | ##iest | grand | 0.042416 | story | worry |
| 2 | board | my | 0.095033 | best | mean | 0.040716 | that | ##ep |
| 3 | basketball | his | 0.089894 | ##est | exciting | 0.040571 | al | if |
| 4 | football | your | 0.088183 | favourite | ##ver | 0.038498 | de | play |
| 5 | soccer | another | 0.084247 | ##de | easy | 0.03838 | one | run |
| 6 | hockey | s | 0.080735 | biggest | quiet | 0.038002 | book | park |
| 7 | ##hip | favorite | 0.079435 | that | ##mp | 0.0379 | talk | rec |
| 8 | ##aw | good | 0.073628 | great | stubb | 0.037877 | whole | do |
| 9 | card | her | 0.070882 | size | ##ient | 0.037596 | our | explain |
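The table lists only the most positive entries of each eigenvector (`torch.topk` on `vecs[:, -n]`). Eigenvectors are only defined up to sign, so the most negative entries are just as meaningful and may contain a second token cluster. In the notebook, `torch.topk(vecs[:, -2], k=20, largest=False)` would show them.

Here is a small hypothetical NumPy helper that returns both ends at once, shown on a toy vector:

```python
import numpy as np

def both_ends(vec, tokens, k):
    """Tokens at the k most positive and the k most negative entries of an eigenvector.

    Eigenvectors are only defined up to sign, so both ends carry meaning.
    """
    order = np.argsort(vec)  # ascending
    top = [tokens[i] for i in order[::-1][:k]]
    bottom = [tokens[i] for i in order[:k]]
    return top, bottom

tokens = ["their", "our", "the", "a", "baseball", "video"]
vec = np.array([0.5, 0.4, -0.1, 0.0, -0.6, -0.3])  # toy eigenvector
print(both_ends(vec, tokens, 2))  # (['their', 'our'], ['baseball', 'video'])
```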


This will mostly become useful once we introduce some additional inspection techniques later, for the attention layers.