In [172]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## The forward and backward passes

[Jump_to lesson 8 video](https://course.fast.ai/videos/?lesson=8&t=4960)

In [173]:
#export
from exp.nb_01 import *

def get_data():
    path = datasets.download_data(MNIST_URL,'d:/.fastai/data/MNIST_URL', ext='.gz')
    with gzip.open(path, 'rb') as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
    return map(tensor, (x_train,y_train,x_valid,y_valid))

def normalize(x, m, s): return (x-m)/s

In [174]:
x_train,y_train,x_valid,y_valid = get_data()
x_train.shape,y_train.shape,x_valid.shape,y_valid.shape

(torch.Size([50000, 784]),
 torch.Size([50000]),
 torch.Size([10000, 784]),
 torch.Size([10000]))

In [175]:
train_mean,train_std = x_train.mean(),x_train.std()
train_mean,train_std

(tensor(0.1304), tensor(0.3073))

In [176]:
x_train = normalize(x_train, train_mean, train_std)
# NB: Use training, not validation mean for validation set
x_valid = normalize(x_valid, train_mean, train_std)  #use training set std and mean to normlize validation set

In [177]:
train_mean,train_std = x_train.mean(),x_train.std()
train_mean,train_std

(tensor(-7.6999e-06), tensor(1.))

In [178]:
#export
def test_near_zero(a,tol=1e-3): assert a.abs()<tol, f"Near zero: {a}"

In [179]:
test_near_zero(x_train.mean())
test_near_zero(1-x_train.std())

In [180]:
n,m = x_train.shape
c = y_train.max()+1
n,m,c

(50000, 784, tensor(10))

## Foundations version

### Basic architecture

[Jump_to lesson 8 video](https://course.fast.ai/videos/?lesson=8&t=5128)

In [181]:
# num hidden
nh = 50

[Tinker practice](https://course.fast.ai/videos/?lesson=8&t=5255)

In [182]:
# simplified kaiming init / he init
w1 = torch.randn(m,nh)/math.sqrt(m)
b1 = torch.zeros(nh)
w2 = torch.randn(nh,1)/math.sqrt(nh)
b2 = torch.zeros(1)
w1.shape, w2.shape, b1.shape, b2.shape

(torch.Size([784, 50]), torch.Size([50, 1]), torch.Size([50]), torch.Size([1]))

In [183]:
test_near_zero(w1.mean())
test_near_zero(w1.std()-1/math.sqrt(m))

In [184]:
# This should be ~ (0,1) (mean,std)...
x_valid.mean(),x_valid.std()

(tensor(-0.0059), tensor(0.9924))

In [185]:
def lin(x, w, b): return x@w + b

In [186]:
t = lin(x_valid, w1, b1)

In [187]:
#...so should this, because we used kaiming init, which is designed to do this
t.mean(),t.std()

(tensor(-0.1790), tensor(1.0309))

In [188]:
def relu(x): return x.clamp_min(0.)

In [189]:
t = relu(lin(x_valid, w1, b1))

In [190]:
#...actually it really should be this!
t.mean(),t.std()

(tensor(0.3224), tensor(0.5174))

From pytorch docs: `a: the negative slope of the rectifier used after this layer (0 for ReLU by default)`

$$\text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan_in}}}$$

