# Tiny Stories | Single Layer Bilinear Model
This notebook generates the figures for the paper "Weight-based Decomposition: A Case for Bilinear MLPs".

# Setup

In [None]:
!pip install einops
!pip install jaxtyping
!git clone https://github.com/tdooms/bilinear-interp.git
!pip install transformers
!pip install wandb
!pip install datasets
!pip install evaluate
!pip install accelerate
!pip install transformer_lens
!pip install nnsight

In [None]:
# Switch the working directory to the cloned repository; presumably this is what
# makes the repo-local `language` package importable in the cells below.
%cd /content/bilinear-interp

In [None]:
# Optional: uncomment the next line to update an existing checkout of the repository.
# !git pull

In [None]:
%load_ext autoreload
%autoreload 2

from language import Transformer, Config
import plotly.express as px
import torch
import pandas as pd
from einops import *
import itertools
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# This notebook only runs inference and weight analysis, so autograd is switched
# off globally for every cell that follows.
torch.set_grad_enabled(False)

# name = "tdooms/TinyStories-1-512-i"
# Settings that select which pretrained checkpoint is loaded.
n_layer = 1
d_model = 1024
modifier = 'i5n'  # NOTE(review): checkpoint variant tag; its meaning is defined by the repository — confirm

# config = Config.from_pretrained(name)
# Load the pretrained model and keep it on the CPU.
model = Transformer.from_pretrained(n_layer=n_layer, d_model=d_model, modifier=modifier).cpu()
# Vocabulary helper; later cells index it with token strings and with token ids.
vocab = model.vocab

# color = dict(color_continuous_midpoint=0, color_continuous_scale="RdBu")
# facet = dict(height=200 * config.n_layer + 200, facet_col=0, facet_col_wrap=config.n_head)

# def set_facet_labels(fig):
#     for annotation in fig.layout.annotations:
#         facet = int(annotation.text.split("=")[-1])
#         annotation.update(text=f"Head {facet // config.n_head}.{facet % config.n_head}")
#     return fig

In [None]:
def style_df(df):
    """Return a pandas Styler that shades `df` with a zero-centred RdBu gradient.

    The colour range is symmetric (``-bound .. +bound``), where ``bound`` is the
    largest absolute value found in the float32 columns of `df`.

    Args:
        df: DataFrame to style; only its float32 columns determine the range.

    Returns:
        The ``Styler`` produced by ``df.style.background_gradient``.
    """
    float_mask = (df.dtypes == 'float32').values
    float_block = df.iloc[:, float_mask]

    # Largest magnitude on either side of zero sets the symmetric colour limits.
    highest = float_block.max().max()
    lowest = float_block.min().min()
    bound = np.max([highest, np.abs(lowest)])

    diverging = sns.color_palette("RdBu", as_cmap=True)
    return df.style.background_gradient(cmap=diverging, vmin=-bound, vmax=bound)

def display_OVE_vec(vec):
    """Show, side by side, the token description of each row of `vec`.

    Row 0 is labelled "Direct" and every following row "Head i". Each row is
    passed to ``vocab.describe`` and its ``'value'`` column is renamed so the
    columns stay distinguishable after concatenation.

    Args:
        vec: indexable by row with a ``.shape`` attribute; laid out as
            [n_head + 1, vocab] — assumes `vocab.describe` returns a DataFrame
            with a 'value' column (TODO confirm against the repository).
    """
    labels = ['Direct'] + [f"Head {h}" for h in range(vec.shape[0] - 1)]

    frames = []
    for row, label in enumerate(labels):
        value_col = 'Value' if row == 0 else f'Value {row-1}'
        described = vocab.describe(vec[row], [label])
        frames.append(described.rename(columns={'value': value_col}))

    styled = style_df(pd.concat(frames, axis=1))

    # Lift pandas' display limits so the whole table is rendered.
    with pd.option_context('display.max_rows', None, 'display.max_columns', None):
        display(styled)

