Skip to content

Commit

Permalink
[pt2] add Python meta for polygamma
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
nkaretnikov committed Aug 6, 2023
1 parent 68cb854 commit 49b5aed
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 13 deletions.
7 changes: 0 additions & 7 deletions test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2812,7 +2812,6 @@ def forward(self, x):
xfail('cdist', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('combinations', ''), # aten.masked_select.default
xfail('diff', ''), # aten.zeros_like.default - couldn't find symbolic meta function/decomposition
xfail('digamma', ''), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
xfail('frexp', ''), # aten.frexp.Tensor - couldn't find symbolic meta function/decomposition
xfail('gradient', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('i0', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
Expand Down Expand Up @@ -2844,18 +2843,12 @@ def forward(self, x):
xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta...
xfail('nn.functional.rrelu', ''), # aten.rrelu_with_noise.default - couldn't find symbolic meta function...
xfail('normal', 'number_mean'), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('polygamma', 'polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic meta function/de...
xfail('polygamma', 'polygamma_n_1'), # aten.polygamma.default - couldn't find symbolic meta function/de...
xfail('polygamma', 'polygamma_n_2'), # aten.polygamma.default - couldn't find symbolic meta function/de...
xfail('polygamma', 'polygamma_n_3'), # aten.polygamma.default - couldn't find symbolic meta function/de...
xfail('polygamma', 'polygamma_n_4'), # aten.polygamma.default - couldn't find symbolic meta function/de...
xfail('prod', ''), # Cannot call numel() on tensor with symbolic sizes/strides
xfail('repeat_interleave', ''), # aten.repeat_interleave.Te...
xfail('_segment_reduce', 'lengths'), # aten.segment_reduce.default - couldn't find symbolic meta functio...
xfail('_segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta functio...
xfail('sgn', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('special.i1', ''), # aten.i0.default - couldn't find symbolic meta function/decomposition
xfail('special.polygamma', 'special_polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic ...
xfail('stft', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('take_along_dim', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('trace', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
Expand Down
6 changes: 0 additions & 6 deletions test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1559,11 +1559,6 @@ def f(t):
xfail('nn.functional.interpolate', 'trilinear'), # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi...
xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco...
xfail('normal', 'number_mean'), # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
xfail('polygamma', 'polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
xfail('polygamma', 'polygamma_n_1'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
xfail('polygamma', 'polygamma_n_2'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
xfail('polygamma', 'polygamma_n_3'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
xfail('polygamma', 'polygamma_n_4'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend.
xfail('repeat_interleave', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('resize_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
Expand All @@ -1581,7 +1576,6 @@ def f(t):
xfail('special.modified_bessel_i1', ''), # aten.special_modified_bessel_i1.default - couldn't find symbolic meta funct...
xfail('special.modified_bessel_k0', ''), # aten.special_modified_bessel_k0.default - couldn't find symbolic meta funct...
xfail('special.modified_bessel_k1', ''), # aten.special_modified_bessel_k1.default - couldn't find symbolic meta funct...
xfail('special.polygamma', 'special_polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic meta function/...
xfail('special.scaled_modified_bessel_k0', ''), # aten.special_scaled_modified_bessel_k0.default - couldn't find symbo...
xfail('special.scaled_modified_bessel_k1', ''), # aten.special_scaled_modified_bessel_k1.default - couldn't find symbo...
xfail('stft', ''), # argument 'size' must be tuple of ints, but found element of type torch._C.SymIntNode at...
Expand Down
11 changes: 11 additions & 0 deletions torch/_meta_registrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5368,6 +5368,17 @@ def meta_searchsorted(
return torch.empty((), dtype=dtype, device=sorted_sequence.device)


@register_meta(aten.polygamma)
@out_wrapper()
def meta_polygamma(n: int, self: Tensor) -> Tensor:
torch._check(n >= 0, lambda: "polygamma(n, x) does not support negative n.")
_, result_dtype = elementwise_dtypes(
self,
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
return torch.empty_like(self, dtype=result_dtype)


# We must also trigger meta registrations from PrimTorch ref
# decompositions
import torch._refs
Expand Down

0 comments on commit 49b5aed

Please sign in to comment.