# Low Rank Approximation (not Low Rank Adaptation)

## setup

In [27]:
import numpy as np
import math
import copy
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML
from sklearn.decomposition import PCA
%matplotlib inline

# Pick the compute target once: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

## data

In [28]:
# ToTensor transform shared by both MNIST splits.
_to_tensor = torchvision.transforms.ToTensor()

# Training loader: shuffled so the mini-batch composition differs between epochs.
data = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data', transform=_to_tensor, download=True, train=True),
    batch_size=128,
    shuffle=True,
)

# Test loader: used only for evaluation (accuracy counts), where sample order
# does not change the result, so shuffling would be wasted work and would make
# the evaluation order needlessly nondeterministic.
data_test = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data', transform=_to_tensor, download=True, train=False),
    batch_size=128,
    shuffle=False,
)

## helpers

In [29]:
@torch.no_grad()
def accuracy(model, data):
    """Return the fraction of samples in `data` that `model` classifies correctly.

    The model is moved to the global `device` and run in eval mode. Gradient
    tracking is disabled by the decorator. The model's previous train/eval mode
    is restored before returning, so calling this from inside a training loop
    does not leave the model stuck in eval mode.

    Args:
        model: nn.Module whose output has per-class scores along its last dim.
        data: iterable of (inputs, labels) batches exposing `.dataset`
            (e.g. a DataLoader); `len(data.dataset)` is the denominator.

    Returns:
        float: number of correct argmax predictions / len(data.dataset).
    """
    model = model.to(device)
    was_training = model.training
    model.eval()
    correct = 0
    try:
        for x, y in data:
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            # Count exact matches between the predicted class and the label.
            correct += (y_hat.argmax(dim=-1) == y).sum().item()
    finally:
        # Put the model back in whatever mode the caller had it in.
        model.train(was_training)
    return correct / len(data.dataset)

## Toy model

In [30]:
class ToyMNIST(nn.Module):
    """Small MLP classifier: 784 -> 512 -> ReLU -> 256 -> ReLU -> 10 logits.

    The submodule attribute names (w1, relu1, w2, relu2, w3) and their
    registration order determine the state_dict keys, so they are kept fixed
    for saving and loading weights.
    """

    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(784, 512)
        self.relu1 = nn.ReLU()
        self.w2 = nn.Linear(512, 256)
        self.relu2 = nn.ReLU()
        self.w3 = nn.Linear(256, 10)

    def forward(self, x):
        """Flatten `x` into rows of 784 features and return one row of 10 logits per sample."""
        h = x.view(-1, 784)
        # Apply the layers in the same order the original forward pass used.
        for layer in (self.w1, self.relu1, self.w2, self.relu2, self.w3):
            h = layer(h)
        return h

In [31]:
# Per-epoch history that train() appends to, plus the model being trained.
losses = []
test_accuracies = []
model = ToyMNIST().to(device)

In [32]:
def train(model, epochs=10, log_every=1, lr=3e-4):
    """Train `model` on the global `data` loader with Adam and cross-entropy loss.

    After each epoch, `accuracy(model, data_test)` is appended to the global
    `test_accuracies` list. The loss of that epoch's final mini-batch (not an
    epoch mean) is appended to the global `losses` list. Stats are printed
    whenever `epoch % log_every == 0`.

    Args:
        model: nn.Module mapping input batches to class logits.
        epochs: number of passes over `data`.
        log_every: print interval, in epochs.
        lr: Adam learning rate.
    """
    model = model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        # Re-enter training mode every epoch. The accuracy() call at the end of
        # the previous epoch may have switched the model to eval mode, which
        # would otherwise persist into the remaining epochs.
        model.train()
        for x, y in data:
            x, y = x.to(device), y.to(device)
            y_hat = model(x)
            loss = F.cross_entropy(y_hat, y)
            opt.zero_grad()
            loss.backward()
            opt.step()
        # stats: loss from the last mini-batch of this epoch
        last_loss = loss.item()
        test_accuracies.append(accuracy(model, data_test))
        losses.append(last_loss)
        if epoch % log_every == 0:
            print(f'{epoch} loss: {last_loss:.4f} test accuracy: {test_accuracies[-1]:.4f}')

# Run ten epochs with the default learning rate and per-epoch logging.
train(model=model, epochs=10)

0 loss: 0.1429 test accuracy: 0.9362
1 loss: 0.1247 test accuracy: 0.9553
2 loss: 0.1328 test accuracy: 0.9653
3 loss: 0.0800 test accuracy: 0.9703
4 loss: 0.0501 test accuracy: 0.9736
5 loss: 0.0178 test accuracy: 0.9770
6 loss: 0.0415 test accuracy: 0.9789
7 loss: 0.1867 test accuracy: 0.9786
8 loss: 0.0109 test accuracy: 0.9779
9 loss: 0.0152 test accuracy: 0.9797


In [37]:
# eval on train and test
# ----------------------
# print(f'{accuracy(model, data)=:.4f} {accuracy(model, data_test)=:.4f}')

accuracy(model, data)=0.9959 accuracy(model, data_test)=0.9797


In [36]:
# save
# ----
# torch.save(model.state_dict(), 'weights/toymnist.pt')

# load
# ----
# toy = ToyMNIST().to(device)
# toy.load_state_dict(torch.load('weights/toymnist.pt'))

## 🎵 work it, make it, do it faster, make me smaller, solve it stronger, code together! 🎶

In [39]:
# <CODE GOES HERE>