def describe(tensor, k=10, axes=None, value=None):
    """Tabulate the `k` largest and `k` smallest entries of a tensor.

    Args:
        tensor: tensor of any rank (at least one dimension).
        k: number of entries taken from each extreme. Clamped to the number of
            elements in `tensor`, because ``torch.topk`` raises for larger k.
        axes: column names for the coordinates, one per tensor dimension.
            Defaults to ``"Dim 0"``, ``"Dim 1"``, ...
        value: name of the value column. Defaults to ``"Value"``.

    Returns:
        A DataFrame with one coordinate column per axis followed by the value
        column. Rows list the k largest values in descending order, then the
        k smallest values, also in descending order (the minimum comes last).
    """
    flat = tensor.detach().flatten()
    k = max(0, min(int(k), flat.numel()))

    hig = torch.topk(flat, k=k, largest=True)
    low = torch.topk(flat, k=k, largest=False)

    # Flip the low end so the whole table reads from largest to smallest.
    values = torch.cat([hig.values, low.values.flip(0)])
    indices = torch.cat([hig.indices, low.indices.flip(0)])

    # Convert to NumPy so the DataFrame gets plain numeric columns regardless
    # of the tensor's device.
    dims = np.unravel_index(indices.cpu().numpy(), tuple(tensor.shape))
    if axes is None:
        axes = [f"Dim {i}" for i in range(len(dims))]
    if value is None:
        value = 'Value'
    data = {axis: coords for axis, coords in zip(axes, dims)}
    return pd.DataFrame({**data, value: values.cpu().numpy()})

def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=True, scale=None) -> torch.Tensor:
    """Return the attention weights ``softmax(Q K^T * scale + bias)``.

    Note that this returns the attention pattern itself; the weights are not
    multiplied with `value`, which is accepted but never read.

    Args:
        query: tensor of shape (..., L, d).
        key: tensor of shape (..., S, d).
        value: unused.
        attn_mask: optional (L, S) mask. Boolean masks mark the positions that
            may be attended to; other dtypes are added to the scores. Must be
            None when `is_causal` is True.
        dropout_p: dropout probability applied to the weights (always in
            training mode, so only 0.0 is deterministic).
        is_causal: when True, position i attends only to positions <= i.
        scale: score multiplier; defaults to ``1 / sqrt(d)``.

    Returns:
        Attention weights of shape (..., L, S).

    Raises:
        AssertionError: if `is_causal` is True and `attn_mask` is given.
    """
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / np.sqrt(query.size(-1)) if scale is None else scale
    # Allocate the bias next to the query so non-CPU inputs do not hit a
    # device mismatch when the bias is added to the scores.
    attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight

In [None]:
# Display the configuration of the loaded model for reference.
model.config

# Eigenvectors for specific output tokens

In [None]:
# Pull the weights used in the analysis out of the model as detached tensors.
W_u = model.w_u.detach()  # indexed by token id below — presumably the unembedding, [vocab, out]; TODO confirm
W_e = model.w_e.detach()  # treated as [d_model, vocab] by the matmul and einsum patterns below
B = model.b.detach()[0]   # first entry of model.b — presumably the layer-0 bilinear tensor [out, d_model, d_model]; TODO confirm

OV = model.ov[0].detach()  # first entry of model.ov — presumably one OV matrix per head of layer 0; TODO confirm
OVE = OV @ W_e
# Prepend the raw embedding as an extra "head" so index 0 represents the direct path.
OVE = torch.cat([W_e.unsqueeze(0), OVE], dim=0) # [n_head+1, d_model, vocab]

In [None]:
# Output token to analyse, and the tokens whose mean unembedding serves as the baseline.
token = 'swim'
comparison_tokens = ['run', 'climb', 'eat', 'see', 'smell', 'walk', 'fly', 'sit', 'sleep']

tok_idx = vocab[token]
comp_tok_ids = vocab[comparison_tokens]

# Unembed = W_u[tok_idx]
# Contrastive output direction: the target token relative to the comparison set.
Unembed = W_u[tok_idx] - W_u[comp_tok_ids].mean(dim=0)

# Contract the output axis of B with that direction to obtain the interaction matrix.
Q = einsum(B, Unembed, "out ..., out -> ...")
# torch.linalg.eigh reads only one triangle of its input, so symmetrise first.
# The quadratic form x^T Q x depends only on the symmetric part, and an already
# symmetric Q is left unchanged.
Q = 0.5 * (Q + Q.mT)
eigvals, eigvecs = torch.linalg.eigh(Q)

