## Fully Connected Layer

```
inputs.shape = (bs, in)
weights.shape = (in, out)
outputs.shape = (bs, out)

(bs, in) @ (in, out) = (bs, out)
```

In [None]:
# Automatically reload edited Python modules before each cell runs.
%load_ext autoreload
%autoreload 2

# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
# Show which Python interpreter backs this kernel.
import sys ; sys.executable

In [None]:
# Imports and previous code

import operator

def test(a,b,cmp,cname=None):
    if cname is None: cname=cmp.__name__
    assert cmp(a,b),f"{cname}:\n{a}\n{b}"

def test_eq(a,b):
    "Assert that `a == b`, printing both values (labelled '==') on failure."
    test(a, b, cmp=operator.eq, cname='==')

from pathlib import Path
from IPython.core.debugger import set_trace
from fastai import datasets
import pickle, gzip, math, torch, matplotlib as mpl
import matplotlib.pyplot as plt
from torch import tensor

# Location of the pickled MNIST dataset fetched by get_data().
# NOTE(review): deeplearning.net may no longer serve this file — confirm the URL is still reachable.
MNIST_URL='http://deeplearning.net/data/mnist/mnist.pkl'

def near(a,b):
    "Return True when tensors `a` and `b` are elementwise close (rtol=1e-3, atol=1e-5)."
    return torch.allclose(a, b, atol=1e-5, rtol=1e-3)
def test_near(a,b):
    "Assert that `a` and `b` are close according to `near` (failure label: 'near')."
    test(a, b, cmp=near)

In [None]:
def get_data():
    """Fetch the MNIST pickle and return its train/validation splits as tensors.

    The file at ``MNIST_URL`` is fetched with fastai's ``datasets.download_data``
    and unpickled; the third split stored in the pickle is discarded.

    Returns:
        A tuple ``(x_train, y_train, x_valid, y_valid)`` of tensors. A tuple is
        returned instead of a one-shot ``map`` iterator, so the result can be
        unpacked, indexed or iterated more than once.
    """
    path = datasets.download_data(MNIST_URL, ext='.gz')
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only use this with a trusted source behind MNIST_URL.
    with gzip.open(path, 'rb') as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
    return tuple(map(tensor, (x_train, y_train, x_valid, y_valid)))

def normalize(x, m, s):
    "Standardize `x`: subtract the mean `m`, then divide by the standard deviation `s`."
    centered = x - m
    return centered / s

In [None]:
# Load the MNIST train/validation splits as tensors (see get_data).
x_train, y_train, x_valid, y_valid = get_data()

In [None]:
# Shape of the training inputs: (examples, features).
x_train.shape

In [None]:
# Mean/std of the raw training inputs; used below to normalize both splits.
train_mean, train_std = x_train.mean(), x_train.std()
train_mean, train_std

In [None]:
# Standardize both splits with the training-set statistics, so validation
# inputs go through exactly the same transform as training inputs.
x_train = normalize(x_train, m=train_mean, s=train_std)
# NOTE: Use train mean/std to normalize validation!
x_valid = normalize(x_valid, m=train_mean, s=train_std)

In [None]:
# Normalization is elementwise, so the shapes are unchanged.
x_train.shape, x_valid.shape

In [None]:
# Sanity check: the normalized training inputs should have mean ~0 and std ~1.
# Display the values directly instead of re-binding train_mean/train_std, which
# would overwrite the statistics that were actually used for normalization.
x_train.mean(), x_train.std()

In [None]:
def test_near_zero(x, tol=1e-3): assert x.abs()<tol, f'Near zero: {x}'

In [None]:
# After normalization the training inputs should have mean ~0 and std ~1 (tol 1e-3).
test_near_zero(x_train.mean())
test_near_zero(1-x_train.std())

In [None]:
# n = number of training examples, m = number of input features per example.
n, m = x_train.shape
# c = number of classes, taken as the largest label + 1
# (assumes integer labels starting at 0 — TODO confirm). Note c is a 0-d tensor.
c = y_train.max()+1
n,m,c

## Basic Architecture

In [None]:
# Number of units in the single hidden layer.
num_hidden = 50

In [None]:
# Parameters of a 2-layer MLP: m inputs -> num_hidden hidden units -> 1 output.
# Dividing randn weights by sqrt(fan_in) keeps the linear outputs at roughly
# unit variance for roughly unit-variance inputs; biases start at zero.
w1 = torch.randn(m, num_hidden) / math.sqrt(m)
b1 = torch.zeros(num_hidden)

w2 = torch.randn(num_hidden, 1) / math.sqrt(num_hidden)
b2 = torch.zeros(1)

In [None]:
# Weights drawn from randn should be centered at zero.
test_near_zero(w1.mean())

In [None]:
# Validation inputs were normalized with the *training* statistics, so expect
# values near (but not exactly) 0 and 1.
x_valid.mean(), x_valid.std()

In [None]:
def lin(x, w, b):
    "Affine layer: matrix-multiply `x` by `w`, then add the (broadcast) bias `b`."
    out = x @ w
    return out + b

In [None]:
# First-layer pre-activations on the validation set (before any nonlinearity).
t = lin(x_valid, w1, b1)

In [None]:
# Scale of the pre-activations under the 1/sqrt(m) initialization.
t.mean(), t.std()

In [None]:
def relu(x):
    "Rectified linear unit: replace every negative entry of `x` with 0."
    return x.clamp(min=0)

In [None]:
# First-layer activations after ReLU.
t = relu(lin(x_valid, w1, b1))

In [None]:
# ReLU discards negative values, so the mean becomes non-negative and the spread typically shrinks.
t.mean(), t.std()

### Let's try a different init scheme 

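We want the post-ReLU activations to keep roughly unit scale. When the pre-activations are symmetric around zero, ReLU zeroes the negative half and so halves their mean square. With the 1/sqrt(m) init above, the signal therefore shrinks at every layer. Kaiming (He) initialization compensates by scaling the weights by sqrt(2/m) instead.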

In [None]:
# Kaiming/He init: scale by sqrt(2/m) instead of 1/sqrt(m) to offset ReLU
# zeroing out (roughly) the negative half of the pre-activations.
w1 = torch.randn(m, num_hidden)*math.sqrt(2/m)

In [None]:
# Expect mean ~0 and std ~sqrt(2/m).
w1.mean(), w1.std()

In [None]:
# Re-check the post-ReLU statistics with the new init (b1 is still all zeros).
t = relu(lin(x_valid, w1, b1))
t.mean(), t.std()

In [None]:
def relu2(x):
    "ReLU shifted down by 0.5, which moves the output mean back toward zero (variant; `model` uses plain `relu`)."
    rectified = x.clamp(min=0.0)
    return rectified - 0.5

In [None]:
def model(x):
    """Forward pass of the 2-layer MLP: linear -> ReLU -> linear.

    Reads the module-level parameters w1, b1, w2, b2 and returns one output
    column per row of `x`.
    """
    hidden = relu(lin(x, w1, b1))
    return lin(hidden, w2, b2)

In [None]:
# Benchmark the forward pass on the validation set (10 loops per timing run).
%timeit -n 10 _=model(x_valid)

In [None]:
# The model must produce exactly one output per validation example.
assert list(model(x_valid).shape) == [x_valid.shape[0], 1]

In [None]:
# Expected: (number of validation examples, 1).
model(x_valid).shape

In [None]:
def mse(output, target):
    """Mean squared error between a column of predictions and a 1-d target.

    Args:
        output: predictions of shape (bs, 1).
        target: targets of shape (bs,).

    Returns:
        A 0-d tensor: the mean of the squared differences.

    The shape checks guard against silent broadcasting. Without them, a
    (bs, 1) target would broadcast against the squeezed (bs,) predictions into
    a (bs, bs) matrix and yield a meaningless loss.
    """
    assert len(output.shape) == 2, output.shape
    assert output.shape[1] == 1, output.shape
    assert target.shape == output.shape[:1], (output.shape, target.shape)
    return (output.squeeze(1) - target).pow(2).mean()

In [None]:
# mse subtracts the labels directly from float predictions, so convert them to float.
y_train, y_valid = y_train.float(), y_valid.float()

In [None]:
# Forward pass over the full training set.
preds = model(x_train)

In [None]:
# Expected: (number of training examples, 1).
preds.shape

In [None]:
# Loss of the untrained model on the training set.
mse(preds, y_train)

# Gradients!

In [None]:
def assert_shape(x, shape:list):
    """Assert that ``x.shape`` matches ``shape``.

    ``shape`` must have one entry per dimension of ``x``; an entry of -1 acts
    as a wildcard that matches any size. On failure the AssertionError carries
    ``(x.shape, shape)``.
    """
    assert len(x.shape) == len(shape), (x.shape, shape)
    dims_match = all(want == -1 or got == want for got, want in zip(x.shape, shape))
    assert dims_match, (x.shape, shape)

In [None]:
def mse_grad(inp, target):
    """Backward pass of `mse`: store d(loss)/d(inp) in ``inp.g``.

    ``inp`` must be a (bs, 1) column. The gradient of mean((inp - target)^2)
    is 2 * (inp - target) / bs, turned back into a column with unsqueeze(-1).
    """
    assert_shape(inp, [-1, 1])
    diff = inp.squeeze() - target
    inp.g = 2 * diff.unsqueeze(-1) / inp.shape[0]

In [None]:
def relu_grad(inp, out):
    """Backward pass through ReLU: set ``inp.g`` from the upstream gradient ``out.g``.

    This chains the gradient rather than returning only the local derivative.
    Positions where the pre-activation ``inp`` was positive pass ``out.g``
    through; all others receive 0.
    """
    assert inp.shape == out.shape == out.g.shape
    mask = (inp > 0).float()
    inp.g = mask * out.g

In [None]:
def lin_grad(inp, out, w, b):
    """Backward pass of ``out = inp @ w + b`` given the upstream gradient ``out.g``.

    Stores the gradients with respect to the input, weights and bias in
    ``inp.g``, ``w.g`` and ``b.g``. The bias gradient sums over dim 0 (the
    batch) because ``b`` is broadcast across every row.
    """
    upstream = out.g
    inp.g = upstream @ w.t()
    w.g = inp.t() @ upstream
    b.g = upstream.sum(0)

In [None]:
def fwd_back(inp, target):
    """Run the forward pass and the manual backward pass of the 2-layer MLP.

    Uses the module-level parameters w1, b1, w2, b2. Afterwards ``.g`` holds
    the gradient of the MSE loss on ``inp``, w1, b1, w2 and b2.

    Returns:
        The MSE loss. It was previously computed and then discarded; callers
        that ignore the return value are unaffected.
    """
    # forward (same computation as `model`, keeping the intermediates)
    step1 = lin(inp, w1, b1)
    step2 = relu(step1)
    step3 = lin(step2, w2, b2)
    loss = mse(step3, target)

    # backward, in reverse order: each *_grad reads .g from its output and sets
    # .g on its input, so this order is required
    mse_grad(step3, target)
    lin_grad(step2, step3, w2, b2)
    relu_grad(step1, step2)
    lin_grad(inp, step1, w1, b1)
    return loss

In [None]:
# Manual forward + backward pass; gradients land in the .g attributes of x_train, w1, b1, w2 and b2.
fwd_back(x_train, y_train)

In [None]:
# Snapshot the manually computed gradients for comparison with autograd below.
w1g = w1.g.clone()
w2g = w2.g.clone()
b1g = b1.g.clone()
b2g = b2.g.clone()
ig  = x_train.g.clone()

In [None]:
# Independent copies of the input and parameters with requires_grad=True,
# so autograd tracks them without touching the originals.
xt2 = x_train.clone().requires_grad_(True)
w12 = w1.clone().requires_grad_(True)
w22 = w2.clone().requires_grad_(True)
b12 = b1.clone().requires_grad_(True)
b22 = b2.clone().requires_grad_(True)

In [None]:
def fwd(inp, target):
    """Forward pass through the autograd-tracked copies (w12, b12, w22, b22).

    Returns the MSE loss, ready for ``.backward()``.
    """
    hidden = relu(inp @ w12 + b12)
    prediction = hidden @ w22 + b22
    return mse(prediction, target)

In [None]:
# Forward pass through the autograd-tracked copies.
loss = fwd(xt2, y_train)

In [None]:
# Let autograd populate .grad on xt2, w12, b12, w22 and b22.
loss.backward()

In [None]:
# Autograd's gradients must match the manual ones within near()'s tolerances.
test_near(w22.grad, w2g)
test_near(b22.grad, b2g)
test_near(w12.grad, w1g)
test_near(b12.grad, b1g)
test_near(xt2.grad, ig )