# Section 1 of Toy Models of Superposition

In [1]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

In [2]:
def get_device():
    """Return the name of the best available torch device.

    Preference order is CUDA, then Apple's MPS backend, then CPU.

    Returns:
        str: one of "cuda", "mps" or "cpu".
    """
    if torch.cuda.is_available():
        return "cuda"
    # PyTorch builds without the MPS backend do not expose torch.backends.mps,
    # so look it up defensively instead of raising AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"

# Select the compute device once; later cells move the model onto it and
# create the training batches for it.
device = get_device()
print(f"Using {device} device")

Using mps device


## Zero-Sparsity Model (dense inputs)

In [3]:
class ToyModelNoSparsity(nn.Module):
    """Tied-weight bottleneck model computing ReLU(W^T (W x) + b).

    The same matrix ``weights`` projects the input down to ``n_hidden``
    dimensions and (transposed) back up to ``n_features`` dimensions.

    Args:
        n_features: Size of the input/output feature dimension (default 5).
        n_hidden: Size of the hidden bottleneck dimension (default 2).

    Attributes:
        weights: Parameter of shape (n_hidden, n_features), initialised
            uniformly in [0, 1).
        bias: Parameter of shape (n_features, 1), initialised uniformly
            in [0, 1).
        ReLU: The output non-linearity.
    """

    def __init__(self, n_features=5, n_hidden=2):
        super().__init__()
        # Weights are drawn before the bias, matching the original init order,
        # so a seeded run with the default sizes yields the same parameters.
        self.weights = nn.Parameter(torch.rand(n_hidden, n_features))
        self.bias = nn.Parameter(torch.rand(n_features, 1))
        # In-place is safe here: it is only ever applied to the freshly
        # computed reconstruction, never to the caller's input tensor.
        self.ReLU = nn.ReLU(inplace=True)

    def forward(self, x):
        """Reconstruct ``x`` through the bottleneck.

        Args:
            x: Tensor whose last two dimensions are (n_features, 1). Leading
                batch dimensions are allowed; the training loop in this
                notebook passes (batch, n_features, 1).

        Returns:
            Tensor of the same shape as ``x`` holding the non-negative
            reconstruction.
        """
        hidden = self.weights @ x
        reconstruction = self.weights.T @ hidden + self.bias
        return self.ReLU(reconstruction)

class CustomMSELoss(nn.Module):
    """Summed squared error restricted to indices 0 and 1 of the second axis.

    The element-wise squared error is first summed over dimension 0. With the
    (batch, features, 1) tensors produced by this notebook's training loop,
    that is the batch axis, so only the first two features contribute to the
    loss; every other feature is ignored.
    """

    def __init__(self):
        super().__init__()

    def forward(self, predictions, targets):
        """Return the loss as a tensor (one element for (batch, features, 1) input).

        Args:
            predictions: Model outputs.
            targets: Reference values with the same shape as ``predictions``.
        """
        squared_error = (predictions - targets).pow(2)
        # Collapse dimension 0, then keep only the first two remaining entries.
        summed_over_dim0 = squared_error.sum(dim=0)
        first_term = summed_over_dim0[0]
        second_term = summed_over_dim0[1]
        return first_term + second_term

# Build the model on the selected device, a plain SGD optimizer over all of
# its parameters (learning rate 1e-3), and the custom loss used for training.
model = ToyModelNoSparsity().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_func = CustomMSELoss()

In [4]:
# Training schedule read by train():
NUM_EPOCHS = 10  # number of passes; one averaged loss line is printed per epoch
BATCHS_PER_EPOCH = 1000  # optimizer steps (freshly sampled batches) per epoch
BATCH_SIZE = 10  # random input vectors per batch


def train(loss_fn, optimizer, net=None, num_epochs=None, batches_per_epoch=None,
          batch_size=None, num_features=5):
    """Train a model to reconstruct freshly sampled uniform random inputs.

    Each step draws a batch of shape (batch_size, num_features, 1) uniformly
    from [0, 1), runs the model on it, and uses the batch itself as the target.
    After every epoch the summed loss divided by the number of samples seen in
    that epoch is printed.

    Args:
        loss_fn: Callable ``loss_fn(pred, target)`` returning a one-element
            tensor.
        optimizer: Optimizer that updates the model's parameters.
        net: Model to train. Defaults to the notebook-level ``model``.
        num_epochs: Number of epochs. Defaults to ``NUM_EPOCHS``.
        batches_per_epoch: Steps per epoch. Defaults to ``BATCHS_PER_EPOCH``.
        batch_size: Samples per batch. Defaults to ``BATCH_SIZE``.
        num_features: Feature dimension of the sampled inputs (default 5).

    Raises:
        ValueError: If ``batches_per_epoch`` or ``batch_size`` is not positive.
    """
    # Fall back to the notebook-level objects only when the caller did not
    # supply a value, so existing ``train(loss_fn, optimizer)`` calls still work.
    net = model if net is None else net
    num_epochs = NUM_EPOCHS if num_epochs is None else num_epochs
    batches_per_epoch = BATCHS_PER_EPOCH if batches_per_epoch is None else batches_per_epoch
    batch_size = BATCH_SIZE if batch_size is None else batch_size
    if batches_per_epoch <= 0 or batch_size <= 0:
        raise ValueError("batches_per_epoch and batch_size must be positive")

    # Sample inputs directly on the device that holds the model's parameters,
    # avoiding a CPU allocation plus a host-to-device copy on every step.
    first_param = next(net.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")
    samples_per_epoch = batches_per_epoch * batch_size

    net.train()
    for epoch in range(num_epochs):
        loss_total = 0
        for _ in range(batches_per_epoch):
            x = torch.rand(batch_size, num_features, 1, device=run_device)
            pred = net(x)
            loss = loss_fn(pred, x)
            loss_total += loss.item()

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

        print("EPHOCH:", epoch + 1, "loss:", loss_total / samples_per_epoch)

In [5]:
# Run the full training schedule; one averaged loss line is printed per epoch.
train(loss_func, optimizer)

EPHOCH: 1 loss: 0.13974299584925176
EPHOCH: 2 loss: 0.05760856195688248
EPHOCH: 3 loss: 0.012864610306825489
EPHOCH: 4 loss: 0.002407945555448532
EPHOCH: 5 loss: 0.00047870866390294396
EPHOCH: 6 loss: 8.57621162009309e-05
EPHOCH: 7 loss: 1.5125769470341765e-05
EPHOCH: 8 loss: 2.786391501012986e-06
EPHOCH: 9 loss: 4.839874901676922e-07
EPHOCH: 10 loss: 8.851857738108038e-08


In [6]:
# Show the learned weights transposed, i.e. one row per input feature and one
# column per hidden dimension.
print(model.weights.T)

tensor([[ 3.0642e-01, -9.5182e-01],
        [ 9.5193e-01,  3.0648e-01],
        [ 9.1428e-05,  3.1215e-04],
        [ 9.4946e-05,  3.2291e-04],
        [ 9.5167e-05,  3.2710e-04]], device='mps:0',
       grad_fn=<PermuteBackward0>)


In [7]:
# Show the learned bias. Only outputs 0 and 1 enter the loss, so the bias
# entries for the remaining features receive no gradient and keep their
# random initial values.
print(model.bias)

Parameter containing:
tensor([[ 5.3713e-04],
        [-3.4023e-04],
        [ 2.0237e-01],
        [ 9.7385e-01],
        [ 9.2674e-01]], device='mps:0', requires_grad=True)
