In [None]:
import torch
import torchvision
import torch.nn.functional as F

# Get Data
train_dataset = torchvision.datasets.MNIST(train=True, download=True, root='./')
rand = torch.randperm(train_dataset.data.shape[0])
X = train_dataset.data.reshape(-1,28*28)[rand].float()
Y = train_dataset.targets[rand]


w1 = torch.randn(784,128) / 784**0.5
w2 = torch.randn(128,64) / 128**0.5
w3 = torch.randn(64,10) / 64**0.5

b3 = torch.zeros(10)

batch_size = 32

for epoch in range(3):
    lr = 0.1 if epoch < 1 else 0.01
    for i in range(0,len(X),batch_size):
        # Forward
        zero = torch.tensor([0])
        Xb = X[i:i+batch_size] / 255 # 32, 784
        Yb = Y[i:i+batch_size]

        l1 = Xb@w1
        relu1 = torch.maximum(zero, l1)
        
        l2 = relu1@w2
        relu2 = torch.maximum(zero, l2)

        logits = relu2@w3 + b3

        logit_maxes = logits.max(dim=1, keepdim=True).values # the 1 max value from each row
        norm_logits = logits - logit_maxes # makes numbers <= 0 for not too high exp outputs

        counts = norm_logits.exp() # e**logits -- top of softmax
        counts_sum = counts.sum(dim=1, keepdim=True) # sum e**logits
        counts_sum_inv = counts_sum**-1 # place sum e**logits on bottom of softmax
        probs = counts * counts_sum_inv # probs = softmax = e**logits / sum e**logits


        # Loss
        # It's correct classes, then ln, then average
        logprobs = probs.log()
        loss = -logprobs[range(len(Yb)), Yb].mean()


        # Backward
        dloss = 1.0
        dlogprobs = torch.zeros_like(logprobs)
        dlogprobs[range(len(Yb)), Yb] = -1.0 / len(Yb)
        dprobs = 1.0 / probs * dlogprobs
        dcounts = counts_sum_inv * dprobs
        dcounts_sum_inv = (counts * dprobs).sum(1, keepdim=True)
        dcounts_sum = -counts_sum**-2 * dcounts_sum_inv
        dcounts += dcounts_sum
        dnorm_logits = norm_logits.exp() * dcounts
        dlogits = dnorm_logits
        dlogit_maxes = -dnorm_logits.sum(dim=1, keepdim=True)
        dlogits += F.one_hot(logits.max(dim=1).indices, logits.shape[1]) * dlogit_maxes
        drelu2 = dlogits@w3.T
        dw3 = relu2.T@dlogits
        dl2 = (l2 > 0) * drelu2
        drelu1 = dl2@w2.T
        dw2 = relu1.T@dl2
        dl1 = (l1 > 0) * drelu1
        dw1 = Xb.T@dl1
        db3 = dlogits.sum(0)


        # Update
        w1 = w1 - lr * dw1
        w2 = w2 - lr * dw2
        w3 = w3 - lr * dw3
        b3 = b3 - lr * db3
        
        print(loss)

torch.Size([60000, 784])
tensor(2.2771)
tensor(2.2888)
tensor(2.2830)
tensor(2.3061)
tensor(2.2184)
tensor(2.1913)
tensor(2.2155)
tensor(2.1990)
tensor(2.1597)
tensor(2.0935)
tensor(2.0668)
tensor(2.0853)
tensor(2.1014)
tensor(2.0726)
tensor(1.9856)
tensor(1.9682)
tensor(1.9888)
tensor(1.8686)
tensor(1.8005)
tensor(1.9473)
tensor(1.7925)
tensor(1.8401)
tensor(1.7316)
tensor(1.7350)
tensor(1.7720)
tensor(1.7464)
tensor(1.4851)
tensor(1.6196)
tensor(1.6445)
tensor(1.3587)
tensor(1.3013)
tensor(1.3661)
tensor(1.4129)
tensor(1.4165)
tensor(1.2010)
tensor(1.2820)
tensor(1.2321)
tensor(1.3573)
tensor(1.1146)
tensor(1.1304)
tensor(1.0977)
tensor(0.9781)
tensor(1.0413)
tensor(0.9888)
tensor(0.9166)
tensor(0.7216)
tensor(0.8811)
tensor(1.1642)
tensor(0.7232)
tensor(0.9894)
tensor(0.9710)
tensor(0.6292)
tensor(0.9356)
tensor(0.7876)
tensor(1.0912)
tensor(1.1403)
tensor(0.9695)
tensor(0.7779)
tensor(0.8939)
tensor(0.8866)
tensor(0.6577)
tensor(0.9048)
tensor(0.6145)
tensor(0.7866)
tensor(0.6946)


KeyboardInterrupt: 

In [2]:
# Get Test Data
test_dataset = torchvision.datasets.MNIST(train=False, download=True, root='./')
rand = torch.randperm(test_dataset.data.shape[0])
X_test = test_dataset.data.reshape(-1,28*28)[rand].float()
Y_test = test_dataset.targets[rand]

batch_size = 32
correct = 0
total = 0
for i in range(0,len(X_test),batch_size):
    # Forward
    zero = torch.tensor([0])
    Xb = X_test[i:i+batch_size] / 255 # 32, 784
    Yb = Y_test[i:i+batch_size]

    l1 = Xb@w1
    relu1 = torch.maximum(zero, l1)
    
    l2 = relu1@w2
    relu2 = torch.maximum(zero, l2)

    logits = relu2@w3 + b3

    prediction = logits.argmax(dim=1)
    correct += (prediction == Yb).sum().item()
    total += len(Yb)
accuracy = correct / total
print(accuracy)

0.9663
