In [10]:
import torchvision
import torch


# ------------------- Get Data -------------------
# 60,000 samples, 28x28 pixels / sample
train_dataset = torchvision.datasets.MNIST(train=True, download=True, root='./')
train_rand = torch.randperm(train_dataset.data.shape[0]) # for picking random images/labels, like randint except has no repeats
train_img_tensors = train_dataset.data.reshape(-1,784)[train_rand]
Y = train_dataset.targets[train_rand]

# 10,000 samples, 28x28 pixels / sample
test_dataset = torchvision.datasets.MNIST(train=False, download=True, root='./')
test_rand = torch.randperm(test_dataset.data.shape[0]) # for picking random images/labels, like randint except has no repeats
test_img_tensors = test_dataset.data.reshape(-1,784)[test_rand]
test_Y = test_dataset.targets[test_rand]


# ------------------- Train -------------------
zero = torch.tensor(0)
batch_size = 32
# input nodes batch
# 32 x 784
X = train_img_tensors[:batch_size].float() / 255

h1_nodes = 128
w1 = torch.randn((h1_nodes, 784)).requires_grad_(True)
w1 = w1 / h1_nodes**0.5
l1 = X@w1.T
relu1 = torch.maximum(zero, l1)

h2_nodes = 64
w2 = torch.randn((h2_nodes, h1_nodes)).requires_grad_(True)
w2 = w2 / h2_nodes**0.5
l2 = relu1@w2.T
relu2 = torch.maximum(zero, l2)

out_nodes = 10
w3 = torch.randn((out_nodes, h2_nodes)).requires_grad_(True)
w3 = w3 / out_nodes**0.5
logits = relu2@w3.T

logit_maxes = logits.max(dim=1, keepdim=True).values
norm_logits = logits - logit_maxes

e_to_the_logits = torch.e**norm_logits
e_to_the_logits_sum = torch.sum(e_to_the_logits, dim=1, keepdim=True)
e_to_the_logits_sum_inv = e_to_the_logits_sum**-1
probs = e_to_the_logits * e_to_the_logits_sum_inv # prob = softmax. softmax is an activation function, as is relu


logprobs = probs.log()
# ------------------- Loss -------------------

# understand this before moving on to anything else
# grabs predictions for 32 images, and correct values in each image
loss = -logprobs[range(batch_size), Y[:batch_size]].mean()
loss

# ------------------- Back -------------------

dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(batch_size), Y[:batch_size]] = -1/batch_size
dprobs = (1 / probs) * dlogprobs
de_to_the_logits = e_to_the_logits_sum_inv * dprobs
de_to_the_logits_sum_inv = (e_to_the_logits * dprobs).sum(dim=1, keepdim=True)
de_to_the_logits_sum = -e_to_the_logits_sum**(-2) * de_to_the_logits_sum_inv
de_to_the_logits = de_to_the_logits + de_to_the_logits_sum
dlogits = torch.e**logits * de_to_the_logits
drelu2 = dlogits @ w3
dw3 = dlogits.T @ relu2
dl2 = (l2 > 0) * drelu2
dw2 = dl2.T @ relu1
drelu1 = dl2 @ w2
dl1 = (l1 > 0) * drelu1
dw1 = dl1.T @ X
e_to_the_logits, loss

(tensor([[0.1750, 0.2277, 0.3016, 0.0830, 0.0658, 0.0972, 0.0368, 1.0000, 0.1991,
          0.3058],
         [0.0137, 0.1028, 0.0272, 0.0289, 0.0408, 0.0065, 0.0177, 1.0000, 0.0430,
          0.0962],
         [0.0082, 0.2907, 0.0621, 0.0184, 0.0581, 0.0418, 0.0032, 1.0000, 0.0352,
          0.4747],
         [0.0217, 0.1063, 0.2089, 0.0263, 0.0736, 0.0212, 0.0322, 0.1344, 0.0767,
          1.0000],
         [0.0036, 0.1693, 1.0000, 0.0578, 0.0109, 0.0122, 0.0042, 0.1375, 0.0307,
          0.5769],
         [0.0270, 0.5638, 0.3562, 0.0251, 1.0000, 0.7574, 0.0155, 0.7883, 0.0283,
          0.5515],
         [0.1375, 0.4880, 0.0648, 0.0688, 0.0642, 0.0407, 0.0683, 1.0000, 0.0865,
          0.5055],
         [0.0532, 0.5122, 0.0265, 0.0493, 0.0348, 0.0553, 0.0356, 0.1422, 0.3626,
          1.0000],
         [0.0335, 0.0502, 0.0028, 0.0117, 0.0372, 0.0062, 0.0020, 0.3779, 0.0087,
          1.0000],
         [0.0639, 0.0527, 0.0853, 0.0488, 0.0478, 0.0153, 0.0106, 0.5025, 0.0777,
         

In [11]:
# After loss computed
loss.backward()

# Compare your gradients vs PyTorch's
print(torch.allclose(dw1, w1.grad, atol=1e-4))
print(torch.allclose(dw2, w2.grad, atol=1e-4))
print(torch.allclose(dw3, w3.grad, atol=1e-4))
# print(torch.allclose(dw1, w1.grad, atol=1e-4))

  print(torch.allclose(dw1, w1.grad, atol=1e-4))


TypeError: allclose(): argument 'other' (position 2) must be Tensor, not NoneType

In [3]:
print((dw1 - w1.grad).abs().max())
print((dw2 - w2.grad).abs().max())
print((dw3 - w3.grad).abs().max())

tensor(nan, grad_fn=<MaxBackward1>)
tensor(nan, grad_fn=<MaxBackward1>)
tensor(nan, grad_fn=<MaxBackward1>)
