[MPS] Adding lgamma, digamma, and polygamma implementations #106292
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106292
Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 8 Unrelated Failures as of commit cc34d98 with merge base 703cdd7:
NEW FAILURES - The following jobs have failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.
@kulinseth Any chance you can give this a look and advise on whether the test failures are a problem?
This failure seems unrelated to the PR. @igm503, can you please rebase the PR?
Force-pushed from 048b88a to 3938014
@igm503 the assertion is coming from the not-implemented test. Can you check that the lgamma tests are not in that category class in test_mps?
@kulinseth I've fixed the assertion error by swapping another not-yet-implemented op in for lgamma in the not_implemented test.
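(For context, a minimal sketch of the kind of check test_error_on_not_implemented performs, under my assumptions about its shape; the op used as the stand-in here is a placeholder, not necessarily the one actually swapped in:)

```python
import torch

# Ops without an MPS kernel raise NotImplementedError (rather than silently
# falling back to CPU) when PYTORCH_ENABLE_MPS_FALLBACK is unset, and the
# test asserts exactly that for one known-unimplemented op.
def assert_not_implemented_on_mps(fn):
    try:
        fn()
    except NotImplementedError:
        return
    raise AssertionError("expected NotImplementedError on MPS")

x = torch.rand(4, device="mps")
# lgamma used to be the op checked here; with this PR it has to be replaced
# by an op that is still unimplemented (placeholder shown below).
assert_not_implemented_on_mps(lambda: torch.special.gammainc(x, x))
```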
@kulinseth So, at least as I'm typing this, the test errors are now the ones I mentioned in the pull request body: in some cases they're precision issues, but in other cases I think the CPU implementation is incorrect.
The tests now pass on the macOS 13 builds. @kulinseth However, since there are precision issues with test_output_grad_match_polygamma_polygamma_n_0_cpu_float32 on macOS 12 as well, where should I put that exception? I scanned the different XFAILLISTs, and I don't see a clear place for it. Of course, I could put it in the pre-13 XFAIL list, but that would make it seem like it's fixed for >13, which it isn't.
…unction had been made static elsewhere
…or test_error_on_not_implemented
@kulinseth I went ahead and added the failing tests to the MACOS_BEFORE_13_3_XFAILLIST as well. Let me know if there's a more appropriate place to put them.
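(For readers following along, a hedged sketch of the shape of such an entry, assuming the op-name-to-dtypes dict layout that the XFAIL lists in test_mps.py use; the keys shown are illustrative, not copied verbatim from the file:)

```python
import torch

# The XFAIL lists in test_mps.py map op names to the dtypes expected to
# fail; entries along these lines would mark the gamma-family precision
# failures on macOS < 13.3.
MACOS_BEFORE_13_3_XFAILLIST = {
    # ... existing entries ...
    'polygamma': [torch.float32],
    'special.polygamma': [torch.float32],
}
```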
Looks good
@pytorchbot merge -i
Merge started. Your change will be merged while ignoring the following 6 checks:
- pull / linux-focal-py3.8-clang10 / test (default, 2, 3, linux.2xlarge)
- pull / linux-jammy-py3.9-clang12-asan / test (default, 6, 6, linux.4xlarge)
- pull / linux-jammy-py3.8-gcc11 / test (default, 2, 3, linux.2xlarge)
- pull / linux-focal-py3.11-clang10 / test (default, 3, 3, linux.2xlarge)
- pull / linux-focal-cuda12.1-py3.10-gcc9-sm86 / test (default, 4, 5, linux.g5.4xlarge.nvidia.gpu, unstable)
- pull / linux-focal-cuda12.1-py3.10-gcc9 / test (default, 1, 5, linux.4xlarge.nvidia.gpu)

Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed: trunk / macos-12-py3-arm64 / test (default, 3, 3, macos-m1-12). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge -i
Fixes issue mentioned in #77764
e.g. #77764 (comment)
Adds MPS support for the following ops:
- lgamma
- digamma
- polygamma
The lgamma function does not yet have an MPS backend implementation, so I've added one using a custom Metal kernel (following John D. Cook's C++ implementation of the log gamma function: https://www.johndcook.com/blog/cpp_gamma/). For its backward pass, I've added a digamma kernel that follows the CPU and CUDA digamma implementations, and for the backward pass of the digamma op, I've added polygamma and trigamma kernels, again following the CPU and CUDA implementations.
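(To give a feel for the approach, here's a minimal Python sketch of a Cook-style log gamma for positive arguments. The actual kernel is written in Metal, and Cook's implementation uses a rational approximation on [1, 12); this sketch substitutes the recurrence Γ(x) = Γ(x+1)/x for that branch, so treat it as illustrative rather than as the kernel itself:)

```python
import math

# The eight asymptotic-series coefficients Cook's C++ log gamma uses:
# log Gamma(x) ~ (x - 1/2) ln x - x + ln(2*pi)/2 + sum_k c_k / x^(2k-1).
_COEFFS = [
    1.0 / 12.0, -1.0 / 360.0, 1.0 / 1260.0, -1.0 / 1680.0,
    1.0 / 1188.0, -691.0 / 360360.0, 1.0 / 156.0, -3617.0 / 122400.0,
]

def log_gamma(x: float) -> float:
    """Sketch of log(Gamma(x)) for x > 0; illustrative, not the Metal kernel."""
    assert x > 0.0
    # Shift x into the asymptotic regime via Gamma(x) = Gamma(x + 1) / x,
    # accumulating the logs we divide out along the way.
    shift = 0.0
    while x < 12.0:
        shift -= math.log(x)
        x += 1.0
    # Evaluate the series in powers of 1/x^2 by Horner's rule.
    z = 1.0 / (x * x)
    series = _COEFFS[-1]
    for c in reversed(_COEFFS[:-1]):
        series = series * z + c
    series /= x
    return (x - 0.5) * math.log(x) - x + 0.5 * math.log(2.0 * math.pi) + shift + series
```

A full lgamma additionally needs the reflection formula to handle negative x, and the digamma and trigamma kernels follow the same shift-then-asymptotic-series pattern as the CPU and CUDA implementations.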
NOTE:
The CPU implementation of the polygamma function incorrectly (as far as I can tell) outputs a finite number for order = 1 and x a negative integer; the MPS implementation correctly outputs infinity (see #106692, and the short repro sketched after this note).
The polygamma tests currently don't pass, partly because of that error in the CPU and CUDA kernels, but also because there are small discrepancies near the negative integers between the CPU/CUDA and the MPS polygamma and trigamma kernels. I'm not sure exactly why that is; let me know if the discrepancies are too big.
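(A minimal repro of the first point, assuming a build from around the time of this PR; trigamma, i.e. polygamma of order 1, has a pole at every non-positive integer, so inf is the mathematically expected output:)

```python
import torch

x = torch.tensor([-3.0, -2.0, -1.0])
# polygamma(1, .) diverges at non-positive integers, so every element
# here should come out as inf.
print(torch.polygamma(1, x))  # CPU: finite values (see #106692)
if torch.backends.mps.is_available():
    print(torch.polygamma(1, x.to("mps")))  # MPS kernel from this PR: inf
```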