Feature/init (#523)
* Add on_init callback hook

* Add tests for on_init callback hook

* Add weight initialisations

* Add weight init callbacks
ethanwharris committed Mar 13, 2019
1 parent cf1485e commit ed91a94
Showing 9 changed files with 251 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [Unreleased]
### Added
- Added cyclic learning rate finder
- Added on_init callback hook to run at the end of trial init
- Added callbacks for weight initialisation in ``torchbearer.callbacks.init``
### Changed
### Deprecated
### Removed
66 changes: 66 additions & 0 deletions tests/callbacks/test_init.py
@@ -0,0 +1,66 @@
from unittest import TestCase

from mock import MagicMock, patch

import torchbearer
import torchbearer.callbacks.init as init


class TestWeightInit(TestCase):
    def test_modules_from_state(self):
        callback = init.WeightInit(targets=['Mock'])
        model = MagicMock()
        state = {torchbearer.MODEL: model}
        callback.on_init(state)
        self.assertTrue(model.modules.call_count == 1)

    def test_filter(self):
        mock = MagicMock()
        callback = init.WeightInit(initialiser=lambda m: m.test(), modules=[mock], targets=['Mock'])
        callback.on_init({})
        self.assertTrue(mock.test.call_count == 1)

        mock = MagicMock()
        callback = init.WeightInit(initialiser=lambda m: m.test(), modules=[mock], targets=['Not'])
        callback.on_init({})
        self.assertTrue(mock.test.call_count == 0)

    def test_module_list(self):
        mock = MagicMock()
        callback = init.WeightInit(initialiser=lambda m: m.test(), modules=[mock], targets=['Mock'])
        model = MagicMock()
        state = {torchbearer.MODEL: model}
        callback.on_init(state)
        self.assertTrue(model.modules.call_count == 0)


class TestSimpleInits(TestCase):
    @patch('torchbearer.callbacks.init.init')
    def test_kaiming(self, nn_init):
        callback = init.KaimingNormal(a=1, mode='test', nonlinearity='test2')
        mock = MagicMock()
        callback.initialiser(mock)
        nn_init.kaiming_normal_.assert_called_once_with(mock.weight.data, a=1, mode='test', nonlinearity='test2')

        callback = init.KaimingUniform(a=1, mode='test', nonlinearity='test2')
        mock = MagicMock()
        callback.initialiser(mock)
        nn_init.kaiming_uniform_.assert_called_once_with(mock.weight.data, a=1, mode='test', nonlinearity='test2')

    @patch('torchbearer.callbacks.init.init')
    def test_xavier(self, nn_init):
        callback = init.XavierNormal(gain=100)
        mock = MagicMock()
        callback.initialiser(mock)
        nn_init.xavier_normal_.assert_called_once_with(mock.weight.data, gain=100)

        callback = init.XavierUniform(gain=100)
        mock = MagicMock()
        callback.initialiser(mock)
        nn_init.xavier_uniform_.assert_called_once_with(mock.weight.data, gain=100)

    def test_bias(self):
        callback = init.ZeroBias()
        mock = MagicMock()
        callback.initialiser(mock)
        self.assertTrue(mock.bias.data.zero_.call_count == 1)
1 change: 1 addition & 0 deletions tests/test_bases.py
@@ -48,6 +48,7 @@ def test_str(self):
    def test_empty_methods(self):
        callback = Callback()

        self.assertIsNone(callback.on_init({}))
        self.assertIsNone(callback.on_start({}))
        self.assertIsNone(callback.on_start_epoch({}))
        self.assertIsNone(callback.on_start_training({}))
9 changes: 6 additions & 3 deletions tests/test_trial.py
@@ -1873,17 +1873,20 @@ def test_str(self):
torchmodel = "mod"
optimizer = "opt"
metric = tb.metrics.Metric('met')
cb = tb.callbacks.Callback()
cb.on_init = Mock()

torchbearertrial = Trial(torchmodel, optimizer, "crit", [metric], ["cb"])
correct_string = "--------------------- OPTIMZER ---------------------\nopt\n\n-------------------- CRITERION ---------------------\ncrit\n\n--------------------- METRICS ----------------------\n['met']\n\n-------------------- CALLBACKS ---------------------\n['cb']\n\n---------------------- MODEL -----------------------\nmod\n\n"
torchbearertrial = Trial(torchmodel, optimizer, "crit", [metric], [cb])
correct_string = "--------------------- OPTIMZER ---------------------\nopt\n\n-------------------- CRITERION ---------------------\ncrit\n\n--------------------- METRICS ----------------------\n['met']\n\n-------------------- CALLBACKS ---------------------\n['torchbearer.bases.Callback']\n\n---------------------- MODEL -----------------------\nmod\n\n"
self.assertEqual(str(torchbearertrial), correct_string)
self.assertEqual(cb.on_init.call_count, 1)

def test_repr(self):
torchmodel = "mod"
optimizer = "opt"
metric = tb.metrics.Metric('met')

torchbearertrial = Trial(torchmodel, optimizer, "crit", [metric], ["cb"])
torchbearertrial = Trial(torchmodel, optimizer, "crit", [metric], [tb.callbacks.Callback()])
self.assertEqual(str(torchbearertrial), repr(torchbearertrial))

def test_train(self):
8 changes: 8 additions & 0 deletions torchbearer/bases.py
@@ -111,6 +111,14 @@ def load_state_dict(self, state_dict):
"""
return self

def on_init(self, state):
"""Perform some action with the given state as context at the init of a trial instance
Args:
state (dict): The current state dict of the :class:`.Trial`.
"""
pass

def on_start(self, state):
"""Perform some action with the given state as context at the start of a model fit.
Expand Down
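For illustration (not part of this commit): a minimal sketch of a user callback overriding the new on_init hook. The InitLogger name is hypothetical.

import torchbearer
from torchbearer.callbacks import Callback

class InitLogger(Callback):
    # on_init runs once, at the end of Trial.__init__ (see the trial.py change below)
    def on_init(self, state):
        # state is the Trial's state dict; torchbearer.MODEL keys the wrapped module
        print('Trial initialised with model:', type(state[torchbearer.MODEL]).__name__)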
8 changes: 8 additions & 0 deletions torchbearer/callbacks/__init__.py
@@ -75,6 +75,13 @@
    :members:
    :undoc-members:

Weight / Bias Initialisation
------------------------------------
.. automodule:: torchbearer.callbacks.init
    :members:
    :undoc-members:

Decorators
------------------------------------
@@ -94,6 +101,7 @@
from .terminate_on_nan import *
from .torch_scheduler import *
from .weight_decay import *
from . import init
from .aggregate_predictions import *
from .decorators import *
from .live_loss_plot import LiveLossPlot
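Because init is imported as a submodule rather than star-imported, its callbacks keep a namespaced path. A quick sketch, assuming the package layout above:

import torchbearer.callbacks as callbacks

# Reached via the submodule, e.g. torchbearer.callbacks.init.KaimingNormal,
# rather than being re-exported at the torchbearer.callbacks top level
kaiming = callbacks.init.KaimingNormal()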
10 changes: 9 additions & 1 deletion torchbearer/callbacks/callbacks.py
@@ -89,7 +89,15 @@ def append(self, callback_list):
                self.callback_list = self.callback_list + callback.callback_list
            else:
                self.callback_list.append(callback)

    def on_init(self, state):
        """Call on_init on each callback in turn with the given state.

        Args:
            state (dict[str,any]): The current state dict of the :class:`.Trial`.
        """
        self._for_list(lambda callback: callback.on_init(state))

    def on_start(self, state):
        """Call on_start on each callback in turn with the given state.
149 changes: 149 additions & 0 deletions torchbearer/callbacks/init.py
@@ -0,0 +1,149 @@
import torchbearer
from torchbearer import cite
from torchbearer.callbacks import Callback

import torch.nn.init as init

__kaiming__ = """
@inproceedings{he2015delving,
  title={Delving deep into rectifiers: Surpassing human-level performance on imagenet classification},
  author={He, Kaiming and Zhang, Xiangyu and Ren, Shaoqing and Sun, Jian},
  booktitle={Proceedings of the IEEE international conference on computer vision},
  pages={1026--1034},
  year={2015}
}"""

__xavier__ = """
@inproceedings{glorot2010understanding,
  title={Understanding the difficulty of training deep feedforward neural networks},
  author={Glorot, Xavier and Bengio, Yoshua},
  booktitle={Proceedings of the thirteenth international conference on artificial intelligence and statistics},
  pages={249--256},
  year={2010}
}
"""


class WeightInit(Callback):
    """Base class for weight initialisations. Performs the provided function for each module when on_init is
    called.

    Args:
        initialiser (lambda): a function which initialises an nn.Module **inplace**
        modules (Iterable[nn.Module] or nn.Module, optional): an iterable of nn.Modules or a
            single nn.Module that will have weights initialised, otherwise this is retrieved from the model
        targets (list[String]): A list of lookup strings to match which modules will be initialised

    State Requirements:
        - :attr:`torchbearer.state.MODEL`: Model should have the `modules` method if modules is None
    """
    def __init__(self, initialiser=lambda module: module, modules=None, targets=['Conv', 'Linear', 'Bilinear']):
        self.initialiser = initialiser
        self.modules = modules
        self.targets = targets

    def on_init(self, state):
        if self.modules is None:
            self.modules = state[torchbearer.MODEL].modules()

        for m in self.modules:
            if len(list(filter(lambda target: target in m.__class__.__name__, self.targets))) > 0:
                self.initialiser(m)


@cite(__kaiming__)
class KaimingNormal(WeightInit):
    """Kaiming Normal weight initialisation. Uses ``torch.nn.init.kaiming_normal_`` on the ``weight`` attribute of the
    filtered modules.

    Args:
        modules (Iterable[nn.Module] or nn.Module, optional): an iterable of nn.Modules or a
            single nn.Module that will have weights initialised, otherwise this is retrieved from the model
        targets (list[String]): A list of lookup strings to match which modules will be initialised

    See:
        `PyTorch kaiming_normal_ <https://pytorch.org/docs/stable/nn.html#torch.nn.init.kaiming_normal_>`_
    """
    def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu', modules=None,
                 targets=['Conv', 'Linear', 'Bilinear']):
        def initialiser(module):
            init.kaiming_normal_(module.weight.data, a=a, mode=mode, nonlinearity=nonlinearity)

        super(KaimingNormal, self).__init__(initialiser, modules=modules, targets=targets)


@cite(__kaiming__)
class KaimingUniform(WeightInit):
    """Kaiming Uniform weight initialisation. Uses ``torch.nn.init.kaiming_uniform_`` on the ``weight`` attribute of the
    filtered modules.

    Args:
        modules (Iterable[nn.Module] or nn.Module, optional): an iterable of nn.Modules or a
            single nn.Module that will have weights initialised, otherwise this is retrieved from the model
        targets (list[String]): A list of lookup strings to match which modules will be initialised

    See:
        `PyTorch kaiming_uniform_ <https://pytorch.org/docs/stable/nn.html#torch.nn.init.kaiming_uniform_>`_
    """
    def __init__(self, a=0, mode='fan_in', nonlinearity='leaky_relu', modules=None,
                 targets=['Conv', 'Linear', 'Bilinear']):
        def initialiser(module):
            init.kaiming_uniform_(module.weight.data, a=a, mode=mode, nonlinearity=nonlinearity)

        super(KaimingUniform, self).__init__(initialiser, modules=modules, targets=targets)


@cite(__xavier__)
class XavierNormal(WeightInit):
    """Xavier Normal weight initialisation. Uses ``torch.nn.init.xavier_normal_`` on the ``weight`` attribute of the
    filtered modules.

    Args:
        modules (Iterable[nn.Module] or nn.Module, optional): an iterable of nn.Modules or a
            single nn.Module that will have weights initialised, otherwise this is retrieved from the model
        targets (list[String]): A list of lookup strings to match which modules will be initialised

    See:
        `PyTorch xavier_normal_ <https://pytorch.org/docs/stable/nn.html#torch.nn.init.xavier_normal_>`_
    """
    def __init__(self, gain=1, modules=None, targets=['Conv', 'Linear', 'Bilinear']):
        def initialiser(module):
            init.xavier_normal_(module.weight.data, gain=gain)

        super(XavierNormal, self).__init__(initialiser, modules=modules, targets=targets)


@cite(__xavier__)
class XavierUniform(WeightInit):
    """Xavier Uniform weight initialisation. Uses ``torch.nn.init.xavier_uniform_`` on the ``weight`` attribute of the
    filtered modules.

    Args:
        modules (Iterable[nn.Module] or nn.Module, optional): an iterable of nn.Modules or a
            single nn.Module that will have weights initialised, otherwise this is retrieved from the model
        targets (list[String]): A list of lookup strings to match which modules will be initialised

    See:
        `PyTorch xavier_uniform_ <https://pytorch.org/docs/stable/nn.html#torch.nn.init.xavier_uniform_>`_
    """
    def __init__(self, gain=1, modules=None, targets=['Conv', 'Linear', 'Bilinear']):
        def initialiser(module):
            init.xavier_uniform_(module.weight.data, gain=gain)

        super(XavierUniform, self).__init__(initialiser, modules=modules, targets=targets)


class ZeroBias(WeightInit):
    """Zero initialisation for the ``bias`` attributes of filtered modules. This is recommended for use in conjunction
    with weight initialisation schemes.

    Args:
        modules (Iterable[nn.Module] or nn.Module, optional): an iterable of nn.Modules or a
            single nn.Module that will have weights initialised, otherwise this is retrieved from the model
        targets (list[String]): A list of lookup strings to match which modules will be initialised
    """
    def __init__(self, modules=None, targets=['Conv', 'Linear', 'Bilinear']):
        def initialiser(module):
            module.bias.data.zero_()

        super(ZeroBias, self).__init__(initialiser, modules=modules, targets=targets)
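A minimal usage sketch (not part of the diff), assuming the Trial API shown in trial.py below. The default targets match Conv, Linear and Bilinear by substring of the class name, so the ReLU here is skipped:

import torch.nn as nn
from torchbearer import Trial
from torchbearer.callbacks.init import KaimingNormal, ZeroBias

# A model used only for construction in this sketch; it is never run
model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU(), nn.Linear(16, 10))

# on_init fires inside the Trial constructor, so the Conv2d and Linear layers
# get Kaiming weights and zero biases as soon as the trial exists
trial = Trial(model, callbacks=[KaimingNormal(nonlinearity='relu'), ZeroBias()])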
2 changes: 2 additions & 0 deletions torchbearer/trial.py
@@ -386,6 +386,8 @@ def criterion(_, __):
            torchbearer.INF_TRAIN_LOADING: False,
        })

        self.state[torchbearer.CALLBACK_LIST].on_init(self.state)

    def __str__(self):
        def state_string(name, state_key):
            import math
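Continuing the hypothetical InitLogger sketch from bases.py above: because the callback list's on_init is invoked here, at the end of Trial.__init__, callbacks observe the trial at construction time rather than on the first call to run:

import torch.nn as nn
from torchbearer import Trial

trial = Trial(nn.Linear(2, 2), callbacks=[InitLogger()])
# prints: Trial initialised with model: Linear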
