Skip to content

Commit

Permalink
Fix accelerate bug, document, add scripts (#947)
Browse files Browse the repository at this point in the history
Partly resolves #944

There is an issue with using skorch in a multi-GPU setting with
accelerate. After some searching, it turns out there were two problems:

1. skorch did not call `accelerator.gather_for_metrics`, which resulted
   in `y_pred` not having the correct size. For more on this, consult the
   [accelerate
   docs](https://huggingface.co/docs/accelerate/quicktour#distributed-evaluation).

2. accelerate has an issue with beeing deepcopied, which happens for
   instance when using GridSearchCV. The problem is that some references
   get messed up, resulting in the GradientState of the accelerator
   instance and of the dataloader to diverge. Therefore, the
   accelerator did not "know" when the last batch was encountered and was
   thus unable to remove the dummy samples added for multi-GPU inference.

The fix for 1. is provided in this PR. For 2., there is no solution in
skorch, but a possible (maybe hacky) fix is suggested in the docs. The
fix consists of writing a custom Accelerator class that overrides
__deepcopy__ to just return self. I don't know enough about accelerate
internals to determine if this is a safe solution or if it can cause
more issues down the line, but it resolves the issue.

Since reproducing this bug requires a multi-GPU setup and running the
scripts with the accelerate launcher, it cannot be covered by normal
unit tests. Instead, this PR adds two scripts to reproduce the issue.
With the appropriate hardware, they can be used to check the solution.
  • Loading branch information
BenjaminBossan committed Apr 28, 2023
1 parent c909e14 commit 3c7fa60
Show file tree
Hide file tree
Showing 8 changed files with 278 additions and 2 deletions.
5 changes: 3 additions & 2 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed

### Fixed
- Fixed install command to work with recent changes in Google Colab. (#928)
- Fixes a couple of bugs related to using non-default modules and criteria (#927)
- Fixed install command to work with recent changes in Google Colab (#928)
- Fixed a couple of bugs related to using non-default modules and criteria (#927)
- Fixed a bug when using `AccelerateMixin` in a multi-GPU setup (#947)

## [0.12.1] - 2022-11-18

Expand Down
2 changes: 2 additions & 0 deletions docs/user/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ current epoch. To write a value to the current batch, use
Distributed history
-------------------

.. _dist-history:

When training a net in a distributed setting, e.g. when using
:class:`torch.nn.parallel.DistributedDataParallel`, directly or indirectly with
the help of :class:`.AccelerateMixin`, the default history class should not be
Expand Down
60 changes: 60 additions & 0 deletions docs/user/huggingface.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ accelerate_ to skorch. E.g., this allows you to use mixed precision training
(AMP), multi-GPU training, training with a TPU, or gradient accumulation. For the
time being, this feature should be considered experimental.

Using accelerate
^^^^^^^^^^^^^^^^

To use this feature, create a new subclass of the neural net class you want to
use and inherit from the mixin class. E.g., if you want to use a
:class:`.NeuralNet`, it would look like this:
Expand Down Expand Up @@ -63,6 +66,63 @@ To install accelerate_, run the following command inside your Python environment
:class:`torch.optim.LBFGS`), you cannot use accelerate.


Caution when using a multi-GPU setup
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

There is a known issue when using accelerate in a multi-GPU setup *if copies of
the net are created*. In particular, be aware that sklearn often creates copies
under the hood, which may not immediately obvious to a user. Examples of
functions and classes creating copies are:

- `GridSearchCV`, `RandomizedSearchCV` etc.
- `cross_validate`, `cross_val_score` etc.
- `VotingClassifier`, `CalibratedClassifierCV` and other meta estimators (but
not `Pipeline`).

When using any of those in a multi-GPU setup with :class:`.AccelerateMixin`, you
may encounter errors. A possible fix is to prevent the ``Accelerator`` instance
from being copied (or, to be precise, deep-copied):

.. code:: python
class AcceleratedNet(AccelerateMixin, NeuralNet):
pass
class MyAccelerator(Accelerator):
def __deepcopy__(self, memo):
return self
accelerator = MyAccelerator()
net = AcceleratedNet(..., accelerator=accelerator)
# now grid search et al. should work
gs = GridSearchCV(net, ...)
gs.fit(X, y)
Note that this is a hacky solution, so monitor your results closely to ensure
nothing strange is going on.

There is also a problem with caching not working correctly in multi-GPU
training. Therefore, if using a scoring callback (e.g.
:class:`skorch.callbacks.EpochScoring`), turn caching off by passing
``use_caching=False``. Be aware that when using
:class:`skorch.NeuralNetClassifier`, a scorer for accuracy on the validation set
is added automatically. Caching can be turned off like this:

.. code:: python
net = NeuralNetClassifier(..., valid_acc__use_caching=False)
When running a lot of scorers, the lack of caching can slow down training
considerably because inference is called once for each scorer, even if the
results are always the same. A possible solution to this is to write your own
scoring callback that records multiple scores to the ``history`` using a single
inference call.

Moreover, if your training relies on the training history on some capacity, e.g.
because you want to early stop when the validation loss stops improving, you
should use :class:`.DistributedHistory` instead of the default history. More
information on this can be found :ref:`here <dist-history>`.

Tokenizers
----------

Expand Down
38 changes: 38 additions & 0 deletions examples/accelerate-multigpu/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Testing skorch with accelerate in multi GPU setting

The full history of this can be found here: https://github.com/skorch-dev/skorch/issues/944

There was an issue with using skorch in a multi-GPU setting with accelerate. After some searching, it turns out there were two problems:

1. skorch did not call `accelerator.gather_for_metrics`, which resulted in `y_pred` not having the correct size. For more on this, consult the [accelerate docs](https://huggingface.co/docs/accelerate/quicktour#distributed-evaluation).
2. accelerate has an issue with beeing deepcopied, which happens for instance when using `GridSearchCV`. The problem is that some references get messed up, resulting in the `GradientState` of the `accelerator` instance and of the `dataloader` to diverge. Therefore, the `accelerator` did not "know" when the last batch was encountered and was thus unable to remove the dummy samples added for multi-GPU inference.

The fix for 1. is provided in the same PR as this was added. For 2., the scripts contain a custom `Accelerator` class that overrides `__deepcopy__` to just return `self`. I don't know enough about accelerate internals to determine if this is a safe solution or if it can cause more issues down the line, but it resolves the issue.

This example contains two scripts, one involving skorch and one with skorch completely removed. The scripts reproduce the issue in a multi-GPU setup (tested on a GCP VM instance with two T4's). Unfortunately, the GitHub Action runners don't have such an option, which is why there is no unit test being added for the bug.

Run the scripts like this:

```sh
accelerate launch <script.py>
```

The accelerate config is:

```yaml
compute_environment: LOCAL_MACHINE
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
92 changes: 92 additions & 0 deletions examples/accelerate-multigpu/run-no-skorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import numpy as np
import torch
from accelerate import Accelerator
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_validate
from sklearn.base import BaseEstimator
from torch import nn


class MyModule(nn.Module):
def __init__(self):
super().__init__()
self.dense0 = nn.Linear(100, 2)
self.nonlin = nn.LogSoftmax(dim=-1)

def forward(self, X):
X = self.dense0(X)
X = self.nonlin(X)
return X


class Net(BaseEstimator):
def __init__(self, module, accelerator):
self.module = module
self.accelerator = accelerator

def fit(self, X, y, **fit_params):
X = torch.as_tensor(X)
y = torch.as_tensor(y)
dataset = torch.utils.data.TensorDataset(X, y)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64)
optimizer = torch.optim.SGD(self.module.parameters(), lr=0.01)

self.module = self.accelerator.prepare(self.module)
optimizer = self.accelerator.prepare(optimizer)
dataloader = self.accelerator.prepare(dataloader)

# training
self.module.train()
for epoch in range(5):
for source, targets in dataloader:
optimizer.zero_grad()
output = self.module(source)
loss = nn.functional.nll_loss(output, targets)
self.accelerator.backward(loss)
optimizer.step()

return self

def predict_proba(self, X):
self.module.eval()
X = torch.as_tensor(X)
dataset = torch.utils.data.TensorDataset(X)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64)
dataloader = self.accelerator.prepare(dataloader)

probas = []
with torch.no_grad():
for source, *_ in dataloader:
output = self.module(source)
output = self.accelerator.gather_for_metrics(output)
output = output.cpu().detach().numpy()
probas.append(output)

return np.vstack(probas)

def predict(self, X):
y_proba = self.predict_proba(X)
return y_proba.argmax(1)


class MyAccelerator(Accelerator):
def __deepcopy__(self, memo):
return self


def main():
X, y = make_classification(10000, n_features=100, n_informative=50, random_state=0)
X = X.astype(np.float32)

module = MyModule()
accelerator = MyAccelerator()
net = Net(module, accelerator)
# cross_validate creates a deepcopy of the accelerator attribute
res = cross_validate(
net, X, y, cv=2, scoring='accuracy', verbose=3, error_score='raise',
)
print(res)


if __name__ == '__main__':
main()
72 changes: 72 additions & 0 deletions examples/accelerate-multigpu/run-with-skorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import numpy as np
import torch
from accelerate import Accelerator
from sklearn.datasets import make_classification
from sklearn.model_selection import cross_validate
from torch import nn
from torch.distributed import TCPStore

from skorch import NeuralNetClassifier
from skorch.hf import AccelerateMixin
from skorch.history import DistributedHistory


class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.dense0 = nn.Linear(100, 2)
self.nonlin = nn.LogSoftmax(dim=-1)

def forward(self, X):
X = self.dense0(X)
X = self.nonlin(X)
return X


# make use of accelerate by creating a class with the AccelerateMixin
class AcceleratedNeuralNetClassifier(AccelerateMixin, NeuralNetClassifier):
pass


# prevent the accelerator from being copied by sklearn
class MyAccelerator(Accelerator):
def __deepcopy__(self, memo):
return self


def main():
X, y = make_classification(10000, n_features=100, n_informative=50, random_state=0)
X = X.astype(np.float32)

accelerator = MyAccelerator()

# use history class that works in distributed setting
# see https://skorch.readthedocs.io/en/latest/user/history.html#distributed-history
is_master = accelerator.is_main_process
world_size = accelerator.num_processes
rank = accelerator.local_process_index
store = TCPStore(
"127.0.0.1", port=8080, world_size=world_size, is_master=is_master)
dist_history = DistributedHistory(
store=store, rank=rank, world_size=world_size)

model = AcceleratedNeuralNetClassifier(
MyModule,
accelerator=accelerator,
max_epochs=3,
lr=0.001,
history=dist_history,
)

cross_validate(
model,
X,
y,
cv=2,
scoring="average_precision",
error_score="raise",
)


if __name__ == '__main__':
main()
8 changes: 8 additions & 0 deletions skorch/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,6 +1011,14 @@ def on_train_end(self, net, X=None, y=None, **kwargs):
super().on_train_end(net, X=X, y=y, **kwargs)
self.module_ = self.accelerator.unwrap_model(self.module_)

def evaluation_step(self, batch, training=False):
# More context:
# https://github.com/skorch-dev/skorch/issues/944
# https://huggingface.co/docs/accelerate/quicktour#distributed-evaluation
output = super().evaluation_step(batch, training=training)
y_pred = self.accelerator.gather_for_metrics(output)
return y_pred


class HfHubStorage:
"""Helper class that allows writing data to the Hugging Face Hub.
Expand Down
3 changes: 3 additions & 0 deletions skorch/tests/test_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,9 @@ def backward(self, loss, **kwargs):
def unwrap_model(self, model):
return model

def gather_for_metrics(self, output):
return output

# pylint: disable=unused-argument
@contextmanager
def accumulate(self, model):
Expand Down

0 comments on commit 3c7fa60

Please sign in to comment.