
Parallelisation #16

Open
yangjackie opened this issue Oct 5, 2023 · 2 comments

@yangjackie

Hi,

Thanks for sharing this code. I had a little play with it on a garnet material, and it works fairly well on training sets that include finite-temperature displaced structures generated with Phonopy.

So far I have only managed to run this on a single core. I am wondering how I can run the training on a parallel architecture; can you provide an example of how to set it up? I have only recently dived into neural networks, so I am not very familiar with JAX and the other libraries used in your code.

Thanks a lot
Jack

@thorben-frank (Owner) commented Oct 5, 2023

Hi Jack,

thanks for bringing this up. Generally, JAX provides functions suited for parallelisation, jax.pjit and jax.pmap, which work in the spirit of jax.jit and jax.vmap but across multiple GPUs.
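As a rough illustration (not mlff-specific; the toy function below is just for demonstration), jax.pmap maps a function over a leading device axis much like jax.vmap maps over a batch axis:

import jax
import jax.numpy as jnp

# Toy example: pmap compiles the function once per device and maps it
# over the leading (device) axis of its inputs.
def square(x):
    return x ** 2

n_devices = jax.local_device_count()
xs = jnp.arange(4.0 * n_devices).reshape(n_devices, 4)  # one row per device
ys = jax.pmap(square)(xs)  # same result as jax.vmap(square)(xs), but spread over devices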

However, if you are interested in parallelising a message-passing neural network (meaning you want to parallelise across graph nodes / atoms), you have to write your code such that it is compatible with these two functions. The jraph library has an experimental feature for parallelising graph neural networks, which might give you an idea (https://github.com/google-deepmind/jraph/tree/master/jraph/experimental).

If you only want to parallelise along the batch dimension during training, for example, you could make use of jax.pmap or jax.pjit (something along the lines of https://www.mishalaskin.com/posts/data_parallel). The corresponding parts in mlff would be here for the loss_fn

def loss_fn(params, batch: DataTupleT):

and here for the update_fn:

state = state.apply_gradients(grads=grads)
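Very roughly, a data-parallel update along those lines could look like the following sketch. The assumptions that loss_fn returns a scalar loss, that state is a Flax TrainState exposing apply_gradients, and the per-device batch layout are illustrative, not the exact mlff code:

import jax

# Illustrative sketch of a pmap-based data-parallel training step; assumes
# loss_fn(params, batch) returns a scalar loss and state is a Flax TrainState,
# as in the lines linked above.
def make_parallel_train_step(loss_fn):
    def train_step(state, batch):
        grads = jax.grad(loss_fn)(state.params, batch)
        # Average the gradients across devices so every replica applies
        # the same parameter update.
        grads = jax.lax.pmean(grads, axis_name='devices')
        return state.apply_gradients(grads=grads)
    # Each device receives one shard of the batch; the state has to be
    # replicated across devices beforehand (e.g. via jax.device_put_replicated).
    return jax.pmap(train_step, axis_name='devices')

The batches would then have to be reshaped to (n_devices, per_device_batch_size, ...) before being passed to the returned step.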

However, I have never tried it myself, so I can't provide any insight on potential pitfalls or best practices. But in case you want to dive deeper into parallelisation, either across atoms or along the batch dimension, I am happy to assist in any way I can.

Best,
Thorben

@yangjackie (Author)

Thanks, Thorben, for the detailed explanation. It does sound a bit involved. I'll give it a try with the basic JAX functions first and see how far I can go.
