Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JaxSolver fails when using GPU support with no input parameters #3423

Merged
merged 7 commits into from Oct 31, 2023

Conversation

jsbrittain
Copy link
Contributor

Description

When specified with input parameters, the JaxSolver will parallelise the solve across parameter sets either using asyncio (cpu) or jax-vmap (gpu). When specified without input parameters the cpu pathway continues to solve the model correctly; however, when the solver is called without an input parameter list in a gpu-enabled jax environment then the solver fails. This is because the vmap function requires at least one input argument to contain a non-empty array.

A fix is to ensure that input parameter sets of length 0 or 1 are directed towards the cpu pathway, since this parallelisation occurs mostly over parameter sets, rather than within solves.

Fixes #3422

Type of change

Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.

  • New feature (non-breaking change which adds functionality)
  • Optimization (back-end change that speeds up the code)
  • Bug fix (non-breaking change which fixes an issue)

Key checklist:

  • No style issues: $ pre-commit run (or $ nox -s pre-commit) (see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)
  • All tests pass: $ python run-tests.py --all (or $ nox -s tests)
  • The documentation builds: $ python run-tests.py --doctest (or $ nox -s doctests)

You can run integration tests, unit tests, and doctests together at once, using $ python run-tests.py --quick (or $ nox -s quick).

Further checks:

  • Code is commented, particularly in hard-to-understand areas
  • Tests added that prove fix is effective or that feature works

** Existing tests should cover this scenario, but depend upon GPU runners which are work-in-progress. **

@codecov
Copy link

codecov bot commented Oct 9, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (0482121) 99.58% compared to head (6cc3940) 99.58%.
Report is 2 commits behind head on develop.

Additional details and impacted files
@@           Coverage Diff            @@
##           develop    #3423   +/-   ##
========================================
  Coverage    99.58%   99.58%           
========================================
  Files          256      256           
  Lines        20048    20048           
========================================
  Hits         19965    19965           
  Misses          83       83           
Files Coverage Δ
pybamm/solvers/jax_solver.py 90.69% <100.00%> (ø)

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Sponsor Member

@brosaplanella brosaplanella left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks! Let's hold merging this until the wheel publishing is fixed.

@agriyakhetarpal
Copy link
Member

agriyakhetarpal commented Oct 12, 2023

Thanks for this @jsbrittain, I tested this on Windows locally and I can confirm that #3371 is fixed through the changes in this PR. I didn't know it would be as easy as this but I get it because of my general inexperience with Jax. Could you please close that issue too here?

I have opened a new issue about bumping the jax and jaxlib versions so that GPU support can be targeted across platforms and for other reasons, see #3443

Copy link
Member

@Saransh-cpp Saransh-cpp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The wheel and rc situation is fixed now. The CHANGELOG line should be moved to the unreleased section and everything should be good to go.

@agriyakhetarpal
Copy link
Member

The failing doctests here should not be a worry. I would suggest setting that configuration value to False, though.

@jsbrittain
Copy link
Contributor Author

@agriyakhetarpal All done and passing. Note that some tests still appear to be quite fragile, with the Example notebooks in particular requiring three attempts to pass in this case (despite no code changes between runs).

@agriyakhetarpal
Copy link
Member

The example notebooks issue should be this one: #3415, which came after the changes made in #3198.

Co-authored-by: Saransh Chopra <saransh0701@gmail.com>
Copy link
Member

@Saransh-cpp Saransh-cpp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @jsbrittain! Sorry for all the release mess 😬

@agriyakhetarpal
Copy link
Member

@brosaplanella: I missed talking about it in the meeting—should this PR be included in the release given that this is a bug fix? I think #3443 ought to be taken a look at too otherwise the GPU support mentioned in our CHANGELOG for v23.9rc0 would not be available for users on some platforms (e.g., macOS with Metal requires v0.4.11—we are still on v0.4.8).

I would be happy to write a PR for that this week and incorporate @jsbrittain's suggestions on it into consideration, please let me know if that would be needed

@brosaplanella
Copy link
Sponsor Member

If we can include it that would be great. Tagging @Saransh-cpp so he is aware and can provide some input.

@Saransh-cpp
Copy link
Member

I'll merge this and add it in the rc1 release.

@Saransh-cpp Saransh-cpp merged commit 138cbf2 into pybamm-team:develop Oct 31, 2023
35 checks passed
Saransh-cpp added a commit that referenced this pull request Oct 31, 2023
JaxSolver fails when using GPU support with no input parameters
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
4 participants