New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
JaxSolver fails when using GPU support with no input parameters #3423
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## develop #3423 +/- ##
========================================
Coverage 99.58% 99.58%
========================================
Files 256 256
Lines 20048 20048
========================================
Hits 19965 19965
Misses 83 83
☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, thanks! Let's hold merging this until the wheel publishing is fixed.
Thanks for this @jsbrittain, I tested this on Windows locally and I can confirm that #3371 is fixed through the changes in this PR. I didn't know it would be as easy as this but I get it because of my general inexperience with Jax. Could you please close that issue too here? I have opened a new issue about bumping the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The wheel and rc situation is fixed now. The CHANGELOG line should be moved to the unreleased section and everything should be good to go.
The failing doctests here should not be a worry. I would suggest setting that configuration value to |
@agriyakhetarpal All done and passing. Note that some tests still appear to be quite fragile, with the Example notebooks in particular requiring three attempts to pass in this case (despite no code changes between runs). |
Co-authored-by: Saransh Chopra <saransh0701@gmail.com>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, @jsbrittain! Sorry for all the release mess 😬
@brosaplanella: I missed talking about it in the meeting—should this PR be included in the release given that this is a bug fix? I think #3443 ought to be taken a look at too otherwise the GPU support mentioned in our CHANGELOG for v23.9rc0 would not be available for users on some platforms (e.g., macOS with Metal requires I would be happy to write a PR for that this week and incorporate @jsbrittain's suggestions on it into consideration, please let me know if that would be needed |
If we can include it that would be great. Tagging @Saransh-cpp so he is aware and can provide some input. |
I'll merge this and add it in the rc1 release. |
JaxSolver fails when using GPU support with no input parameters
Description
When specified with input parameters, the JaxSolver will parallelise the solve across parameter sets either using asyncio (cpu) or jax-vmap (gpu). When specified without input parameters the cpu pathway continues to solve the model correctly; however, when the solver is called without an input parameter list in a gpu-enabled jax environment then the solver fails. This is because the vmap function requires at least one input argument to contain a non-empty array.
A fix is to ensure that input parameter sets of length 0 or 1 are directed towards the cpu pathway, since this parallelisation occurs mostly over parameter sets, rather than within solves.
Fixes #3422
Type of change
Please add a line in the relevant section of CHANGELOG.md to document the change (include PR #) - note reverse order of PR #s. If necessary, also add to the list of breaking changes.
Key checklist:
$ pre-commit run
(or$ nox -s pre-commit
) (see CONTRIBUTING.md for how to set this up to run automatically when committing locally, in just two lines of code)$ python run-tests.py --all
(or$ nox -s tests
)$ python run-tests.py --doctest
(or$ nox -s doctests
)You can run integration tests, unit tests, and doctests together at once, using
$ python run-tests.py --quick
(or$ nox -s quick
).Further checks:
** Existing tests should cover this scenario, but depend upon GPU runners which are work-in-progress. **