[MPS] Adding lgamma, digamma, and polygamma implementations (#106292)
Fixes the issue mentioned in #77764 (e.g. #77764 (comment)).

Adds MPS support for the following ops:

- lgamma
- mvlgamma
- digamma
- polygamma

The lgamma function did not previously have an MPS backend implementation. I've added one using a custom Metal kernel, following John D. Cook's C++ implementation of the log gamma function: https://www.johndcook.com/blog/cpp_gamma/. For lgamma's backward pass, I've added a digamma kernel that follows the CPU and CUDA digamma implementations, and for the backward pass of the digamma op, I've added polygamma and trigamma kernels that again follow the CPU and CUDA implementations.

NOTE:

The CPU implementation of the polygamma function incorrectly (as far as I can tell) outputs a finite number for order = 1 and x at the negative integers. The MPS implementation correctly outputs infinity. (see #106692)

The polygamma tests currently don't pass, both because of the error in the CPU and CUDA kernels and because there are small discrepancies near the negative integers between the CPU/CUDA and the MPS polygamma and trigamma kernels. I'm not sure exactly why this is, but let me know if the discrepancies are too big.

Pull Request resolved: #106292
Approved by: https://github.com/kulinseth
igm503 authored and pytorchmergebot committed Sep 12, 2023
1 parent c8e577b commit 1b9b3a2
Showing 3 changed files with 639 additions and 18 deletions.