# Section 1 of Toy Models of Superposition

In [1]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

In [2]:
def get_device():
    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"

device = get_device()
print(f"Using {device} device")

Using mps device


## 0 Sparsity Model

In [3]:
class ToyModelNoSparsity(nn.Module):
    def __init__(self):
        super().__init__()
        self.weights = nn.Parameter(torch.rand(2, 5), requires_grad=True)
        self.bias = nn.Parameter(torch.rand(5, 1), requires_grad=True)
        self.ReLU = nn.ReLU(inplace=True)
        
    def forward(self, x): # x is 5 * 1
        hidden = torch.matmul(self.weights, x)
        final = torch.matmul(self.weights.T, hidden)
        final += self.bias
        return self.ReLU(final)
        # return hidden

class CustomMSELoss(nn.Module):
    def __init__(self):
        super(CustomMSELoss, self).__init__()

    def forward(self, predictions, targets):
        return (predictions[0] - targets[0])**2 + (predictions[1] - targets[1])**2

model = ToyModelNoSparsity().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
loss_func = CustomMSELoss()

def train(loss_fn, optimizer):
    model.train()
    loss_total = 0
    for i in range(20000):
        # x = torch.ones(5, 1).to(device)
        x = torch.rand(5, 1).to(device)
        pred = model(x)
        loss = loss_fn(pred, x)
        loss_total += loss.item()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # return
        

        if i % 1000 == 0 and i != 0:
            print("loss:", loss_total / 1000)
            loss_total = 0

In [4]:
# for i in range(10):
#     x = torch.rand(5,1).to(device)
#     print(x)
#     print(model(x))

In [5]:
# importance = torch.tensor([1, 1, 0.0, 0.0, 0.0])

In [6]:
train(loss_func, optimizer)

loss: 0.33812452913391783
loss: 0.13467690217640485
loss: 0.12783609981200425
loss: 0.12179778016703495
loss: 0.11305085868879905
loss: 0.0929384533845623
loss: 0.09619273880020683
loss: 0.09053687064720725
loss: 0.07961259642376217
loss: 0.08178187953260477
loss: 0.07655097778526943
loss: 0.07224933696678999
loss: 0.06987324978539983
loss: 0.06915357640144884
loss: 0.06935293157021079
loss: 0.05771950121252894
loss: 0.05540801834274407
loss: 0.04960186266301389
loss: 0.0416733879721578


In [7]:
print(model.weights.T)

tensor([[-0.6705,  0.4646],
        [ 0.7398,  0.5218],
        [-0.0177,  0.3185],
        [ 0.0542,  0.1224],
        [-0.0057,  0.2717]], device='mps:0', grad_fn=<PermuteBackward0>)


In [8]:
print(model.bias)

Parameter containing:
tensor([[0.1638],
        [0.0324],
        [0.7138],
        [0.0383],
        [0.3858]], device='mps:0', requires_grad=True)
