In [1]:
import torch
import torchvision
import torch.nn.functional as F

# Get Data
# Seed torch's global RNG before the first random call so that the shuffle below
# (and any later torch.rand*/randperm call made in the same run, such as the
# weight initialisation) produces the same numbers on every Restart & Run All.
torch.manual_seed(0)

# MNIST training split, downloaded into the current directory if not present.
train_dataset = torchvision.datasets.MNIST(train=True, download=True, root='./')

# One permutation is applied to both images and labels so pairs stay aligned.
rand = torch.randperm(train_dataset.data.shape[0])
X = train_dataset.data.reshape(-1,28*28)[rand].float()  # (num_images, 784), float32
Y = train_dataset.targets[rand]                          # (num_images,), class indices


# Parameters of the 784 -> 128 -> 64 -> 10 network.
# Each weight matrix is sampled from a standard normal and divided by
# sqrt(fan_in); the matrices are created in order w1, w2, w3.
w1, w2, w3 = (
    torch.randn(fan_in, fan_out) / fan_in**0.5
    for fan_in, fan_out in ((784, 128), (128, 64), (64, 10))
)

# Bias of the output layer (the only bias defined here), starting at zero.
b3 = torch.zeros(10)

# Number of samples per mini-batch.
batch_size = 32

# Train the three-layer ReLU network with plain mini-batch SGD.
# Gradients are written out by hand (no autograd); the weights are rebound to
# new tensors on every step.
for epoch in range(3):
    # Step size: 0.1 during the first epoch, 0.01 for the remaining ones.
    lr = 0.1 if epoch < 1 else 0.01
    # loop_count restarts at 0 each epoch and only drives the periodic print.
    for loop_count, i in enumerate(range(0, len(X), batch_size)):
        # Forward
        Xb = X[i:i+batch_size] / 255  # (batch, 784), raw pixel values divided by 255
        Yb = Y[i:i+batch_size]
        n = len(Yb)  # actual batch size (the final slice may be shorter)

        l1 = Xb@w1
        relu1 = torch.relu(l1)

        l2 = relu1@w2
        relu2 = torch.relu(l2)

        logits = relu2@w3 + b3

        # Softmax, shifted by the row maximum so every exponent is <= 0.
        logit_maxes = logits.max(dim=1, keepdim=True).values
        norm_logits = logits - logit_maxes
        counts = norm_logits.exp()                    # numerator of softmax
        counts_sum = counts.sum(dim=1, keepdim=True)  # >= 1: the max entry contributes exp(0)
        probs = counts / counts_sum

        # Log-probabilities via log-sum-exp rather than probs.log(): a class whose
        # probability underflows to 0 would otherwise yield log(0) = -inf here and
        # inf * 0 = NaN in the backward pass, corrupting every weight.
        logprobs = norm_logits - counts_sum.log()

        # Loss: mean negative log-probability of the correct class.
        loss = -logprobs[range(n), Yb].mean()

        # Backward
        # Closed-form gradient of the mean cross-entropy w.r.t. the logits:
        # (softmax - one_hot(target)) / n. This is what the step-by-step chain
        # through log, the reciprocal sum, exp and the max shift reduces to
        # (the max-shift term cancels because each row of the gradient sums to 0).
        dlogits = probs.clone()
        dlogits[range(n), Yb] -= 1.0
        dlogits /= n

        # Output layer
        drelu2 = dlogits@w3.T
        dw3 = relu2.T@dlogits
        db3 = dlogits.sum(0)

        # Second hidden layer; ReLU passes gradient only where its input was > 0.
        dl2 = (l2 > 0) * drelu2
        drelu1 = dl2@w2.T
        dw2 = relu1.T@dl2

        # First hidden layer
        dl1 = (l1 > 0) * drelu1
        dw1 = Xb.T@dl1

        # Update
        w1 = w1 - lr * dw1
        w2 = w2 - lr * dw2
        w3 = w3 - lr * dw3
        b3 = b3 - lr * db3

        # Report the current batch loss every 100 batches.
        if loop_count % 100 == 0:
            print(loss)

tensor(2.3067)
tensor(0.4012)
tensor(0.2863)
tensor(0.4701)
tensor(0.3065)
tensor(0.0700)
tensor(0.2526)
tensor(0.2502)
tensor(0.2002)
tensor(0.2134)
tensor(0.1760)
tensor(0.0875)
tensor(0.2063)
tensor(0.0819)
tensor(0.2761)
tensor(0.2198)
tensor(0.1094)
tensor(0.1052)
tensor(0.2709)
tensor(0.0467)
tensor(0.1769)
tensor(0.0633)
tensor(0.1116)
tensor(0.1139)
tensor(0.0158)
tensor(0.0647)
tensor(0.0485)
tensor(0.0690)
tensor(0.0270)
tensor(0.0694)
tensor(0.0550)
tensor(0.1402)
tensor(0.0569)
tensor(0.0821)
tensor(0.1209)
tensor(0.0850)
tensor(0.0364)
tensor(0.1438)
tensor(0.0298)
tensor(0.1836)
tensor(0.0488)
tensor(0.0772)
tensor(0.1064)
tensor(0.0112)
tensor(0.0509)
tensor(0.0436)
tensor(0.0690)
tensor(0.0273)
tensor(0.0589)
tensor(0.0525)
tensor(0.1324)
tensor(0.0472)
tensor(0.0686)
tensor(0.1085)
tensor(0.0637)
tensor(0.0311)
tensor(0.1253)


In [15]:
# Test
# Load the MNIST split selected by train=False, then apply one shared random
# permutation to images and labels and flatten each image to 784 values.
test_dataset = torchvision.datasets.MNIST(root='./', train=False, download=True)
rand = torch.randperm(len(test_dataset.data))
X_test = test_dataset.data[rand].reshape(-1, 28*28).float()
Y_test = test_dataset.targets[rand]

# Classification accuracy of the current w1, w2, w3, b3 on X_test / Y_test.
# Each sample's prediction depends only on that sample, so batching is just a
# way to walk through the data in chunks.
batch_size = 32
correct = 0
total = 0
for Xb_raw, Yb in zip(X_test.split(batch_size), Y_test.split(batch_size)):
    # Forward pass: same layers and input scaling as in training.
    Xb = Xb_raw / 255  # (batch, 784)
    relu1 = torch.relu(Xb@w1)
    relu2 = torch.relu(relu1@w2)
    logits = relu2@w3 + b3

    # The predicted class is the index of the largest logit in each row.
    prediction = logits.argmax(dim=1)
    correct += (prediction == Yb).sum().item()
    total += len(Yb)
accuracy = correct / total
print(f'Accuracy: {accuracy*100:.0f}%')

Accuracy: 96%
