In [1]:
%load_ext autoreload
%autoreload 2

%matplotlib inline

In [32]:
# print all outputs from jupyter
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

In [4]:
from pathlib import Path
from IPython.core.debugger import set_trace
from fastai import datasets
import pickle, gzip, math, torch, matplotlib as mpl
import matplotlib.pyplot as plt
from torch import tensor
import operator

In [5]:
def test(a,b,cmp,cname=None):
    if cname is None: cname=cmp.__name__
    assert cmp(a,b),f"{cname}:\n{a}\n{b}"

def test_eq(a,b): test(a,b,operator.eq,'==')
    
def test_near_zero(a,tol=1e-3): assert a.abs()<tol, f"Near zero: {a}"

In [6]:
def near(a,b): return torch.allclose(a,b, rtol=1e-3, atol=1e-5)
def test_near(a,b): test(a,b,near)

In [7]:
MNIST_URL='http://deeplearning.net/data/mnist/mnist.pkl'

def get_data():
    path = datasets.download_data(MNIST_URL, ext='.gz')
    with gzip.open(path, 'rb') as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1')
    return map(tensor, (x_train,y_train,x_valid,y_valid))

def normalize(x, m, s): return (x-m)/s

In [16]:
torch.__version__


'1.1.0'

In [145]:
from torch.nn import init

from torch import nn

In [222]:
torch.nn.modules.conv._ConvNd.reset_parameters??

In [225]:
torch.nn.init.kaiming_uniform_??

In [19]:
x_train,y_train,x_valid,y_valid = get_data()
train_mean,train_std = x_train.mean(),x_train.std()
x_train = normalize(x_train, train_mean, train_std)
x_valid = normalize(x_valid, train_mean, train_std)

In [20]:
x_train = x_train.view(-1,1,28,28)
x_valid = x_valid.view(-1,1,28,28)
x_train.shape,x_valid.shape

(torch.Size([50000, 1, 28, 28]), torch.Size([10000, 1, 28, 28]))

In [23]:
x = x_valid[:100]

In [24]:
x.shape

torch.Size([100, 1, 28, 28])

In [25]:
def stats(x): return x.mean(), x.std()

In [45]:
import torch.nn.functional as F
f1 = lambda z, leak=0: F.leaky_relu(z,leak)

## Pytorch 1.1 Conv default Init + Relu (No Leaks)

In [226]:
l1 = nn.Conv2d(1, nh, 5)
z1 = l1(x)
a1 = f1(z1, leak=0)

# Print stats
for e, n in zip((x, z1, a1), ('x', 'z1', 'a1')):
    n
    stats(e)
    

'x'

(tensor(-0.0363), tensor(0.9602))

'z1'

(tensor(0.0289, grad_fn=<MeanBackward0>),
 tensor(0.6615, grad_fn=<StdBackward0>))

'a1'

(tensor(0.2368, grad_fn=<MeanBackward0>),
 tensor(0.4579, grad_fn=<StdBackward0>))

## Pytorch 1.1 Conv default Init + Relu (Leak=√5)

In [237]:
l1 = nn.Conv2d(1, nh, 5)
z1 = l1(x)
a1 = f1(z1, 
        leak=math.sqrt(5))

# Print stats
for e, n in zip((x, z1, a1), ('x', 'z1', 'a1')):
    n
    stats(e)

'x'

(tensor(-0.0363), tensor(0.9602))

'z1'

(tensor(0.0121, grad_fn=<MeanBackward0>),
 tensor(0.5706, grad_fn=<StdBackward0>))

'a1'

(tensor(-0.2295, grad_fn=<MeanBackward0>),
 tensor(0.9744, grad_fn=<StdBackward0>))

## Pytorch 1.1 Kaiming Init + Relu (No Leaks)

In [244]:
l1 = nn.Conv2d(1, nh, 5)
_  = init.kaiming_uniform_(l1.weight)
z1 = l1(x)
a1 = f1(z1, 
        leak=0)

# Print stats
for e, n in zip((x, z1, a1), ('x', 'z1', 'a1')):
    n
    stats(e)

'x'

(tensor(-0.0363), tensor(0.9602))

'z1'

(tensor(-0.0431, grad_fn=<MeanBackward0>),
 tensor(1.5606, grad_fn=<StdBackward0>))

'a1'

(tensor(0.5099, grad_fn=<MeanBackward0>),
 tensor(0.8811, grad_fn=<StdBackward0>))

## Kaiming From Scratch + Relu (No Leaks)

In [247]:
def kaiming_uniform_scratch(weights, leak, 
                            use_fan_out=False):
    n_filters, n_channels_in, *_ = weights.shape
    receptive_field = l1.weight[0, 0].shape.numel()
    # compute fan in and fan out
    fan_in  = n_channels_in * receptive_field
    fan_out = n_filters     * receptive_field
    fan = fan_in if not use_fan_out else fan_out
    # init weights
    gain = lambda a : math.sqrt(2.0 / (1 + a**2))
    std = gain(a=leak) / math.sqrt(fan)
    bound = math.sqrt(3.) * std
    weights.data.uniform_(-bound, bound)
    

In [254]:
l1 = nn.Conv2d(1, nh, 5)
kaiming_uniform_scratch(l1.weight,
                        leak=0)
z1 = l1(x)
a1 = f1(z1, leak=0)

for e, n in zip((x, z1, a1), ('x', 'z1', 'a1')):
    n
    stats(e)
    

'x'

(tensor(-0.0363), tensor(0.9602))

'z1'

(tensor(0.0123, grad_fn=<MeanBackward0>),
 tensor(1.4455, grad_fn=<StdBackward0>))

'a1'

(tensor(0.4811, grad_fn=<MeanBackward0>),
 tensor(0.9655, grad_fn=<StdBackward0>))

