[MPS] Adding lgamma, digamma, and polygamma implementations #106292

Closed
wants to merge 14 commits into from


igm503
Contributor

@igm503 igm503 commented Jul 31, 2023

Fixes issue mentioned in #77764

e.g. #77764 (comment)

Adds MPS support for the following ops:

  • lgamma
  • mvlgamma
  • digamma
  • polygamma

The lgamma function does not yet have an MPS backend implementation. I've added one using a custom Metal kernel, following John D. Cook's C++ implementation of the log gamma function (https://www.johndcook.com/blog/cpp_gamma/). For the backward pass of lgamma, I've added a digamma kernel that follows the CPU and CUDA digamma implementation, and for the backward pass of the digamma op, I've added polygamma and trigamma kernels that again follow the CPU and CUDA implementations.
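For readers who want a feel for how a log-gamma routine can be structured, here is a minimal Python sketch. It is not the Metal kernel from this PR and not a port of the linked C++ code; it swaps in a simpler, standard technique (the recurrence lgamma(x) = lgamma(x + 1) - log(x) followed by a truncated Stirling series). The function name `lgamma_sketch` and the cutoff of 12 are assumptions chosen for illustration, and the sketch only handles x > 0.

```python
import math


def lgamma_sketch(x: float) -> float:
    """Illustrative log-gamma for x > 0 (hypothetical helper, not the PR's kernel)."""
    if x <= 0.0:
        raise ValueError("this sketch only handles x > 0")

    # Shift x upward until the asymptotic series is accurate, using
    # lgamma(x) = lgamma(x + 1) - log(x). The cutoff 12.0 is an assumption.
    shift = 0.0
    while x < 12.0:
        shift += math.log(x)
        x += 1.0

    inv = 1.0 / x
    inv2 = inv * inv
    # Truncated Stirling correction terms:
    # 1/(12x) - 1/(360x^3) + 1/(1260x^5) - 1/(1680x^7)
    series = inv * (1.0 / 12.0 - inv2 * (1.0 / 360.0 - inv2 * (1.0 / 1260.0 - inv2 / 1680.0)))

    stirling = (x - 0.5) * math.log(x) - x + 0.5 * math.log(2.0 * math.pi) + series
    return stirling - shift
```

Comparing `lgamma_sketch(v)` against `math.lgamma(v)` for a few positive values is a quick sanity check. A real kernel would also need to handle negative arguments and the poles at the non-positive integers.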

NOTE:

The CPU implementation of the polygamma function incorrectly (as far as I can tell) outputs a finite number for order = 1 and x in the negative integers. The MPS implementation correctly outputs infinity. (See #106692.)

The polygamma tests currently don't pass, partly because of the error in the CPU and CUDA kernels and partly because there are smallish discrepancies near the negative integers between the CPU/CUDA and the MPS polygamma and trigamma kernels. I'm not sure exactly why this is, but let me know if the discrepancies are too big.
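To make the order = 1 case at the negative integers concrete, here is a small pure-Python sketch of trigamma (polygamma of order 1). It is an illustration under stated assumptions, not the CPU, CUDA, or MPS kernel: the names `trigamma_sketch` and `naive_reflection_term` are hypothetical, and the recurrence cutoff of 6 is chosen only for this example. It uses the reflection formula psi1(x) = pi^2 / sin^2(pi x) - psi1(1 - x), whose first term diverges at every integer. In floating point, `sin(pi * x)` is generally not exactly zero at an integer, so evaluating that term without an explicit pole check can yield a huge but finite number. That is offered as one possible mechanism for a finite result, not as a diagnosis of any particular kernel.

```python
import math


def naive_reflection_term(x: float) -> float:
    """pi^2 / sin^2(pi x), evaluated directly with no pole handling."""
    return (math.pi / math.sin(math.pi * x)) ** 2


def trigamma_sketch(x: float) -> float:
    """Illustrative trigamma (hypothetical helper, not a PyTorch kernel)."""
    # Trigamma has poles at 0, -1, -2, ...; return +inf there explicitly.
    if x <= 0.0 and x == math.floor(x):
        return math.inf

    if x < 0.5:
        # Reflection formula moves the argument to 1 - x > 0.5.
        return naive_reflection_term(x) - trigamma_sketch(1.0 - x)

    # Recurrence psi1(x) = psi1(x + 1) + 1/x^2 until x is large enough.
    result = 0.0
    while x < 6.0:
        result += 1.0 / (x * x)
        x += 1.0

    # Truncated asymptotic series:
    # 1/x + 1/(2x^2) + 1/(6x^3) - 1/(30x^5) + 1/(42x^7)
    ixx = 1.0 / (x * x)
    result += (1.0 + 1.0 / (2.0 * x) + ixx * (1.0 / 6.0 - ixx * (1.0 / 30.0 - ixx / 42.0))) / x
    return result
```

With the explicit pole check, `trigamma_sketch(-1.0)` is infinite, while `naive_reflection_term(-1.0)` alone is very large but finite. Points close to a negative integer divide by a tiny `sin` value, so small rounding differences between implementations are amplified there.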

@pytorch-bot

pytorch-bot bot commented Jul 31, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/106292

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 8 Unrelated Failures

As of commit cc34d98 with merge base 703cdd7 (image):

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

UNSTABLE - The following job failed but was likely due to flakiness present on trunk and has been marked as unstable:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added ciflow/mps Run MPS tests (subset of trunk) release notes: mps Release notes category labels Jul 31, 2023
@igm503 igm503 marked this pull request as ready for review August 7, 2023 06:29
@igm503 igm503 requested a review from kulinseth as a code owner August 7, 2023 06:29
@igm503 igm503 changed the title added log-gamma function kernel for mps backend [MPS] lgamma, digamma, and polygamma implementations Aug 7, 2023
@igm503 igm503 changed the title [MPS] lgamma, digamma, and polygamma implementations [MPS] Adding lgamma, digamma, and polygamma implementations Aug 7, 2023
@zou3519 zou3519 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Aug 7, 2023
@igm503
Contributor Author

igm503 commented Aug 26, 2023

@kulinseth Any chance you can give this a look and advise about whether the test failures are a problem?

@kulinseth
Collaborator

=================================== FAILURES ===================================
______________ TestFallbackWarning.test_error_on_not_implemented _______________
Traceback (most recent call last):
  File "/Users/ec2-user/runner/_work/pytorch/pytorch/test/test_mps.py", line 10438, in test_error_on_not_implemented
    fn(*args, **kwargs)
  File "/Users/ec2-user/runner/_work/_temp/conda_environment_5983090866/lib/python3.9/unittest/case.py", line 226, in __exit__
    self._raiseFailure("{} not raised".format(exc_name))
  File "/Users/ec2-user/runner/_work/_temp/conda_environment_5983090866/lib/python3.9/unittest/case.py", line 163, in _raiseFailure
    raise self.test_case.failureException(msg)
AssertionError: NotImplementedError not raised

This issue seems unrelated to the PR. Can you @igm503 please rebase the PR?

@igm503 igm503 force-pushed the lgamma branch 2 times, most recently from 048b88a to 3938014 Compare September 2, 2023 20:04
@igm503 igm503 closed this Sep 2, 2023
@igm503 igm503 deleted the lgamma branch September 2, 2023 20:08
@igm503 igm503 restored the lgamma branch September 2, 2023 20:14
@igm503 igm503 reopened this Sep 2, 2023
@kulinseth
Collaborator

> @kulinseth Any chance you can give this a look and advise about whether the test failures are a problem?

@igm503 the assertion is coming from the not-implemented test. Can you check whether lgamma is used in that test class (TestFallbackWarning) in test_mps?

@igm503
Contributor Author

igm503 commented Sep 3, 2023

@kulinseth I've fixed the assertion error by swapping another not-yet-implemented op for lgamma in the not_implemented test.
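The failure and the fix can be reproduced in miniature without PyTorch. The sketch below is a hypothetical stand-in: `IMPLEMENTED_OPS`, `run_op`, and `FallbackWarningSketch` are invented for illustration and do not mirror the real contents of test_mps.py. It shows why a test that expects `NotImplementedError` starts failing once the probed op gains an implementation, and why pointing the test at a different unimplemented op makes it pass again.

```python
import unittest

# Hypothetical set of ops the backend implements (illustrative names only).
IMPLEMENTED_OPS = {"add", "mul"}


def run_op(name: str) -> str:
    """Stand-in for dispatching an op to a backend."""
    if name not in IMPLEMENTED_OPS:
        raise NotImplementedError(f"{name} is not implemented for this backend")
    return f"{name}: ok"


class FallbackWarningSketch(unittest.TestCase):
    # The probed op must be one the backend does NOT implement.
    probe_op = "lgamma"

    def test_error_on_not_implemented(self):
        with self.assertRaises(NotImplementedError):
            run_op(self.probe_op)


def run_sketch() -> bool:
    """Run the test case and report whether it passed."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(FallbackWarningSketch)
    result = unittest.TestResult()
    suite.run(result)
    return result.wasSuccessful()
```

Calling `run_sketch()` passes while "lgamma" is missing from `IMPLEMENTED_OPS`. Adding it makes the test fail with "NotImplementedError not raised", and changing `probe_op` to another unimplemented name restores the pass.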

@igm503
Contributor Author

igm503 commented Sep 6, 2023

@kulinseth So, at least as I'm typing this, the test errors are now those that I mentioned in the pull request body: in some cases, they're precision issues, but in other cases, I think the cpu implementation is incorrect.

@kulinseth
Collaborator

> @kulinseth So, at least as I'm typing this, the test errors are now those that I mentioned in the pull request body: in some cases, they're precision issues, but in other cases, I think the cpu implementation is incorrect.

@igm503, I see. We can add these tests to the XFAILLIST here.

@igm503
Contributor Author

igm503 commented Sep 7, 2023

> > @kulinseth So, at least as I'm typing this, the test errors are now those that I mentioned in the pull request body: in some cases, they're precision issues, but in other cases, I think the cpu implementation is incorrect.
>
> @igm503 , I see, we can add these tests to XFAILLIST here.

The tests now pass on the macos 13 builds.

@kulinseth However, since there are precision issues with test_output_grad_match_polygamma_polygamma_n_0_cpu_float32 on macos 12 as well, where should I put that exception? I scanned the different XFAILLISTs, and I don't see a clear place for it. Of course, I could put it in the pre-13 XFAIL list, but that would make it seem like it's fixed for >13, which it isn't.

@igm503
Contributor Author

igm503 commented Sep 11, 2023

@kulinseth I went ahead and added the failing tests to the MACOS_BEFORE_13_3_XFAILLIST as well. Let me know if there's a more appropriate place to put them.
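The placement question raised above (version-gated list versus an unconditional one) can be sketched as follows. Everything here is hypothetical: the table names, the string dtype labels, the placeholder entry `some_other_op`, and the helper `expect_failure` are assumptions for illustration, and the thread does not say whether the real test file has an unconditional list. The sketch only shows how keeping the two kinds of expected failure in separate tables avoids implying that an issue is fixed on newer macOS versions.

```python
from typing import Dict, List, Tuple

# Hypothetical tables: op name -> dtypes expected to fail (illustrative entries only).
XFAILLIST_ALL_VERSIONS: Dict[str, List[str]] = {
    "polygamma.polygamma_n_0": ["float32"],
}
XFAILLIST_BEFORE_13_3: Dict[str, List[str]] = {
    "some_other_op": ["float32"],  # placeholder entry, not from the PR
}


def expect_failure(op: str, dtype: str, macos_version: Tuple[int, int]) -> bool:
    """Return True if the (op, dtype) pair is expected to fail on this macOS version."""
    # Failures that are independent of the OS version.
    if dtype in XFAILLIST_ALL_VERSIONS.get(op, []):
        return True
    # Failures that only apply before macOS 13.3.
    if macos_version < (13, 3) and dtype in XFAILLIST_BEFORE_13_3.get(op, []):
        return True
    return False
```

With this layout, `expect_failure("polygamma.polygamma_n_0", "float32", (14, 0))` still reports an expected failure, whereas an entry that lives only in the pre-13.3 table stops being expected once the version reaches 13.3.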

Collaborator

@kulinseth kulinseth left a comment


Looks good

@igm503
Contributor Author

igm503 commented Sep 12, 2023

@pytorchbot merge -i

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Sep 12, 2023
@pytorchmergebot
Collaborator

Merge failed

Reason: 1 jobs have failed, first few of them are: trunk / macos-12-py3-arm64 / test (default, 3, 3, macos-m1-12)

Details for Dev Infra team Raised by workflow job

@igm503
Contributor Author

igm503 commented Sep 12, 2023

@pytorchbot merge -i

Labels
  • ciflow/mps: Run MPS tests (subset of trunk)
  • ciflow/trunk: Trigger trunk jobs on your pull request
  • Merged
  • open source
  • release notes: mps: Release notes category
  • triaged: This issue has been looked at a team member, and triaged and prioritized into an appropriate module

5 participants