# Language modelling
*Thomas Dooms*

The previous chapter covered how to study bilinear MLPs in a real-world scenario. We also discussed how to decompose interaction matrices into eigenvectors, which are interpretable and causally relevant. We now turn to language models and explain how this method can be used to better understand latent interactions. We do this by expressing the weights in a basis of latent features, such as the features of an SAE. Our aim is to demonstrate that weight-based interpretability is not merely a theoretical method that only works in small models, but that it can be combined with many other techniques in complex real-world settings. In essence, weight-based interpretability allows us to 'trace' between existing features to better understand how they are formed or what their effect is.

The main analysis technique we will discuss here is training two SAEs around the bilinear MLP: one on its input and one on its output. We then use the features from the output SAE as output directions (just like the digits we chose previously). That way, we can decompose the MLP to find which input SAE features interact strongly to produce a given output feature. In a sense, this technique finds shallow circuits in an MLP that are grounded in the weights.
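To make the idea concrete, here is a minimal NumPy sketch of the underlying math. It assumes a bilinear MLP of the form `W_out @ ((W @ x) * (V @ x))` with biases omitted, and that an output feature is read off along a direction `u` (for example, an output-SAE encoder row; which direction `Tracer` actually uses is an assumption here). Because the MLP is bilinear, `u · out(x)` is a quadratic form `x^T Q x`, and writing `x` in terms of input-SAE decoder directions `D_in` turns `Q` into a feature-by-feature interaction matrix. All names (`W`, `V`, `W_out`, `D_in`, `u`, and the helper functions) are hypothetical and not the internals of `Tracer`.

```python
import numpy as np

def bilinear_mlp(x, W, V, W_out):
    """Bilinear MLP without biases (assumed form): W_out @ ((W @ x) * (V @ x))."""
    return W_out @ ((W @ x) * (V @ x))

def interaction_matrix(u, W, V, W_out):
    """Symmetric Q such that u @ bilinear_mlp(x) == x @ Q @ x for every x."""
    c = W_out.T @ u                          # contribution of each hidden unit toward u
    Q = np.einsum("h,hi,hj->ij", c, W, V)    # sum_h c_h * outer(W[h], V[h])
    return 0.5 * (Q + Q.T)                   # only the symmetric part affects x^T Q x

def project_to_features(Q, D_in):
    """Express Q in a feature basis; columns of D_in are (hypothetical) input-SAE decoder directions."""
    return D_in.T @ Q @ D_in

rng = np.random.default_rng(0)
d_model, d_hidden, n_feat = 16, 32, 64
W = rng.normal(size=(d_hidden, d_model))
V = rng.normal(size=(d_hidden, d_model))
W_out = rng.normal(size=(d_model, d_hidden))
D_in = rng.normal(size=(d_model, n_feat))
u = rng.normal(size=d_model)

Q = interaction_matrix(u, W, V, W_out)
x = rng.normal(size=d_model)
print(np.allclose(u @ bilinear_mlp(x, W, V, W_out), x @ Q @ x))

# If the input is a sparse combination of SAE features, x = D_in @ a, then
# u . out(x) = a^T (D_in^T Q D_in) a, so entry (i, j) of Q_feat scores how
# strongly features i and j interact toward the output direction u.
Q_feat = project_to_features(Q, D_in)
a = np.zeros(n_feat)
a[[3, 10]] = [1.0, 2.0]
print(np.allclose(u @ bilinear_mlp(D_in @ a, W, V, W_out), a @ Q_feat @ a))
```

Symmetrizing `Q` loses nothing, since `x^T Q x` only depends on the symmetric part, and it guarantees real eigenvalues for the eigendecompositions used later.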

In [58]:
%load_ext autoreload
%autoreload 2

from language import Transformer
from sae import Tracer, Visualizer
from sae.functions import compute_truncated_eigenvalues

import plotly.express as px
import torch

device = "cuda:0"


### Setting up.
We were generously given quite a bit of compute by Eleuther to train capable bilinear transformers. Here, we will study a model at the scale of GPT2-medium. It is probably somewhat undertrained, but it can still produce coherent sentences and has non-trivial knowledge about some topics. The model has 16 layers and about 300M parameters. We have trained SAEs around each MLP layer for later analysis.

Since we are using a modified architecture, we wrote the infrastructure ourselves, which may result in some annoyances due to the lack of standardization. Anyway, we load the model as in the code below. We can also generate some text to assess its capabilities. 

In [59]:
# Load a bilinear language model, trained on FineWeb-EDU.
# This dataset is mostly (high-quality) scientific and educational texts.
# This corresponds to GPT2-medium (16 layers, 1024 dimensions).
torch.set_grad_enabled(False)
model = Transformer.from_pretrained("tdooms/fw-medium").to(device)

In [60]:
# We can see the model can generate coherent text.
# The model tries to be a smart-ass (FineWeb) but kinda fails.
model.generate("If I may give you a piece of advice:", max_length=30)

'If I may give you a piece of advice: life is only a walk in a shop. It can actually be said much more loudly than life can be said in a shop. Give it your appreciation'

We then instantiate some useful helper objects: 
- ``Tracer`` loads SAEs around the MLP of a given layer and contains some helper functions to compute interaction matrices between the two.
- ``Visualizer`` shows the (pre-computed) top activations of SAE features to understand their meaning.
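Max-activation tables like the ones `Visualizer` shows can be pre-computed with a simple top-k over a dataset. The sketch below is a hypothetical `top_activations` helper, not the `Visualizer` implementation; it assumes the feature activations are stored as one row per token position of a single concatenated token stream (so contexts may cross document boundaries), and it brackets the peak token.

```python
import numpy as np

def top_activations(acts, tokens, feature, k=5, window=8):
    """Return the k strongest activations of `feature` with surrounding context.

    acts:   (n_positions, n_features) array of SAE feature activations
    tokens: list of n_positions token strings, aligned with the rows of `acts`
    """
    col = acts[:, feature]
    top = np.argsort(col)[::-1][:k]          # positions with the largest activations
    rows = []
    for pos in top:
        lo, hi = max(0, pos - window), min(len(tokens), pos + window + 1)
        context = "".join(
            f"[{tok}]" if i == pos else tok  # bracket the peak token
            for i, tok in enumerate(tokens[lo:hi], start=lo)
        )
        rows.append((float(col[pos]), context))
    return rows

# Toy data: one feature, six token positions.
tokens = ["No", " work", " is", " lost", ".", " Windows"]
acts = np.array([[0.1], [0.0], [0.3], [9.2], [0.2], [0.0]])
for score, context in top_activations(acts, tokens, feature=0, k=2, window=2):
    print(f"{score:.1f} : {context}")
```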

In [61]:
# We set up a Tracer object, which is a utility class to find interesting interactions between two SAEs around an MLP. 
# Let's inspect a middle layer.
tracer = Tracer(model, layer=7, inp=dict(expansion=8), out=dict(expansion=8))

# We then create a visualizer for both SAEs.
# Implementation-wise, this queries some pre-computed max-activations and shows them in a nice format.
inp_vis = Visualizer(model, tracer.inp)
out_vis = Visualizer(model, tracer.out)

### Finding interesting output features.
With that out of the way, we now get to the interesting bit: analyzing interactions. This can be done in several ways and with multiple levels of rigor; this tutorial focuses only on the simplest: cherry-picking interesting examples. One approach that worked well for us was to compute a summary statistic of every output feature's interaction matrix and pick the outliers. As a metric, we use the magnitude of the top eigenvalues. Keep in mind that this is simply a heuristic and not particularly principled. Given the selected features, let's look at the cosine similarities between their (output SAE decoder) directions.
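The heuristic can be sketched in NumPy as follows, using random symmetric stand-ins for the per-feature interaction matrices and one planted outlier. Scoring each output feature by the sum of its k largest-magnitude eigenvalues is our assumption; `compute_truncated_eigenvalues` may aggregate them differently. The decoder matrix `w_dec` is likewise hypothetical, with one column per output feature as in the cell below.

```python
import numpy as np

def top_eigenvalue_score(Q, k=2):
    """Sum of the k largest |eigenvalues| of a symmetric interaction matrix."""
    eigvals = np.linalg.eigvalsh(Q)
    return np.sort(np.abs(eigvals))[::-1][:k].sum()

def cosine_similarity_matrix(dirs):
    """Pairwise cosine similarity between the columns of `dirs` (d_model, n)."""
    unit = dirs / np.linalg.norm(dirs, axis=0, keepdims=True)
    return unit.T @ unit

rng = np.random.default_rng(0)
n_out, n_in, d_model = 20, 12, 8
# Stand-in interaction matrices: one symmetric (n_in, n_in) matrix per output feature.
A = rng.normal(size=(n_out, n_in, n_in))
Qs = 0.5 * (A + A.transpose(0, 2, 1))
Qs[7] *= 10.0                                # plant one outlier with large eigenvalues

scores = np.array([top_eigenvalue_score(Q) for Q in Qs])
idxs = np.argsort(scores)[::-1][:5]          # five highest-scoring output features
print(idxs[0])                               # the planted outlier, feature 7

w_dec = rng.normal(size=(d_model, n_out))    # hypothetical output-SAE decoder
sims = cosine_similarity_matrix(w_dec[:, idxs])
print(sims.round(2))
```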

In [62]:
# Running this cell may take a while depending on your CPU/GPU.

# Compute output features whose top eigenvalues are high, likely indicating some interesting structure.
eigenvals = tracer.compute(compute_truncated_eigenvalues, project=False, k=2)
vals, idxs = eigenvals.topk(10)

# Plot the cosine similarity between these features to see if any are related.
dirs = tracer.out.w_dec.weight[:, idxs]
sims = torch.cosine_similarity(dirs[..., None], dirs[:, None], dim=0)

# Visualize them nicely.
labels = [f"{i}" for i in idxs.cpu()]
px.imshow(sims.cpu(), color_continuous_scale="RdBu", color_continuous_midpoint=0, x=labels, y=labels)

Interestingly, the top two features form a (somewhat) linear subspace. We can try to understand their meaning by looking at their top activations.

In [None]:
# Let's inspect the (somewhat) linear subspace.
# This function can visualize arbitrarily many features.
# It even has a dark mode, which you should disable if you're using a white background.
out_vis(3834, 751, dark=True)



feature 3834 (peak token in [brackets])
9.2 :   next time the child logs on. No work is [lost].<0x0A>Windows Vista even provides easy-to-read activity reports that show how children have
8.9 :  , but he doesn’t let that interf[ere] with him enjoying his life like any other 22-year-old. He has a
8.8 :   The Tafahum initiative counters bigot[ry] and promotes tolerance. In the aftermath of September 11, we need to develop
8.5 :   lawful content, applications, services or unharm[ful] devices.<0x0A>- No Throttling: Broadband providers may not deliberately target some lawful
8.5 :  100 aerial combat victories with no

The first feature fires on negative words that have recently been negated (not lost, no interference). The second feature does the inverse; it fires on positive words that were negated (not free, little relief). It's not surprising that this forms a linear subspace as the two are opposites, but it is still interesting. What I find fascinating is that we also found this subspace (using the same technique) in a completely different model trained on TinyStories. 

### Understanding feature interactions.
To understand how these features are formed, we can look at the most salient entries in their interaction matrices. We find that the interaction matrix of any output feature is quite low-rank: only a few eigenvalues are large. Sadly, while interpreting eigenvectors in image models is easy (the eigenvectors are simply images), here each eigenvector has one entry per input SAE feature. Consequently, we have to resort to sparsity and only look at the largest entries.
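Both steps can be sketched in NumPy on a synthetic interaction matrix with planted rank-2 structure over 500 features plus a little noise. The functions `spectral_mass` and `sparse_summary` are hypothetical helpers: the first measures how much of the spectrum the top eigenvalues carry (a rough low-rank check), the second summarizes each top eigenvector by its few largest-magnitude entries, i.e. the SAE features that dominate it.

```python
import numpy as np

def spectral_mass(Q, k):
    """Fraction of sum(|eigenvalues|) carried by the k largest-magnitude eigenvalues."""
    vals = np.abs(np.linalg.eigvalsh(Q))
    return np.sort(vals)[::-1][:k].sum() / vals.sum()

def sparse_summary(Q, n_vecs=2, n_entries=3):
    """For the top eigenvectors, return (eigenvalue, dominant feature indices) pairs."""
    vals, vecs = np.linalg.eigh(Q)                  # eigenvectors are the columns of vecs
    order = np.argsort(np.abs(vals))[::-1][:n_vecs]
    summary = []
    for i in order:
        v = vecs[:, i]
        top = np.argsort(np.abs(v))[::-1][:n_entries]  # features that dominate this eigenvector
        summary.append((float(vals[i]), sorted(top.tolist())))
    return summary

# Synthetic Q: one positive and one negative component, each spread over three features.
rng = np.random.default_rng(0)
n = 500
p = np.zeros(n)
p[[1, 2, 3]] = 1 / np.sqrt(3)
m = np.zeros(n)
m[[10, 11, 12]] = 1 / np.sqrt(3)
noise = 1e-5 * rng.normal(size=(n, n))
Q = 5.0 * np.outer(p, p) - 3.0 * np.outer(m, m) + 0.5 * (noise + noise.T)

print(f"top-2 spectral mass: {spectral_mass(Q, k=2):.3f}")
for val, feats in sparse_summary(Q):
    print(f"eigenvalue {val:+.2f}: dominated by features {feats}")
```

Ranking by absolute eigenvalue matters here: negative eigenvalues correspond to feature combinations that suppress the output, and they can be just as informative as positive ones.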

Here, we take the top 50 interactions (by absolute value) in the whole interaction matrix and plot the submatrix restricted to the features involved (the smallest submatrix containing all of them). This matrix is quite small, and the features interact in a seemingly systematic manner.

In [64]:
# Compute the interaction matrix between SAE inputs for a given output SAE feature.
q = tracer.q(3834, project=True)

# Find the most relevant rows/columns by max interactions.
idxs = q.flatten().abs().topk(50).indices
i1, i2 = torch.unravel_index(idxs, q.shape)
idxs = torch.cat([i1, i2]).unique()

# Plot the most important sub-interactions for the output feature.
labels = [f"{i}" for i in idxs.cpu()]
# Rows are flipped, so the y-axis labels are reversed to stay aligned with them.
fig = px.imshow(q[idxs.flip(0)][:, idxs].cpu(), color_continuous_midpoint=0, color_continuous_scale="RdBu", x=labels, y=labels[::-1])
fig.update_xaxes(side="top")
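Because the interaction matrix is symmetric, each off-diagonal pair (i, j) appears twice in `q.flatten()`, so the top 50 entries contain roughly half as many distinct pairs. A minimal NumPy variant that counts each pair once, by ranking only the upper triangle (diagonal included), could look like this; `top_pair_submatrix` is a hypothetical helper and an alternative we suggest, not what the cell above does.

```python
import numpy as np

def top_pair_submatrix(Q, n_pairs=50):
    """Indices and submatrix covering the n_pairs strongest distinct interactions of a symmetric Q."""
    rows, cols = np.triu_indices(Q.shape[0])       # each unordered pair (i <= j) exactly once
    strength = np.abs(Q[rows, cols])
    best = np.argsort(strength)[::-1][:n_pairs]
    idxs = np.unique(np.concatenate([rows[best], cols[best]]))
    return idxs, Q[np.ix_(idxs, idxs)]

# Toy symmetric matrix with one strong pair, one diagonal term, and one weak pair.
Q = np.zeros((6, 6))
Q[0, 4] = Q[4, 0] = 3.0
Q[2, 2] = -2.0
Q[1, 5] = Q[5, 1] = 0.5
idxs, sub = top_pair_submatrix(Q, n_pairs=2)
print(idxs)
print(sub)
```

On this toy matrix, taking the top 2 entries of the full matrix would return (0, 4) and (4, 0), a single pair, whereas the upper-triangle version also picks up the diagonal term at (2, 2).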

### Future directions.
While we believe these results are exciting and finally provide a way forward towards understanding MLPs, there is still a lot to explore. We're not at all convinced that the proposed methodology is the best way to understand feature interactions. One possible research avenue is (again) to consider shared structure between features. This can be done using variants of sparse coding or certain tensor decompositions.