In [None]:
#default_exp lesson1

In [None]:
#export
from torch import tensor
import torch
import pickle
from fastai.utils import *
from fastai.datasets import *
import gzip
import math
from typing import *

In [None]:
MNIST_URL='http://deeplearning.net/data/mnist/mnist.pkl'

In [3]:
data_path = download_data(MNIST_URL, ext='.gz')

Downloading http://deeplearning.net/data/mnist/mnist.pkl


In [4]:
! ls {data_path}

/root/.fastai/data/mnist.pkl.gz


In [None]:
#export
def get_data(data_path):
  with gzip.open(data_path, mode="rb") as d:
    ((X_train, y_train), (X_valid, y_valid), _) = pickle.load(d, encoding='latin-1')
    X_train, y_train, X_valid, y_valid = map(torch.tensor, (X_train, y_train, X_valid, y_valid))
    return (X_train, y_train, X_valid, y_valid)

In [None]:
X_train, y_train, X_valid, y_valid = get_data(data_path)

In [8]:
X_train.shape

torch.Size([50000, 784])

In [9]:
X_valid.shape

torch.Size([10000, 784])

In [10]:
y_train.shape

torch.Size([50000])

In [None]:
y_train = y_train.unsqueeze(-1)

In [12]:
y_train.shape

torch.Size([50000, 1])

In [13]:
y_valid.shape

torch.Size([10000])

In [None]:
y_valid = y_valid.unsqueeze(-1)

In [None]:
nh = 1000
n_out = 1
# 784 x 1000
w1 = torch.randn(X_train.shape[1], nh)/math.sqrt(X_train.shape[1])
# 1 x 1000
b1 = torch.zeros(1, nh)
# 1000 x 1
w2 = torch.randn(nh, n_out)/math.sqrt(nh)
# 1 x 1
b2 = torch.zeros(1, n_out)

In [None]:
#export
def linear(w, x, b): 
  return x@w + b

In [None]:
#export
def relu(z): return z.clamp_min(0.).float()

In [None]:
# NOTE: this does not make sense as a loss function in this case
# we're using it purely for pedagogical purposes

In [None]:
#export
def mse(y, y_hat): return (y_hat - y).pow(2).mean()

In [None]:
def sigmoid(x): return 1./(1+torch.exp(-x))

In [22]:
w1.shape

torch.Size([784, 1000])

In [23]:
X_train.shape

torch.Size([50000, 784])

In [24]:
w2.shape

torch.Size([1000, 1])

In [25]:
w1.shape

torch.Size([784, 1000])

In [26]:
X_train.shape

torch.Size([50000, 784])

