In [44]:
import math
import torch
from fastai import datasets

import gzip
import pickle

# data

In [45]:
MNIST_URL = 'http://deeplearning.net/data/mnist/mnist.pkl'

def get_data():
    """Download (via fastai) and load the pickled MNIST train/validation splits.

    Returns:
        Tuple ``(x_train, y_train, x_valid, y_valid)`` of torch tensors. The third
        split stored in the pickle is discarded.
    """
    path = datasets.download_data(MNIST_URL, ext='.gz')
    with gzip.open(path, 'rb') as f:
        # SECURITY NOTE: pickle.load can execute arbitrary code; the file is fetched
        # over plain http, so only use this with a trusted/verified download.
        (x_train, y_train), (x_valid, y_valid), _ = pickle.load(f, encoding='latin-1')
    # Materialise as a tuple: a bare `map` is a one-shot iterator that is silently
    # empty if iterated a second time.
    return tuple(map(torch.tensor, (x_train, y_train, x_valid, y_valid)))

x_train, y_train, x_valid, y_valid = get_data()

In [46]:
def normalize(x, m, s):
    """Standardise ``x``: subtract the centre ``m`` and divide by the scale ``s``."""
    centred = x - m
    return centred / s

In [47]:
# Normalise both splits with the *training* statistics, captured once before x_train
# is overwritten. Re-reading x_train.mean()/std() after the first line would give the
# already-normalised values (~0 and ~1), leaving x_valid effectively unnormalised.
train_mean, train_std = x_train.mean(), x_train.std()
x_train = normalize(x_train, train_mean, train_std)
x_valid = normalize(x_valid, train_mean, train_std)

In [48]:
# n: rows of x_train (examples), m: columns (input features); nh: hidden-layer width.
nh = 30
n, m = x_train.size()

# forward pass

In [49]:
def relu(x):
    """Rectified linear unit: entries of ``x`` below 0 are replaced by 0."""
    return torch.clamp(x, min=0.)

In [50]:
# He-style init: N(0, 1) draws scaled by sqrt(2 / fan_in); biases start at zero.
W1, b1 = torch.randn(m, nh) * math.sqrt(2 / m), torch.zeros(nh)
W2, b2 = torch.randn(nh, 1) * math.sqrt(2 / nh), torch.zeros(1)

In [51]:
def model(inp):
    """Two-layer MLP forward pass using the global parameters W1, b1, W2, b2."""
    hidden = relu(inp @ W1 + b1)
    return hidden @ W2 + b2

In [52]:
# Forward pass over the whole training set; display the prediction shape.
output = model(x_train)
output.size()

torch.Size([50000, 1])

In [53]:
y_train.size()

torch.Size([50000])

In [54]:
def mse(output, target):
    """Mean squared error; ``output``'s trailing size-1 axis is dropped before comparing to ``target``."""
    residual = output.squeeze(-1) - target
    return torch.mean(residual ** 2)

In [55]:
mse(output=output, target=y_train)

tensor(17.0231)

# gradients and backward pass

In [56]:
def mse_grad(inp, target):
    """Store d(mse)/d(inp) in ``inp.g``, keeping ``inp``'s trailing size-1 axis."""
    n_samples = target.shape[0]
    residual = inp.squeeze(-1) - target
    inp.g = (2 * (residual / n_samples)).unsqueeze(-1)

In [57]:
def relu_grad(inp, out):
    """Backprop through ReLU: pass ``out.g`` through where ``inp`` > 0, else 0; store in ``inp.g``."""
    positive = (inp > 0).float()
    inp.g = out.g * positive

In [58]:
def lin_grad(inp, out, w, b):
    """Backprop through ``out = inp @ w + b`` given ``out.g``; writes ``inp.g``, ``w.g`` and ``b.g``."""
    upstream = out.g
    inp.g = upstream @ w.T
    w.g = inp.T @ upstream
    b.g = upstream.sum(dim=0)

In [59]:
# Fresh He-scaled weights and zero biases for the manual backward pass below.
W1 = torch.randn(m, nh).mul(math.sqrt(2 / m))
b1 = torch.zeros(nh)
W2 = torch.randn(nh, 1).mul(math.sqrt(2 / nh))
b2 = torch.zeros(1)

In [60]:
def forward_and_backward(inp, targ):
    """Forward pass, then a hand-written backward pass; gradients land in ``.g`` attributes."""
    # Forward: keep every intermediate, since each backward step reads one.
    pre_act = inp @ W1 + b1
    act = relu(pre_act)
    out = act @ W2 + b2

    # Backward: walk the layers in reverse order, each step consuming its output's `.g`.
    mse_grad(out, targ)
    lin_grad(act, out, W2, b2)
    relu_grad(pre_act, act)
    lin_grad(inp, pre_act, W1, b1)

forward_and_backward(x_train, y_train)