## Kaiming From Scratch (with leak=√5)+ Relu (with leak=0) (PYTORCH BUG)

In [261]:
l1 = nn.Conv2d(1, nh, 5)
kaiming_uniform_scratch(l1.weight,
                        leak=math.sqrt(5))
z1 = l1(x)
a1 = f1(z1, leak=0)

for e, n in zip((x, z1, a1), ('x', 'z1', 'a1')):
    n
    stats(e)

'x'

(tensor(-0.0363), tensor(0.9602))

'z1'

(tensor(-0.0039, grad_fn=<MeanBackward0>),
 tensor(0.6973, grad_fn=<StdBackward0>))

'a1'

(tensor(0.2428, grad_fn=<MeanBackward0>),
 tensor(0.4405, grad_fn=<StdBackward0>))

# Study of bug impact

#### Define Network Architectures

In [144]:
class Flatten(nn.Module):
    def forward(self, x): return x.view(-1)

In [146]:
def mse(output, targ): 
    return (output.squeeze(-1) - targ).pow(2).mean()

In [149]:
model = nn.Sequential(
        nn.Conv2d( 1,  8, 5, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d( 8, 16, 3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(16, 32, 3, stride=2, padding=1),
        nn.ReLU(),
        nn.Conv2d(32,  1, 3, stride=2, padding=1),
        nn.AdaptiveAvgPool2d(1),
        Flatten()
        )

### Study of variance accross gradients and activations 

In [262]:
y = y_valid[:100].float()


is_conv2d_layer = lambda l : isinstance(l, nn.Conv2d)

def reset_parameters_with_kaiming_uniform(model,leak=0):
    for layer in model:
        if is_conv2d_layer(layer):
            # init weights with kaiming uniform
            kaiming_uniform_scratch(
                weights=layer.weight,
                leak=leak)
            # init bias to zero
            layer.bias.data.zero_()

#### Pytorch Bug Init & Relu No Leak


In [265]:
reset_parameters_with_kaiming_uniform(model, leak=math.sqrt(5))

out = model(x)

"output activation stats"
stats(out)

loss = mse(out, y)
f"loss = {loss}"

loss.backward()
"Gradient stats from input layer to ouput layer"

conv_layers = (l for l in model if is_conv2d_layer(l))
acti = x
for i, cl in enumerate(conv_layers):
    f'layer {i+1}'
    acti = f1(cl(acti), leak=0)
    stats(acti)
    stats(cl.weight.grad)

'output activation stats'

(tensor(0.0010, grad_fn=<MeanBackward0>),
 tensor(0.0010, grad_fn=<StdBackward0>))

'loss = 30.70157241821289'

'Gradient stats from input layer to ouput layer'

'layer 1'

(tensor(0.1718, grad_fn=<MeanBackward0>),
 tensor(0.2529, grad_fn=<StdBackward0>))

(tensor(-0.1673), tensor(0.4589))

'layer 2'

(tensor(0.0347, grad_fn=<MeanBackward0>),
 tensor(0.0547, grad_fn=<StdBackward0>))

(tensor(-0.0167), tensor(0.8515))

'layer 3'

(tensor(0.0066, grad_fn=<MeanBackward0>),
 tensor(0.0109, grad_fn=<StdBackward0>))

(tensor(0.0706), tensor(0.8574))

'layer 4'

(tensor(0.0018, grad_fn=<MeanBackward0>),
 tensor(0.0023, grad_fn=<StdBackward0>))

(tensor(-22.9978), tensor(7.5723))

#### Kaiming No Leak and Relu No Leak

In [266]:
reset_parameters_with_kaiming_uniform(model, leak=0)

out = model(x)

"output activation stats"
stats(out)

loss = mse(out, y)
f"loss = {loss}"

loss.backward()
"Gradient stats from input layer to ouput layer"

conv_layers = (l for l in model if is_conv2d_layer(l))
acti = x
for i, cl in enumerate(conv_layers):
    f'layer {i+1}'
    acti = f1(cl(acti), leak=0)
    stats(acti)
    stats(cl.weight.grad)

'output activation stats'

(tensor(0.0091, grad_fn=<MeanBackward0>),
 tensor(0.0874, grad_fn=<StdBackward0>))

'loss = 30.674882888793945'

'Gradient stats from input layer to ouput layer'

'layer 1'

(tensor(0.4704, grad_fn=<MeanBackward0>),
 tensor(1.1132, grad_fn=<StdBackward0>))

(tensor(-0.1541), tensor(0.4571))

'layer 2'

(tensor(0.3427, grad_fn=<MeanBackward0>),
 tensor(0.7312, grad_fn=<StdBackward0>))

(tensor(0.0052), tensor(0.9036))

'layer 3'

(tensor(0.1310, grad_fn=<MeanBackward0>),
 tensor(0.2910, grad_fn=<StdBackward0>))

(tensor(0.0816), tensor(0.8955))

'layer 4'

(tensor(0.0804, grad_fn=<MeanBackward0>),
 tensor(0.1220, grad_fn=<StdBackward0>))

(tensor(-23.9137), tensor(7.8495))

# ...

In [246]:
# Why use last????

l1.weight.numel()
l1.weight[0].numel()
l1.weight[0, 0].numel()
l1.weight[0, 0].shape.numel()

800

25

25

25

In [219]:
U = torch.FloatTensor(10000).uniform_(-1, 1)

In [220]:
U

tensor([-0.4685, -0.5267, -0.8348,  ...,  0.7095, -0.4523, -0.0186])

In [221]:
stats(U)

stats(U * math.sqrt(3))

(tensor(-0.0030), tensor(0.5793))

(tensor(-0.0052), tensor(1.0033))