In [3]:
import torchvision
import torch


# ------------------- Get Data -------------------
# Seed the global torch RNG so the randperm shuffles below (and any later torch.randn
# calls in this kernel) give the same result on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

IMG_PIXELS = 28 * 28  # each MNIST image is 28x28; flattened to one row per image

# 60,000 samples, 28x28 pixels / sample
train_dataset = torchvision.datasets.MNIST(train=True, download=True, root='./')
# randperm = random ordering with no repeats; the same permutation shuffles images
# and labels so each image stays paired with its label.
train_rand = torch.randperm(train_dataset.data.shape[0])
train_img_tensors = train_dataset.data.reshape(-1, IMG_PIXELS)[train_rand]
Y = train_dataset.targets[train_rand]

# 10,000 samples, 28x28 pixels / sample
test_dataset = torchvision.datasets.MNIST(train=False, download=True, root='./')
test_rand = torch.randperm(test_dataset.data.shape[0])
test_img_tensors = test_dataset.data.reshape(-1, IMG_PIXELS)[test_rand]
test_Y = test_dataset.targets[test_rand]


# ------------------- Train -------------------
# Forward pass of a 784 -> 128 -> 64 -> 10 ReLU MLP on one batch, then a hand-derived
# backward pass producing dw1, dw2, dw3 (compared against autograd in the next cell).
zero = torch.tensor(0)
batch_size = 32
# input nodes batch
# 32 x 784, pixel values scaled from 0..255 down to 0..1
X = train_img_tensors[:batch_size].float() / 255
Yb = Y[:batch_size]                      # correct label for each image in the batch
batch_rows = torch.arange(batch_size)    # paired with Yb to pick one entry per row
input_nodes = X.shape[1]                 # 784

h1_nodes = 128
# divide by sqrt(fan_in) so pre-activations start at roughly unit scale
w1 = torch.randn((h1_nodes, input_nodes)) / input_nodes**0.5
w1.requires_grad_(True)
l1 = X@w1.T                              # 32 x 128
relu1 = torch.maximum(zero, l1)

h2_nodes = 64
w2 = torch.randn((h2_nodes, h1_nodes)) / h1_nodes**0.5
w2.requires_grad_(True)
l2 = relu1@w2.T                          # 32 x 64
relu2 = torch.maximum(zero, l2)

out_nodes = 10
w3 = torch.randn((out_nodes, h2_nodes)) / h2_nodes**0.5
w3.requires_grad_(True)
logits = relu2@w3.T                      # 32 x 10

# Softmax. Subtracting each row's max first leaves the softmax unchanged (it is
# invariant to a per-row shift) but keeps exp() from overflowing to inf -> NaN probs.
# The max is detached, so it is a constant for both autograd and the manual backward.
logit_maxes = logits.amax(dim=1, keepdim=True).detach()
norm_logits = logits - logit_maxes
e_to_the_logits = torch.e**norm_logits
e_to_the_logits_sum = torch.sum(e_to_the_logits, dim=1, keepdim=True)
e_to_the_logits_sum_inv = e_to_the_logits_sum**-1
probs = e_to_the_logits * e_to_the_logits_sum_inv # prob = softmax. softmax is an activation function, as is relu

logprobs = probs.log()
# ------------------- Loss -------------------

# Cross-entropy: for each of the 32 rows, pick the log-probability assigned to that
# image's correct label, then average over the batch and negate.
loss = -logprobs[batch_rows, Yb].mean()

# ------------------- Back -------------------
# Each d<name> holds d(loss)/d(<name>), computed in reverse order of the forward pass.

# loss = -mean of the picked entries, so only those entries get gradient -1/batch_size
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[batch_rows, Yb] = -1/batch_size
dprobs = (1 / probs) * dlogprobs                       # d log(p)/dp = 1/p
# probs = e * s_inv, with s_inv broadcast across each row: product rule, and the
# s_inv gradient is summed over the broadcast dimension.
de_to_the_logits = e_to_the_logits_sum_inv * dprobs
de_to_the_logits_sum_inv = (e_to_the_logits * dprobs).sum(dim=1, keepdim=True)
de_to_the_logits_sum = -e_to_the_logits_sum**(-2) * de_to_the_logits_sum_inv  # d(s^-1)/ds = -s^-2
# every e entry also fed its row sum; that gradient broadcasts back over the row
de_to_the_logits = de_to_the_logits + de_to_the_logits_sum
# d exp(z)/dz = exp(z): reuse the forward value. logit_maxes is a detached constant,
# so d(norm_logits)/d(logits) is the identity.
dlogits = e_to_the_logits * de_to_the_logits
drelu2 = dlogits @ w3
dw3 = dlogits.T @ relu2
dl2 = (l2 > 0) * drelu2                  # relu passes gradient only where its input was > 0
dw2 = dl2.T @ relu1
drelu1 = dl2 @ w2
dl1 = (l1 > 0) * drelu1
dw1 = dl1.T @ X

