Small changes to net that make grad accumulation easier (#516)
* Small changes to net that make grad accumulation easier

* Move some lines around for less indentation

* Correct wrong implementation of gradient accumulation

loss.backward() needs to be called for each batch (see the sketch after this list).

* Divide loss at the correct place.

* Improve and fix test for gradient accumulation

Test different accumulation step sizes, fix a bug in calculating
expected number.

* Add entry to FAQ for how to do gradient accumulation

* Entry to CHANGES.md

* Use correct class name in FAQ example

* Remove unnecessary check
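
Taken together, the points above amount to the standard gradient accumulation pattern: call loss.backward() for every batch, divide the loss by the number of accumulation steps, and only call optimizer.step() followed by zero_grad() every nth batch. A minimal plain-PyTorch sketch of that pattern (the model, data, and hyperparameters below are made up purely for illustration):

    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    ACC_STEPS = 2  # update the weights only every ACC_STEPS batches

    # toy model and data so the sketch is self-contained
    model = nn.Linear(20, 2)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    X = torch.randn(64, 20)
    y = torch.randint(0, 2, (64,))
    loader = DataLoader(TensorDataset(X, y), batch_size=8)

    model.train()
    optimizer.zero_grad()
    for i, (Xi, yi) in enumerate(loader, start=1):
        # divide the loss so the accumulated gradient matches the gradient
        # of one batch that is ACC_STEPS times as large
        loss = criterion(model(Xi), yi) / ACC_STEPS
        # backward() runs for every batch so that gradients accumulate
        loss.backward()
        if i % ACC_STEPS == 0:
            # weights are updated (and gradients reset) only every nth batch
            optimizer.step()
            optimizer.zero_grad()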
BenjaminBossan authored and ottonemo committed Sep 16, 2019
1 parent a361bc1 commit a62e419
Showing 4 changed files with 102 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

- Improve numerical stability when using `NLLLoss` in `NeuralNetClassifier` (#491)
- Refactor code to make gradient accumulation easier to implement (#506)

### Fixed

43 changes: 43 additions & 0 deletions docs/user/FAQ.rst
@@ -282,3 +282,46 @@ To inspect all output values, you can use either the

For an example of how this works, have a look at this `notebook
<https://nbviewer.jupyter.org/github/skorch-dev/skorch/blob/master/notebooks/Advanced_Usage.ipynb#Multiple-return-values-from-forward>`_.

How can I perform gradient accumulation with skorch?
----------------------------------------------------

There is no direct option to turn on gradient accumulation (at least
for now). However, with a few modifications, you can implement
gradient accumulation yourself:


.. code:: python

    ACC_STEPS = 2  # number of steps to accumulate before updating weights

    class GradAccNet(NeuralNetClassifier):
        """Net that accumulates gradients"""
        def __init__(self, *args, acc_steps=ACC_STEPS, **kwargs):
            super().__init__(*args, **kwargs)
            self.acc_steps = acc_steps

        def get_loss(self, *args, **kwargs):
            loss = super().get_loss(*args, **kwargs)
            return loss / self.acc_steps  # normalize loss

        def train_step(self, Xi, yi, **fit_params):
            """Perform gradient accumulation

            Only optimize every nth batch.

            """
            # note that n_train_batches starts at 1 for each epoch
            n_train_batches = len(self.history[-1, 'batches'])
            step = self.train_step_single(Xi, yi, **fit_params)

            if n_train_batches % self.acc_steps == 0:
                self.optimizer_.step()
                self.optimizer_.zero_grad()

            return step

This is not a complete recipe. For example, if you optimize every 2nd
step and the number of training batches is uneven, you should make
sure that there is an optimization step after the last batch of each
epoch. However, this example can serve as a starting point for
implementing your own version of gradient accumulation.
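
One possible way to close that gap is sketched below. It builds on the
``GradAccNet`` example above; the ``_needs_update`` flag and the
``on_epoch_end`` override are illustrative additions, not an official
skorch recipe, and they rely on ``notify`` also invoking the net's own
``on_epoch_end`` hook:

.. code:: python

    class GradAccNetFlush(GradAccNet):
        """Like GradAccNet, but also optimizes after the last batch of an epoch."""
        def train_step(self, Xi, yi, **fit_params):
            n_train_batches = len(self.history[-1, 'batches'])
            step = self.train_step_single(Xi, yi, **fit_params)
            if n_train_batches % self.acc_steps == 0:
                self.optimizer_.step()
                self.optimizer_.zero_grad()
                self._needs_update = False
            else:
                # gradients were accumulated but not yet applied
                self._needs_update = True
            return step

        def on_epoch_end(self, net, **kwargs):
            # flush gradients left over when the number of training batches
            # is not a multiple of acc_steps; validation does not compute
            # gradients, so running this after validation is harmless
            if getattr(self, '_needs_update', False):
                self.optimizer_.step()
                self.optimizer_.zero_grad()
                self._needs_update = False
            super().on_epoch_end(net, **kwargs)

Note that the final, partial update uses a loss that was still divided
by ``acc_steps``, so it is deliberately smaller than a full accumulated
step.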
6 changes: 3 additions & 3 deletions skorch/net.py
Expand Up @@ -593,7 +593,6 @@ def train_step_single(self, Xi, yi, **fit_params):
"""
self.module_.train()
self.optimizer_.zero_grad()
y_pred = self.infer(Xi, **fit_params)
loss = self.get_loss(y_pred, yi, X=Xi, training=True)
loss.backward()
@@ -656,6 +655,7 @@ def step_fn():
             step_accumulator.store_step(step)
             return step['loss']
         self.optimizer_.step(step_fn)
+        self.optimizer_.zero_grad()
         return step_accumulator.get_step()

     def evaluation_step(self, Xi, training=False):
@@ -730,10 +730,10 @@ def fit_loop(self, X, y=None, epochs=None, **fit_params):
                 yi_res = yi if not y_train_is_ph else None
                 self.notify('on_batch_begin', X=Xi, y=yi_res, training=True)
                 step = self.train_step(Xi, yi, **fit_params)
+                train_batch_count += 1
                 self.history.record_batch('train_loss', step['loss'].item())
                 self.history.record_batch('train_batch_size', get_len(Xi))
                 self.notify('on_batch_end', X=Xi, y=yi_res, training=True, **step)
-                train_batch_count += 1
             self.history.record("train_batch_count", train_batch_count)

             if dataset_valid is None:
@@ -746,10 +746,10 @@ def fit_loop(self, X, y=None, epochs=None, **fit_params):
                 yi_res = yi if not y_valid_is_ph else None
                 self.notify('on_batch_begin', X=Xi, y=yi_res, training=False)
                 step = self.validation_step(Xi, yi, **fit_params)
+                valid_batch_count += 1
                 self.history.record_batch('valid_loss', step['loss'].item())
                 self.history.record_batch('valid_batch_size', get_len(Xi))
                 self.notify('on_batch_end', X=Xi, y=yi_res, training=False, **step)
-                valid_batch_count += 1
             self.history.record("valid_batch_count", valid_batch_count)

             self.notify('on_epoch_end', **on_epoch_kwargs)
55 changes: 55 additions & 0 deletions skorch/tests/test_net.py
Expand Up @@ -2215,6 +2215,61 @@ def test_set_lr_at_runtime_sets_lr_pgroups(self, net_cls, module_cls, data):
        assert net.optimizer_.param_groups[0]['lr'] == lr_pgroup_0_new
        assert net.optimizer_.param_groups[1]['lr'] == lr_pgroup_1_new

    @pytest.mark.parametrize('acc_steps', [1, 2, 3, 5, 10])
    def test_gradient_accumulation(self, net_cls, module_cls, data, acc_steps):
        # Test if gradient accumulation technique is possible,
        # i.e. performing a weight update only every couple of
        # batches.
        mock_optimizer = Mock()

        class GradAccNet(net_cls):
            """Net that accumulates gradients"""
            def __init__(self, *args, acc_steps=acc_steps, **kwargs):
                super().__init__(*args, **kwargs)
                self.acc_steps = acc_steps

            def initialize(self):
                # This is not necessary for gradient accumulation but
                # only for testing purposes
                super().initialize()
                self.true_optimizer_ = self.optimizer_
                mock_optimizer.step.side_effect = self.true_optimizer_.step
                mock_optimizer.zero_grad.side_effect = self.true_optimizer_.zero_grad
                self.optimizer_ = mock_optimizer

            def get_loss(self, *args, **kwargs):
                loss = super().get_loss(*args, **kwargs)
                # because only every nth step is optimized
                return loss / self.acc_steps

            def train_step(self, Xi, yi, **fit_params):
                """Perform gradient accumulation

                Only optimize every nth batch.

                """
                # note that n_train_batches starts at 1 for each epoch
                n_train_batches = len(self.history[-1, 'batches'])
                step = self.train_step_single(Xi, yi, **fit_params)

                if n_train_batches % self.acc_steps == 0:
                    self.optimizer_.step()
                    self.optimizer_.zero_grad()
                return step

        max_epochs = 5
        net = GradAccNet(module_cls, max_epochs=max_epochs)
        X, y = data
        net.fit(X, y)

        n = len(X) * 0.8  # number of training samples
        b = np.ceil(n / net.batch_size)  # batches per epoch
        s = b // acc_steps  # number of acc steps per epoch
        calls_total = s * max_epochs
        calls_step = mock_optimizer.step.call_count
        calls_zero_grad = mock_optimizer.zero_grad.call_count
        assert calls_total == calls_step == calls_zero_grad
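
As a quick sanity check of the expected-call arithmetic above, with hypothetical numbers (not the fixture's actual values, only an illustration of the formula the test uses):

    import numpy as np

    # hypothetical: 1000 samples, batch size 128, 2 acc steps, 5 epochs
    n_samples, batch_size, acc_steps, max_epochs = 1000, 128, 2, 5

    n_train = n_samples * 0.8                # 800 training samples (80/20 split)
    batches = np.ceil(n_train / batch_size)  # ceil(800 / 128) = 7 batches per epoch
    steps = batches // acc_steps             # 7 // 2 = 3 optimizer steps per epoch
    print(steps * max_epochs)                # 15.0 expected step()/zero_grad() calls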


class TestNetSparseInput:
    @pytest.fixture(scope='module')
