Add DistributedHistory for multi-gpu training #955

Merged
BenjaminBossan merged 12 commits into master from distributed-history on Apr 21, 2023

Conversation

@BenjaminBossan (Collaborator) commented on Apr 14, 2023

Description

(this is more or less copied from the docs)

When training a net in a distributed setting, e.g. when using torch.nn.parallel.DistributedDataParallel, directly or indirectly with the help of AccelerateMixin, the default history class should not be used. This is because each process will have its own history instance with no syncing happening between processes. Therefore, the information in the histories can diverge. When steering the training process through the histories, the resulting differences can cause trouble. When using early stopping, for instance, one process could receive the signal to stop but not the other.

DistributedHistory will take care of syncing the distributed batch information across processes, which will prevent the issue just described.

This class needs to be initialized with a distributed store provided by PyTorch (https://pytorch.org/docs/stable/distributed.html#distributed-key-value-store). I have only tested torch.distributed.TCPStore so far, but torch.distributed.FileStore should also work. The DistributedHistory also needs to be initialized with its rank and the world size (number of processes) so that it has all the required information to perform the syncing. When using accelerate, that information can be retrieved from the Accelerator instance.
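
For illustration, here is a minimal sketch of how these pieces could be wired together. `MyModule` and the host/port values are placeholders, and importing `AccelerateMixin` from `skorch.hf` is an assumption on my side; treat it as an illustration of the idea rather than a verbatim copy of the docs:

```python
import torch
from accelerate import Accelerator
from torch.distributed import TCPStore

from skorch import NeuralNetClassifier
from skorch.hf import AccelerateMixin  # assumed import location
from skorch.history import DistributedHistory


class MyModule(torch.nn.Module):
    # placeholder module, just to make the example self-contained
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(20, 2)

    def forward(self, X):
        return torch.softmax(self.lin(X), dim=-1)


class AcceleratedNet(AccelerateMixin, NeuralNetClassifier):
    """NeuralNetClassifier that trains through accelerate."""


accelerator = Accelerator()
world_size = accelerator.num_processes

# kv store shared by all processes; host and port are placeholders,
# torch.distributed.FileStore should work as well
store = TCPStore(
    "127.0.0.1",
    port=1234,
    world_size=world_size,
    is_master=accelerator.is_main_process,
)
dist_history = DistributedHistory(
    store=store,
    rank=accelerator.local_process_index,
    world_size=world_size,
)

net = AcceleratedNet(
    MyModule,
    accelerator=accelerator,
    history=dist_history,
)
# net.fit(X, y)  # X, y: your training data
```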

Comments

  1. Even though the batch information, which is split across processes, is synced, the epoch information, which is not split, is not synced. For example, the recorded duration can differ between processes. It is not quite clear what the "correct" behavior should be here; it probably depends on what is done with this information.
  2. To make it possible to use the new class, I had to change the net initialization code to not create a new history instance (except when it is None); instead, the existing history is just cleared. Otherwise, calling fit would always overwrite the DistributedHistory with a normal History object.
  3. Unfortunately, TCPStore cannot be pickled. Therefore, I set it to None when pickling. This is not tragic as long as users pickle the final model and only load it for inference. If they want to keep on training, they need to set net.history.store manually (see the sketch after this list).
  4. To store values in the kv store as strings, I JSON-serialize them, which means that certain types are not supported (say, numpy arrays). I guess the values could be pickled instead, but that seems like overkill.
  5. I had some bugs when testing this with certain Python versions under pytest ("Fatal Python error: Aborted"). This did not happen with other Python versions or when not using pytest. Fingers crossed that this works in CI. I also had to add some time.sleep calls in the multiprocessing tests to avoid "broken pipe" errors and the like. Update: CI had segfaults for PyTorch 1.11, so I'm skipping the DistributedHistory tests for that PyTorch version. Maybe the tests are just flaky, but we don't want that.
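
Regarding point 3, here is a rough sketch of what resuming training after unpickling could look like; the file name, host, and port are placeholders, and `X_train`/`y_train` stand in for whatever data the user has:

```python
import pickle

from accelerate import Accelerator
from torch.distributed import TCPStore

accelerator = Accelerator()

with open("net.pkl", "rb") as f:  # placeholder file name
    net = pickle.load(f)

# the store was set to None during pickling, so it has to be recreated
# manually before training can continue
net.history.store = TCPStore(
    "127.0.0.1",  # placeholder host
    port=1234,    # placeholder port
    world_size=accelerator.num_processes,
    is_master=accelerator.is_main_process,
)
# net.partial_fit(X_train, y_train)
```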

BenjaminBossan self-assigned this on Apr 14, 2023
@BenjaminBossan (Collaborator, Author) commented on Apr 14, 2023

For the history of this feature, see this comment and associated thread. Ping @Raphaaal.

accelerator.process_index seems to be something else here.
@githubnemo (Contributor) left a comment:
Looks clean and well thought out. LGTM

state = self.__dict__.copy()
try:
pickle.dumps(state['store'])
except TypeError:
githubnemo (Contributor):
it would be beneficial to know why this is expected

:class:`torch.distributed.TCPStore` has been tested to work.

rank : int
The rank of this process.
githubnemo (Contributor):
As a single description of the parameter this is quite tautological. Maybe it is worthwhile to remark on the properties and the possible origin of the rank, e.g. "Number differentiating the distributed training processes, e.g. as provided by accelerate.local_process_index"?

- Better docstring
- More code comments

Also:

- Catch PicklingError in case user provides custom Store that can't be
  pickled.
@BenjaminBossan (Collaborator, Author) commented:
@githubnemo Very good points. I have extended the docstrings and added more context to the comment in __getstate__. I also made a minor change there: besides TypeError, PicklingError is now caught as well. The PyTorch stores raise TypeError, but some users may want to provide custom stores, so it's prudent to also catch PicklingError.

Oh, and by the way: coverage claims that these lines are not covered, which is not true (deliberately introducing an error there makes the tests fail). I suspect that coverage is not measured correctly because of the use of multiprocessing.
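
For readers of this thread, here is a hedged sketch of what the `__getstate__` handling discussed above amounts to, pieced together from the quoted snippet and the comments here; it is not a verbatim copy of the patch:

```python
import pickle
from pickle import PicklingError


def __getstate__(self):
    # sketch of DistributedHistory.__getstate__
    state = self.__dict__.copy()
    try:
        # probe whether the store can be pickled; PyTorch stores such as
        # TCPStore raise a TypeError here, custom stores might raise a
        # PicklingError instead
        pickle.dumps(state['store'])
    except (TypeError, PicklingError):
        # drop the store; after unpickling, users who want to continue
        # training have to set net.history.store again manually
        state['store'] = None
    return state
```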

@BenjaminBossan (Collaborator, Author) commented:
Failing CI seems to be unrelated.

BenjaminBossan merged commit c909e14 into master on Apr 21, 2023
15 checks passed
BenjaminBossan deleted the distributed-history branch on April 21, 2023 08:08
BenjaminBossan added a commit that referenced this pull request Apr 21, 2023
BenjaminBossan mentioned this pull request May 8, 2023
BenjaminBossan added a commit that referenced this pull request May 17, 2023
Preparation for release of version 0.13.0

Release text:

The new skorch release is here and it has some changes that will be exciting for
some users.

- First of all, you may have heard of the [PyTorch 2.0
  release](https://pytorch.org/get-started/pytorch-2.0/), which includes the
  option to compile the PyTorch module for better runtime performance. This
  skorch release allows you to pass `compile=True` when initializing the net to
  enable compilation.
- Support for training on multiple GPUs with the help of the
  [`accelerate`](https://huggingface.co/docs/accelerate/index) package has been
  improved by fixing some bugs and providing a dedicated [history
  class](https://skorch.readthedocs.io/en/latest/user/history.html#distributed-history).
  Our documentation contains more information on [what to consider when training
  on multiple
  GPUs](https://skorch.readthedocs.io/en/latest/user/huggingface.html#caution-when-using-a-multi-gpu-setup).
- If you have ever been frustrated with your neural net not training properly,
  you know how hard it can be to discover the underlying issue. Using the new
  [`SkorchDoctor`](https://skorch.readthedocs.io/en/latest/helper.html#skorch.helper.SkorchDoctor)
  class will simplify the diagnosis of underlying issues. Take a look at the
  accompanying
  [notebook](https://nbviewer.org/github/skorch-dev/skorch/blob/master/notebooks/Skorch_Doctor.ipynb).

Apart from that, a few bugs have been fixed and the included notebooks have been
updated to properly install requirements on Google Colab.

We are grateful for external contributors, many thanks to:

- Kshiteej K (kshitij12345)
- Muhammad Abdullah (abdulasiraj)
- Royi (RoyiAvital)
- Sawradip Saha (sawradip)
- y10ab1 (y10ab1)

Find the list of all changes since v0.12.1 below:

### Added
- Add support for compiled PyTorch modules using the `torch.compile` function,
  introduced in the [PyTorch 2.0
  release](https://pytorch.org/get-started/pytorch-2.0/), which can greatly
  improve performance on new GPU architectures; to use it, initialize your net
  with the `compile=True` argument; further compilation arguments can be
  specified using the dunder notation, e.g. `compile__dynamic=True` (see the
  short sketch after this list)
- Add a class
  [`DistributedHistory`](https://skorch.readthedocs.io/en/latest/history.html#skorch.history.DistributedHistory)
  which should be used when training in a multi GPU setting (#955)
- `SkorchDoctor`: A helper class that assists in understanding and debugging the
  neural net training, see [this
  notebook](https://nbviewer.org/github/skorch-dev/skorch/blob/master/notebooks/Skorch_Doctor.ipynb)
  (#912)
- When using `AccelerateMixin`, it is now possible to prevent unwrapping of the
  modules by setting `unwrap_after_train=True` (#963)
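
As a short illustration of the `compile` bullet above (a sketch with a placeholder module; requires PyTorch >= 2.0):

```python
import torch
from skorch import NeuralNetClassifier


class MyModule(torch.nn.Module):
    # placeholder module, just for illustration
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(20, 2)

    def forward(self, X):
        return torch.softmax(self.lin(X), dim=-1)


net = NeuralNetClassifier(
    MyModule,
    compile=True,            # compile the module(s) via torch.compile
    compile__dynamic=True,   # further torch.compile arguments via dunder notation
    max_epochs=10,
)
# net.fit(X, y)  # X, y: your training data
```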

### Fixed
- Fixed install command to work with recent changes in Google Colab (#928)
- Fixed a couple of bugs related to using non-default modules and criteria
  (#927)
- Fixed a bug when using `AccelerateMixin` in a multi-GPU setup (#947)
- `_get_param_names` returns a list instead of a generator so that subsequent
  error messages return useful information instead of a generator `repr` string
  (#925)
- Fixed a bug that caused modules to not be sufficiently unwrapped at the end of
  training when using `AccelerateMixin`, which could prevent them from being
  pickleable (#963)