In [None]:
import torchvision
import torch


# ------------------- Get Data -------------------
# 60,000 samples, 28x28 pixels / sample
train_dataset = torchvision.datasets.MNIST(train=True, download=True, root='./')
train_rand = torch.randperm(train_dataset.data.shape[0]) # for picking random images/labels, like randint except has no repeats
train_img_tensors = train_dataset.data.reshape(-1,784)[train_rand]
Y = train_dataset.targets[train_rand]

# 10,000 samples, 28x28 pixels / sample
test_dataset = torchvision.datasets.MNIST(train=False, download=True, root='./')
test_rand = torch.randperm(test_dataset.data.shape[0]) # for picking random images/labels, like randint except has no repeats
test_img_tensors = test_dataset.data.reshape(-1,784)[test_rand]
test_Y = test_dataset.targets[test_rand]


# ------------------- Train -------------------
zero = torch.tensor(0)
batch_size = 32
# input nodes batch
# 32 x 784
X = train_img_tensors[:batch_size] / 255

h1_nodes = 128
w1 = torch.randn((h1_nodes, 784))
w1 = w1 / h1_nodes**0.5
l1 = X@w1.T
relu1 = torch.maximum(zero, l1)

h2_nodes = 64
w2 = torch.randn((h2_nodes, h1_nodes))
w2 = w2 / h2_nodes**0.5
l2 = relu1@w2.T
relu2 = torch.maximum(zero, l2)

out_nodes = 10
w3 = torch.randn((out_nodes, h2_nodes))
w3 = w3 / out_nodes**0.5
logits = relu2@w3.T

e_to_the_logits = torch.e**logits
e_to_the_logits_sum = torch.sum(e_to_the_logits, dim=1, keepdim=True)
e_to_the_logits_sum_inv = e_to_the_logits_sum**-1
probs = e_to_the_logits * e_to_the_logits_sum_inv # prob = softmax. softmax is an activation function, as is relu


logprobs = probs.log()
# ------------------- Loss -------------------

# understand this before moving on to anything else
# grabs predictions for 32 images, and correct values in each image
loss = -logprobs[range(batch_size), Y[:batch_size]].mean()
loss

# ------------------- Back -------------------

dlogprobs = torch.zeros_like(logprobs)
dlogprobs[range(batch_size), Y[:batch_size]] = -1/batch_size
dprobs = (1 / probs) * dlogprobs
de_to_the_logits = e_to_the_logits_sum_inv * dprobs
de_to_the_logits_sum_inv = (e_to_the_logits * dprobs).sum(dim=1, keepdim=True)
de_to_the_logits_sum = -e_to_the_logits_sum**(-2) * de_to_the_logits_sum_inv
de_to_the_logits = de_to_the_logits + de_to_the_logits_sum
dlogits = torch.e**logits * de_to_the_logits
drelu2 = dlogits @ w3
dw3 = dlogits.T@(1.0*(relu2 > 0))

tensor([[-9.2772e-02, -2.9267e-02, -7.3800e-02, -2.3591e-02, -2.0155e-02,
         -5.3629e-02,  1.1020e-02,  1.0880e-02, -6.4127e-02, -5.6240e-02,
         -9.8745e-02, -6.4268e-02, -3.8800e-02, -1.3070e-02, -7.8112e-02,
         -7.9066e-02, -3.8441e-02, -7.5412e-02, -7.6691e-02,  0.0000e+00,
          2.5748e-04,  2.7123e-03,  3.5962e-02,  1.9398e-03, -8.9559e-02,
         -6.2226e-02, -7.3514e-02,  5.2089e-03, -2.7872e-02, -6.5855e-02,
         -4.7265e-02, -1.1084e-02, -2.5298e-02, -5.5015e-02,  1.2245e-02,
         -6.5059e-02, -8.8715e-02, -2.9619e-02, -9.6322e-02, -5.1364e-02,
          0.0000e+00, -1.1386e-01, -1.3319e-02, -8.2056e-02, -1.0266e-01,
          1.0587e-04,  8.5092e-03, -7.3800e-02, -3.6412e-02, -7.7200e-02,
         -7.1097e-02, -6.1490e-02, -1.0385e-01, -2.5891e-02,  9.4774e-03,
         -2.6085e-02, -7.4290e-02, -7.3800e-02, -9.9881e-02, -6.1241e-02,
         -1.6025e-03, -3.5696e-02, -4.5522e-03,  3.3297e-03],
        [-2.5120e-02, -8.5982e-02, -8.3991e-02, -3

In [41]:
w3.shape, relu2.shape, dlogits.shape

(torch.Size([10, 64]), torch.Size([32, 64]), torch.Size([32, 10]))