Multi-gpu with pmap docs #147

Closed
sooheon opened this issue Jan 29, 2021 · 6 comments

sooheon (Contributor) commented Jan 29, 2021

One of the selling points of JAX is the pmap transformation, but best practices for actually parallelizing your training loop are still confusing. What is elegy's story around multi-GPU training? Is it possible to get a PyTorch Lightning-like API, where distribution is enabled with a single argument to model.fit?

cgarciae (Collaborator)

Hey @sooheon

Right now we don't handle that case since we are still defining some of the basic APIs (check #139 if you are interested in PyTorch Lightning-like APIs), but:

  • I believe you can easily use pmap inside your Module; see the sketch below.
  • A flag could be added as you say, but we would have to test how it integrates with hooks like add_loss, add_summary, etc.
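
As a rough illustration of the first point, here is a minimal plain-JAX sketch of wrapping a forward pass with pmap inside a module-style call. The names forward and call and the toy parameters are illustrative only (not elegy's actual Module API), and the batch size is assumed to divide evenly across local devices:

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

def forward(params, x):
    # toy single-device forward pass
    return jnp.dot(x, params["w"]) + params["b"]

# pmap the forward pass once; it expects a leading device axis on its inputs
p_forward = jax.pmap(forward)

def call(params, batch):
    # replicate the parameters on every local device
    replicated = jax.device_put_replicated(params, jax.local_devices())
    # shard the batch: [batch, ...] -> [n_devices, batch // n_devices, ...]
    sharded = batch.reshape(n_devices, -1, *batch.shape[1:])
    out = p_forward(replicated, sharded)
    # drop the device axis again before returning
    return out.reshape(-1, *out.shape[2:])
```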

sooheon (Contributor, Author) commented Jan 29, 2021

Adding pmap inside the Module, and remembering to also shape your data in a pmap-friendly way before passing it to .fit, seems to be the default approach then?
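
For the data side, a hedged sketch of what "pmap-friendly" shaping could look like; shard is a hypothetical helper, not an elegy API, and it assumes the batch size is divisible by the number of local devices:

```python
import jax
import numpy as np

n_devices = jax.local_device_count()

def shard(batch: np.ndarray) -> np.ndarray:
    # Add a leading device axis: [batch, ...] -> [n_devices, batch // n_devices, ...]
    return batch.reshape(n_devices, -1, *batch.shape[1:])

x = np.zeros((64, 28, 28, 1), dtype=np.float32)
x_sharded = shard(x)  # shape: (n_devices, 64 // n_devices, 28, 28, 1)
```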

cgarciae (Collaborator) commented Feb 3, 2021

Giving it some thought, I think it would be good to get a working example using pmap with the new low-level API to get a better sense of how to generalize / automate it via simple arguments to Model.

I believe there are multiple ways of doing this; we can check what Keras and PyTorch Lightning propose, but I can think of 2 strategies:

  • Only parallelize the call to the main module; gradients are calculated outside pmap, so they should not have a device dimension.
  • Parallelize everything up to the optimizer call; gradients are calculated inside pmap, so they should have a device dimension (sketched below).

If there are examples in Flax or Haiku, we can get a better sense of how to do this properly.
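
A hedged sketch of the second strategy, assuming a toy linear model and optax as the optimizer library (neither is prescribed by elegy here). The whole step runs inside pmap, so the gradients carry a device dimension and are averaged with lax.pmean before the optimizer update; params and opt_state are assumed to be replicated across devices, x and y sharded:

```python
from functools import partial

import jax
import jax.numpy as jnp
import optax

optimizer = optax.sgd(1e-3)

def loss_fn(params, x, y):
    preds = x @ params["w"] + params["b"]  # toy linear model
    return jnp.mean((preds - y) ** 2)

# Strategy 2: forward, backward and optimizer update all run inside pmap.
@partial(jax.pmap, axis_name="devices")
def train_step(params, opt_state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # average the per-device gradients so every replica applies the same update
    grads = jax.lax.pmean(grads, axis_name="devices")
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
```

The first strategy would instead pmap only the module call and compute the gradients outside of it, avoiding the device dimension on the gradients entirely.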

sooheon (Contributor, Author) commented Feb 4, 2021

Yeah, I was trying to do just that: get an example working with the low-level API. Some examples to look at:
flax transformer
haiku imagenet

Looks like they both use strategy 2, IIUC.

cgarciae (Collaborator) commented Feb 20, 2021

@sooheon I think this is a good first step in this direction:

https://github.com/poets-ai/elegy/blob/master/examples/elegy_mnist_conv_pmap.py

We can build on this to add it to ModelCore in a future PR; ideally you would just pass a flag to enable distributed training.

My main concern is how to synchronize the states / batch statistics: are there states that you synchronize and states that you don't? If there are two types of state, we might need to expand the GeneralizedModule API to differentiate them across all frameworks, e.g. instead of params, states, have params, sync_states, states.
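
To make the sync concern concrete, a speculative sketch (the params / sync_states / states split and the helper below are hypothetical, not an existing elegy API): batch statistics that should agree across replicas are averaged with lax.pmean inside the pmap'd step, while per-device state such as an RNG key is left device-local:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.pmap, axis_name="devices")
def update_states(sync_states, states, x):
    # local batch statistic on this device's shard
    batch_mean = jnp.mean(x, axis=0)
    # synchronized state: average across devices so every replica agrees
    batch_mean = jax.lax.pmean(batch_mean, axis_name="devices")
    new_sync_states = {
        "running_mean": 0.9 * sync_states["running_mean"] + 0.1 * batch_mean
    }
    # unsynchronized state: stays device-local, e.g. a per-device RNG key
    new_states = {"rng": jax.random.split(states["rng"])[0]}
    return new_sync_states, new_states
```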

cgarciae (Collaborator) commented Nov 9, 2021

Took a while, but it's finally supported in 0.8.1 :)

cgarciae closed this as completed Nov 9, 2021