In [3]:
import torchvision
import torch


# ------------------- Get Data -------------------
# 60,000 samples, 28x28 pixels / sample
train_dataset = torchvision.datasets.MNIST(train=True, download=True, root='./')
train_rand = torch.randperm(train_dataset.data.shape[0]) # for picking random images/labels, like randint except has no repeats
train_img_tensors = train_dataset.data.reshape(-1,784)[train_rand]
Y = train_dataset.targets[train_rand]

# 10,000 samples, 28x28 pixels / sample
test_dataset = torchvision.datasets.MNIST(train=False, download=True, root='./')
test_rand = torch.randperm(test_dataset.data.shape[0]) # for picking random images/labels, like randint except has no repeats
test_img_tensors = test_dataset.data.reshape(-1,784)[test_rand]
test_Y = test_dataset.targets[test_rand]


# ------------------- Train -------------------
zero = torch.tensor(0)
batch_size = 32
# input nodes batch
# 32 x 784
X = train_img_tensors[:batch_size].float() / 255

h1_nodes = 128
w1 = torch.randn((h1_nodes, 784)) / 784**0.5
w1.requires_grad_(True)
l1 = X@w1.T
relu1 = torch.maximum(zero, l1)

h2_nodes = 64
w2 = torch.randn((h2_nodes, h1_nodes)) / h1_nodes**0.5
w2.requires_grad_(True)
l2 = relu1@w2.T
relu2 = torch.maximum(zero, l2)

out_nodes = 10
w3 = torch.randn((out_nodes, h2_nodes)) / h2_nodes**0.5
w3.requires_grad_(True)
logits = relu2@w3.T

e_to_the_logits = torch.e**logits
e_to_the_logits_sum = torch.sum(e_to_the_logits, dim=1, keepdim=True)
e_to_the_logits_sum_inv = e_to_the_logits_sum**-1
probs = e_to_the_logits * e_to_the_logits_sum_inv # prob = softmax. softmax is an activation function, as is relu


logprobs = probs.log()
# ------------------- Loss -------------------

# understand this before moving on to anything else
# grabs predictions for 32 images, and correct values in each image
loss = -logprobs[range(batch_size), Y[:batch_size]].mean()
loss

# ------------------- Back -------------------

dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(batch_size), Y[:batch_size]] = -1/batch_size
dprobs = (1 / probs) * dlogprobs
de_to_the_logits = e_to_the_logits_sum_inv * dprobs
de_to_the_logits_sum_inv = (e_to_the_logits * dprobs).sum(dim=1, keepdim=True)
de_to_the_logits_sum = -e_to_the_logits_sum**(-2) * de_to_the_logits_sum_inv
de_to_the_logits = de_to_the_logits + de_to_the_logits_sum
dlogits = torch.e**logits * de_to_the_logits
drelu2 = dlogits @ w3
dw3 = dlogits.T @ relu2
dl2 = (l2 > 0) * drelu2
dw2 = dl2.T @ relu1
drelu1 = dl2 @ w2
dl1 = (l1 > 0) * drelu1
dw1 = dl1.T @ X

In [4]:
# After loss computed
loss.backward()

# Compare your gradients vs PyTorch's
print(torch.allclose(dw1, w1.grad, atol=1e-4))
print(torch.allclose(dw2, w2.grad, atol=1e-4))
print(torch.allclose(dw3, w3.grad, atol=1e-4))
# print(torch.allclose(dw1, w1.grad, atol=1e-4))

True
True
True


In [None]:
import torch
import torchvision

# Get Data
train_dataset = torchvision.datasets.MNIST(train=True, download=True, root='./')
rand = torch.randperm(train_dataset.data.shape[0])
X = train_dataset.data.reshape(-1,28*28)[rand]
Y = train_dataset.targets[rand]

test_dataset.shape, 

# Forward

# Backward

tensor([  0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,  20, 1

In [3]:
print((dw1 - w1.grad).abs().max())
print((dw2 - w2.grad).abs().max())
print((dw3 - w3.grad).abs().max())

tensor(nan, grad_fn=<MaxBackward1>)
tensor(nan, grad_fn=<MaxBackward1>)
tensor(nan, grad_fn=<MaxBackward1>)