In [61]:
# Snapshot the hand-computed gradients left in the `.g` attributes.
w1g, w2g, b1g, b2g, ig = (t.g.clone() for t in (W1, W2, b1, b2, x_train))

# Leaf copies of the data/parameters that autograd can differentiate.
xt2, w12, w22, b12, b22 = (t.clone().requires_grad_(True) for t in (x_train, W1, W2, b1, b2))

def forward(inp, targ):
    """Same network as `forward_and_backward`, on the autograd copies; returns the mse loss."""
    # The manual backward never needed the loss value itself; autograd does.
    hidden = relu(inp @ w12 + b12)
    return mse(hidden @ w22 + b22, targ)

loss = forward(xt2, y_train)
loss.backward()

In [62]:
def test_near(a,b): return torch.allclose(a, b, rtol=1e-3, atol=1e-5)

In [63]:
# Autograd `.grad` vs. the manual `.g` snapshots, for every parameter and the input.
tuple(
    test_near(auto_grad, manual_grad)
    for auto_grad, manual_grad in (
        (w22.grad, w2g),
        (b22.grad, b2g),
        (w12.grad, w1g),
        (b12.grad, b1g),
        (xt2.grad, ig),
    )
)

(True, True, True, True, True)

# modular

In [64]:
class Linear():
    """Affine layer ``inp @ W + b`` that caches its input/output for a manual backward pass."""

    def __init__(self, W, b):
        self.W = W
        self.b = b

    def __call__(self, inp):
        self.inp = inp
        self.out = out = inp @ self.W + self.b
        return out

    def backward(self):
        # Reads self.out.g (set by the next layer) and writes .g on inp, W and b.
        grad_out = self.out.g
        self.inp.g = grad_out @ self.W.T
        self.W.g = self.inp.T @ grad_out
        self.b.g = grad_out.sum(dim=0)

In [65]:
class Relu():
    """ReLU layer whose forward output is shifted down by 0.5; caches input/output for backward.

    NOTE(review): the standalone `relu` helper has no -0.5 shift, so this layer's forward
    values differ from `model` / `forward_and_backward`. The shift does not change the derivative.
    """

    def __call__(self, inp):
        self.inp = inp
        self.out = torch.clamp(inp, min=0.) - 0.5
        return self.out

    def backward(self):
        # d/dinp of (max(inp, 0) - 0.5) is 1 where inp > 0, else 0.
        gate = (self.inp > 0).float()
        self.inp.g = self.out.g * gate

In [66]:
class Mse():
    """Mean-squared-error loss layer; ``inp``'s trailing size-1 axis is dropped before comparing."""

    def __call__(self, inp, targ):
        self.inp, self.targ = inp, targ
        residual = inp.squeeze(-1) - targ
        return residual.pow(2).mean()

    def backward(self):
        # Gradient of the mean: 2 * residual / N, restored to inp's (N, 1) layout.
        residual = self.inp.squeeze(-1) - self.targ
        self.inp.g = (2 * (residual / self.targ.shape[0])).unsqueeze(-1)

In [67]:
class Model():
    """Linear -> Relu -> Linear network with an Mse loss; calling it returns the loss."""

    def __init__(self, W1, b1, W2, b2):
        self.layers = [Linear(W1, b1), Relu(), Linear(W2, b2)]
        self.loss = Mse()

    def __call__(self, x, targ):
        activ = x
        for layer in self.layers:
            activ = layer(activ)
        return self.loss(activ, targ)

    def backward(self):
        # Loss first, then the layers from last to first.
        self.loss.backward()
        for layer in reversed(self.layers):
            layer.backward()

In [68]:
model = Model(W1=W1, b1=b1, W2=W2, b2=b2)

In [69]:
model(x=x_train, targ=y_train)

tensor(31.6460)

In [70]:
Model.backward(model)  # fills .g on W1, b1, W2, b2 and x_train

In [71]:
w1g = W1.g.clone()
w2g = W2.g.clone()
b1g = b1.g.clone()
b2g = b2.g.clone()
ig  = x_train.g.clone()

xt2 = x_train.clone().requires_grad_(True)
w12 = W1.clone().requires_grad_(True)
w22 = W2.clone().requires_grad_(True)
b12 = b1.clone().requires_grad_(True)
b22 = b2.clone().requires_grad_(True)

def forward(inp, targ):
    # forward pass:
    l1 = inp @ w12 + b12
    l2 = relu(l1)
    out = l2 @ w22 + b22
    # we don't actually need the loss in backward!
    return mse(out, targ)

loss = forward(xt2, y_train)
loss.backward()

In [72]:
(
test_near(w2g, W2.g),
test_near(b2g, b2.g),
test_near(w1g, W1.g),
test_near(b1g, b1.g),
test_near(ig, x_train.g)
)

(True, True, True, True, True)

# layers inheriting from a Module base class

