Demonstrate parallel execution of a loss function #50

Open
jackbaker1001 opened this issue Sep 12, 2023 · 3 comments
Comments

@jackbaker1001
Collaborator

On an HPC cluster, each term in a mean squared error loss can be computed independently of the others, so the loss can be evaluated with embarrassingly parallel logic.
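A minimal sketch of why the loss splits this way, assuming a hypothetical per-molecule squared-error term (`term`) and plain Python lists standing in for the Molecule data. Each worker sums the terms for its own subset of molecules without talking to anyone, and a single reduction (Python's `sum` here; on a cluster, something like an MPI allreduce) combines one scalar per worker. Because nothing is stacked, the per-molecule arrays are free to have different shapes.

```python
import jax.numpy as jnp

def term(pred, target):
    """Hypothetical squared-error term for one molecule; shapes may differ per molecule."""
    return jnp.sum((pred - target) ** 2)

def split(items, n_workers):
    """Round-robin assignment of molecules to workers."""
    return [items[i::n_workers] for i in range(n_workers)]

def partial_sum(pairs):
    """Work done independently by one worker: no communication is needed."""
    return sum(float(term(p, t)) for p, t in pairs)

def mse_loss(pairs, n_workers):
    # Each partial sum could run on a separate node; only one scalar per worker is reduced.
    partials = [partial_sum(chunk) for chunk in split(pairs, n_workers)]
    return sum(partials) / len(pairs)

# Ragged toy data: "predictions" are identity matrices of different sizes, targets are zero.
pairs = [(jnp.eye(n), jnp.zeros((n, n))) for n in (2, 3, 5)]
print(mse_loss(pairs, n_workers=2))  # ≈ 3.333 (the terms are 2, 3 and 5)
```

The result does not depend on `n_workers`, which is exactly the property that makes the loss embarrassingly parallel.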

Unfortunately, the native way of doing this in JAX (using jax.vmap and jax.pmap) is not compatible with the input we must parallelize over: the Molecule object. This is because its data is stored in a "ragged" structure: the grid dimensions for one molecule very often differ from those of another, and the same is true of the 1-RDM dimensions. Since per-molecule arrays cannot be stacked along a batch axis, jnp.array([rdm1_1, rdm1_2]) will not work.
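A minimal sketch reproducing the stacking failure, followed by one common workaround (zero-padding plus a mask) that makes vmap applicable. This is a swapped-in technique for illustration, not the approach chosen in this issue; `pad_square`, `square_mask`, `masked_sq_error` and `batched_loss` are hypothetical names, and identity/zero matrices stand in for real 1-RDMs.

```python
import jax
import jax.numpy as jnp

# Toy "ragged" 1-RDMs: two molecules with different numbers of orbitals.
rdm1_1 = jnp.eye(2)
rdm1_2 = jnp.eye(3)

try:
    jnp.array([rdm1_1, rdm1_2])
    stacking_failed = False
except (ValueError, TypeError):
    stacking_failed = True  # shapes (2, 2) and (3, 3) cannot form one array

def pad_square(x, n_max):
    """Zero-pad an (n, n) matrix to (n_max, n_max)."""
    k = x.shape[0]
    return jnp.pad(x, ((0, n_max - k), (0, n_max - k)))

def square_mask(k, n_max):
    """Boolean (n_max, n_max) mask that is True only on the original k x k block."""
    valid = jnp.arange(n_max) < k
    return valid[:, None] & valid[None, :]

def masked_sq_error(pred, target, mask):
    # Entries outside the mask are padding and must not contribute to the loss.
    return jnp.sum(jnp.where(mask, pred - target, 0.0) ** 2)

@jax.jit
def batched_loss(preds, targets, masks):
    # After padding, every molecule has the same shape, so vmap applies.
    return jnp.mean(jax.vmap(masked_sq_error)(preds, targets, masks))

rdms = [rdm1_1, rdm1_2]
n_max = max(r.shape[0] for r in rdms)
preds = jnp.stack([pad_square(r, n_max) for r in rdms])
targets = jnp.zeros_like(preds)
masks = jnp.stack([square_mask(r.shape[0], n_max) for r in rdms])
print(stacking_failed, batched_loss(preds, targets, masks))
```

The trade-off is that every molecule pays the memory and compute cost of the largest one, and `jax.jit` retraces whenever `n_max` changes; grids would need the same padding treatment.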

This means that for loss parallelism, we need to think differently. Sharding may be the way forward, but this requires more thought. A good reference is the JAX tutorial on distributed arrays and automatic parallelization: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
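A minimal sketch of the sharding route, assuming the batch has already been padded to a common size (a sharded jax.Array is still rectangular, so sharding alone does not remove the ragged-data problem). It uses `jax.sharding.Mesh`, `NamedSharding` and `PartitionSpec` to split the molecule axis across all visible devices; the mesh axis name `"mol"` and `sharded_loss` are hypothetical. It also runs on a single CPU device, where the "sharding" is trivial.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis ("mol") spanning all visible devices; molecules are split along it.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("mol",))
mol_sharding = NamedSharding(mesh, P("mol"))

# Toy padded batch: the molecule count must divide evenly by the device count.
n_max = 4
sizes = [(i % n_max) + 1 for i in range(2 * len(devices))]
valid = np.arange(n_max)[None, :] < np.array(sizes)[:, None]    # (n_mol, n_max)
masks = valid[:, :, None] & valid[:, None, :]                    # (n_mol, n_max, n_max)
preds = np.where(masks, np.eye(n_max)[None], 0.0).astype(np.float32)  # padded identities
targets = np.zeros_like(preds)

# Place each molecule's slice on the device that owns it.
preds, targets, masks = (jax.device_put(x, mol_sharding) for x in (preds, targets, masks))

@jax.jit
def sharded_loss(preds, targets, masks):
    # Per-molecule terms are computed where the data lives; XLA inserts the
    # cross-device reduction needed for the final mean.
    per_mol = jnp.sum(jnp.where(masks, preds - targets, 0.0) ** 2, axis=(1, 2))
    return jnp.mean(per_mol)

print(sharded_loss(preds, targets, masks))
```

Scaling this across nodes additionally requires JAX's multi-host setup (`jax.distributed.initialize()`), which this sketch does not cover.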

I don't think we will get around to solving this problem before our release deadline, but if we want to do something with HPC, getting this right is non-negotiable.

@PabloAMC
Collaborator

@Matematija recommended sharding too.

@jackbaker1001
Collaborator Author

Related to #83

@jackbaker1001
Collaborator Author

Having played around with multi-host parallelism in JAX, I ran into many issues with GPU detection on Perlmutter.

I'm giving mpi4jax a go for this task now. It should be fairly easy if this works well on Perlmutter.

jackbaker1001 self-assigned this on Dec 12, 2023

2 participants