Skip to content

Commit

Permalink
Experimental support for accelerate (#826)
Browse files Browse the repository at this point in the history
Add support for mixed precision training and more with accelerate

https://github.com/huggingface/accelerate

The feature is treated as experimental for now. This is why we don't make it
part of the core classes yet but instead have the user create a new class that
inherits from the mixin.

Mixed precision training was tested on Turing and Ampere architectures
successfully. Some of the other accelerate features, such as DeepSpeed
integration, were not tested.

Implementation

When using AMP, accelerate applies grad scaling under the hood using
GradScaler. That does not support passing the train step as a closure to
optimizer.step. Therefore, we need to step explicitly.

We could use a more sophisticated approach of trying to figure out if
grad scaler is actually being used and only stepping explicitly if
needed. However, the need for the closure is quite rare and we want to
treat accelerate as a black box instead of relying on implementation
details (which we would have to in order to figure out when grad scaling
is applied).
  • Loading branch information
BenjaminBossan committed Mar 4, 2022
1 parent f5bb1a5 commit c0df4d1
Show file tree
Hide file tree
Showing 6 changed files with 403 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added
- Added `load_best` attribute to `EarlyStopping` callback to automatically load module weights of the best result at the end of training
- Added experimental support for [huggingface accelerate](https://github.com/huggingface/accelerate); use the provided mixin class to add advanced training capabilities provided by the accelerate library to skorch

### Changed
- Initialize data loaders for training and validation dataset once per fit call instead of once per epoch ([migration guide](https://skorch.readthedocs.io/en/stable/user/FAQ.html#migration-from-0-11-to-0-12))
Expand Down
54 changes: 54 additions & 0 deletions docs/user/helper.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,58 @@ argument ``idx=0``, the default) and one for y (with argument
gs.fit(X_sl, y_sl)
AccelerateMixin
---------------

This mixin class can be used to add support for huggingface accelerate_ to
skorch. E.g., this allows you to use mixed precision training (AMP), multi-GPU
training, or training with a TPU. For the time being, this feature should be
considered experimental.

To use this feature, create a new subclass of the neural net class you want to
use and inherit from the mixin class. E.g., if you want to use a
:class:`.NeuralNet`, it would look like this:

.. code:: python
from skorch import NeuralNet
from skorch.helper import AccelerateMixin
class AcceleratedNet(AccelerateMixin, NeuralNet):
"""NeuralNet with accelerate support"""
The same would work for :class:`.NeuralNetClassifier`,
:class:`.NeuralNetRegressor`, etc. Then pass an instance of Accelerator_ with
the desired parameters and you're good to go:

.. code:: python
from accelerate import Accelerator
accelerator = Accelerator(...)
net = AcceleratedNet(
MyModule,
accelerator=accelerator,
)
net.fit(X, y)
accelerate_ recommends to leave the device handling to the Accelerator_, which
is why ``device`` defautls to ``None`` (thus telling skorch not to change the
device).

To install accelerate_, run the following command inside your Python environment:

.. code:: bash
python -m pip install accelerate
.. note::

Under the hood, accelerate uses :class:`~torch.cuda.amp.GradScaler`,
which does not support passing the training step as a closure.
Therefore, if your optimizer requires that (e.g.
:class:`torch.optim.LBFGS`), you cannot use accelerate.

Command line interface helpers
------------------------------

Expand Down Expand Up @@ -201,6 +253,8 @@ callbacks through the command line (but you can modify existing ones
as usual).
.. _accelerate: https://github.com/huggingface/accelerate
.. _Accelerator: https://huggingface.co/docs/accelerate/accelerator.html
.. _fire: https://github.com/google/python-fire
.. _numpydoc: https://github.com/numpy/numpydoc
.. _example: https://github.com/skorch-dev/skorch/tree/master/examples/cli
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
accelerate
fire
flaky
future>=0.17.1
Expand Down
161 changes: 161 additions & 0 deletions skorch/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import torch

from skorch.cli import parse_args # pylint: disable=unused-import
from skorch.dataset import unpack_data
from skorch.utils import _make_split
from skorch.utils import is_torch_data_type
from skorch.utils import to_tensor
Expand Down Expand Up @@ -508,3 +509,163 @@ def describe_signature(self, df):
)

return signature


class AccelerateMixin:
"""Mixin class to add support for huggingface accelerate
This is an *experimental* feature.
Use this mixin class with one of the neural net classes (e.g. ``NeuralNet``,
``NeuralNetClassifier``, or ``NeuralNetRegressor``) and pass an instance of
``Accelerator`` for mixed precision, multi-GPU, or TPU training.
Install the accelerate library using:
.. code-block::
python -m pip install accelerate
skorch does not itself provide any facilities to enable these training
features. A lot of them can still be implemented by the user with a little
bit of extra work but it can be a daunting task. That is why this helper
class was added: Using this mixin in conjunction with the accelerate library
should cover a lot of common use cases.
.. note::
Under the hood, accelerate uses :class:`~torch.cuda.amp.GradScaler`,
which does not support passing the training step as a closure.
Therefore, if your optimizer requires that (e.g.
:class:`torch.optim.LBFGS`), you cannot use accelerate.
.. warning::
Since accelerate is still quite young and backwards compatiblity
breaking features might be added, we treat its integration as an
experimental feature. When accelerate's API stabilizes, we will consider
adding it to skorch proper.
Examples
--------
>>> from skorch import NeuralNetClassifier
>>> from skorch.helper import AccelerateMixin
>>> from accelerate import Accelerator
>>>
>>> class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
... '''NeuralNetClassifier with accelerate support'''
>>>
>>> accelerator = Accelerator(...)
>>> net = AcceleratedNet(MyModule, accelerator=accelerator)
>>> net.fit(X, y)
The same approach works with all the other skorch net classes.
Parameters
----------
accelerator : accelerate.Accelerator
In addition to the usual parameters, pass an instance of
``accelerate.Accelerator`` with the desired settings.
device : str, torch.device, or None (default=None)
The compute device to be used. When using accelerate, it is recommended to
leave device handling to accelerate. Therefore, it is best to leave this
argument to be None, which means that skorch does not set the device.
callbacks__print_log__sink : 'auto' or callable
If 'auto', uses the ``print`` function of the accelerator, if it has one.
This avoids printing the same output multiple times when training
concurrently on multiple machines. If the accelerator does not have a
``print`` function, use Python's ``print`` function instead.
"""
def __init__(
self,
*args,
accelerator,
device=None,
callbacks__print_log__sink='auto',
**kwargs
):
super().__init__(
*args,
device=device,
callbacks__print_log__sink=callbacks__print_log__sink,
**kwargs
)
self.accelerator = accelerator

def _check_kwargs(self, kwargs):
super()._check_kwargs(kwargs)

if self.accelerator.device_placement and (self.device is not None):
raise ValueError(
"When device placement is performed by the accelerator, set device=None"
)

def _initialize_callbacks(self):
if self.callbacks__print_log__sink == 'auto':
print_func = getattr(self.accelerator, 'print', print)
self.callbacks__print_log__sink = print_func
super()._initialize_callbacks()
return self

def _initialize_criterion(self, *args, **kwargs):
super()._initialize_criterion(*args, **kwargs)

with self._current_init_context('criterion'):
for name in self._criteria:
criterion = getattr(self, name + '_')
if isinstance(criterion, torch.nn.Module):
setattr(self, name + '_', self.accelerator.prepare(criterion))

return self

def _initialize_module(self, *args, **kwargs):
super()._initialize_module(*args, **kwargs)

with self._current_init_context('module'):
for name in self._modules:
module = getattr(self, name + '_')
if isinstance(module, torch.nn.Module):
setattr(self, name + '_', self.accelerator.prepare(module))

return self

def _initialize_optimizer(self, *args, **kwargs):
super()._initialize_optimizer(*args, **kwargs)

with self._current_init_context('optimizer'):
for name in self._optimizers:
optimizer = getattr(self, name + '_')
if isinstance(optimizer, torch.optim.Optimizer):
setattr(self, name + '_', self.accelerator.prepare(optimizer))

return self

def train_step_single(self, batch, **fit_params):
self._set_training(True)
Xi, yi = unpack_data(batch)
y_pred = self.infer(Xi, **fit_params)
loss = self.get_loss(y_pred, yi, X=Xi, training=True)
self.accelerator.backward(loss)
return {
'loss': loss,
'y_pred': y_pred,
}

def get_iterator(self, *args, **kwargs):
iterator = super().get_iterator(*args, **kwargs)
iterator = self.accelerator.prepare(iterator)
return iterator

def _step_optimizer(self, step_fn):
# We cannot step_fn as a 'closure' to .step because GradScaler doesn't
# suppor it:
# https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler.step
# Therefore, we need to call step_fn explicitly and step without
# argument.
step_fn()
for name in self._optimizers:
optimizer = getattr(self, name + '_')
optimizer.step()
2 changes: 1 addition & 1 deletion skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ class NeuralNet:
summary scores are always logged in the history attribute,
regardless of the verbose setting.
device : str, torch.device (default='cpu')
device : str, torch.device, or None (default='cpu')
The compute device to be used. If set to 'cuda', data in torch
tensors will be pushed to cuda tensors before being sent to the
module. If set to None, then all compute devices will be left
Expand Down

0 comments on commit c0df4d1

Please sign in to comment.