In [1]:
import torch
import torchvision
import torch.nn.functional as F

# Get Data
train_dataset = torchvision.datasets.MNIST(train=True, download=True, root='./')
rand = torch.randperm(train_dataset.data.shape[0])
X = train_dataset.data.reshape(-1,28*28)[rand].float()
Y = train_dataset.targets[rand]


w1 = torch.randn(784,128) / 784**0.5
w2 = torch.randn(128,64) / 128**0.5
w3 = torch.randn(64,10) / 64**0.5

b3 = torch.zeros(10)

batch_size = 32

for epoch in range(3):
    lr = 0.1 if epoch < 1 else 0.01
    for i in range(0,len(X),batch_size):
        # Forward
        zero = torch.tensor([0])
        Xb = X[i:i+batch_size] / 255 # 32, 784
        Yb = Y[i:i+batch_size]

        l1 = Xb@w1
        relu1 = torch.maximum(zero, l1)
        
        l2 = relu1@w2
        relu2 = torch.maximum(zero, l2)

        logits = relu2@w3 + b3

        logit_maxes = logits.max(dim=1, keepdim=True).values # the 1 max value from each row
        norm_logits = logits - logit_maxes # makes numbers <= 0 for not too high exp outputs

        counts = norm_logits.exp() # e**logits -- top of softmax
        counts_sum = counts.sum(dim=1, keepdim=True) # sum e**logits
        counts_sum_inv = counts_sum**-1 # place sum e**logits on bottom of softmax
        probs = counts * counts_sum_inv # probs = softmax = e**logits / sum e**logits


        # Loss
        # It's correct classes, then ln, then average
        logprobs = probs.log()
        loss = -logprobs[range(len(Yb)), Yb].mean()


        # Backward
        dloss = 1.0
        dlogprobs = torch.zeros_like(logprobs)
        dlogprobs[range(len(Yb)), Yb] = -1.0 / len(Yb)
        dprobs = 1.0 / probs * dlogprobs
        dcounts = counts_sum_inv * dprobs
        dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
        dcounts_sum = -counts_sum**-2 * dcounts_sum_inv
        dcounts += dcounts_sum
        dnorm_logits = norm_logits.exp() * dcounts
        dlogits = dnorm_logits
        dlogit_maxes = -dnorm_logits.sum(dim=1, keepdim=True)
        dlogits += F.one_hot(logits.max(dim=1).indices, logits.shape[1]) * dlogit_maxes
        drelu2 = dlogits@w3.T
        dw3 = relu2.T@dlogits
        dl2 = (l2 > 0) * drelu2
        drelu1 = dl2@w2.T
        dw2 = relu1.T@dl2
        dl1 = (l1 > 0) * drelu1
        dw1 = Xb.T@dl1
        db3 = dlogits.sum(0)


        # Update
        w1 = w1 - lr * dw1
        w2 = w2 - lr * dw2
        w3 = w3 - lr * dw3
        b3 = b3 - lr * db3
        
        print(loss)

tensor(2.3548)
tensor(2.3129)
tensor(2.2627)
tensor(2.2440)
tensor(2.2277)
tensor(2.2023)
tensor(2.2183)
tensor(2.1664)
tensor(2.1465)
tensor(2.0796)
tensor(2.1085)
tensor(2.1309)
tensor(1.9663)
tensor(2.0248)
tensor(2.0740)
tensor(2.0178)
tensor(1.9805)
tensor(1.9003)
tensor(1.8625)
tensor(1.9010)
tensor(1.8520)
tensor(1.7655)
tensor(1.6726)
tensor(1.7661)
tensor(1.5009)
tensor(1.7434)
tensor(1.5808)
tensor(1.5005)
tensor(1.6302)
tensor(1.3930)
tensor(1.4676)
tensor(1.4147)
tensor(1.1608)
tensor(1.4194)
tensor(1.2047)
tensor(1.4450)
tensor(1.2615)
tensor(1.1702)
tensor(1.0602)
tensor(1.0659)
tensor(1.2204)
tensor(1.0932)
tensor(1.3577)
tensor(0.9958)
tensor(0.9153)
tensor(1.0706)
tensor(0.7935)
tensor(0.9784)
tensor(0.8514)
tensor(1.0897)
tensor(1.0692)
tensor(1.0699)
tensor(0.8660)
tensor(0.7986)
tensor(0.7646)
tensor(1.0713)
tensor(0.9288)
tensor(1.0527)
tensor(0.9323)
tensor(0.4837)
tensor(0.6447)
tensor(0.9069)
tensor(0.8895)
tensor(0.8391)
tensor(0.9097)
tensor(0.8732)
tensor(0.7

In [2]:
# Test
# Get Test Data
test_dataset = torchvision.datasets.MNIST(train=False, download=True, root='./')
rand = torch.randperm(test_dataset.data.shape[0])
X_test = test_dataset.data.reshape(-1,28*28)[rand].float()
Y_test = test_dataset.targets[rand]

batch_size = 32
correct = 0
total = 0
for i in range(0,len(X_test),batch_size):
    # Forward
    zero = torch.tensor([0])
    Xb = X_test[i:i+batch_size] / 255 # 32, 784
    Yb = Y_test[i:i+batch_size]

    l1 = Xb@w1
    relu1 = torch.maximum(zero, l1)
    
    l2 = relu1@w2
    relu2 = torch.maximum(zero, l2)

    logits = relu2@w3 + b3

    prediction = logits.argmax(dim=1)
    correct += (prediction == Yb).sum().item()
    total += len(Yb)
accuracy = correct / total
print(accuracy)

0.9647
