# N-grams in Bilinear Transformers

Bilinear transformers are great for interpretability because their MLPs have no elementwise nonlinearity. A bilinear layer multiplies two linear projections of its input elementwise, so it is an exact quadratic function of that input and can be described by a single third-order tensor. This allows us to apply standard linear-algebra analyses to each component separately (or even together). This notebook in particular focuses on extracting 2-grams from the weights. It is meant as an introduction to the capabilities of bilinear layers and shouldn't be used to draw rigorous conclusions.

In [1]:
%load_ext autoreload
%autoreload 2

from stories.model import Transformer, Config
import plotly.express as px
import torch
from bidict import bidict
import pandas as pd
from einops import *

color = dict(color_continuous_midpoint=0, color_continuous_scale="RdBu")
name = "tdooms/MicroStories-1-256-b"

config = Config.from_pretrained(name)
model = Transformer.from_pretrained(name, config=config).cuda()

model.center_unembed().fold_norms()
vocab = model.vocab

## Direct Path

Let's start with the obvious way to look at 2-grams: the direct embedding-unembedding path, which maps each input token straight to output logits while skipping every layer.
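To make the orientation of this matrix explicit, here is a toy sketch with random weights. It assumes both the embedding `E` and the unembedding `U` are stored as (vocab, d_model) matrices; that layout is an assumption for illustration and not necessarily how `stories.model` stores `w_e`. Under that convention the direct-path logit matrix is `U @ E.T`, with rows indexing the output token and columns the input token.

```python
import torch

torch.manual_seed(0)
dt = torch.float64
vocab_size, d_model = 6, 4

# Hypothetical toy weights, both stored as (vocab, d_model).
E = torch.randn(vocab_size, d_model, dtype=dt)  # embedding: one row per input token
U = torch.randn(vocab_size, d_model, dtype=dt)  # unembedding: one row per output token

# Direct path: embed the input token, skip every layer, unembed straight to logits.
direct = U @ E.T  # direct[out, in] = U[out] . E[in]

# Column i is the logit vector the direct path produces after seeing token i.
i = 2
assert torch.allclose(direct[:, i], U @ E[i])
```

With a different storage layout the transpose simply moves; what matters is keeping track of which axis is the input and which is the output when reading off the top entries.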

In [2]:
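# Direct path: embed a token and unembed it immediately, skipping all layers.
direct = (model.w_u @ model.w_e).detach().cpu()
assert direct.shape == (len(vocab), len(vocab))

# Show the 10 largest (input, output) entries.
vocab.get_max_activations(direct, ["input", "output"], 10)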

| | input | output | value |
|---|---|---|---|
| 0 | ##ady | ##ri | 2.942549 |
| 1 | ##dd | ##ro | 2.862378 |
| 2 | ##en | ##ock | 2.80148 |
| 3 | ##nts | ##my | 2.778521 |
| 4 | ##ious | then | 2.711935 |
| 5 | ##em | tri | 2.666429 |
| 6 | ##ill | ##ey | 2.662986 |
| 7 | ##llow | once | 2.649849 |
| 8 | ##ock | park | 2.636638 |
| 9 | ##ht | fo | 2.598485 |


I'm a bit surprised that this doesn't make much sense. ``us ##ually quite`` and ``pe ##ople opened`` seem like good 3-grams, but I can't discern any other strong reasons for these pairs.

TODO: why doesn't this work?

## MLP path

Now, onto the good stuff: the MLP. In a standard transformer, the MLP's elementwise nonlinearity prevents us from studying it with SVD or any other linear technique. Bilinear layers, however, do allow this. In this section, we limit ourselves to the direct MLP path, i.e. embedding -> MLP -> unembedding. To our knowledge, this hasn't been done before. To study the direct path, we can take the diagonal over the last two dimensions of the B tensor. I won't go into the full math here for brevity; the short version is that on the direct path both inputs of the bilinear layer see the same token, so only the entries where the two input indices coincide contribute.
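A minimal sketch of why the diagonal is the right object, using random toy weights. It assumes a single bilinear layer `W_p((W_l x) * (W_r x))` with no residual stream or normalization, so it is not a reproduction of `model.ube.diagonal(residual=True)`; all names are hypothetical. Contracting the unembedding, the bilinear weights, and the embedding gives an (output, input, input) tensor, and pushing a single token through the MLP only touches its `i == j` entries.

```python
import torch

torch.manual_seed(0)
dt = torch.float64
vocab_size, d_model, d_hidden = 5, 4, 8

# Hypothetical toy weights; shapes follow the einsum in the token-interaction cell.
E = torch.randn(vocab_size, d_model, dtype=dt)  # token embeddings
U = torch.randn(vocab_size, d_model, dtype=dt)  # unembedding
W_l = torch.randn(d_hidden, d_model, dtype=dt)  # left projection
W_r = torch.randn(d_hidden, d_model, dtype=dt)  # right projection
W_p = torch.randn(d_model, d_hidden, dtype=dt)  # output projection


def bilinear_mlp(x: torch.Tensor) -> torch.Tensor:
    """Bilinear layer: elementwise product of two linear maps, no activation function."""
    return W_p @ ((W_l @ x) * (W_r @ x))


# UBE[o, i, j] = sum_h (U W_p)[o, h] * (W_l E^T)[h, i] * (W_r E^T)[h, j]
ube = torch.einsum("oh,hi,hj->oij", U @ W_p, W_l @ E.T, W_r @ E.T)

# On the direct path both bilinear inputs are the same token t, so only UBE[:, t, t] is used.
diag = ube.diagonal(dim1=1, dim2=2)  # (vocab_out, vocab_in)

# Check: unembedding the MLP output for token t reproduces column t of the diagonal.
t = 3
assert torch.allclose(diag[:, t], U @ bilinear_mlp(E[t]))
```

The same argument explains why the diagonal is only a slice of what the MLP stores: the off-diagonal entries describe pairs of different tokens, which the direct path never exercises.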

Before looking at the eigenvalues, let's look at the highest activations in general. This gives a map of input -> output, i.e. the pairs the model is most sure about.

In [3]:
diag = model.ube.diagonal(residual=True).detach().cpu()

Implementation note: PyTorch's `topk` only operates along a single dimension, so we have to flatten the matrix and then reconstruct the 2-D indices.
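A minimal sketch of the flatten-and-reconstruct trick. `top_k_2d` is a hypothetical helper and not necessarily how `vocab.get_max_activations` implements it.

```python
import torch


def top_k_2d(matrix: torch.Tensor, k: int):
    """Return the k largest entries of a 2-D tensor and their (row, col) positions.

    torch.topk works along a single dimension, so we search the flattened
    tensor and convert each flat index back into a (row, col) pair.
    flatten() follows logical row-major order, so this also works on
    transposed views such as diag[0].T.
    """
    n_cols = matrix.shape[1]
    values, flat = torch.topk(matrix.flatten(), k)  # sorted, largest first
    rows = torch.div(flat, n_cols, rounding_mode="floor")
    cols = flat % n_cols
    return values, rows, cols


m = torch.tensor([[0.1, 5.0, 0.3],
                  [4.0, 0.2, 0.0],
                  [0.5, 0.6, 3.0]])
values, rows, cols = top_k_2d(m, k=3)
print(values.tolist(), rows.tolist(), cols.tolist())
# [5.0, 4.0, 3.0] [0, 1, 2] [1, 0, 2]
```

On PyTorch 2.2 and later, `torch.unravel_index` performs the same index arithmetic, as in the token-interaction cell further down.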

In [10]:
vocab.get_max_activations(diag[0].T, ["input", "output"], k=20)

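| | input | output | value |
|---|---|---|---|
| 0 | ##ro | ##dd | 3.471501 |
| 1 | la | ##un | 3.26104 |
| 2 | cur | ##s | 3.158608 |
| 3 | str | ##ange | 3.154869 |
| 4 | da | ##mp | 3.064186 |
| 5 | st | ##am | 3.061367 |
| 6 | tri | ##ump | 3.042327 |
| 7 | ##my | ##nts | 3.031496 |
| 8 | bea | ##k | 3.031174 |
| 9 | fe | ##ug | 3.026405 |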


Okay, so we see 3 different classes of 2-grams.
- token combinations into words (e.g. ``for ##ce`` -> force)
- (articles, noun) pairs
- other: (of course), (their names), (an forth)?, (an ##ach)?

Importantly, these are not the most frequent 2-grams, just the ones that the model decided to learn in the MLP path. 

Some notes:
- ``a ##ud`` and ``a ##ut`` can be completed into many words and are probably the pairs the model is most annoyed weren't tokenized together.
- I'm a bit surprised at the complexity of words that snuck into the dataset, "mustered", "anent", "histo...", "anach...", "aud..."

#### Preceding and Following tokens

Given this diagonal matrix, we can also ask which tokens are the strongest indicators of the next token, or the other way around.
For instance, we can ask:
- *"what tokens are most important for the model to decide to predict the token 'game'"* (preceding token).
- *"what tokens does the token 'game' predict most"* (following token).

In [5]:
token = "magic"
idx = vocab[token]

preceding = vocab.tokenize(torch.topk(diag[0, idx], k=10).indices)
following = vocab.tokenize(torch.topk(diag[0, :, idx], k=10).indices)

pd.DataFrame(dict(preceding=preceding, self=[token]*10, following=following))

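| | preceding | self | following |
|---|---|---|---|
| 0 | ##ious | magic | her |
| 1 | ##ro | magic | with |
| 2 | inside | magic | a |
| 3 | ##one | magic | in |
| 4 | of | magic | e |
| 5 | something | magic | too |
| 6 | fe | magic | something |
| 7 | bea | magic | prin |
| 8 | ##al | magic | per |
| 9 | after | magic | mi |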


Again, the preceding and following columns are unrelated row by row; they're only placed side by side for a concise visualization.
We can see that it makes a lot of sense (except ``magic ##ge``). Magician, magic dust, and so forth seem very obvious.

#### Articles

An interesting phenomenon observed above is that the model has very strong activations for articles. Let's study this in more depth.

We can do this quite simply: for every possible next token, take its direct-path weight from 'a' and from 'an', and plot the two against each other.

In [7]:
# For every candidate next token, compare the direct-path weight from "a" with the one from "an".
tokens = vocab.tokenize(torch.arange(len(vocab)))
df = pd.DataFrame(dict(x=diag[0, :, vocab["a"]].cpu(), y=diag[0, :, vocab["an"]].cpu(), token=tokens))

# Drop subword continuations (##...) and punctuation: keep tokens starting with a letter.
df = df[df.token.str[0].str.isalpha()]

# Naive heuristic: "an" should precede tokens that start with a vowel.
vowels = ['a', 'e', 'i', 'o', 'u']
df["guess"] = df.token.str[0].isin(vowels)

px.scatter(df, x="x", y="y", hover_name="token", color="guess", labels=dict(x="a", y="an")).show()

The result isn't as clean as I'd hoped, but the model seems to have a general bias towards picking 'a', which is sensible since 'a' is far more common than 'an' in English. I'd expect the separation to become clearer as models improve.
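The scatter plot can be summarized with two numbers. The sketch below assumes the `df` built in the articles cell (columns `x` for 'a', `y` for 'an', `guess` for a leading vowel); both helpers are hypothetical and are shown on a made-up toy frame, not on model output. A general bias towards 'a' would show up as a negative mean margin in both groups, with the vowel group less negative.

```python
import pandas as pd


def article_agreement(df: pd.DataFrame) -> float:
    """Fraction of tokens where the direct path prefers 'an' over 'a'
    exactly when the token starts with a vowel (the `guess` column)."""
    prefers_an = df["y"] > df["x"]
    return float((prefers_an == df["guess"]).mean())


def article_margin(df: pd.DataFrame) -> pd.Series:
    """Mean ('an' weight - 'a' weight), split by whether the token starts with a vowel."""
    return (df["y"] - df["x"]).groupby(df["guess"]).mean()


# Made-up toy frame with the same columns as `df` (not model output).
toy = pd.DataFrame(dict(
    x=[1.0, 0.2, 0.8, 0.1],
    y=[0.5, 0.9, 0.1, 0.3],
    token=["dog", "apple", "cat", "egg"],
    guess=[False, True, False, True],
))
print(article_agreement(toy))  # 1.0
print(article_margin(toy))
```

On the real data, calling `article_agreement(df)` and `article_margin(df)` after the scatter cell would give the corresponding numbers.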

## Token Interactions
Until now, we've only looked at the direct path. This is fine, but the MLP encodes much more information than (input, output)-pairs. Specifically, it encodes (input, input, output)-triplets, which is one of the reasons for its effectiveness.

In essence, we have so far only looked at each token interacting with itself, which reduces the UBE tensor to a matrix we can study. Now we perform a different reduction: we fix the first (output) dimension of the UBE tensor to a single token, which gives the input-input interactions for that output token.
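Here is a toy sketch of what this reduction does, before the full computation. It assumes the weight shapes implied by the einsum in the next cell (`w_e` and `w_u` as (vocab, d_model), `w_l`/`w_r` as (d_hidden, d_model), `w_p` as (d_model, d_hidden)) and uses random weights with hypothetical names. Fixing one output token collapses the output projection to a single hidden-space vector `v`, so the interactions form a (vocab, vocab) matrix `W_e M W_e^T` with `M = W_l^T diag(v) W_r`. It also illustrates, under the assumption that the residual holds the sum of two token embeddings, why only the symmetric part matters.

```python
import torch

torch.manual_seed(0)
dt = torch.float64
vocab_size, d_model, d_hidden = 5, 4, 8

# Hypothetical toy weights with the shapes implied by the notebook's einsum.
W_e = torch.randn(vocab_size, d_model, dtype=dt)
W_u = torch.randn(vocab_size, d_model, dtype=dt)
W_l = torch.randn(d_hidden, d_model, dtype=dt)
W_r = torch.randn(d_hidden, d_model, dtype=dt)
W_p = torch.randn(d_model, d_hidden, dtype=dt)

out_idx = 2

# Read out one output token: pull its unembedding back to the hidden layer.
v = W_p.T @ W_u[out_idx]             # (d_hidden,)
# Interaction matrix in embedding space: M = W_l^T diag(v) W_r.
M = W_l.T @ torch.diag(v) @ W_r      # (d_model, d_model)
# Map both sides into token space: Q[i, j] = e_i^T M e_j.
Q = W_e @ M @ W_e.T                  # (vocab, vocab)

# The same quantity as one contraction, mirroring the notebook's einsum.
Q_ref = torch.einsum("ae,bf,he,hf,oh,o->ab", W_e, W_e, W_l, W_r, W_p, W_u[out_idx])
assert torch.allclose(Q, Q_ref)

# Assumed setting: the residual holds the sum of two token embeddings.
# The logit is then a quadratic form in that sum, and the cross term is
# Q[i, j] + Q[j, i], i.e. twice the symmetric part.
i, j = 1, 3
x = W_e[i] + W_e[j]
logit = W_u[out_idx] @ W_p @ ((W_l @ x) * (W_r @ x))
assert torch.allclose(logit, Q[i, i] + Q[j, j] + Q[i, j] + Q[j, i])
```

Computing `M` first keeps the work in d_model space until the last step, which is cheaper than a six-operand contraction over the full vocabulary if the einsum backend does not pick a good contraction order.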

In [None]:
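# Fix an output token and compute its (input, input) interaction matrix:
# o[i, j] = contribution to the 'girl' logit from tokens i and j entering the bilinear layer together.
idx = vocab["girl"]
o = einsum(model.w_e, model.w_e, model.w_l, model.w_r, model.w_p, model.w_u[idx], "in1 emb1, in2 emb2, hid emb1, hid emb2, out hid, out -> in1 in2").detach()

# A pair of inputs only affects the logit through o + o.T, so symmetrize.
o = 0.5 * (o + o.T)

# Keep the lower triangle so each unordered pair is counted once, then take the top 25.
topk = torch.topk(o.tril().flatten(), k=25)
input1, input2 = torch.unravel_index(topk.indices, o.size())
pd.DataFrame(dict(input1=vocab.tokenize(input1), input2=vocab.tokenize(input2), value=topk.values.cpu()))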