In [1]:
import torch
import torchvision
import torch.nn.functional as F

# Train a 784-128-64-10 ReLU MLP on MNIST with hand-written forward/backward
# passes and plain mini-batch SGD (no autograd is used anywhere in this cell).

# Seed the global RNG so the shuffle and the weight initialisation below are
# repeatable across kernel restarts.
torch.manual_seed(0)

# Get Data
train_dataset = torchvision.datasets.MNIST(train=True, download=True, root='./')
rand = torch.randperm(train_dataset.data.shape[0])
X = train_dataset.data.reshape(-1,28*28)[rand].float()  # one flattened 784-vector per image
Y = train_dataset.targets[rand]

# Weights are drawn from a normal distribution scaled by 1/sqrt(fan_in).
w1 = torch.randn(784,128) / 784**0.5
w2 = torch.randn(128,64) / 128**0.5
w3 = torch.randn(64,10) / 64**0.5

# Only the output layer has a bias.
b3 = torch.zeros(10)

batch_size = 32

for epoch in range(3):
    # Step schedule: larger step for the first epoch, smaller afterwards.
    lr = 0.1 if epoch < 1 else 0.01
    for loop_count, i in enumerate(range(0,len(X),batch_size)):
        # Forward
        Xb = X[i:i+batch_size] / 255 # 32, 784
        Yb = Y[i:i+batch_size]
        n = len(Yb)
        rows = torch.arange(n)

        l1 = Xb@w1
        relu1 = F.relu(l1)

        l2 = relu1@w2
        relu2 = F.relu(l2)

        logits = relu2@w3 + b3

        logit_maxes = logits.max(dim=1, keepdim=True).values # the 1 max value from each row
        norm_logits = logits - logit_maxes # makes numbers <= 0 for not too high exp outputs

        counts = norm_logits.exp() # e**logits -- top of softmax
        counts_sum = counts.sum(dim=1, keepdim=True) # sum e**logits; >= 1 since the row max contributes exp(0)

        # Loss
        # Log-softmax written as norm_logits - log(sum exp): unlike log(softmax),
        # this never takes log(0) or divides by a probability that underflowed
        # to 0, so neither the loss nor the gradients can turn into inf/NaN.
        logprobs = norm_logits - counts_sum.log()
        # It's correct classes, then average
        loss = -logprobs[rows, Yb].mean()


        # Backward
        # d loss / d logprobs: -1/n at each sample's target class, 0 elsewhere.
        dlogprobs = torch.zeros_like(logprobs)
        dlogprobs[rows, Yb] = -1.0 / n
        # logprobs = norm_logits - log(counts_sum): the log term is shared by
        # every column of a row, so its gradient is the negated row sum.
        dlog_counts_sum = -dlogprobs.sum(dim=1, keepdim=True)
        dcounts_sum = dlog_counts_sum / counts_sum
        # counts_sum = counts.sum(1) and counts = exp(norm_logits): the sum
        # broadcasts its gradient to each column, exp multiplies by counts.
        dnorm_logits = dlogprobs + counts * dcounts_sum
        # norm_logits = logits - logit_maxes: direct path plus the path through
        # the row max, which only reaches the arg-max entry of each row.
        dlogit_maxes = -dnorm_logits.sum(dim=1, keepdim=True)
        dlogits = dnorm_logits + F.one_hot(logits.max(dim=1).indices, logits.shape[1]) * dlogit_maxes
        drelu2 = dlogits@w3.T
        dw3 = relu2.T@dlogits
        dl2 = (l2 > 0) * drelu2 # ReLU passes gradient only where its input was positive
        drelu1 = dl2@w2.T
        dw2 = relu1.T@dl2
        dl1 = (l1 > 0) * drelu1
        dw1 = Xb.T@dl1
        db3 = dlogits.sum(0)


        # Update
        w1 = w1 - lr * dw1
        w2 = w2 - lr * dw2
        w3 = w3 - lr * dw3
        b3 = b3 - lr * db3

        # Report the mini-batch loss every 100 steps.
        if loop_count %100 == 0:
            print(f'Loss: {loss:.3f}')

Loss: 2.234
Loss: 0.707
Loss: 0.491
Loss: 0.904
Loss: 0.207
Loss: 0.278
Loss: 0.308
Loss: 0.075
Loss: 0.375
Loss: 0.065
Loss: 0.074
Loss: 0.308
Loss: 0.157
Loss: 0.102
Loss: 0.107
Loss: 0.386
Loss: 0.027
Loss: 0.130
Loss: 0.272
Loss: 0.024
Loss: 0.136
Loss: 0.052
Loss: 0.243
Loss: 0.075
Loss: 0.097
Loss: 0.160
Loss: 0.024
Loss: 0.180
Loss: 0.030
Loss: 0.049
Loss: 0.154
Loss: 0.029
Loss: 0.047
Loss: 0.053
Loss: 0.239
Loss: 0.022
Loss: 0.056
Loss: 0.093
Loss: 0.016
Loss: 0.134
Loss: 0.046
Loss: 0.207
Loss: 0.069
Loss: 0.083
Loss: 0.142
Loss: 0.022
Loss: 0.174
Loss: 0.025
Loss: 0.034
Loss: 0.135
Loss: 0.026
Loss: 0.040
Loss: 0.046
Loss: 0.220
Loss: 0.020
Loss: 0.054
Loss: 0.082


In [14]:
# Test
# Measures classification accuracy on the MNIST test split using whatever
# values w1, w2, w3 and b3 currently hold in the kernel (no autograd, no update).
test_dataset = torchvision.datasets.MNIST(root='./', train=False, download=True)
num_test_images = test_dataset.data.shape[0]
rand = torch.randperm(num_test_images)
X_test = test_dataset.data.reshape(-1, 28*28)[rand].float()
Y_test = test_dataset.targets[rand]

batch_size = 32
correct = 0
total = 0
zero = torch.tensor([0])  # ReLU is expressed as an element-wise max against this
for start in range(0, len(X_test), batch_size):
    stop = start + batch_size
    Xb = X_test[start:stop] / 255  # same 1/255 scaling as the training loop
    Yb = Y_test[start:stop]

    # Forward pass only: two ReLU layers followed by the biased output layer.
    l1 = Xb @ w1
    relu1 = torch.maximum(zero, l1)
    l2 = relu1 @ w2
    relu2 = torch.maximum(zero, l2)
    logits = relu2 @ w3 + b3

    # The predicted class is the index of the largest logit in each row.
    prediction = logits.argmax(dim=1)
    hits = (prediction == Yb).sum().item()
    correct = correct + hits
    total = total + len(Yb)

accuracy = correct / total
print(f'Accuracy: {accuracy*100:.0f}%')

Accuracy: 9%


In [13]:
# Scratch experiment: replace the output-layer weights with zeros.
# WARNING: this discards the trained w3. With w3 == 0 the expression
# `relu2@w3 + b3` collapses to b3 for every input, so any evaluation run after
# this cell predicts the same class for all images. Re-run the training cell
# before evaluating. NOTE(review): the execution counts (In [13] here, In [14]
# on the evaluation cell) suggest this ran before the recorded accuracy.
w3 = w3.mul(0)