In [None]:
# Eigenvalue spectrum of the interaction matrix, drawn through the explicit Axes API.
fig, ax = plt.subplots()
ax.plot(eigvals, '.-')
ax.set_ylabel('Eigenvalue')
ax.set_xlabel('Index')
ax.set_title(f'Eigenvalues for "{token}" interaction matrix')

## Virtual Token Basis
Virtual tokens are input tokens that have been passed through the embedding layer and then either through one attention head or straight past attention (the "direct" path).

In [None]:
# torch.linalg.eigh returns eigenvalues in ascending order, so index -1 selects
# the eigenvector with the largest eigenvalue.
eig_idx = -1
eigvec = eigvecs[:,eig_idx]
# Project the eigenvector onto every virtual-token direction (direct path + each head).
eigvec_toks = einsum(eigvec, OVE, "d_model, head d_model tok -> head tok")

display_OVE_vec(eigvec_toks)

In [None]:
# torch.linalg.eigh returns eigenvalues in ascending order, so index 2 selects
# the eigenvector with the third-smallest eigenvalue.
eig_idx = 2
eigvec = eigvecs[:,eig_idx]
# Project the eigenvector onto every virtual-token direction (direct path + each head).
eigvec_toks = einsum(eigvec, OVE, "d_model, head d_model tok -> head tok")

display_OVE_vec(eigvec_toks)

## Top Activating Inputs

In [None]:
# Wrap the model with nnsight so intermediate module inputs and outputs can be
# captured inside `trace(...)` contexts in the cells below.
from nnsight import LanguageModel
nnsight_model = LanguageModel(model, tokenizer=model.tokenizer)

In [None]:
def _cielab_lightness(rgb):
    """Approximate perceptual lightness (0-100 scale) for rows of RGB(A) values in [0, 1].

    Both piecewise steps are evaluated on the untouched inputs via ``np.where``,
    so a value handled by the low branch is never re-processed by the high one.
    NOTE(review): the sRGB standard uses exponent 2.4 for linearisation; the
    exponent 2 is kept from the original code — confirm this is intended.
    """
    rgb = np.asarray(rgb, dtype=float)[:, :3]
    linear = np.where(rgb <= 0.04045, rgb / 12.92, ((rgb + 0.055) / 1.055) ** 2)
    luminance = 0.2126 * linear[:, 0] + 0.7152 * linear[:, 1] + 0.0722 * linear[:, 2]
    return np.where(luminance <= 0.008856, 903.3 * luminance, luminance ** (1/3) * 116 - 16)


def print_attn_weighted_text(data, batch_idx, pos_idx, eigvec, eigval, start_toks = 15, end_toks = 3):
    """Print a text snippet coloured by each token's contribution to an eigenvector.

    For the sequence `batch_idx`, the attention-weighted value vector that each
    source position sends to `pos_idx` is pushed through the attention output
    projection, the attention input at `pos_idx` is added as the direct-path
    term, and every per-position vector is dotted with `eigvec`. The tokens
    from `start` up to `pos_idx` are printed with an RdBu background encoding
    that contribution; the token at `pos_idx` is bold and bracketed, and
    `end_toks` uncoloured tokens follow it.

    Args:
        data: mapping with an 'input_ids' entry indexable as [batch, pos].
        batch_idx: index of the sequence to show.
        pos_idx: position whose activation is being explained.
        eigvec: eigenvector (length d_model) of the interaction matrix.
        eigval: matching eigenvalue; used only for the printed activation.
        start_toks: number of context tokens shown before `pos_idx`.
        end_toks: number of tokens shown after `pos_idx`.

    Relies on the notebook globals `model` and `nnsight_model`.
    NOTE(review): summing per-position contributions assumes the output
    projection is linear without a bias term — confirm against the model code.
    """
    start = max(pos_idx-start_toks, 0)
    end = pos_idx + end_toks

    input_ids = data['input_ids'][batch_idx]
    with nnsight_model.trace(input_ids):
        attn_input = nnsight_model.transformer.h[0].attn.input[0][0].save()
        qkv = nnsight_model.transformer.h[0].attn.qkv.output.save()

    # get attention weighted contributions to pos_idx
    q, k, v = rearrange(qkv, 'batch seq (n_proj n_head d_head) -> n_proj batch n_head seq d_head', n_proj=3, n_head=model.config.n_head).unbind(dim=0)
    q, k = model.transformer.h[0].attn.rotary(q,k, q.device)
    attn_weight = scaled_dot_product_attention(q, k, v)
    z = attn_weight[:,:,pos_idx].unsqueeze(-1) * v    #get attn-weighted value vecs that contribute to pos_idx
    z = rearrange(z, 'batch n_head seq d_head -> batch seq (n_head d_head)')
    o = model.transformer.h[0].attn.o(z)

    # add direct path contribution
    o[:, pos_idx] += attn_input[:,pos_idx]

    # dot product with eigvec
    sims = o[0] @ eigvec
    max_sim = sims.abs().max()
    # Orient the contributions so that their total is non-negative.
    sign = torch.sign(sims.sum())
    sims = sign * sims

    # get color rgb
    colors = 255 * plt.cm.RdBu((sims[start:pos_idx+1]+max_sim)/ (2*max_sim))
    start_text = model.vocab[data['input_ids'][batch_idx, start:pos_idx+1]]
    end_text = model.vocab[data['input_ids'][batch_idx, pos_idx+1:end+1]]

    # compute brightness/luminance in order to change text color for dark backgrounds
    luminance = _cielab_lightness(colors / 255)

    # White text on dark backgrounds, black on light ones; the final (target) token is bold and bracketed.
    color_text = ["\033[" + ("37;" if luminance[i] < 60 else "30;") + f"48;2;{int(colors[i,0])};{int(colors[i,1])};{int(colors[i,2])}m" + \
                (start_text[i] if i < len(start_text)-1 else "\033[1m[" + start_text[i]+"]\033[0m") for i in range(len(start_text))
                ]

    act = eigval * (sims.sum())**2
    print(f"Activation: {act:.2f} | " + ' '.join(color_text+end_text))

