# MNIST | Two Layer Bilinear Model
This notebook trains a two-layer bilinear MLP on MNIST and generates the figures for the paper "Weight-based Decomposition: A Case for Bilinear MLPs".

# Setup

In [None]:
# Automatically re-import edited modules (such as the mnist.* package imported below)
# before each cell executes, so library edits take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
!pip install einops
!pip install jaxtyping
!git clone https://github.com/tdooms/bilinear-interp.git

In [None]:
# Work from the repository root so the `mnist.*` imports below resolve and relative
# paths such as './data' are created inside the cloned repo.
%cd /content/bilinear-interp

In [None]:
# Optional: uncomment to update the cloned repository to its latest commit.
# !git pull

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import itertools
import einops
from collections import defaultdict

from mnist.model import *
from mnist.utils import *
from mnist.plotting import *

# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load Data

In [None]:
# MNIST train/test splits stored under ./data, converted to tensors by ToTensor.
# Only the training split is constructed with download=True; the test split expects its
# files to be present afterwards (presumably fetched by that same download — verify).
mnist_root = './data'
to_tensor = transforms.ToTensor()

train_dataset = torchvision.datasets.MNIST(
    root=mnist_root, train=True, transform=to_tensor, download=True
)
test_dataset = torchvision.datasets.MNIST(
    root=mnist_root, train=False, transform=to_tensor
)

In [None]:
# Mini-batch loaders: the training set is shuffled, the test set keeps its fixed order.
batch_size = 100
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset=split, batch_size=batch_size, shuffle=do_shuffle)
    for split, do_shuffle in ((train_dataset, True), (test_dataset, False))
)


In [None]:
# Preview the first six digits of the first test batch in a 2x3 grid.
examples = iter(test_loader)
example_data, example_targets = next(examples)

fig, axes = plt.subplots(2, 3)
for ax, image in zip(axes.flat, example_data[:6]):
    ax.imshow(image[0], cmap='binary')  # image[0]: the single channel of the sample
plt.show()

# Train Model

In [None]:
# Training configuration for the two-layer bilinear MLP.
WARMUP_EPOCHS = 2   # linear LR warmup phase
DECAY_EPOCHS = 12   # nominal length of the step-decay phase
CONST_EPOCHS = 50   # nominal length of the constant-LR tail

cfg = MnistConfig()
cfg.random_seed = 0
cfg.n_layers = 2
cfg.d_hidden = 300
cfg.num_epochs = WARMUP_EPOCHS + DECAY_EPOCHS + CONST_EPOCHS
cfg.lr = 0.001
cfg.lr_decay = 0.5
cfg.lr_decay_step = 2
cfg.weight_decay = 0.5
cfg.rms_norm = False
cfg.bias = False
cfg.noise_sparse = 0
cfg.noise_dense = 0.33
cfg.layer_noise = 0.33
cfg.logit_bias = False

# Use the device selected in the setup cell instead of hard-coding "cuda", so this cell
# also runs on a CPU-only runtime.
# NOTE(review): MnistModel.train presumably moves batches to cfg.device — confirm that
# setting matches `device` as well.
model = MnistModel(cfg).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

# LR schedule chained with SequentialLR (assuming model.train steps the scheduler once per epoch):
#   * LinearLR: warmup from 1% of cfg.lr towards cfg.lr for the first WARMUP_EPOCHS steps.
#   * StepLR: multiply the LR by cfg.lr_decay every cfg.lr_decay_step steps. The second
#     milestone is WARMUP_EPOCHS + DECAY_EPOCHS + 1 so that StepLR takes enough steps to
#     apply its DECAY_EPOCHS / cfg.lr_decay_step-th decay before handing over.
#   * ConstantLR: hold cfg.lr * final_factor, which equals the LR StepLR ended on.
#     ConstantLR restores the base LR once total_iters is reached; 1000 keeps that from
#     happening within the run.
linearLR = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.01, end_factor=1, total_iters=WARMUP_EPOCHS
)
stepLR = torch.optim.lr_scheduler.StepLR(optimizer, step_size=cfg.lr_decay_step, gamma=cfg.lr_decay)
final_factor = cfg.lr_decay ** (DECAY_EPOCHS / cfg.lr_decay_step)
constLR = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=final_factor, total_iters=1000)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[linearLR, stepLR, constLR],
    milestones=[WARMUP_EPOCHS, WARMUP_EPOCHS + DECAY_EPOCHS + 1],
)

model.train(train_loader, test_loader, optimizer=optimizer, scheduler=scheduler)

## Save model

In [None]:
# Mount Google Drive so the trained model can be written to it in the next cell.
import pickle  # NOTE(review): not used below — torch.save handles serialization
from google.colab import drive

drive.mount('/content/drive')

In [None]:
# Serialize the whole model object (not just a state_dict) to Google Drive.
filename = '/content/drive/MyDrive/AI Safety/Bilinear Features/two_layer_mnist_20240523.pkl'
with open(filename, 'wb') as checkpoint_file:
    torch.save(model, checkpoint_file)

## Load model

In [None]:
# Mount Google Drive (repeated here so the "Load model" section can run on its own).
import pickle  # NOTE(review): not used below — torch.load handles deserialization
from google.colab import drive

drive.mount('/content/drive')

In [None]:
# Restore the whole model object written by the "Save model" cell.
# * weights_only=False: the file holds a pickled model object, which torch.load refuses
#   under the weights_only=True default of recent PyTorch releases. This unpickles
#   arbitrary objects, so only load checkpoints from a trusted source.
# * map_location=device: remaps tensors saved from a GPU, so loading also works on a CPU
#   runtime. NOTE(review): model.cfg.device is restored unchanged from the checkpoint and is
#   read by later cells — confirm it matches `device`.
filename = '/content/drive/MyDrive/AI Safety/Bilinear Features/two_layer_mnist_20240523.pkl'
with open(filename, 'rb') as f:
    model = torch.load(f, map_location=device, weights_only=False)

# Eigen-decomposition

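* For a simple single-layer model, the output directions are given directly by the weights of the linear readout layer.
* The output space is low-dimensional: only 10 directions, one per digit.
* For the second layer, we project its bilinear tensor $B_{ajk} = W_{aj} V_{ak}$ onto these readout directions, giving one interaction matrix per class: $Q_{djk} = W^\text{out}_{da} B_{ajk}$ (computed as `B1` below, then symmetrized over $j,k$ as `B_proj1`).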


In [None]:
def _cpu_weight(linear):
    """Return a detached CPU copy of a linear layer's weight matrix."""
    return linear.weight.detach().cpu()

# Bilinear layer weights (linear1 -> W, linear2 -> V) plus the embedding/readout matrices.
layer0, layer1 = model.layers[0], model.layers[1]
W0, V0 = _cpu_weight(layer0.linear1), _cpu_weight(layer0.linear2)
W1, V1 = _cpu_weight(layer1.linear1), _cpu_weight(layer1.linear2)
W_in = _cpu_weight(model.linear_in)
W_out = _cpu_weight(model.linear_out)

# Layer-2 interaction tensor projected onto the readout directions, one (in1, in2)
# matrix per class, then symmetrized over its two input axes.
B1 = einops.einsum(W_out, W1, V1, "class h, h in1, h in2 -> class in1 in2")
B_proj1 = 0.5 * B1 + 0.5 * B1.transpose(-2, -1)

In [None]:
# Eigen-decompose the symmetrized layer-2 interaction matrix of one class and plot its
# spectrum (torch.linalg.eigh returns eigenvalues in ascending order).
class_idx = 9
Q = B_proj1[class_idx]
eigvals1, eigvecs1 = torch.linalg.eigh(Q)

fig, ax = plt.subplots(figsize=(4, 3))
ax.plot(eigvals1, '.-', markersize=7)
ax.set_xlabel('Rank')
ax.set_ylabel('Eigenvalue')
ax.set_title(f'Layer 2 Eigenvalues for "{class_idx}"')

In [None]:
# Visualize, for the class decomposed above, how layer 1 builds the selected layer-2
# eigenvectors. Each eigenvector v (a direction in the hidden space) is contracted with the
# layer-1 bilinear weights, sum_h v_h * W0[h] (x) V0[h], and symmetrized over its inputs.
eig_idxs = [-1, -2, 0, 1]    # eigh is ascending: the two largest, then the two smallest
flip_signs = [-1, 1, -1, 1]  # eigenvector signs are arbitrary; flips only affect B_proj0
assert len(flip_signs) == len(eig_idxs), "need exactly one sign per selected eigenvector"

top_eigvecs = eigvecs1[:, eig_idxs]  # (h, eig)

B0 = einops.einsum(top_eigvecs, W0, V0, "h eig, h in1, h in2 -> eig in1 in2")
B_proj0 = 0.5 * B0 + 0.5 * B0.transpose(-2, -1)
B_proj0 = torch.tensor(flip_signs).view(-1, 1, 1) * B_proj0

# Logits obtained by feeding each eigenvector straight into layer 2 and the readout.
# Named `eig_inputs` (not `input`) so the builtin is not shadowed for the rest of the
# session; no_grad because these logits are only inspected, never differentiated.
with torch.no_grad():
    eig_inputs = top_eigvecs.T.to(model.cfg.device)
    # Disabled alternative: evaluate with added dense noise (0.33, presumably mirroring
    # cfg.noise_dense / cfg.layer_noise used in training):
    # x = model.layers[1](eig_inputs + 0.33 * eig_inputs.std() * torch.randn_like(eig_inputs))
    # logits = model.linear_out(x + 0.33 * x.std() * torch.randn_like(x)).cpu()
    logits = model.linear_out(model.layers[1](eig_inputs)).cpu()

eig_plotter = EigenvectorPlotter(B_proj0, logits, dataset=train_dataset, Embed=W_in)

for i in range(B_proj0.shape[0]):
    eig_plotter.plot_component(i, suptitle=f"Layer-2 Eig Rank: {eig_idxs[i]}", vmax=0.25,
                               classes=range(10), topk_eigs=4, sort='eigs')

In [None]:
import os

# For every digit class, save the layer-2 eigenvalue spectrum and the layer-1
# visualizations of its two largest and two smallest layer-2 eigenvectors.
filename_base = '/content/drive/MyDrive/AI Safety/Bilinear Features/Two Layers/'
os.makedirs(filename_base, exist_ok=True)  # savefig does not create missing folders

eig_idxs = [-1, -2, 0, 1]  # eigh is ascending: the two largest, then the two smallest

for class_idx in range(10):
    Q = B_proj1[class_idx]
    eigvals1, eigvecs1 = torch.linalg.eigh(Q)

    plt.figure(figsize=(4, 3))
    plt.plot(eigvals1, '.-', markersize=7)
    plt.ylabel('Eigenvalue')
    plt.xlabel('Rank')
    plt.title(f'Layer 2 Eigenvalues for "{class_idx}"')
    plt.savefig(os.path.join(filename_base, f'layer2_eigenvalues_{class_idx}.png'))

    top_eigvecs = eigvecs1[:, eig_idxs]  # (h, eig)
    B0 = einops.einsum(top_eigvecs, W0, V0, "h eig, h in1, h in2 -> eig in1 in2")
    B_proj0 = 0.5 * B0 + 0.5 * B0.transpose(-2, -1)

    # `eig_inputs` rather than `input` avoids shadowing the builtin; no_grad because the
    # logits are only inspected.
    with torch.no_grad():
        eig_inputs = top_eigvecs.T.to(model.cfg.device)
        logits = model.linear_out(model.layers[1](eig_inputs)).cpu()

    eig_plotter = EigenvectorPlotter(B_proj0, logits, dataset=train_dataset, Embed=W_in)

    for i in range(B_proj0.shape[0]):
        # Include the class in the title so saved figures are distinguishable on their own.
        eig_plotter.plot_component(i, suptitle=f"Class {class_idx} | Layer-2 Eig Rank: {eig_idxs[i]}",
                                   vmax=0.25, classes=range(10), topk_eigs=4, sort='eigs')
        # Assumes plot_component leaves its figure as the current pyplot figure — verify.
        plt.savefig(os.path.join(filename_base, f'layer1_eigenvectors_{class_idx}_{i}.png'))

    # Render this class's figures, then release them. Otherwise all ~50 figures stay open
    # until the cell ends, and matplotlib warns once more than 20 are open.
    plt.show()
    plt.close('all')