Multi-gpu with pmap docs #147
One of the selling points of jax is the pmap transformation, but best practices around actually making your training loop parallel are still confusing. What is elegy's story around multi-GPU training? Is it possible to get a PyTorch Lightning-like API, with multi-GPU as a single arg to model.fit?
Comments
Hey @sooheon! Right now we don't handle that case, as we are still defining some of the basic APIs (check #139 if you are interested in PyTorch Lightning-like APIs), but:
Adding pmap in Module, and remembering to shape your data in a pmap-friendly way as input to .fit, seems like the default way then?
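For reference, a minimal sketch (raw JAX, not Elegy's API; `shard_batch` is a hypothetical helper) of what "pmap-friendly" shaping means: `jax.pmap` maps over the leading axis, so the host batch has to be reshaped to `[n_devices, batch // n_devices, ...]` first.

```python
import jax
import numpy as np

def shard_batch(batch: np.ndarray) -> np.ndarray:
    """Reshape [batch, ...] -> [n_devices, batch // n_devices, ...]."""
    n = jax.local_device_count()
    assert batch.shape[0] % n == 0, "batch size must be divisible by the device count"
    return batch.reshape(n, batch.shape[0] // n, *batch.shape[1:])

# Toy MNIST-shaped batch: (128, 28, 28, 1) -> (n_devices, 128 // n_devices, 28, 28, 1)
images = np.zeros((128, 28, 28, 1), dtype=np.float32)
sharded = shard_batch(images)
```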
Giving it some thought, I think it would be good to get a working example using the low-level API first. I believe there are multiple ways of doing this; we can check what Keras and PyTorch Lightning propose, but I can think of 2 strategies. If there are examples in Flax or Haiku, we can get a better sense of how to properly do this.
Yeah, I was trying to do just that: get an example working with the low-level API. Some examples to look at: Looks like they both do strat 2, IIUC.
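For reference, this is roughly what such a data-parallel train step looks like in raw JAX (a sketch, assuming "strat 2" means pmapping the whole update step over replicated parameters; `loss_fn`, `LR`, and the shapes below are placeholders, not Elegy code):

```python
from functools import partial
import jax
import jax.numpy as jnp

LR = 0.1  # placeholder learning rate

def loss_fn(params, x, y):
    # Stand-in linear model; a real network would go here.
    preds = x @ params["w"] + params["b"]
    return jnp.mean((preds - y) ** 2)

@partial(jax.pmap, axis_name="devices")
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Average gradients across devices so every replica applies the same update.
    grads = jax.lax.pmean(grads, axis_name="devices")
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return params, loss

# Replicate the parameters across devices once, then feed per-device shards.
n = jax.local_device_count()
params = jax.tree_util.tree_map(
    lambda p: jnp.stack([p] * n),
    {"w": jnp.zeros((8, 1)), "b": jnp.zeros((1,))},
)
x, y = jnp.zeros((n, 16, 8)), jnp.zeros((n, 16, 1))
params, loss = train_step(params, x, y)
```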
@sooheon I think this is a good first step in this direction: https://github.com/poets-ai/elegy/blob/master/examples/elegy_mnist_conv_pmap.py We can build on this to add it to the high-level API. My main concern is how to synchronize the states / batch statistics: are there states that you synchronize and states that you don't? If there are 2 types of states, we might need to expand the API.
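To illustrate the synchronization question (again a raw-JAX sketch, not Elegy's state handling; the momentum constant and shapes are made up): batch statistics computed per device can be averaged across replicas with `jax.lax.pmean` inside the pmapped step, while state that should stay local would simply not be pmean'd.

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.pmap, axis_name="devices")
def batch_stats_step(x, running_mean):
    local_mean = jnp.mean(x, axis=0)                    # stats from this device's shard
    global_mean = jax.lax.pmean(local_mean, "devices")  # synchronized across replicas
    # Hypothetical BatchNorm-style momentum update.
    return 0.9 * running_mean + 0.1 * global_mean

n = jax.local_device_count()
x = jnp.zeros((n, 32, 10))         # per-device shards of a feature batch
running_mean = jnp.zeros((n, 10))  # replicated running statistics
running_mean = batch_stats_step(x, running_mean)
```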
Took a while, but it's finally supported.