In [None]:
# Load the tokenized, collated validation split; later cells read data['input_ids'].
data = model.dataset(collated = True, tokenized = True, split = 'validation')

In [None]:
# Eigenvectors to evaluate: the first 20 and the last 20 indices. Since
# torch.linalg.eigh sorts eigenvalues in ascending order, these are the two
# ends of the spectrum.
eig_idxs = torch.tensor(list(range(20)) + list(range(-20,0)))

# Per-batch activation tensors, concatenated after the loop.
eig_acts = []

batch_size = 1000
for i in range(0, data['input_ids'].shape[0], batch_size):
    input_ids = data['input_ids'][i:i+batch_size]
    # Capture the input of the layer-0 MLP for this batch of sequences.
    with nnsight_model.trace(input_ids) as trace:
        mlp_input = nnsight_model.transformer.h[0].mlp.input[0][0].save()
    # Projection of every position's MLP input onto each selected eigenvector.
    sims = einsum(mlp_input, eigvecs[:,eig_idxs], "batch pos d_model, d_model eig -> batch pos eig")
    # Squared projection; the eigenvalue is not applied here.
    acts = (sims**2)
    eig_acts.append(acts)

# [batch, pos, eig] over the whole split.
eig_acts = torch.cat(eig_acts, dim=0)

In [None]:
# Positions along the last axis of `eig_acts` (equivalently, into `eig_idxs`) to
# inspect: the four last and the four first entries of the selected spectrum.
idxs = [-1, -2, -3, -4, 0, 1, 2, 3]
k = 200             # candidate (sequence, position) pairs considered per eigenvector
max_examples = 8    # distinct sequences printed per eigenvector

for idx in idxs:
    # Highest squared projections over all (sequence, position) pairs.
    topk = rearrange(eig_acts[:,:,idx], "batch pos -> (batch pos)").topk(dim=0, k = k)
    batch_idxs, pos_idxs = torch.unravel_index(topk.indices, eig_acts.shape[:2])

    # Map the position in `eig_idxs` back to the actual eigenpair.
    eigvec = eigvecs[:,eig_idxs[idx]]
    eigval = eigvals[eig_idxs[idx]]
    print(f'\nEig Idx: {idx} for "{token}"\n')

    # Show each sequence at most once so the examples are not near-duplicates.
    batches_used = set()
    for batch_idx, pos_idx in zip(batch_idxs, pos_idxs):
        if len(batches_used) >= max_examples:
            break
        if batch_idx.item() in batches_used:
            continue
        print_attn_weighted_text(data, batch_idx, pos_idx, eigvec, eigval, start_toks = 20)
        batches_used.add(batch_idx.item())