In [73]:
class Module():
    """Minimal layer base class: `__call__` caches inputs/output, `backward` hands them to `bwd`.

    Subclasses implement ``forward(*args)`` (returns the output) and
    ``bwd(out, *args)`` (writes gradients into ``.g`` attributes).
    """

    def __call__(self, *args):
        self.args = args
        self.out = self.forward(*args)
        return self.out

    def forward(self, *args):
        # Accept *args so an un-overridden forward reports "not implemented" instead of
        # a TypeError about the number of positional arguments.
        raise NotImplementedError('not implemented')

    def bwd(self, out, *args):
        raise NotImplementedError('not implemented')

    def backward(self):
        self.bwd(self.out, *self.args)

In [74]:
class Linear(Module):
    """Affine layer ``inp @ W + b`` on top of ``Module``; gradients are written to ``.g``."""

    def __init__(self, W, b):
        self.W = W
        self.b = b

    def forward(self, inp):
        return inp @ self.W + self.b

    def bwd(self, out, inp):
        grad_out = out.g
        inp.g = grad_out @ self.W.T
        self.W.g = inp.T @ grad_out
        self.b.g = grad_out.sum(dim=0)

In [75]:
class Relu(Module):
    """ReLU shifted down by 0.5 (unlike the plain `relu` helper); derivative is the inp > 0 mask."""

    def forward(self, inp):
        return torch.clamp(inp, min=0.) - 0.5

    def bwd(self, out, inp):
        mask = (inp > 0).float()
        inp.g = out.g * mask

In [76]:
class Mse(Module):
    """Mean-squared-error loss on top of ``Module``; ``inp``'s trailing size-1 axis is dropped."""

    def forward(self, inp, targ):
        residual = inp.squeeze(-1) - targ
        return residual.pow(2).mean()

    def bwd(self, out, inp, targ):
        n_samples = targ.shape[0]
        inp.g = (2 * ((inp.squeeze(-1) - targ) / n_samples)).unsqueeze(-1)

In [77]:
class Model():
    """Linear -> Relu -> Linear network with an Mse loss, built from the Module-based layers."""

    def __init__(self, W1, b1, W2, b2):
        self.layers = [Linear(W1, b1), Relu(), Linear(W2, b2)]
        self.loss = Mse()

    def __call__(self, x, targ):
        hidden = x
        for layer in self.layers:
            hidden = layer(hidden)
        return self.loss(hidden, targ)

    def backward(self):
        # The loss comes first, followed by the layers in reverse order.
        for step in [self.loss, *reversed(self.layers)]:
            step.backward()

In [78]:
model = Model(W1=W1, b1=b1, W2=W2, b2=b2)
model(x=x_train, targ=y_train)

tensor(31.6460)

In [79]:
Model.backward(model)  # fills .g on W1, b1, W2, b2 and x_train

In [80]:
# Module-based gradients vs. the snapshots taken from the earlier hand-written Model.
tuple(
    test_near(snapshot, current)
    for snapshot, current in zip(
        (w2g, b2g, w1g, b1g, ig),
        (W2.g, b2.g, W1.g, b1.g, x_train.g),
    )
)

(True, True, True, True, True)

# torch nn

In [117]:
from torch import nn

In [118]:
def mse(output, target):
    """Mean squared error between ``output`` and ``target``.

    If ``output`` has exactly one more dimension than ``target`` and that trailing
    dimension has size 1 (e.g. ``(N, 1)`` vs ``(N,)``), it is squeezed first.
    Otherwise ``(N, 1) - (N,)`` would silently broadcast to ``(N, N)``, which is the
    case the earlier `mse` guarded against with ``squeeze(-1)``.
    """
    if output.dim() == target.dim() + 1 and output.shape[-1] == 1:
        output = output.squeeze(-1)
    return (output - target).pow(2).mean()

In [119]:
class Model(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) from ``torch.nn``; calling it returns the loss."""

    def __init__(self, n_in, nh, n_out):
        super().__init__()
        # nn.ModuleList (not a plain list) registers the sub-layers, so their weights
        # appear in `model.parameters()` (and therefore reach an optimizer).
        self.layers = nn.ModuleList([nn.Linear(n_in, nh), nn.ReLU(), nn.Linear(nh, n_out)])
        self.loss = mse

    def forward(self, x, targ):
        # Override `forward` rather than `__call__` so nn.Module's call machinery
        # (hooks etc.) still runs; `model(x, targ)` dispatches here.
        for layer in self.layers:
            x = layer(x)
        # squeeze(-1) drops only the output-feature axis; a bare squeeze() would also
        # drop the batch axis for a single-example batch.
        return self.loss(x.squeeze(-1), targ)

In [120]:
model = Model(n_in=m, nh=nh, n_out=1)

In [121]:
# Full-batch loss from the torch.nn model (carries a grad_fn for autograd).
loss = model(x_train, targ=y_train)
loss

tensor(29.0795, grad_fn=<MeanBackward0>)

In [122]:
torch.autograd.backward(loss)  # accumulate .grad on the nn.Linear weights and biases