This was introduced in the paper that described the Imagenet-winning approach from *He et al*: [Delving Deep into Rectifiers](https://arxiv.org/abs/1502.01852), which was also the first paper that claimed "super-human performance" on Imagenet (and, most importantly, it introduced resnets!)

[Jump_to lesson 8 video](https://course.fast.ai/videos/?lesson=8&t=5128)

In [191]:
# kaiming init / he init for relu
w1 = torch.randn(m,nh)*math.sqrt(2/m)

In [192]:
w1.mean(),w1.std()

(tensor(3.5084e-05), tensor(0.0507))

In [193]:
t = relu(lin(x_valid, w1, b1))
t.mean(),t.std()

(tensor(0.5050), tensor(0.8091))

In [194]:
#export
from torch.nn import init

In [195]:
w1 = torch.zeros(m,nh)
init.kaiming_normal_(w1, mode='fan_out') #pytorch porch.nn.Linear function will do a transpose to weight matrix, so need fan_out
#fan_in' (default) or 'fan_out'. Choosing `fan_in` preserves the magnitude of the variance of the weights in the
#forward pass (like sqrt(m) in the earlier example). Choosing `fan_out` preserves the magnitudes in the backwards pass.will be
# sqrt(nh) in the earlier example
t = relu(lin(x_valid, w1, b1))

In [196]:
init.kaiming_normal_??

In [197]:
w1.mean(),w1.std()

(tensor(-0.0005), tensor(0.0501))

In [198]:
t.mean(),t.std()

(tensor(0.5617), tensor(0.8381))

In [199]:
w1.shape

torch.Size([784, 50])

In [200]:
import torch.nn

In [201]:
torch.nn.Linear(m,nh).weight.shape

torch.Size([50, 784])

In [202]:
torch.nn.Linear.forward??

In [203]:
torch.nn.functional.linear??

In [204]:
torch.nn.Conv2d??

In [205]:
torch.nn.modules.conv._ConvNd.reset_parameters??

In [206]:
# what if...?
def relu(x): return x.clamp_min(0.) - 0.5

In [207]:
# kaiming init / he init for relu
w1 = torch.randn(m,nh)*math.sqrt(2./m )
t1 = relu(lin(x_valid, w1, b1))
t1.mean(),t1.std()

(tensor(0.1773), tensor(0.9148))

In [208]:
def model(xb):
    l1 = lin(xb, w1, b1)
    l2 = relu(l1)
    l3 = lin(l2, w2, b2)
    return l3

In [209]:
%timeit -n 10 _=model(x_valid)

8.38 ms ± 638 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [210]:
model(x_valid).shape,torch.Size([x_valid.shape[0],1])

(torch.Size([10000, 1]), torch.Size([10000, 1]))

In [211]:
assert model(x_valid).shape==torch.Size([x_valid.shape[0],1])

### Loss function: MSE

[Jump_to lesson 8 video](https://course.fast.ai/videos/?lesson=8&t=6372)

In [212]:
model(x_valid).shape

torch.Size([10000, 1])

We need `squeeze()` to get rid of that trailing (,1), in order to use `mse`. (Of course, `mse` is not a suitable loss function for multi-class classification; we'll use a better loss function soon. We'll use `mse` for now to keep things simple.)

In [213]:
#export
def mse(output, targ): return (output.squeeze(-1) - targ).pow(2).mean()

In [214]:
y_train,y_valid = y_train.float(),y_valid.float()

In [215]:
preds = model(x_train)

In [216]:
preds.shape

torch.Size([50000, 1])

In [217]:
mse(preds, y_train)

tensor(40.8813)

### Gradients and backward pass

[Jump_to lesson 8 video](https://course.fast.ai/videos/?lesson=8&t=6493)

In [218]:
def mse_grad(inp, targ): 
    # grad of loss with respect to output of previous layer
    inp.g = 2. * (inp.squeeze() - targ).unsqueeze(-1) / inp.shape[0]

In [219]:
def relu_grad(inp, out):
    # grad of relu with respect to input activations
    inp.g = (inp>0).float() * out.g

In [220]:
def lin_grad(inp, out, w, b):
    # grad of matmul with respect to input
    inp.g = out.g @ w.t()
    print(inp.shape,out.shape)
    w.g = (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(0)
    print(inp.unsqueeze(-1).shape,  out.unsqueeze(1).shape,  w.g.shape)
    b.g = out.g.sum(0)
    print((inp.unsqueeze(-1) * out.g.unsqueeze(1)).shape,  b.g.shape)
    

In [221]:
def forward_and_backward(inp, targ):
    # forward pass:
    l1 = inp @ w1 + b1
    l2 = relu(l1)
    out = l2 @ w2 + b2
    # we don't actually need the loss in backward!
    loss = mse(out, targ)
    
    # backward pass:
    mse_grad(out, targ)
    lin_grad(l2, out, w2, b2)
    
    relu_grad(l1, l2)
    lin_grad(inp, l1, w1, b1)

In [222]:
forward_and_backward(x_train, y_train)


torch.Size([50000, 50]) torch.Size([50000, 1])
torch.Size([50000, 50, 1]) torch.Size([50000, 1, 1]) torch.Size([50, 1])
torch.Size([50000, 50, 1]) torch.Size([1])
torch.Size([50000, 784]) torch.Size([50000, 50])
torch.Size([50000, 784, 1]) torch.Size([50000, 1, 50]) torch.Size([784, 50])
torch.Size([50000, 784, 50]) torch.Size([50])


In [223]:
# Save for testing against later
w1g = w1.g.clone()
w2g = w2.g.clone()
b1g = b1.g.clone()
b2g = b2.g.clone()
ig  = x_train.g.clone()

We cheat a little bit and use PyTorch autograd to check our results.

In [224]:
xt2 = x_train.clone().requires_grad_(True)
w12 = w1.clone().requires_grad_(True)
w22 = w2.clone().requires_grad_(True)
b12 = b1.clone().requires_grad_(True)
b22 = b2.clone().requires_grad_(True)

In [225]:
def forward(inp, targ):
    # forward pass:
    l1 = inp @ w12 + b12
    l2 = relu(l1)
    out = l2 @ w22 + b22
    # we don't actually need the loss in backward!
    return mse(out, targ)

In [226]:
loss = forward(xt2, y_train)

In [227]:
loss.backward()

In [228]:
test_near(w22.grad, w2g)
test_near(b22.grad, b2g)
test_near(w12.grad, w1g)
test_near(b12.grad, b1g)
test_near(xt2.grad, ig )

## Refactor model

### Layers as classes

[Jump_to lesson 8 video](https://course.fast.ai/videos/?lesson=8&t=7112)

In [250]:
class Relu():
    def __call__(self, inp):
        self.inp = inp
        self.out = inp.clamp_min(0.)-0.5
        print('Relu input',hex(id(self.inp)))
        print('Relu output',hex(id(self.out)))
        return self.out
    
    def backward(self): self.inp.g = (self.inp>0).float() * self.out.g

In [251]:
class Lin():
    def __init__(self, w, b): self.w,self.b = w,b
        
    def __call__(self, inp):
        self.inp = inp
        self.out = inp@self.w + self.b
        print('lin input',hex(id(self.inp)))
        print('lin output',hex(id(self.out)))
        return self.out
    
    def backward(self):
        self.inp.g = self.out.g @ self.w.t()
        print(self.out.g)
        # Creating a giant outer product, just to sum it, is inefficient!
        self.w.g = (self.inp.unsqueeze(-1) * self.out.g.unsqueeze(1)).sum(0)
        self.b.g = self.out.g.sum(0)

In [252]:
class Mse():
    def __call__(self, inp, targ):
        self.inp = inp
        self.targ = targ
        self.out = (inp.squeeze() - targ).pow(2).mean()
        print('Mse input',hex(id(self.inp)))
        print('Mse output',hex(id(self.out)))
        return self.out
    
    def backward(self):
        self.inp.g = 2. * (self.inp.squeeze() - self.targ).unsqueeze(-1) / self.targ.shape[0]
        print(self.inp.g)

In [253]:
class Model():
    def __init__(self, w1, b1, w2, b2):
        self.layers = [Lin(w1,b1), Relu(), Lin(w2,b2)]
        self.loss = Mse()
        
    def __call__(self, x, targ):
        for l in self.layers: x = l(x)
        return self.loss(x, targ)
    
    def backward(self):
        self.loss.backward()
        for l in reversed(self.layers): l.backward()

In [254]:
w1.g,b1.g,w2.g,b2.g = [None]*4
model = Model(w1, b1, w2, b2)

In [255]:
loss = model(x_train, y_train) #the current output memory address is the same as the input of next layer 

lin input 0x2be81451fc0
lin output 0x2be8146bd80
Relu input 0x2be8146bd80
Relu output 0x2be8146b0d8
lin input 0x2be8146b0d8
lin output 0x2be81431990
Mse input 0x2be81431990
Mse output 0x2be81431900


In [242]:
%time loss = model(x_train, y_train)

Wall time: 40.9 ms


In [243]:
%time model.backward()

tensor([[-2.7791e-04],
        [-3.1471e-05],
        [-1.6162e-04],
        ...,
        [-3.8055e-04],
        [-2.3560e-04],
        [-3.4172e-04]])
tensor([[-2.7791e-04],
        [-3.1471e-05],
        [-1.6162e-04],
        ...,
        [-3.8055e-04],
        [-2.3560e-04],
        [-3.4172e-04]])
tensor([[-1.8434e-05, -0.0000e+00, -1.9105e-05,  ..., -3.9064e-05,
         -1.9092e-06, -2.1107e-05],
        [-2.0875e-06, -0.0000e+00, -2.1635e-06,  ..., -0.0000e+00,
         -2.1620e-07, -2.3902e-06],
        [-0.0000e+00, -0.0000e+00, -1.1111e-05,  ..., -0.0000e+00,
         -1.1103e-06, -0.0000e+00],
        ...,
        [-0.0000e+00, -0.0000e+00, -2.6161e-05,  ..., -0.0000e+00,
         -2.6143e-06, -2.8902e-05],
        [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
         -1.6185e-06, -1.7894e-05],
        [-2.2667e-05, -0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
         -2.3476e-06, -2.5953e-05]])
Wall time: 2.99 s


In [149]:
test_near(w2g, w2.g)
test_near(b2g, b2.g)
test_near(w1g, w1.g)
test_near(b1g, b1.g)
test_near(ig, x_train.g)

### Module.forward()

In [150]:
class Module():
    def __call__(self, *args):
        self.args = args
        self.out = self.forward(*args)
        return self.out
    
    def forward(self): raise Exception('not implemented') #check if the classes have forward function implemented
    def backward(self): self.bwd(self.out, *self.args)  #define backward function with args [out, *args]

In [151]:
class Relu(Module):
    def forward(self, inp): return inp.clamp_min(0.)-0.5
    def bwd(self, out, inp): inp.g = (inp>0).float() * out.g

In [152]:
class Lin(Module):
    def __init__(self, w, b): self.w,self.b = w,b
        
    def forward(self, inp): return inp@self.w + self.b
    
    def bwd(self, out, inp):
        inp.g = out.g @ self.w.t()
        self.w.g = torch.einsum("bi,bj->ij", inp, out.g)
        self.b.g = out.g.sum(0)

In [153]:
class Mse(Module):
    def forward (self, inp, targ): return (inp.squeeze() - targ).pow(2).mean()
    def bwd(self, out, inp, targ): inp.g = 2*(inp.squeeze()-targ).unsqueeze(-1) / targ.shape[0]

In [154]:
class Model():
    def __init__(self):
        self.layers = [Lin(w1,b1), Relu(), Lin(w2,b2)]
        self.loss = Mse()
        
    def __call__(self, x, targ):
        for l in self.layers: x = l(x)
        return self.loss(x, targ)
    
    def backward(self):
        self.loss.backward()
        for l in reversed(self.layers): l.backward()

In [155]:
w1.g,b1.g,w2.g,b2.g = [None]*4
model = Model()

In [156]:
%time loss = model(x_train, y_train)

Wall time: 43.9 ms


In [157]:
%time model.backward()

Wall time: 154 ms


In [158]:
test_near(w2g, w2.g)
test_near(b2g, b2.g)
test_near(w1g, w1.g)
test_near(b1g, b1.g)
test_near(ig, x_train.g)

### Without einsum

[Jump_to lesson 8 video](https://course.fast.ai/videos/?lesson=8&t=7484)

In [159]:
class Lin(Module):
    def __init__(self, w, b): self.w,self.b = w,b
        
    def forward(self, inp): return inp@self.w + self.b
    
    def bwd(self, out, inp):
        inp.g = out.g @ self.w.t()
        self.w.g = inp.t() @ out.g
        self.b.g = out.g.sum(0)

In [160]:
w1.g,b1.g,w2.g,b2.g = [None]*4
model = Model()

In [161]:
%time loss = model(x_train, y_train)

Wall time: 38.9 ms


In [162]:
%time model.backward()

Wall time: 123 ms


In [163]:
test_near(w2g, w2.g)
test_near(b2g, b2.g)
test_near(w1g, w1.g)
test_near(b1g, b1.g)
test_near(ig, x_train.g)

### nn.Linear and nn.Module

In [164]:
#export
from torch import nn

In [165]:
class Model(nn.Module):
    def __init__(self, n_in, nh, n_out):
        super().__init__()
        self.layers = [nn.Linear(n_in,nh), nn.ReLU(), nn.Linear(nh,n_out)]
        self.loss = mse
        
    def __call__(self, x, targ):
        for l in self.layers: x = l(x)
        return self.loss(x.squeeze(), targ)

In [166]:
model = Model(m, nh, 1)

In [167]:
%time loss = model(x_train, y_train)

Wall time: 109 ms


In [168]:
%time loss.backward()

Wall time: 53.8 ms


## Export

In [169]:
!notebook2script.py 02_fully_connected.ipynb

Converted 02_fully_connected.ipynb to exp\nb_02.py
