## Virtual Token Interactions

This notebook studies token interactions that pass through the attention mechanism, which we call virtual token interactions.

In [None]:
from shared.transformer import Transformer, Config
import plotly.express as px
import torch
from einops import einsum
import pandas as pd
import itertools

torch.set_grad_enabled(False)
color = dict(color_continuous_midpoint=0, color_continuous_scale="RdBu")

name = "tdooms/TinyStories-1-512"
config = Config.from_pretrained(name)
model = Transformer.from_pretrained(name, config=config).cuda()

model.center_unembed().fold_norms()
vocab = model.vocab

When analyzing the MLP in our single-layer bilinear models, we can construct all kinds of tensors describing interactions in the latent (residual) space. We can then project these through the embedding or unembedding to see which actual tokens the network manipulates. This is a simplification because it ignores the attention mechanism and only considers the direct path from the embedding to the MLP; we therefore call it the "direct embedding".
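As a minimal sketch of the direct embedding, the NumPy code below uses random weights with hypothetical shapes in place of the model's embedding `W_E` (hidden x vocab, matching the layout the notebook's einsums assume for `w_e`), unembedding `W_U` (vocab x hidden) and bilinear tensor `B` (out, in1, in2). It folds one output token's unembedding direction into `B`, projects both inputs through the embedding to get a token-by-token interaction matrix, and checks one entry against the explicit bilinear form.

```python
import numpy as np

# Hypothetical toy sizes and random weights; the real model supplies its own.
d_model, n_vocab = 16, 50
rng = np.random.default_rng(0)

W_E = rng.normal(size=(d_model, n_vocab))         # embedding: columns are token vectors
W_U = rng.normal(size=(n_vocab, d_model))         # unembedding: rows are output-token directions
B = rng.normal(size=(d_model, d_model, d_model))  # bilinear MLP tensor: (out, in1, in2)


def direct_interactions(out_token: int) -> np.ndarray:
    """Token-by-token interaction matrix for one output token along the direct path.

    Entry [t1, t2] is the contribution to the logit of `out_token` from the
    bilinear interaction between input tokens t1 and t2.
    """
    # Fold the output token's unembedding direction into B: (in1, in2)
    q = np.einsum("o,oij->ij", W_U[out_token], B)
    # Project both inputs through the embedding: (n_vocab, n_vocab)
    return W_E.T @ q @ W_E


out_token, t1, t2 = 3, 7, 11
inter = direct_interactions(out_token)

# Sanity check against the explicit bilinear form B(x1, x2) followed by the unembedding
x1, x2 = W_E[:, t1], W_E[:, t2]
explicit = W_U[out_token] @ np.einsum("oij,i,j->o", B, x1, x2)
print(inter.shape, np.isclose(inter[t1, t2], explicit))
```

Folding the unembedding into `B` first keeps the intermediate at hidden x hidden, so only the final projection scales with the vocabulary.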

When we include the attention mechanism, projecting out of the latent space is no longer as straightforward. The residual stream at the MLP input is now a sum of the attention head outputs and the direct path. Formally:

$$\text{residual}^{\text{mid}} = \sum_i \lambda_i \, OV_i \, E \, t_v + E \, t_d$$

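where $t_d$ is the current (destination) token, $t_v$ is the attended (source) token, and $\lambda_i$ is the attention weight of head $i$, determined by the QK score $(E t_d)^\top QK_i \, (E t_v)$. If we ignore $\lambda$, we see that the residual is simply a sum of paths. So, instead of only projecting through $E$ as in the direct path, we can project through the other terms too.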
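Because the MLP is bilinear, a residual that is a sum of paths produces an output that is a sum over pairs of paths. The NumPy sketch below checks this with random matrices; the toy shapes, the stand-in values for $\lambda_i$ and the names `W_E`, `OV`, `B` are hypothetical, and the softmax that produces real attention weights is ignored.

```python
import numpy as np

# Hypothetical toy sizes and random weights standing in for the real model.
d_model, n_vocab, n_heads = 16, 50, 4
rng = np.random.default_rng(0)

W_E = rng.normal(size=(d_model, n_vocab))          # embedding E (columns are tokens)
OV = rng.normal(size=(n_heads, d_model, d_model))  # OV_i for each head
B = rng.normal(size=(d_model, d_model, d_model))   # bilinear MLP tensor: (out, in1, in2)
lam = rng.uniform(size=n_heads)                    # stand-in attention weights lambda_i

t_d, t_v = 5, 9  # current (destination) token and attended (source) token

# One residual contribution per path: index 0 is the direct path E t_d,
# index 1 + i is head i moving the source token through OV_i.
paths = np.stack(
    [W_E[:, t_d]] + [lam[i] * (OV[i] @ W_E[:, t_v]) for i in range(n_heads)]
)  # (1 + n_heads, d_model)
residual = paths.sum(axis=0)

# The bilinear MLP applied to the full residual ...
full = np.einsum("oij,i,j->o", B, residual, residual)
# ... equals the sum of its outputs over every ordered pair of paths.
pairwise = np.einsum("oij,pi,qj->pqo", B, paths, paths)  # (paths, paths, d_model)
print(np.allclose(full, pairwise.sum(axis=(0, 1))))
```

Each `pairwise[p, q]` is one path-by-path interaction, which is the quantity the per-path projections below study.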

``TODO``: write some coherent story.

We define virtual tokens as tokens that have passed through the OV circuit of an attention head. Our full embedding then becomes $\text{cat}(E, OV_0 E, \dots, OV_{H-1} E)$, where $H$ is the number of heads.
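The sketch below, using random NumPy arrays with hypothetical sizes in place of `model.w_e` and `model.ov[0]`, shows how a batched matrix product builds one virtual embedding per head, and that the stacked layout holds the same numbers as the concatenation $\text{cat}(E, OV_0 E, \dots, OV_{H-1} E)$.

```python
import numpy as np

# Hypothetical toy sizes and random weights; the notebook uses model.w_e and model.ov[0].
d_model, n_vocab, n_heads = 16, 50, 8
rng = np.random.default_rng(0)
W_E = rng.normal(size=(d_model, n_vocab))          # direct embedding E
OV = rng.normal(size=(n_heads, d_model, d_model))  # OV_0, ..., OV_{H-1}

# (n_heads, d, d) @ (1, d, V) broadcasts to (n_heads, d, V): one virtual embedding per head
virtual = OV @ W_E[None]
# Stack with the direct embedding first: (1 + n_heads, d_model, n_vocab)
e_full = np.concatenate([W_E[None], virtual], axis=0)

# Flattening the stack along the token axis recovers cat(E, OV_0 E, ..., OV_{H-1} E)
e_cat = np.concatenate(list(e_full), axis=1)  # (d_model, (1 + n_heads) * n_vocab)
print(e_full.shape, e_cat.shape)
```

Keeping the stacked form lets later einsums index paths and tokens as separate axes.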

In [None]:
# The augmented embedding, containing direct and indirect (virtual) embeddings.
# This is a stack of embeddings rather than a concatenation, which is easier to work with.
# Shape: (1 + n_head, hidden, vocab); the direct path is the first element.
e_full = torch.cat([model.w_e[None], model.ov[0] @ model.w_e[None]], dim=0)
b = model.b[0]  # bilinear MLP tensor: (output, input1, input2)

# Project b through the full embedding to get the interactions between virtual and direct tokens.
# The full tensor over all output tokens would be far too large to construct,
# so we only compute it for a single output token.
token = "girl"
idx = vocab[token]

blocks = einsum(e_full, e_full, b, model.w_u[idx], "b1 hid1 tok1, b2 hid2 tok2, res hid1 hid2, res -> b1 b2 tok1 tok2")

To see how strong each interaction block is, we can take any norm of it; here we use the Frobenius (L2) norm over both token dimensions.

In [None]:
# l1_norms = blocks.abs().sum((2, 3))  # entrywise L1 alternative
l2_norms = torch.linalg.norm(blocks, dim=(2, 3))  # Frobenius norm of each (tok1 x tok2) block

title = f"Attention Head Interactions for \"{token}\""
px.imshow(l2_norms.cpu(), **color, title=title).update_layout(title_x=0.5)
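Different norms emphasize different things, so it can be worth comparing a few. The NumPy sketch below uses a random stand-in for `blocks` (hypothetical shape) and computes the Frobenius norm (the default `torch.linalg.norm` uses when `dim` is a 2-tuple), the entrywise L1 norm, and the spectral norm, which measures the strongest rank-one interaction inside a block.

```python
import numpy as np

# Toy stand-in for `blocks` with shape (paths, paths, tokens, tokens)
rng = np.random.default_rng(0)
blocks = rng.normal(size=(3, 3, 20, 20))

# Frobenius norm of each token-by-token block
fro = np.linalg.norm(blocks, axis=(2, 3))
# Entrywise L1 norm: sum of absolute interactions
l1 = np.abs(blocks).sum(axis=(2, 3))
# Spectral norm: largest singular value of each block
spectral = np.linalg.norm(blocks, ord=2, axis=(2, 3))

print(np.allclose(fro, np.sqrt((blocks ** 2).sum(axis=(2, 3)))))
# The spectral norm never exceeds the Frobenius norm, which never exceeds the entrywise L1 norm
print(np.all(spectral <= fro + 1e-9), np.all(fro <= l1 + 1e-9))
```

A block whose spectral norm is close to its Frobenius norm is dominated by a single low-rank interaction; a large gap suggests the interaction is spread over many directions.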

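This gives us a map {token -> attention interaction}. Ideally, however, we would like the inverse map, which would show which tokens most strongly drive a given attention head interaction.

Luckily, creating that is not difficult. Because the bilinear form is linear in each of its inputs, averaging the embeddings over the vocabulary before contracting gives exactly the average of each block over both token dimensions, and it is cheap enough to compute for every output token at once.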

In [None]:
# Average each (virtual) embedding over the vocabulary, then contract for all output tokens at once
e_mean = e_full.mean(-1)
means = einsum(e_mean, e_mean, b, model.w_u, "b1 hid1, b2 hid2, res hid1 hid2, out res -> out b1 b2")
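The identity behind this shortcut can be checked on small random arrays. In the NumPy sketch below, the sizes are hypothetical and the arrays only loosely mirror the notebook's `e_full`, `b` and `w_u`; it compares the expensive route (build all token-by-token blocks for one output token, then average) with the cheap route (average the embeddings first, then contract for every output token).

```python
import numpy as np

# Toy, hypothetical sizes standing in for e_full, b and w_u
n_paths, d_model, n_vocab = 3, 8, 30
rng = np.random.default_rng(0)
e_full = rng.normal(size=(n_paths, d_model, n_vocab))  # (path, hidden, token)
B = rng.normal(size=(d_model, d_model, d_model))       # (res, hid1, hid2)
W_U = rng.normal(size=(n_vocab, d_model))              # (out, res)

idx = 4  # an arbitrary output token

# Expensive route: all token-by-token blocks for one output token, averaged over both token axes
blocks = np.einsum("aht,bgs,rhg,r->abts", e_full, e_full, B, W_U[idx])
slow = blocks.mean(axis=(2, 3))

# Cheap route: average the embeddings over tokens first, then contract for every output token
e_mean = e_full.mean(-1)
means = np.einsum("ah,bg,rhg,or->oab", e_mean, e_mean, B, W_U)

print(np.allclose(means[idx], slow))
```

Note that the shortcut recovers the signed mean of each block, not a norm: positive and negative token pairs can cancel.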

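Let's check that our implementation is correct: for a single output token, `means` should match the average of the corresponding `blocks` over both token dimensions.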

In [None]:
print((means[idx] - blocks.mean((2, 3))).abs().max())  # should be ~0 up to floating-point error
px.imshow(means[idx].cpu(), **color, title=title).update_layout(title_x=0.5)

Let's make some observations:
- In the single-layer 256 model (which has 4 attention heads), most of the interaction seems to happen in the direct-direct path.
- The single-layer 512 model (8 heads) has its strongest interactions through attention head 3.

In [None]:
# Trace of the token-level QK matrix E^T QK_h E for each head (same token index on both sides)
traces = einsum(model.w_e, model.w_e, model.qk[0], "hid1 tok, hid2 tok, ... hid1 hid2 -> ...")
px.bar(y=traces.cpu(), labels=dict(x="head", y="trace"), title="Trace of QK circuits").update_layout(title_x=0.5)
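The trace of a head's token-level QK matrix $E^\top QK_h E$ sums each token's score for attending to another occurrence of itself, so a large positive trace loosely suggests a head that prefers attending to copies of the current token. The NumPy sketch below, with hypothetical random weights, shows that computing the trace requires the same token index on both embeddings; two independent token indices instead sum every entry of the matrix.

```python
import numpy as np

# Hypothetical toy sizes and random weights
d_model, n_vocab, n_heads = 8, 30, 4
rng = np.random.default_rng(0)
W_E = rng.normal(size=(d_model, n_vocab))          # (hidden, token)
QK = rng.normal(size=(n_heads, d_model, d_model))  # one QK matrix per head

# Token-level QK matrix per head: entry [h, t, s] scores query token t against key token s
qk_tokens = np.einsum("dt,hde,es->hts", W_E, QK, W_E)

# Trace: reuse the same token index t for both embeddings
traces = np.einsum("dt,hde,et->h", W_E, QK, W_E)
print(np.allclose(traces, np.trace(qk_tokens, axis1=1, axis2=2)))

# Two independent token indices sum every entry instead, which is not the trace
total = np.einsum("dt,hde,es->h", W_E, QK, W_E)
print(np.allclose(total, qk_tokens.sum(axis=(1, 2))))
```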

For each pair of paths (index 0 is the direct path), we can now find the output tokens with the largest positive or negative interaction, i.e. the largest in absolute value.

In [None]:
df = pd.DataFrame(index=list(range(20)))

# e_full has 1 + n_head entries (direct path first), so iterate over all of them
for i, j in itertools.combinations_with_replacement(range(config.n_head + 1), 2):
    df[f"{i}-{j}"] = vocab.tokenize(means[:, i, j].abs().topk(20).indices)

df
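One caveat, stated as an assumption: if `b` is not symmetric in its two inputs, paths $i \neq j$ contribute through both `means[:, i, j]` and `means[:, j, i]`, so ranking only the $i \le j$ entry can miss part of the interaction. The NumPy sketch below is a hypothetical variant that symmetrizes before ranking, run on a random stand-in for `means`; the labels `direct` and `head k` are illustrative.

```python
import itertools

import numpy as np

# Toy stand-in for `means`: (out_token, path, path); index 0 is the direct path
rng = np.random.default_rng(0)
n_vocab, n_heads, k = 30, 3, 5
means = rng.normal(size=(n_vocab, n_heads + 1, n_heads + 1))

# Paths i and j contribute B(p_i, p_j) + B(p_j, p_i), so symmetrize before ranking unordered pairs
sym = means + means.transpose(0, 2, 1)
names = ["direct"] + [f"head {h}" for h in range(n_heads)]

top = {}
for i, j in itertools.combinations_with_replacement(range(n_heads + 1), 2):
    # A path paired with itself appears only once, so it is not doubled
    score = means[:, i, i] if i == j else sym[:, i, j]
    # Indices of the k output tokens with the largest absolute interaction
    top[f"{names[i]}-{names[j]}"] = np.argsort(-np.abs(score))[:k]

print(len(top), top["direct-direct"].shape)
```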

The amount of structure in these top tokens is quite limited.

In [None]:
px.imshow(means.mean(0).cpu(), **color, title="Mean Interaction Strength").update_layout(title_x=0.5)