Commit

Add lsuv init (#534)
* Add lsuv init

* Update changelog

* Formatting

* Remove device management

* Move to torch svd

* Formatting

* Formatting

* Send update to data

* Remove bias and extra weight checks

* Test break

* Fix multiple run bug
MattPainter01 authored and ethanwharris committed Apr 3, 2019
1 parent 0a08885 commit 4e034a7
Showing 4 changed files with 232 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added ``ImagingCallback`` class for callbacks which produce images that can be sent to tensorboard, visdom or a file
- Added ``CachingImagingCallback`` and ``MakeGrid`` callback to make a grid of images
- Added the option to give the ``only_if`` callback decorator a function of self and state rather than just state
- Added Layer-sequential unit-variance (LSUV) initialization
### Changed
### Deprecated
### Removed
57 changes: 57 additions & 0 deletions tests/callbacks/test_init.py
@@ -1,3 +1,4 @@
import torch
from unittest import TestCase

from mock import MagicMock, patch
@@ -64,3 +65,59 @@ def test_bias(self):
mock = MagicMock()
callback.initialiser(mock)
self.assertTrue(mock.bias.data.zero_.call_count == 1)


class TestLsuv(TestCase):
def test_end_to_end(self):
import numpy as np
from torchbearer.callbacks.init import ZeroBias

np.random.seed(7)
torch.manual_seed(7)

class Flatten(torch.nn.Module):
def forward(self, x):
return x.view(x.shape[0], -1)

model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 1, 1),
Flatten(),
torch.nn.Linear(4, 2),
)

state = {torchbearer.MODEL: model}
data = torch.rand(2, 1, 2, 2)
ZeroBias(model.modules()).on_init(state) # LSUV expects biases to be zero
init.LsuvInit(data).on_init(state)

correct_conv_weight = torch.FloatTensor([[[[3.2236]]]])
correct_linear_weight = torch.FloatTensor([[-0.3414, -0.5503, -0.4402, -0.4367],
[0.3425, -0.0697, -0.6646, 0.4900]])

conv_weight = list(model.modules())[1].weight
linear_weight = list(model.modules())[3].weight
        diff_conv = torch.abs(conv_weight - correct_conv_weight) < 0.0001
        diff_linear = torch.abs(linear_weight - correct_linear_weight) < 0.0001
self.assertTrue(diff_conv.all().item())
self.assertTrue(diff_linear.all().item())

def test_break(self):
import numpy as np
from torchbearer.callbacks.init import ZeroBias

np.random.seed(7)
torch.manual_seed(7)

model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 1, 1),
)

with patch('torchbearer.callbacks.lsuv.LSUV.apply_weights_correction') as awc:
state = {torchbearer.MODEL: model}
data = torch.rand(2, 1, 2, 2)
ZeroBias(model.modules()).on_init(state) # LSUV expects biases to be zero
init.LsuvInit(data, std_tol=1e-20, max_attempts=0, do_orthonorm=False).on_init(state)

self.assertTrue(awc.call_count == 2)

44 changes: 44 additions & 0 deletions torchbearer/callbacks/init.py
@@ -23,6 +23,15 @@
}
"""

__lsuv__ = """
@article{mishkin2015all,
title={All you need is a good init},
author={Mishkin, Dmytro and Matas, Jiri},
journal={arXiv preprint arXiv:1511.06422},
year={2015}
}
"""


class WeightInit(Callback):
"""Base class for weight initialisations. Performs the provided function for each module when on_init is
@@ -51,6 +60,41 @@ def on_init(self, state):
self.initialiser(m)


@cite(__lsuv__)
class LsuvInit(Callback):
"""Layer-sequential unit-variance (LSUV) initialization as described in
`All you need is a good init <https://arxiv.org/abs/1511.06422>`_ and
modified from the code by `ducha-aiki <https://github.com/ducha-aiki/LSUV-pytorch>`__.
To be consistent with the paper, LsuvInit should be preceeded by a ZeroBias init on the Linear and Conv layers.
Args:
data_item (torch.Tensor: A representative data item to put through the model
weight_lambda (lambda): A function that takes a module and returns the weight attribute. If none defaults to
module.weight.
needed_std: See `paper <https://arxiv.org/abs/1511.06422>`__, where needed_std is always 1.0
std_tol: See `paper <https://arxiv.org/abs/1511.06422>`__, Tol_{var}
max_attempts: See `paper <https://arxiv.org/abs/1511.06422>`__, T_{max}
do_orthonorm: See `paper <https://arxiv.org/abs/1511.06422>`__, first pre-initialise with orthonormal matricies
State Requirements:
- :attr:`torchbearer.state.MODEL`: Model should have the `modules` method if modules is None
"""
def __init__(self, data_item, weight_lambda=None, needed_std=1.0, std_tol=0.1, max_attempts=10, do_orthonorm=True):
from torchbearer.callbacks.lsuv import LSUV
self.lsuv_init = LSUV
self.data = data_item
self.needed_std = needed_std
self.std_tol = std_tol
self.max_attempts = max_attempts
        self.do_orthonorm = do_orthonorm
self.weight_lambda = weight_lambda

def on_init(self, state):
lsuv = self.lsuv_init()
        state[torchbearer.MODEL] = lsuv.init_model(
            state[torchbearer.MODEL], self.data, self.weight_lambda, self.needed_std,
            self.std_tol, self.max_attempts, self.do_orthonorm)


@cite(__kaiming__)
class KaimingNormal(WeightInit):
"""Kaiming Normal weight initialisation. Uses ``torch.nn.init.kaiming_normal_`` on the ``weight`` attribute of the
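For reference, a minimal usage sketch of the new LsuvInit callback, mirroring the test above (the Trial-based form at the end is an assumption about the surrounding torchbearer API, not part of this commit):

import torch
import torchbearer
from torchbearer.callbacks.init import LsuvInit, ZeroBias

model = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1))
sample_batch = torch.rand(2, 1, 2, 2)  # a representative batch for LSUV to measure output std

# Zero the biases first (LSUV expects zero biases), then run LSUV on the sample batch
state = {torchbearer.MODEL: model}
ZeroBias(model.modules()).on_init(state)
LsuvInit(sample_batch).on_init(state)

# Equivalently (assumed API), pass the callbacks to a Trial so on_init runs automatically:
# trial = torchbearer.Trial(model, callbacks=[ZeroBias(model.modules()), LsuvInit(sample_batch)])
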
130 changes: 130 additions & 0 deletions torchbearer/callbacks/lsuv.py
@@ -0,0 +1,130 @@
# Copyright (C) 2017, Dmytro Mishkin
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the
# distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# This file has been modified from https://github.com/ducha-aiki/LSUV-pytorch

import numpy as np
import torch
import torch.nn as nn


class LSUV(object):
def __init__(self):
super(LSUV, self).__init__()
self.gg = self.reset_parameters()

def reset_parameters(self):
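        # (Re)build the shared state for one LSUV pass; hook_position and
        # done_counter track which layer is currently being processed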
self.gg = {
'hook_position': 0,
'total_fc_conv_layers': 0,
'done_counter': -1,
'hook': None,
'act_dict': {},
'counter_to_apply_correction': 0,
'correction_needed': False,
'current_coef': 1.0,
'weight_lambda': lambda m: m.weight,
}
return self.gg

def svd_orthonormal(self, w):
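        # Produce an orthonormal tensor with the same shape as w via the SVD of a random matrix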
shape = w.shape
flat_shape = (shape[0], np.prod(shape[1:]))
a = torch.rand(flat_shape, device=w.device)
u, _, v = torch.svd(a, some=True)
q = u if u.shape == flat_shape else v.t()
q = q.view(shape)
return q.to(torch.float)

def add_current_hook(self, m):
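        # Register a forward hook on the next Conv2d/Linear layer that has not yet been processed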
if self.gg['hook'] is not None:
return
if (isinstance(m, nn.Conv2d)) or (isinstance(m, nn.Linear)):
if self.gg['hook_position'] > self.gg['done_counter']:
self.gg['hook'] = m.register_forward_hook(self.store_activations_wrapper())
else:
self.gg['hook_position'] += 1

def count_conv_fc_layers(self, m):
if (isinstance(m, nn.Conv2d)) or (isinstance(m, nn.Linear)):
self.gg['total_fc_conv_layers'] += 1

def orthogonal_weights_init(self, m):
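        # Pre-initialise a Conv2d/Linear layer's weight with an orthonormal matrix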
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
weight = self.gg['weight_lambda'](m)
w_ortho = self.svd_orthonormal(weight.data)
m.weight.data = w_ortho.data

def apply_weights_correction(self, m):
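        # Scale the weights of the currently hooked layer by the correction computed in init_model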
if self.gg['hook'] is None or not self.gg['correction_needed']:
return
if (isinstance(m, nn.Conv2d)) or (isinstance(m, nn.Linear)):
if self.gg['counter_to_apply_correction'] < self.gg['hook_position']:
self.gg['counter_to_apply_correction'] += 1
else:
weight = self.gg['weight_lambda'](m)
weight.data *= float(self.gg['current_coef'])
self.gg['correction_needed'] = False

def init_model(self, model, data, weight_lambda=None, needed_std=1.0, std_tol=0.1, max_attempts=10, do_orthonorm=True):
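        # For each Conv2d/Linear layer in turn: hook its output, then repeatedly rescale its
        # weights until the std of its output on `data` is within std_tol of needed_std
        # (giving up after max_attempts iterations)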
self.gg = self.reset_parameters()
        train = model.training
self.gg['weight_lambda'] = self.gg['weight_lambda'] if weight_lambda is None else weight_lambda

model.eval()
model.apply(self.count_conv_fc_layers)
if do_orthonorm:
model.apply(self.orthogonal_weights_init)
for layer_idx in range(self.gg['total_fc_conv_layers']):
model.apply(self.add_current_hook)
_ = model(data)
current_std = self.gg['act_dict'].std()
attempts = 0
while torch.abs(current_std - needed_std).item() > std_tol:
self.gg['current_coef'] = needed_std / (current_std + 1e-8)
self.gg['correction_needed'] = True
model.apply(self.apply_weights_correction)
_ = model(data)
current_std = self.gg['act_dict'].std()
attempts += 1
if attempts > max_attempts:
break
if self.gg['hook'] is not None:
self.gg['hook'].remove()
self.gg['done_counter'] += 1
self.gg['counter_to_apply_correction'] = 0
self.gg['hook_position'] = 0
self.gg['hook'] = None

if train:
model.train()
return model

def store_activations_wrapper(self):
gg = self.gg
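        # Forward hook: stash the hooked layer's output so init_model can measure its std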
        def store_activations(module, input, output):
gg['act_dict'] = output.data
return store_activations
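
Stripped of the hook bookkeeping, the per-layer loop in init_model reduces to the following update (a minimal sketch on a single bias-free Linear layer; names mirror the arguments above):

import torch

layer = torch.nn.Linear(4, 2, bias=False)
x = torch.rand(8, 4)
needed_std, std_tol, max_attempts = 1.0, 0.1, 10
for _ in range(max_attempts):
    current_std = layer(x).std()
    if torch.abs(current_std - needed_std).item() <= std_tol:
        break
    # Same correction as apply_weights_correction: scale the weights so the output std moves toward needed_std
    layer.weight.data *= needed_std / (current_std + 1e-8)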
