# Fashion-MNIST | Single Layer Bilinear Model
This notebook trains a single-layer bilinear model on Fashion-MNIST and generates the figures for the paper "Weight-based Decomposition: A Case for Bilinear MLPs".

# Setup

In [None]:
# Automatically reload edited project modules (e.g. the `mnist` package) before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
!pip install einops
!pip install jaxtyping
!git clone https://github.com/tdooms/bilinear-interp.git
# !pip install git+https://github.com/2020leon/rpca.git


In [None]:
# Work from the cloned repository root; './data' below then resolves inside it, and the `mnist`
# package imported below presumably lives at the repo root — confirm.
%cd /content/bilinear-interp

In [None]:
# !git pull

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import itertools
import einops
from collections import defaultdict

from mnist.model import *
from mnist.utils import *
from mnist.plotting import *

# Prefer the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Load Data

In [None]:
# Fashion-MNIST train/test splits, images converted to tensors by ToTensor.
# download=True on both splits so neither constructor depends on the other having run first;
# it is a no-op once the files are already present under ./data.
train_dataset = torchvision.datasets.FashionMNIST(root='./data',
                                                  train=True,
                                                  transform=transforms.ToTensor(),
                                                  download=True)

test_dataset = torchvision.datasets.FashionMNIST(root='./data',
                                                 train=False,
                                                 transform=transforms.ToTensor(),
                                                 download=True)

In [None]:
# Mini-batch loaders: the training set is reshuffled every pass, the test set keeps its order.
batch_size = 100
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Human-readable class names, indexed by Fashion-MNIST label.
DATA_CLASSES = [
    'T-shirt/Top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot',
]

In [None]:
# Show up to 10 examples of each class taken from the first (unshuffled) test batch.
example_data, example_targets = next(iter(test_loader))

for label, class_name in enumerate(DATA_CLASSES):
    plt.figure()
    class_images = example_data[example_targets == label][:10]
    for slot, image in enumerate(class_images, start=1):
        plt.subplot(2, 5, slot)
        plt.imshow(image[0], cmap='binary')
    plt.suptitle(f'{label}: {class_name}')
    plt.show()

# Train Model

In [None]:
# Hyperparameters for the single-layer bilinear model.
cfg = MnistConfig()
cfg.random_seed = 0
cfg.n_layers = 1
cfg.hidden_dim = 300
cfg.lr = 0.001
cfg.lr_decay = 0.5
cfg.lr_decay_step = 2
cfg.weight_decay = 1
cfg.rms_norm = False
cfg.bias = False
cfg.noise_sparse = 0
cfg.noise_dense = 0.33
cfg.layer_noise = 0.33

# Learning-rate schedule, counted in scheduler steps (assumed one per epoch — TODO confirm in MnistModel.train):
#   warm-up: linear ramp from 1% of cfg.lr up to cfg.lr
#   decay:   multiply by cfg.lr_decay every cfg.lr_decay_step steps
#   final:   hold the fully decayed LR
warmup_epochs, decay_epochs, final_epochs = 2, 10, 40
cfg.num_epochs = warmup_epochs + decay_epochs + final_epochs

# Seed the global torch RNG (weight init, DataLoader shuffling); harmless if MnistModel also seeds itself.
torch.manual_seed(cfg.random_seed)

model = MnistModel(cfg).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
linearLR = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, end_factor=1, total_iters=warmup_epochs)
stepLR = torch.optim.lr_scheduler.StepLR(optimizer, step_size=cfg.lr_decay_step, gamma=cfg.lr_decay)
# After `decay_epochs` of its own steps StepLR sits at lr_decay ** (decay_epochs // lr_decay_step);
# integer division matches StepLR exactly even when lr_decay_step does not divide decay_epochs.
final_factor = cfg.lr_decay ** (decay_epochs // cfg.lr_decay_step)
# total_iters only needs to outlast training so the factor is never lifted.
constLR = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=final_factor, total_iters=1000)
# Switch to StepLR after warm-up, and to ConstantLR one step after the decay phase so StepLR
# applies the last decay itself and ConstantLR continues at that same LR.
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[linearLR, stepLR, constLR],
    milestones=[warmup_epochs, warmup_epochs + decay_epochs + 1],
)

model.train(train_loader, test_loader, optimizer=optimizer, scheduler=scheduler)

## Save model

In [None]:
# Mount Google Drive so the trained model can be written to it.
import pickle
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os

# Pickle the whole model object to Drive.
# NOTE(review): the file name says "deep ... with biases" but this notebook trains a single-layer
# model with cfg.bias = False — consider renaming (and keep the Load cell's path in sync).
filename = '/content/drive/MyDrive/AI Alignment/Codebooks_In_Superposition/deep_bilinear_model_with_biases_(test).pkl'
# open(..., 'wb') fails if the Drive folder does not exist yet, so create it first.
os.makedirs(os.path.dirname(filename), exist_ok=True)
with open(filename, 'wb') as f:
    pickle.dump(model, f)

## Load model

In [None]:
# Mount Google Drive so a previously saved model can be read from it.
import pickle
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Restore the pickled model written by the "Save model" cell (same path), replacing `model`.
# pickle.load can execute arbitrary code — only load files you created yourself.
filename = '/content/drive/MyDrive/AI Alignment/Codebooks_In_Superposition/deep_bilinear_model_with_biases_(test).pkl'
with open(filename, 'rb') as pkl_file:
    model = pickle.load(pkl_file)

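# Eigendecomposition
Combine the layer's two weight matrices into its bilinear interaction tensor, project it onto the output classes, and plot the top eigenvector components of each class's interaction matrix.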

In [None]:
# Pull the trained weights onto the CPU, detached from autograd, for analysis.
W = model.layers[0].linear1.weight.detach().cpu()
V = model.layers[0].linear2.weight.detach().cpu()
W_out = model.linear_out.weight.detach().cpu()
W_in = model.linear_in.weight.detach().cpu()

# Interaction tensor of the bilinear layer; get_B_tensor returns it with the two input dims flattened.
# Unflatten using linear1's own input width instead of model.cfg.d_hidden (this notebook only sets
# cfg.hidden_dim). Assumes linear1.weight is laid out (out, in) like nn.Linear — TODO confirm.
B = get_B_tensor(W, V)
B = einops.rearrange(B, "out (in1 in2) -> out in1 in2", in1=W.shape[1])

# Project the layer outputs onto the class logits: one (in1, in2) interaction matrix per class.
B_proj = einops.einsum(W_out, B, "class h2, h2 in1 in2 -> class in1 in2")

In [None]:
# One identity row per class; presumably selects each class's logit direction in turn — confirm in EigenvectorPlotter.
logits = torch.eye(B_proj.shape[0])
eig_plotter = EigenvectorPlotter(B_proj, logits, dataset=train_dataset, Embed=W_in)

# Plot every class (driven by DATA_CLASSES rather than a hard-coded 10), keeping the top 3 eigenvectors.
for i, class_name in enumerate(DATA_CLASSES):
    eig_plotter.plot_component(i, suptitle=f"Class: {class_name}", vmax=0.25, classes=DATA_CLASSES,
                               topk_eigs=3, sort='eigs')

In [None]:
import os

# Per-class eigenvector plots, additionally written to Drive as <file_pre><class index>.png.
file_pre = '/content/drive/MyDrive/AI Safety/Bilinear Features/Noise Regulation/fmnist_noise_reg_weight_decay_and_noise_class_'
# Create the output folder up front; plot_component presumably saves to `filename` and would fail otherwise.
os.makedirs(os.path.dirname(file_pre), exist_ok=True)

logits = torch.eye(B_proj.shape[0])
eig_plotter = EigenvectorPlotter(B_proj, logits, dataset=train_dataset, Embed=W_in)

for i, class_name in enumerate(DATA_CLASSES):
    eig_plotter.plot_component(i, suptitle=f"Class: {class_name}", vmax=0.25, classes=DATA_CLASSES,
                               topk_eigs=3, sort='eigs',
                               filename=f'{file_pre}{i}.png')