In [None]:
# After loss computed: compare the hand-derived dw1/dw2/dw3 with autograd's gradients.
# backward() accumulates into .grad, so clear any gradient left by an earlier run.
for param in (w1, w2, w3):
    param.grad = None
# retain_graph=True keeps the forward graph alive; without it, a second run of this
# cell raises because the graph's saved tensors were freed by the first backward().
loss.backward(retain_graph=True)

# Compare your gradients vs PyTorch's
for manual_grad, param in ((dw1, w1), (dw2, w2), (dw3, w3)):
    print(torch.allclose(manual_grad, param.grad, atol=1e-4))

True
True
True


In [None]:
import torch
import torchvision

# Get Data: the MNIST training split, each 28x28 image flattened into one row of
# 784 pixels, with images (X) and labels (Y) shuffled by the same permutation.
train_dataset = torchvision.datasets.MNIST(root='./', train=True, download=True)
rand = torch.randperm(len(train_dataset.data))
X = train_dataset.data.flatten(start_dim=1)[rand]
Y = train_dataset.targets[rand]

# Forward
# NOTE(review): this cell rebinds X, Y (via the data step above), w1, w2, h1_nodes and
# batch_size, which the first cell also defines. After running it, the gradient-check
# cell would compare against these new weights unless the first cell is re-run.
batch_size = 32
X = X[:batch_size].float() / 255 # 32,784

input_nodes = X.shape[1] # 784
h1_nodes = 64
w1 = torch.randn(h1_nodes, input_nodes) / input_nodes**0.5 # 64,784
w1.requires_grad_(True)
l1 = X@w1.T # 32,64

# NOTE(review): hidden sizes are 64 -> 128 here but 128 -> 64 in the first cell; confirm which is intended.
h2_nodes = 128
# Scale by 1/sqrt(fan_in) like w1, and track gradients for the backward pass.
w2 = torch.randn(h2_nodes, h1_nodes) / h1_nodes**0.5 # 128,64
w2.requires_grad_(True)

# Backward
# TODO: backward pass not written yet for this cell.
# Backward

tensor([[-0.9324,  0.1170,  0.9071,  ...,  0.4601, -1.2758,  0.7112],
        [ 0.4783, -1.1647, -0.5102,  ..., -0.4442, -0.7226,  1.3399],
        [ 0.1622,  1.2657, -0.8576,  ..., -1.9877,  0.3447, -0.1993],
        ...,
        [ 0.1897,  1.2692,  0.4251,  ..., -1.1419,  1.4508, -0.2449],
        [ 0.5428, -1.2337,  0.7475,  ..., -1.0843, -1.1363, -0.5797],
        [-0.5593, -0.7174,  1.3487,  ..., -1.0512,  0.4484,  1.8252]])

In [None]:
# Scratch values: a 1x1 nested list and a one-element list (nothing in this cell uses them).
x, w = [[1]], [3]



tensor([-0.9324,  0.1170,  0.9071, -1.4386,  0.1840,  1.0503,  0.0721,  0.4707,
        -0.9222, -0.0851,  2.4080, -0.2026,  0.2711,  0.1810, -0.1934,  0.2746,
        -0.8197, -0.5770,  0.3930, -2.2109,  0.5486, -0.3192, -0.1818,  1.5491,
        -0.0632,  0.8684, -0.5252, -1.3926, -0.9046,  0.6134,  0.8063,  0.6605,
         0.8105,  0.4281,  0.1734, -2.0207, -2.3229, -0.5025, -0.6211, -0.1063,
        -0.6072,  0.2242,  0.3163,  0.2059, -2.0174,  0.0556,  0.0238,  0.2537,
         2.5588,  0.3726,  0.8176,  0.9084, -0.7613,  1.7464, -0.4457, -0.2634,
         0.0761,  0.0709,  0.9940,  0.0032,  1.1738,  0.4601, -1.2758,  0.7112])