# Lesson 8: Making ReLU / Initialization

<img src="https://snag.gy/Uy9qxS.jpg" style="width:700px"/>

In [3]:
%load_ext autoreload
%autoreload 2

%matplotlib inline



In [8]:
#export
from exp.nb_lesson82 import *

def get_data():
    """
    Loads the MNIST data from before
    """
    path = datasets.download_data(MNIST_URL, ext='.gz')
    with gzip.open(path, 'rb') as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
    return map(tensor, (x_train,y_train,x_valid,y_valid))

def normalize(x, m, s): 
    """
    Normalizes an input array
    Subtract the mean and divide by standard dev
    result should be mean 0, std 1
    """
    return (x-m)/s

def test_near_zero(a,tol=1e-3): 
    assert a.abs()<tol, f"Near zero: {a}"


#### Load the MNIST data and normalize

In [9]:
# load the data
x_train, y_train, x_valid, y_valid = get_data()

# calculate the mean and standard deviation
train_mean,train_std = x_train.mean(),x_train.std()
print("original mean and std:", train_mean,train_std)

# normalize the values
x_train = normalize(x_train, train_mean, train_std)
x_valid = normalize(x_valid, train_mean, train_std)

# check the updated values
train_mean,train_std = x_train.mean(),x_train.std()
print("normalized mean and std:", train_mean, train_std)

original mean and std: tensor(0.1304) tensor(0.3073)
normalized mean and std: tensor(0.0001) tensor(1.)


In [11]:
# check to ensure that mean is near zero
test_near_zero(x_train.mean())

# check to ensure that std is near one (so 1 - std is near zero)
test_near_zero(1-x_train.std())

### Take a look at the training data

Note the size of the training set: `n` images, each with `m` pixels, and `c` classes (the digits 0 to 9).

In [12]:
n,m = x_train.shape
c = y_train.max()+1
n,m,c

(50000, 784, tensor(10))

# Our first model

Our first model will have a single hidden layer with 50 hidden units, so it needs two weight matrices:

1. first layer (`w1`): size `input_shape` x `hidden units` (784 x 50)
2. second layer (`w2`): size `hidden units` x 1 (a single output)

In [15]:
# our linear layer definition

def lin(x, w, b):
    return x@w + b

# number of hidden units
nh = 50

# initialize our weights and bias
# simplified kaiming init / he init: scale by 1/sqrt(fan_in)
# (the full version for ReLU uses sqrt(2/fan_in))
w1 = torch.randn(m,nh)/math.sqrt(m)
b1 = torch.zeros(nh)

w2 = torch.randn(nh,1)/math.sqrt(nh)
b2 = torch.zeros(1)

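## Getting normalized weights

We want the output of each layer to have mean 0 and standard deviation 1, just like our normalized inputs. Each output of `lin` is a sum of `m` products of unit-variance terms, so with standard-normal weights its variance would be about `m`. Dividing the weights by `sqrt(m)` brings the output's standard deviation back to about 1. Kaiming normal initialization formalizes this idea; for now we approximate it by dividing by the square root of the number of inputs.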
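To see the effect of dividing by `sqrt(m)`, here is a minimal sketch that uses only the standard library (no PyTorch), with smaller, made-up sizes: 256 inputs instead of 784, 50 outputs and 100 random input vectors. Both runs use the same random draws, so the only difference is the weight scale.

```python
import math
import random

def linear_output_std(fan_in, n_out, n_samples, scale, seed=0):
    """Std of the outputs of x @ w, with x ~ N(0, 1) and w ~ N(0, 1) * scale."""
    rng = random.Random(seed)
    w = [[rng.gauss(0.0, 1.0) * scale for _ in range(n_out)] for _ in range(fan_in)]
    cols = list(zip(*w))  # column j holds the fan_in weights feeding output j
    outputs = []
    for _ in range(n_samples):
        x = [rng.gauss(0.0, 1.0) for _ in range(fan_in)]
        # each output is a sum of fan_in products of unit-variance terms
        outputs.extend(sum(xi * wi for xi, wi in zip(x, col)) for col in cols)
    mean = sum(outputs) / len(outputs)
    return math.sqrt(sum((o - mean) ** 2 for o in outputs) / len(outputs))

fan_in, n_out, n_samples = 256, 50, 100
raw = linear_output_std(fan_in, n_out, n_samples, scale=1.0)
scaled = linear_output_std(fan_in, n_out, n_samples, scale=1 / math.sqrt(fan_in))
print(f"weights ~ N(0, 1):           output std = {raw:.2f}")
print(f"weights ~ N(0, 1)/sqrt(256): output std = {scaled:.2f}")
```

Without scaling, the output std comes out near `sqrt(256) = 16`; after dividing by `sqrt(fan_in)` it is close to 1.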

In [16]:
t = lin(x_valid, w1, b1)
print(t.mean(), t.std())

tensor(-0.0155) tensor(1.0006)


Weight initialization matters. For example, the Fixup paper ([https://arxiv.org/abs/1901.09321](https://arxiv.org/abs/1901.09321)) trained a very deep residual network without normalization layers, relying on a very specific weight initialization. It turns out that even with one-cycle training, those first iterations are very important. We will come back to this.

<img src='https://snag.gy/osvYL4.jpg' style='width:600px' />

### Our ReLU (Rectified Linear Unit)

In [17]:
def relu(x):
    """
    Returns x where x > 0, and 0 elsewhere (elementwise)
    """
    return x.clamp_min(0.)


#### Check for mean 0 std 1

This check will fail: ReLU replaces every negative value with 0, so the mean becomes positive and the standard deviation drops well below 1.
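The numbers below are not arbitrary. If the pre-activations are roughly standard normal (the previous cell measured std 1.0006), ReLU's output moments have a closed form. This standard-library sketch samples a standard normal, applies ReLU, and compares against that closed form; the analytic values, about 0.399 and 0.584, are close to the 0.3896 and 0.5860 measured on MNIST.

```python
import math
import random

rng = random.Random(0)
z = [rng.gauss(0.0, 1.0) for _ in range(100_000)]  # pre-activations ~ N(0, 1)
r = [max(v, 0.0) for v in z]                        # ReLU: negatives become 0

sample_mean = sum(r) / len(r)
sample_std = math.sqrt(sum((v - sample_mean) ** 2 for v in r) / len(r))

# Closed form for ReLU applied to a standard normal:
#   E[relu(z)]   = 1 / sqrt(2*pi)
#   E[relu(z)^2] = 1/2  (the positive half keeps half of E[z^2] = 1)
analytic_mean = 1 / math.sqrt(2 * math.pi)
analytic_std = math.sqrt(0.5 - analytic_mean ** 2)

print(f"sampled : mean={sample_mean:.4f}  std={sample_std:.4f}")
print(f"analytic: mean={analytic_mean:.4f}  std={analytic_std:.4f}")
```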

In [18]:
t = relu(lin(x_valid, w1, b1))
print(t.mean(), t.std())

tensor(0.3896) tensor(0.5860)


#### How do we bring ReLU outputs back to mean 0, std 1?


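Delving Deep into Rectifiers (He et al., 2015), an ImageNet winners' paper

Papers by competition winners contain many good ideas. This one, by Kaiming He and colleagues, introduced PReLU and what is now called Kaiming (or He) initialization; the same group later introduced ResNets.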

<img src='https://snag.gy/qeJVki.jpg' style='width:600px' />

<img src='https://snag.gy/E6efz4.jpg' style='width:600px' />

In section 2.2

    "Rectifier networks are easier to train"
    "Very deep models > 8 conv layers have difficulties to converge"
    
You may also see `Glorot` (or Xavier) initialization (Glorot & Bengio, 2010). It is a great paper and highly influential.

<img src='https://snag.gy/NmqKbJ.jpg' style='width:600px' />

One initialization they suggested was this one:

<img src='https://snag.gy/gAJQUz.jpg' style='width:600px' />

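So the ImageNet paper's authors modified the equation to account for ReLU, which zeroes out the negative half of its inputs:

$$\text{std} = \sqrt{\frac{2}{(1 + a^2) \times \text{fan\_in}}}$$

Here $a$ is the negative slope of a leaky ReLU ($a = 0$ for a plain ReLU), and $\text{fan\_in}$ is the number of inputs to the layer.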
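The formula is easy to evaluate directly. In this small sketch `kaiming_std` is a hypothetical helper name, not a library function. For the MNIST layer ($\text{fan\_in} = 784$, plain ReLU) it gives $\sqrt{2/784} \approx 0.0505$, the standard deviation that the `w1` created in the next cells should have.

```python
import math

def kaiming_std(fan_in, a=0.0):
    """std = sqrt(2 / ((1 + a^2) * fan_in)); a is the leaky-ReLU negative slope."""
    return math.sqrt(2.0 / ((1 + a ** 2) * fan_in))

m = 784  # number of MNIST input pixels
print(f"ReLU (a=0):          {kaiming_std(m):.4f}")
print(f"leaky ReLU (a=0.01): {kaiming_std(m, a=0.01):.4f}")
print(f"a=1 (identity):      {kaiming_std(m, a=1.0):.4f}  vs 1/sqrt(m) = {1 / math.sqrt(m):.4f}")
```

With $a = 1$ the "leaky" ReLU is just the identity, and the formula reduces to the $1/\sqrt{m}$ scaling used for the plain linear layer earlier.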


In [19]:
# kaiming init / he init for relu
w1 = torch.randn(m,nh)*math.sqrt(2/m)

In [21]:
w1.mean(),w1.std()

(tensor(0.0003), tensor(0.0506))

### Now the standard deviation is much closer to 1 (the mean is still far from 0, because ReLU discards negatives)

In [22]:
t = relu(lin(x_valid, w1, b1))
t.mean(),t.std()

(tensor(0.5896), tensor(0.8658))

The paper is worth digging into. Another interesting point they make is that a conv layer is essentially a matrix multiplication.

<img src='https://snag.gy/SB5yFZ.jpg' style='width:600px' />

Then they walk step by step through how the variance changes as it flows through the network.

<img src='https://snag.gy/mypw3u.jpg' style='width:600px' />

The forward pass is a matrix multiply, and the backward pass is a matrix multiply with a transpose. They finally recommend initializing the weights with standard deviation sqrt(2 / fan_in), where fan_in is the number of inputs to the layer. Now that we understand how to scale the weights and how to calculate the Kaiming normal, let's use PyTorch's version of it.
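The variance argument shows up clearly in a deep stack. This standard-library sketch, with made-up sizes (10 square layers of width 100, no biases, a batch of 20 random inputs), pushes the same inputs through linear + ReLU layers twice. One run uses weight std 1/sqrt(fan_in), the other sqrt(2/fan_in). Each ReLU halves the second moment, so without the factor of 2 the activations shrink by about sqrt(2) per layer, while the Kaiming scaling keeps them around 1.

```python
import math
import random

def matmul(a, b):
    """(n x k) @ (k x m) for lists of lists."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]

def relu(a):
    return [[max(v, 0.0) for v in row] for row in a]

def rms(a):
    """Root mean square of all entries: the overall scale of the activations."""
    vals = [v for row in a for v in row]
    return math.sqrt(sum(v * v for v in vals) / len(vals))

def run_stack(x, n_layers, gain, seed=0):
    """Apply n_layers of (linear -> ReLU), weights ~ N(0, gain^2 / fan_in), no bias."""
    rng = random.Random(seed)  # same seed => same underlying random draws per run
    h = x
    for _ in range(n_layers):
        fan_in = len(h[0])
        std = gain / math.sqrt(fan_in)
        w = [[rng.gauss(0.0, std) for _ in range(fan_in)] for _ in range(fan_in)]
        h = relu(matmul(h, w))
    return h

width, batch, n_layers = 100, 20, 10
data_rng = random.Random(42)
x = [[data_rng.gauss(0.0, 1.0) for _ in range(width)] for _ in range(batch)]

plain = run_stack(x, n_layers, gain=1.0)               # std = 1 / sqrt(fan_in)
kaiming = run_stack(x, n_layers, gain=math.sqrt(2.0))  # std = sqrt(2 / fan_in)
print(f"1/sqrt(fan_in): rms after {n_layers} layers = {rms(plain):.4f}")
print(f"sqrt(2/fan_in): rms after {n_layers} layers = {rms(kaiming):.4f}")
```

ReLU is positively homogeneous (relu(c*z) = c*relu(z) for c > 0), and both runs use the same random draws. As a result, the ratio between the two results is exactly sqrt(2)^10 = 32, up to floating-point rounding.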

In [23]:
#export
from torch.nn import init

w1 = torch.zeros(m,nh)
init.kaiming_normal_(w1, mode='fan_out')
t = relu(lin(x_valid, w1, b1))

#### Fan in or Fan out

    mode: either 'fan_in' (default) or 'fan_out'. Choosing `fan_in`
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing `fan_out` preserves the magnitudes in the
            backwards pass.
            
So why are we using `fan_out`? Should we divide by `sqrt(m)` or by `sqrt(nh)`? Our weight matrix has shape 784 x 50, but PyTorch stores linear-layer weights the other way round (50 x 784). How does this work?

```python
import torch.nn
torch.nn.Linear(m,nh).weight.shape
```

```
torch.Size([50, 784])
```

Source

```python
torch.nn.Linear.forward??
```

```python
...
# Source:   
    @weak_script_method
    def forward(self, input):
        return F.linear(input, self.weight, self.bias)
...
```

In pytorch **`F`** always refers to **`torch.nn.functional`**

Source

```python
torch.nn.functional.linear??
```

```python
@torch._jit_internal.weak_script
def linear(input, weight, bias=None):
    # type: (Tensor, Tensor, Optional[Tensor]) -> Tensor
    r"""
    Applies a linear transformation to the incoming data: :math:`y = xA^T + b`.

    Shape:

        - Input: :math:`(N, *, in\_features)` where `*` means any number of
          additional dimensions
        - Weight: :math:`(out\_features, in\_features)`
        - Bias: :math:`(out\_features)`
        - Output: :math:`(N, *, out\_features)`
    """
    if input.dim() == 2 and bias is not None:
        # fused op is marginally faster
        ret = torch.addmm(torch.jit._unwrap_optional(bias), input, weight.t())
    else:
        output = input.matmul(weight.t())
        if bias is not None:
            output += torch.jit._unwrap_optional(bias)
        ret = output
    return ret
```

The source shows that `F.linear` multiplies by **`weight.t()`** (the transpose), computing `input @ weight.T + bias`. That's why PyTorch stores the weights as `(out_features, in_features)`, flipped relative to our `w1`.
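PyTorch's init functions assume that `(out, in)` layout, so they read fan-in from dimension 1 and fan-out from dimension 0. Our `w1` is laid out `(784, 50)`, which means PyTorch sees its "fan_out" as 784, exactly the `m` we want. That is why the earlier cell passes `mode='fan_out'`. The helper below is a hypothetical standard-library sketch of that rule. PyTorch's own version is the private `torch.nn.init._calculate_fan_in_and_fan_out`, which also multiplies in the kernel size for conv weights.

```python
import math

def fans_from_shape(shape):
    """Hypothetical helper: dim 1 is treated as inputs, dim 0 as outputs,
    and any trailing dims (e.g. a conv kernel's height x width) multiply both."""
    if len(shape) < 2:
        raise ValueError("need at least 2 dimensions")
    receptive_field = math.prod(shape[2:])  # 1 for a plain matrix
    fan_in = shape[1] * receptive_field
    fan_out = shape[0] * receptive_field
    return fan_in, fan_out

print(fans_from_shape((784, 50)))      # our w1 (in, out): (50, 784), so 'fan_out' picks 784
print(fans_from_shape((50, 784)))      # nn.Linear layout (out, in): (784, 50)
print(fans_from_shape((32, 16, 3, 3)))  # conv weight (out_ch, in_ch, kh, kw): (144, 288)
```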

#### What does PyTorch do for Conv2d layers?

```python
torch.nn.Conv2d??
```

It turns out that `Conv2d` passes almost all of its code to a base class, `_ConvNd`:

```python
#Source:
class Conv2d(_ConvNd):
...

#File:           ~/envs/py3/lib/python3.6/site-packages/torch/nn/modules/conv.py
#Type:           type
```

So let's dig one level deeper into the library:

```python
torch.nn.modules.conv._ConvNd.reset_parameters??
```

```python
# Source:
    def reset_parameters(self):
        n = self.in_channels
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)
```

Note the **`a=math.sqrt(5)`** argument: it is plugged into the Kaiming formula as the leaky-ReLU slope `a`, and this default turns out not to perform very well.
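To see what `a=math.sqrt(5)` does to the scale, this sketch assumes the uniform variant draws from U(-b, b) with b = gain * sqrt(3 / fan_in) and gain = sqrt(2 / (1 + a^2)). A uniform distribution on (-b, b) has std b / sqrt(3). The `fan_in` of 144 is a hypothetical conv layer with 16 input channels and a 3x3 kernel.

```python
import math

def kaiming_uniform_bound(fan_in, a=0.0):
    """Bound b of U(-b, b): gain * sqrt(3 / fan_in), with gain = sqrt(2 / (1 + a^2))."""
    gain = math.sqrt(2.0 / (1 + a ** 2))
    return gain * math.sqrt(3.0 / fan_in)

fan_in = 16 * 3 * 3  # hypothetical conv layer: 16 input channels, 3x3 kernel
for name, a in [("ReLU, a=0", 0.0), ("default, a=sqrt(5)", math.sqrt(5))]:
    b = kaiming_uniform_bound(fan_in, a)
    print(f"{name:20s} bound={b:.4f}  std={b / math.sqrt(3):.4f}")
print(f"1/sqrt(fan_in) = {1 / math.sqrt(fan_in):.4f}")
```

With a = sqrt(5), the bound simplifies to exactly 1 / sqrt(fan_in), so the weights' std is sqrt(6) (about 2.45) times smaller than the value the Kaiming formula gives for a plain ReLU.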

### Back to activation functions

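One tweak is to subtract 0.5 from the ReLU output, which shifts the mean back toward 0. Across ten fresh Kaiming initializations, the mean is now close to 0, and the standard deviation falls between roughly 0.72 and 0.94.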

In [32]:
def relu(x): 
    return x.clamp_min(0.) - 0.5


for i in range(10):
    # kaiming init / he init for relu
    w1 = torch.randn(m,nh)*math.sqrt(2./m )
    t1 = relu(lin(x_valid, w1, b1))
    print(t1.mean(), t1.std(), '| ')

tensor(0.0482) tensor(0.7982) | 
tensor(0.0316) tensor(0.8060) | 
tensor(0.1588) tensor(0.9367) | 
tensor(0.0863) tensor(0.8403) | 
tensor(-0.0310) tensor(0.7310) | 
tensor(0.0467) tensor(0.7965) | 
tensor(0.1252) tensor(0.8700) | 
tensor(-0.0610) tensor(0.7189) | 
tensor(0.0264) tensor(0.7755) | 
tensor(0.1081) tensor(0.8605) | 
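Why 0.5? A quick calculation, assuming the pre-activations $z$ are roughly Gaussian with variance 2 (what Kaiming init aims for when the inputs have unit variance):

$$z \sim \mathcal{N}(0, \sigma^2), \quad \sigma^2 = 2$$

$$\mathbb{E}[\max(z, 0)] = \frac{\sigma}{\sqrt{2\pi}} = \frac{1}{\sqrt{\pi}} \approx 0.564$$

$$\operatorname{Var}[\max(z, 0)] = \sigma^2\left(\frac{1}{2} - \frac{1}{2\pi}\right) = 1 - \frac{1}{\pi} \approx 0.682, \quad \text{std} \approx 0.826$$

Subtracting 0.5 shifts the mean to about $0.564 - 0.5 = 0.064$ and leaves the std unchanged. This is in line with the ten runs above, whose means average about 0.05 and whose stds average about 0.81.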


#### So where are we now? Fully connected Layers

<img src="https://snag.gy/Uy9qxS.jpg" style="width:700px"/>

## Our first model

In [33]:
def relu(x): 
    return x.clamp_min(0.) - 0.5

def lin(x, w, b):
    return x@w + b

def model(xb):
    l1 = lin(xb, w1, b1)
    l2 = relu(l1)
    l3 = lin(l2, w2, b2)
    return l3

In [35]:
# timing it on the validation set
%timeit -n 10 _=model(x_valid)

6.71 ms ± 456 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [36]:
assert model(x_valid).shape==torch.Size([x_valid.shape[0],1])

### Loss function: MSE

We need **`squeeze()`** to get rid of the trailing dimension of size 1: the model outputs shape `(N, 1)`, while the targets have shape `(N,)`. (Of course, MSE is not a suitable loss function for multi-class classification; we'll use a better loss function soon. We use MSE for now to keep things simple.)
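Without the squeeze, the subtraction does not fail; it broadcasts. PyTorch follows NumPy's broadcasting rules, and the hypothetical helper below applies those rules to shapes only. It shows that `(N, 1) - (N,)` silently becomes an `(N, N)` matrix of every prediction minus every target, which would make the mean meaningless.

```python
def broadcast_shape(a, b):
    """Hypothetical helper applying NumPy/PyTorch broadcasting rules to two shapes."""
    result = []
    # Compare dimensions from the right; a missing dimension counts as size 1.
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and 1 not in (da, db):
            raise ValueError(f"shapes {a} and {b} are not broadcastable")
        result.append(db if da == 1 else da)  # a size-1 dim stretches to match
    return tuple(reversed(result))

print(broadcast_shape((10000, 1), (10000,)))  # (10000, 10000): the bug
print(broadcast_shape((10000,), (10000,)))    # (10000,): after squeeze(-1)
```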



In [37]:
model(x_valid).shape

torch.Size([10000, 1])

In [38]:
#export
def mse(output, targ): 
    # we want to drop the last dimension
    return (output.squeeze(-1) - targ).pow(2).mean()

In [40]:
# convert the integer (long) target tensors to float so they can be compared with float predictions
y_train, y_valid = y_train.float(), y_valid.float()

# make our predictions
preds = model(x_train)
print(preds.shape)

# check our mse
print(mse(preds, y_train))

torch.Size([50000, 1])
tensor(22.1963)


## Gradients

How much should you know about matrix calculus? It's up to you, but there's a great reference article:

<img src='https://snag.gy/URwGZu.jpg' style='width:600px' />

One thing you should learn is the **chain rule**. 

When one function feeds into another, the order works like this:

<img src='https://snag.gy/MsT82c.jpg' style='width:600px' />

Then we can simplify to:

$y = f(u)$

$u = g(x)$

And then the derivative is:

$$\frac{dy}{dx} = \frac{dy}{du} \frac{du}{dx}$$
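Here is a tiny numeric check of the chain rule, using made-up functions u = g(x) = 3x + 1 and y = f(u) = u^2. The chain rule gives dy/dx = 2u * 3, and a central finite difference should agree.

```python
def g(x):
    """Inner function u = g(x)."""
    return 3 * x + 1

def f(u):
    """Outer function y = f(u)."""
    return u ** 2

def dy_dx_chain(x):
    """Chain rule: dy/dx = dy/du * du/dx = 2u * 3."""
    u = g(x)
    dy_du = 2 * u
    du_dx = 3
    return dy_du * du_dx

def dy_dx_numeric(x, h=1e-6):
    """Central finite difference of y = f(g(x))."""
    return (f(g(x + h)) - f(g(x - h))) / (2 * h)

x = 2.0
print(dy_dx_chain(x), dy_dx_numeric(x))  # both about 42
```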

<img src='https://snag.gy/4GJa9C.jpg' style='width:200px' />

The representation looks like this:

<img src='https://snag.gy/hdNabj.jpg' style='width:200px' />


#### To apply the chain rule, start from the loss and work backwards

In [41]:
def mse_grad(inp, targ): 
    # grad of loss with respect to output of previous layer
    # the derivative of squared output x^2 => 2x
    inp.g = 2. * (inp.squeeze() - targ).unsqueeze(-1) / inp.shape[0]
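One way to check a hand-written gradient is to compare it with finite differences. The hypothetical sketch below uses plain Python lists instead of tensors. The analytic gradient is 2(p_i - t_i) / N, where the 1/N comes from the mean (the `/ inp.shape[0]` in `mse_grad`). The `unsqueeze(-1)` in `mse_grad` only restores the `(N, 1)` shape of the model output.

```python
def mse(preds, targ):
    return sum((p - t) ** 2 for p, t in zip(preds, targ)) / len(preds)

def mse_grad_list(preds, targ):
    """Analytic gradient: d/dp_i of mean((p - t)^2) = 2 * (p_i - t_i) / N."""
    n = len(preds)
    return [2 * (p - t) / n for p, t in zip(preds, targ)]

def numeric_grad(fn, xs, h=1e-6):
    """Central finite differences, one coordinate at a time."""
    grads = []
    for i in range(len(xs)):
        up, down = list(xs), list(xs)
        up[i] += h
        down[i] -= h
        grads.append((fn(up) - fn(down)) / (2 * h))
    return grads

preds = [0.5, 2.0, -1.0, 3.5]
targ = [1.0, 2.0, 0.0, 3.0]
analytic = mse_grad_list(preds, targ)
numeric = numeric_grad(lambda p: mse(p, targ), preds)
print(analytic)  # [-0.25, 0.0, -0.5, 0.25]
print(numeric)
```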

In [None]:
def relu_grad(inp, out):
    # grad of relu with respect to input activations
    inp.g = (inp>0).float() * out.g
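Putting the two backward steps together, the hypothetical sketch below (plain Python lists, no PyTorch) computes `loss = mse(relu(x), targ)` and then applies the chain rule in reverse. It first takes the MSE gradient with respect to the ReLU output, then multiplies by ReLU's local gradient (1 where the input is positive, 0 elsewhere), and compares the result with finite differences. The test points avoid x = 0, where ReLU has no derivative. Subtracting a constant such as 0.5 from the ReLU output would not change this gradient.

```python
def relu(xs):
    return [max(x, 0.0) for x in xs]

def mse(preds, targ):
    return sum((p - t) ** 2 for p, t in zip(preds, targ)) / len(preds)

def loss_and_grad(xs, targ):
    """Forward: loss = mse(relu(xs), targ). Backward: chain rule, outermost first."""
    out = relu(xs)
    loss = mse(out, targ)
    n = len(xs)
    # Step 1 (like mse_grad): d loss / d out
    out_g = [2 * (o - t) / n for o, t in zip(out, targ)]
    # Step 2 (like relu_grad): d out / d x is 1 where x > 0, else 0
    inp_g = [(1.0 if x > 0 else 0.0) * g for x, g in zip(xs, out_g)]
    return loss, inp_g

def numeric_grad(fn, xs, h=1e-6):
    """Central finite differences, one coordinate at a time."""
    grads = []
    for i in range(len(xs)):
        up, down = list(xs), list(xs)
        up[i] += h
        down[i] -= h
        grads.append((fn(up) - fn(down)) / (2 * h))
    return grads

xs = [1.5, -0.7, 0.3, -2.0]  # kept away from 0, where ReLU has no derivative
targ = [1.0, 0.5, 0.0, -1.0]
loss, analytic = loss_and_grad(xs, targ)
numeric = numeric_grad(lambda v: mse(relu(v), targ), xs)
print(analytic)
print(numeric)
```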