Support for gradient accumulation with accelerate (#874)

The latest accelerate release added gradient accumulation support. This
requires AccelerateMixin to run each training step inside a context
manager, which is now done. Users can therefore use accelerate's
gradient accumulation feature.
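
As a rough usage sketch (not part of this commit; MyModule, X, and y are placeholders), enabling gradient accumulation then only requires constructing the Accelerator with gradient_accumulation_steps and handing it to a net that uses AccelerateMixin:

from accelerate import Accelerator
from skorch import NeuralNetClassifier
from skorch.helper import AccelerateMixin

class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
    """NeuralNetClassifier that leaves device placement etc. to accelerate."""

accelerator = Accelerator(gradient_accumulation_steps=4)
net = AcceleratedNet(
    MyModule,      # placeholder for any torch.nn.Module
    accelerator=accelerator,
    device=None,   # device placement is handled by accelerate
    max_epochs=10,
)
net.fit(X, y)  # gradients are now accumulated over 4 batches per optimizer update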

Furthermore, the learning rate scheduler is now also prepared if it was
used with skorch's LRScheduler callback.

Note:
If users don't use skorch's LRScheduler callback, their learning rate
scheduler cannot be prepared because there is no reliable way of
detecting it.
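
Continuing the sketch above (again only illustrative, reusing the hypothetical AcceleratedNet, MyModule, and accelerator), the scheduler is therefore only prepared when it is registered through skorch's LRScheduler callback:

from skorch.callbacks import LRScheduler
from torch.optim.lr_scheduler import StepLR

net = AcceleratedNet(
    MyModule,
    accelerator=accelerator,
    device=None,
    # Registering the scheduler via skorch's LRScheduler callback is what
    # allows AccelerateMixin to pass the created policy through
    # accelerator.prepare(); a scheduler stepped manually outside of this
    # callback would not be detected.
    callbacks=[('lr_scheduler', LRScheduler(policy=StepLR, step_size=5))],
)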
BenjaminBossan committed Jul 22, 2022
1 parent d7cfdcc commit 9bc8fe6
Showing 2 changed files with 98 additions and 6 deletions.
21 changes: 21 additions & 0 deletions skorch/helper.py
@@ -12,6 +12,7 @@
import torch

from skorch.cli import parse_args # pylint: disable=unused-import
from skorch.callbacks import LRScheduler
from skorch.dataset import unpack_data
from skorch.utils import _make_split
from skorch.utils import to_numpy
@@ -653,6 +654,26 @@ def _initialize_optimizer(self, *args, **kwargs):

return self

def initialize_callbacks(self, *args, **kwargs):
super().initialize_callbacks(*args, **kwargs)

for _, callback in self.callbacks_:
if isinstance(callback, LRScheduler):
callback.policy_ = self.accelerator.prepare(callback.policy_)

return self

def train_step(self, batch, **fit_params):
# Call training step within the accelerator context manager
with self.accelerator.accumulate(self.module_):
# Why do we pass only module_ here, even though there might be
# other modules as well? First, it is not possible to pass multiple
# modules. Second, module_ is only used to determine whether
# Distributed Data Parallel is being used, not for anything else.
# Therefore, passing module_ should be sufficient most of the time.
return super().train_step(batch, **fit_params)

def train_step_single(self, batch, **fit_params):
self._set_training(True)
Xi, yi = unpack_data(batch)
83 changes: 77 additions & 6 deletions skorch/tests/test_helper.py
@@ -1,5 +1,6 @@
"""Test for helper.py"""
import pickle
from contextlib import contextmanager
from distutils.version import LooseVersion
from functools import partial
from unittest.mock import Mock
@@ -888,14 +889,15 @@ def test_print_log_sink_uses_print_if_accelerator_has_no_print(

def test_all_components_prepared(self, module_cls, data):
# We cannot test whether accelerate is really performing its job.
# Instead, we test that all modules and optimizers, even custom
# user-defined ones, are properly prepared. We also test that
# Instead, we test that all modules, optimizers, and lr schedulers, even
# custom user-defined ones, are properly prepared. We also test that
# loss.backward() is called. This means that we do test implementation
# details of accelerate that may change in the future.
from skorch import NeuralNetClassifier
from skorch.callbacks import LRScheduler
from skorch.helper import AccelerateMixin

# pylint: disable=missing-class-docstring
# pylint: disable=missing-docstring
class MockAccelerator:
def __init__(self):
self.device_placement = True
@@ -913,7 +915,11 @@ def backward(self, loss, **kwargs):
def unwrap_model(self, model):
return model

# pylint: disable=missing-class-docstring
@contextmanager
def accumulate(self, model):
yield

# pylint: disable=missing-docstring,arguments-differ
class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
def get_iterator(self, *args, **kwargs):
iterator = super().get_iterator(*args, **kwargs)
@@ -952,10 +958,14 @@ def infer(self, *args, **kwargs):
return super().infer(*args, **kwargs)

def train_step_single(self, *args, **kwargs):
# check that all optimizers are prepared and that
# loss.backward() was called
# check that all optimizers and the lr scheduler are prepared,
# and that loss.backward() was called
assert self.optimizer_.is_prepared
assert self.optimizer2_.is_prepared

lr_scheduler = dict(self.callbacks_)['lr_scheduler'].policy_
assert lr_scheduler.is_prepared

output = super().train_step_single(*args, **kwargs)
assert output['loss'].backward_was_called
return output
@@ -966,8 +976,69 @@ def train_step_single(self, *args, **kwargs):
device=None,
accelerator=accelerator,
max_epochs=2,
callbacks=[('lr_scheduler', LRScheduler)],
)
X, y = data
# does not raise
net.fit(X, y)
net.predict(X)

# make sure that even after resetting parameters, components are still prepared
net.set_params(
module__hidden_units=7,
lr=0.05,
batch_size=33,
criterion__reduction='sum',
callbacks__lr_scheduler__policy=torch.optim.lr_scheduler.ReduceLROnPlateau,
)
# does not raise
net.fit(X, y)
net.predict(X)

def test_gradient_accumulation_with_accelerate(
self, module_cls, accelerator_cls, data
):
# Check that using gradient accumulation provided by accelerate actually
# works. Testing this is not quite trivial. E.g. we cannot check how
# often optimizer.step() is called because accelerate still calls it on
# each step but does not necessarily update the weights. Therefore, we
# check if there was an update step by comparing the weights before and
# after the train_step call. If the weights changed, then there was a
# step, otherwise not.
from skorch import NeuralNetClassifier
from skorch.helper import AccelerateMixin

def weight_sum(module):
return sum(weights.sum() for weights in module.parameters())

# Record for each training step if there was an update of the weights
updated = []

# pylint: disable=missing-docstring
class GradAccNet(AccelerateMixin, NeuralNetClassifier):
# pylint: disable=arguments-differ
def train_step(self, *args, **kwargs):
# Note: We use a very simplified way of checking if weights were
# updated by just comparing their sum. This way, we don't need
# to keep a copy around.
weight_sum_before = weight_sum(self.module_)
step = super().train_step(*args, **kwargs)
weight_sum_after = weight_sum(self.module_)
update_occurred = (weight_sum_before != weight_sum_after).item()
updated.append(update_occurred)
return step

max_epochs = 2
acc_steps = 3
accelerator = accelerator_cls(gradient_accumulation_steps=acc_steps)
net = GradAccNet(module_cls, accelerator=accelerator, max_epochs=max_epochs)
X, y = data
net.fit(X, y)

# Why we expect this outcome: Since acc_steps is 3, we expect that
# updated should be [False, False, True]. However, since we have 1000
# samples and a batch size of 128, every 7th batch is the last batch of
# the epoch, after which there should also be an update. Therefore,
# every 7th entry is also True.
updated_expected = [False, False, True, False, False, True, True] * max_epochs
assert updated == updated_expected