In [27]:
c = tensor(y_train).max().item() + 1; c

  """Entry point for launching an IPython kernel.


10

In [None]:
y_cls_train = torch.zeros(size=(y_train.size()[0], c))

In [29]:
y_cls_train.size()

torch.Size([50000, 10])

In [None]:

# def dense2cls(dense):
#   c = dense.max().item() + 1
#   cls = torch.zeros((dense.size()[0], c))
#   return cls.scatter(1, dense[:, None], 1)

In [None]:
# y_train_cls, y_valid_cls = dense2cls(y_train), dense2cls(y_valid)

In [None]:
def forward(x, y):
  a1 = relu(linear(w1, x, b1))
  out = linear(w2, a1, b2)
  mse_ = mse(y, out)
  return mse_

In [33]:
b1.shape

torch.Size([1, 1000])

In [34]:
forward(X_train, y_train)

tensor(27.5301)

In [None]:
#export
def d_mse(output, y): output.g =  2.*(output - y) / output.shape[0]
def d_relu(inp, output): inp.g = (inp > 0).float() * output.g

In [None]:
#export
def d_linear(inp, output, w, b): 
  inp.g = output.g@w.t()
  w.g = inp.t()@output.g
  b.g = output.g.sum(0)

In [None]:
#export
def forward_and_backward(x, y):
  # 50k x 1000
  z1 = linear(w1, x, b1)
  # 50k x 1000
  a1 = relu(z1)
  # 50k x 1
  z2 = linear(w2, a1, b2)
  # 1
  mse_ = mse(y, z2)

  # 50k x 1
  d_mse(z2, y)
  d_linear(a1, z2, w2, b2)
  d_relu(z1, a1)
  d_linear(x, z1, w1, b1)
  return mse_

In [None]:
#export
def test_allclose(a, b): assert torch.allclose(a, b, rtol=1e-3, atol=1e-5)

In [None]:
man_loss = forward_and_backward(X_train, y_train)

In [40]:
(w1.g != 0).float().sum()

tensor(691029.)

In [41]:
man_loss

tensor(27.5301)

In [None]:
w1g = w1.g.clone()
w2g = w2.g.clone()
b1g = b1.g.clone()
b2g = b2.g.clone()
ig = X_train.g.clone()

In [None]:
xt2 = X_train.clone().requires_grad_(True)
w12 = w1.clone().requires_grad_(True)
w22 = w2.clone().requires_grad_(True)
b12 = b1.clone().requires_grad_(True)
b22 = b2.clone().requires_grad_(True)

In [None]:
def forward2(inp, outp):
  z1 = linear(w12, inp, b12)
  a1 = relu(z1)
  z2 = linear(w22, a1, b22)
  loss = mse(outp, z2)
  return loss

In [None]:
loss = forward2(xt2, y_train)

In [46]:
loss

tensor(27.5301, grad_fn=<MeanBackward0>)

In [None]:
loss.backward()

In [48]:
(w12.grad - w1g).sum()

tensor(0.0008)

In [49]:
w12.grad.sum()

tensor(-107.5014)

In [50]:
w1g.sum()

tensor(-107.5023)

In [51]:
w1g.sum() - w12.grad.sum()

tensor(-0.0008)

In [None]:
test_allclose(w22.grad, w2g)

In [None]:
test_allclose(w12.grad, w1.g)

Now we switch to refactoring.

In [None]:
class Lin(object):
  def __init__(self, w, b):
    self.w = w
    self.b = b

  def __call__(self, x):
    return x@self.w + self.b

In [None]:
class ReLU(object):
  def __call__(self, x):
    return x.clamp_min(0).float()

In [None]:
class MSE(object):
  def __init__(self, y):
    self.y = y

  def __call__(self, x):
    return (x - self.y).pow(2).mean()

In [None]:
class Model(object):
  def __init__(self, layers):
    self.layers = layers

  def forward(self, x):
    for l in self.layers: x = l(x)
    return x

  def backward(self):
    for l in self.layers[::-1]: l.backward()

In [None]:
mdl = Model([
  Lin(w1, b1),
  ReLU(),
  Lin(w2, b2),
  MSE(y_train)    
])

In [59]:
mdl.forward(X_train)

tensor(27.5301)

In [None]:
class ReLU(object):
  def forward(self, x):
    self.inp = x
    self.out = x.clamp_min(0).float()
    return self.out

  def backward(self):
    self.inp.g = (self.inp > 0).float() * self.out.g

  def __call__(self, x):
    return self.forward(x)

In [None]:
class MSE(object):

  def __init__(self, y):
    self.y = y

  def forward(self, x):
    self.inp = x
    self.out = (x - self.y).pow(2).mean()
    return self.out

  def backward(self):
    self.inp.g = 2. * (self.inp - self.y) / self.y.shape[0]

  def __call__(self, x):
    return self.forward(x)

In [None]:
class Lin(object):
  def __init__(self, w, b):
    self.w = w
    self.b = b

  def forward(self, x):
    self.inp = x
    self.out = x@self.w + self.b
    return self.out

  def backward(self):
    self.inp.g = self.out.g @ self.w.t()
    self.w.g = self.inp.t() @ self.out.g
    self.b.g = self.out.g.sum(0)

  def __call__(self, x):
    return self.forward(x)

In [None]:
mdl = Model([
  Lin(w1, b1),
  ReLU(),
  Lin(w2, b2),
  MSE(y_train)
])

In [89]:
mdl.forward(X_train)

tensor(27.5301)

In [None]:
mdl.backward()

In [None]:
w1g = w1.g.clone()
w2g = w2.g.clone()
b1g = b1.g.clone()
b2g = b2.g.clone()
ig = X_train.g.clone()
xt2 = X_train.clone().requires_grad_(True)
w12 = w1.clone().requires_grad_(True)
w22 = w2.clone().requires_grad_(True)
b12 = b1.clone().requires_grad_(True)
b22 = b2.clone().requires_grad_(True)

In [67]:
man_loss = forward2(xt2, y_train); man_loss

tensor(27.5301, grad_fn=<MeanBackward0>)

In [None]:
man_loss.backward()

In [None]:
test_allclose(w1g, w12.grad)

In [None]:
test_allclose(w1g, w12.grad)

In [None]:
test_allclose(w2g, w22.grad)

In [None]:
test_allclose(w2g, w22.grad)

In [None]:
test_allclose(b1g, b12.grad)

In [None]:
test_allclose(b1g, b12.grad)

In [None]:
test_allclose(b2g, b22.grad)

In [None]:
test_allclose(b2g, b22.grad)

In [None]:
my_grads = [w1g, w2g, b1g, b2g]
torch_grads = [w12.grad, w22.grad, b12.grad, b22.grad]
for i,j in zip(my_grads, torch_grads):
  test_allclose(i, j)

We see that each module has a backward and forward pass. Each model component directly implements a forward and a backward. A model is comprised of several model components, through which inputs are passed. When `backward` is called, it calls backward on each of the model components. 

In [None]:
#export
class Module(object):
  
  def forward(self, x):
    raise NotImplementedError()

  def backward(self):
    raise NotImplementedError()

  def __call__(self, x):
     return self.forward(x)

In [None]:
#export
class Loss(object):
  def forward(self, x, y):
    raise NotImplementedError

  def backward(self): 
    raise NotImplementedError

  def __call__(self, x, y):
    return self.forward(x, y)

In [None]:
#export
class Model(object):
  def __init__(self, layers: List[Module], loss: Loss):
    self.layers = layers
    self.loss = loss

  def forward(self, x, y):
    self.inp = x
    for l in self.layers: x = l(x)
    x = self.loss(x, y)
    self.out = x
    return x

  def backward(self):
    self.loss.backward()
    for l in self.layers[::-1]: l.backward()

  def __call__(self, x, y):
    return self.forward(x, y)

In [None]:
#export
class Lin(Module):
  def __init__(self, w, b):
    self.w = w
    self.b = b

  def forward(self, x):
    self.inp = x
    self.out = x@self.w + self.b
    return self.out

  def backward(self):
    self.inp.g = self.out.g @ self.w.t()
    self.w.g = self.inp.t() @ self.out.g
    self.b.g = self.out.g.sum(0)

class ReLU(Module):
  def forward(self, x):
    self.inp = x
    self.out = x.clamp_min(0).float()
    return self.out

  def backward(self):
    self.inp.g = (self.inp > 0).float() * self.out.g

class MSE(Loss):
  
  def forward(self, x, y):
    self.inp = x
    self.y = y
    self.out = (x - y).pow(2).mean()
    return self.out

  def backward(self):
    self.inp.g = 2. * (self.inp - self.y) / self.y.shape[0]

In [None]:
mdl = Model([
  Lin(w1, b1),
  ReLU(),
  Lin(w2, b2)
], loss=MSE())

In [96]:
mdl(X_train, y_train)

tensor(27.5301)

In [97]:
mdl.forward(X_train, y_train)

tensor(27.5301)

In [None]:
mdl.backward()

In [None]:
w1g = w1.g.clone()
w2g = w2.g.clone()
b1g = b1.g.clone()
b2g = b2.g.clone()
ig = X_train.g.clone()
xt2 = X_train.clone().requires_grad_(True)
w12 = w1.clone().requires_grad_(True)
w22 = w2.clone().requires_grad_(True)
b12 = b1.clone().requires_grad_(True)
b22 = b2.clone().requires_grad_(True)
def forward2(inp, outp):
  z1 = linear(w12, inp, b12)
  a1 = relu(z1)
  z2 = linear(w22, a1, b22)
  loss = mse(outp, z2)
  return loss

In [None]:
loss = forward2(xt2, y_train)

In [101]:
loss

tensor(27.5301, grad_fn=<MeanBackward0>)

In [None]:
loss.backward()

In [None]:
test_allclose(w1g, w12.grad)

In [None]:
test_allclose(w2g, w22.grad)

In [None]:
test_allclose(b1g, b12.grad)

In [None]:
test_allclose(b2g, b22.grad)