In [167]:
import torch 
from torch.nn import functional as F 
from torch import nn, optim
from sys import exit as e
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score

In [168]:
trans = transforms.Compose([
  transforms.ToTensor()
])

train_dataset = MNIST("./data", train=True, download=True, transform=trans)
test_dataset = MNIST("./data", train=False, download=True, transform=trans)

train_dataset, val_dataset = torch.utils.data.random_split(train_dataset, [50000, 10000])
len(train_dataset), len(val_dataset)

(50000, 10000)

In [169]:
batch_size = 32
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, drop_last=True)
val_loader = DataLoader(val_dataset, shuffle=True, batch_size=batch_size, drop_last=True)
test_loader = DataLoader(test_dataset, shuffle=True, batch_size=batch_size, drop_last=True)

In [170]:
torch.manual_seed(42)

<torch._C.Generator at 0x7f96b727a450>

In [171]:
# utility function we will use later when comparing manual gradients to PyTorch gradients
def cmp(s, dt, t):
  ex = torch.all(dt == t.grad).item()
  app = torch.allclose(dt, t.grad)
  maxdiff = (dt - t.grad).abs().max().item()
  print(f'{s:15s} | exact: {str(ex):5s} | approximate: {str(app):5s} | maxdiff: {maxdiff}')

In [172]:
n_dim = 784 
n_hidden1 = 256 
n_hidden2 = 256 
n_output = 10
device = torch.device('cuda')
W1 = torch.randn((n_dim, n_hidden1), device=device) / n_dim**0.5 
b1 = torch.randn((n_hidden1), device=device)
W2 = torch.randn((n_hidden1, n_output), device=device)
b2 = torch.randn((n_output), device=device)

W1 *= 2**0.5
parameters = [W1, b1, W2, b2]

for p in parameters:
  p.requires_grad=True

print(sum(p.nelement() for p in parameters))

203530


In [173]:
n_epochs = 20
with torch.no_grad():
  for epoch in range(n_epochs):
    for idx, (x, label) in enumerate(train_loader):
      x = x.to(device)
      label = label.to(device)

      x = x.flatten(start_dim=1)
      h1 = x @ W1 + b1 
      h1R = torch.relu(h1)
      logits = h1R @ W2 + b2 

      loss = F.cross_entropy(logits, label)

      # for p in parameters:
      #   p.grad = None

      # backprop
      dlogits = F.softmax(logits, dim=-1)
      dlogits[range(batch_size), label] -= 1 
      dlogits /= batch_size

      dh1R = dlogits @ W2.T
      dW2 = h1R.T @ dlogits
      db2 = dlogits.sum(0)

      dh1 = torch.zeros_like(h1)
      dh1[h1 > 0]  = 1
      dh1  = dh1 * dh1R

      dW1 = x.T @ dh1
      db1 = dh1.sum(0)
      grads = [dW1, db1, dW2, db2]

      lr = 0.001 if epoch < 100000 else 0.01 # step learning rate decay
      for p, grad in zip(parameters, grads):
        p.data += -lr * grad

      # all_tensors = [logprobs, probs, logits_exp_sum, logits_exp, logits_main, logits_max, logits, h1R, W2, b2, h1, W1,b1]
      # for t in all_tensors:
        # t.retain_grad()
      # loss2 = F.cross_entropy(logits, label)
      # loss2.backward()

    if epoch % 10 == 0:
      print(f"{epoch}/{n_epochs}: {loss.item()}")

  #   break
  # break



0/20: 0.5183426141738892
10/20: 0.23148706555366516


In [174]:
dlogits = F.softmax(logits, dim=-1)
dlogits[range(batch_size), label] -= 1 
dlogits /= batch_size

dh1R = dlogits @ W2.T
dW2 = h1R.T @ dlogits
db2 = dlogits.sum(0)

dh1 = torch.zeros_like(h1)
dh1[h1 > 0]  = 1
dh1  = dh1 * dh1R

dW1 = x.T @ dh1
db1 = dh1.sum(0)
grads = [dW1, db1, dW2, db2]

lr = 0.1 if epoch < 100000 else 0.01 # step learning rate decay
for p, grad in zip(parameters, grads):
  p.data += -lr * grad

cmp('logits', dlogits, logits)
cmp('h1R', dh1R, h1R)
cmp('W2', dW2, W2)
cmp('b2', db2, b2)

cmp('h1', dh1, h1)
cmp('dW1', dW1, W1)
cmp('db1', db1, b1)

TypeError: all() received an invalid combination of arguments - got (bool), but expected one of:
 * (Tensor input, *, Tensor out)
 * (Tensor input, tuple of ints dim, bool keepdim, *, Tensor out)
 * (Tensor input, int dim, bool keepdim, *, Tensor out)
 * (Tensor input, name dim, bool keepdim, *, Tensor out)


In [175]:
h1.shape, W1.shape, b1.shape, dh1.shape, x.shape

(torch.Size([32, 256]),
 torch.Size([784, 256]),
 torch.Size([256]),
 torch.Size([32, 256]),
 torch.Size([32, 784]))

In [176]:
with torch.no_grad():
  for x, label in val_loader:
    x = x.to(device)
    label = label.to(device)

    x = x.flatten(start_dim=1)
    h1 = x @ W1 + b1 
    h1R = torch.relu(h1)
    logits = h1R @ W2 + b2 

    probs = F.softmax(logits, dim=-1)
    y_preds = torch.argmax(probs, dim=-1).detach().tolist()
    y_true = label.detach().tolist()

    print(accuracy_score(y_preds, y_true))
    break
  

0.875
