In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from exp.nb_07 import *

## Layerwise Sequential Unit Variance (LSUV)

In [55]:
x_train, y_train, x_valid, y_valid = get_data()

x_train, x_valid = normalize_to(x_train, x_valid)
train_ds, valid_ds = Dataset(x_train, y_train), Dataset(x_valid, y_valid)

nh, bs = 50, 512
c = y_train.max().item() + 1
loss_func = F.cross_entropy

data = DataBunch(*get_dls(train_ds, valid_ds, bs), c)

In [56]:
mnist_view = view_tfm(1, 28, 28)
cbfs = [Recorder,
        partial(AvgStatsCallback, accuracy),
        CudaCallback,
        partial(BatchTransformXCallback, mnist_view)]

In [57]:
nfs = [8, 16, 32, 64, 64]

In [58]:
class ConvLayer(nn.Module):
    def __init__(self, ni, nf, ks=3, stride=2, sub=0., **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(ni, nf, ks, padding=ks//2, stride=stride, bias=True)
        self.relu = GeneralRelu(sub=sub, **kwargs)
        
    def forward(self, x): return self.relu(self.conv(x))
    
    @property
    def bias(self): return - self.relu.sub
    @bias.setter
    def bias(self, v): self.relu.sub = - v
    @property
    def weight(self): return self.conv.weight

In [59]:
learn, run = get_learn_run(nfs, data, .6, ConvLayer, cbs=cbfs)

In [60]:
run.fit(2, learn)

train: [2.255385625, tensor(0.2134, device='cuda:0')]
valid: [1.7634078125, tensor(0.3449, device='cuda:0')]
train: [0.688213515625, tensor(0.7721, device='cuda:0')]
valid: [0.183526904296875, tensor(0.9435, device='cuda:0')]


In [61]:
[cb for cb in run.cbs if getattr(cb, 'begin_batch', None)]

[<exp.nb_06.CudaCallback at 0x7fbc0b26ea20>,
 <exp.nb_06.BatchTransformXCallback at 0x7fbc0b26e6a0>]

In [62]:
def get_batch(dl, run):
    run.xb, run.yb = next(iter(dl))
    for cb in run.cbs: cb.set_runner(run)
    run('begin_batch')
    return run.xb, run.yb

In [63]:
xb, yb = get_batch(data.train_dl, run)

In [64]:
def find_modules(m, cond):
    if cond(m): return [m]
    return sum([find_modules(o, cond) for o in m.children()], [])

def is_lin_layer(l):
    lin_layers = (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.Linear, nn.ReLU)
    return isinstance(l, lin_layers)

In [65]:
sum([[4], [5], [6, 7], [34, [56, 23]]], [])

[4, 5, 6, 7, 34, [56, 23]]

In [66]:
mods = find_modules(learn.model, lambda o: isinstance(o, ConvLayer))

In [67]:
mods

[ConvLayer(
   (conv): Conv2d(1, 8, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
   (relu): GeneralRelu()
 ), ConvLayer(
   (conv): Conv2d(8, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (relu): GeneralRelu()
 ), ConvLayer(
   (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (relu): GeneralRelu()
 ), ConvLayer(
   (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (relu): GeneralRelu()
 ), ConvLayer(
   (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
   (relu): GeneralRelu()
 )]

In [68]:
def append_stat(hook, mod, inp, outp):
    d = outp.data
    hook.mean, hook.std = d.mean().item(), d.std().item()

In [69]:
mdl = learn.model.cuda()

In [70]:
with Hooks(mods, append_stat) as hooks:
    mdl(xb)
    for h in hooks: print(h.mean, h.std)

0.523516058921814 1.1133036613464355
0.3681550621986389 1.437757968902588
0.3139904737472534 1.2075289487838745
0.29176145792007446 0.9788991808891296
0.43855100870132446 1.3311876058578491


In [71]:
def lsuv_module(m, xb):
    h = Hook(m, append_stat)
    while mdl(xb) is not None and abs(h.mean) > 1e-3: m.bias -= h.mean
    while mdl(xb) is not None and abs(h.std - 1) > 1e-3: m.weight.data /= h.std
    h.remove()
    return h.mean, h.std

In [72]:
for m in mods: print(lsuv_module(m, xb))

(-0.07046079635620117, 0.9998978972434998)
(-0.7317276000976562, 0.9994165301322937)
(-1.7011597156524658, 1.0008671283721924)
(-8.042387962341309, 1.0009326934814453)
(-40.43047332763672, 0.9998490810394287)


In [73]:
%time run.fit(2, learn)

train: [inf, tensor(0.1020, device='cuda:0')]
valid: [1.7199740484977972e+22, tensor(0.1030, device='cuda:0')]
train: [1.7216806674725205e+22, tensor(0.1020, device='cuda:0')]
valid: [1.7199740484977972e+22, tensor(0.1030, device='cuda:0')]
CPU times: user 2.38 s, sys: 436 ms, total: 2.82 s
Wall time: 2.84 s
