diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index 1a7486a019a0..25cda648c89a 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -296,7 +296,6 @@ _(aten, diagonal) \ _(aten, fill_diagonal_) \ _(aten, diff) \ _(aten, frexp) \ -_(aten, digamma) \ _(aten, dim) \ _(aten, dist) \ _(aten, dot) \ diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 96a49f4426be..d06618271bc7 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -322,6 +322,9 @@ namespace c10 { _(aten, moveaxis) \ _(aten, lgamma) \ _(aten, special_gammaln) \ + _(aten, digamma) \ + _(aten, special_psi) \ + _(aten, special_digamma) \ _(aten, erf) \ _(aten, special_erf) \ _(aten, erfc) \ diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index 624881a01dbf..4ba59ef3d05d 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -441,6 +441,13 @@ Tensor special_erfc(const Tensor& self) { return self.erfc(); } Tensor& special_erfinv_out(const Tensor& self, Tensor& result) { return at::erfinv_out(result, self); } Tensor special_erfinv(const Tensor& self) { return self.erfinv(); } +// special_psi, alias for digamma +Tensor& special_psi_out(const Tensor& self, Tensor& result) { return at::digamma_out(result, self); } +Tensor special_psi(const Tensor& self) { return self.digamma(); } +// special_digamma, alias for digamma +Tensor& special_digamma_out(const Tensor& self, Tensor& result) { return at::digamma_out(result, self); } +Tensor special_digamma(const Tensor& self) { return self.digamma(); } + // special_i0, alias for i0 Tensor& special_i0_out(const Tensor& self, Tensor& result) { return at::i0_out(result, self); } Tensor special_i0(const Tensor& self) { return self.i0(); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 5e0dd9917dd9..f0762dfc535c 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -9426,6 +9426,22 @@ python_module: special variants: function +- func: special_psi(Tensor self) -> Tensor + python_module: special + variants: function + +- func: special_psi.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + variants: function + +- func: special_digamma(Tensor self) -> Tensor + python_module: special + variants: function + +- func: special_digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + python_module: special + variants: function + - func: special_gammaln(Tensor self) -> Tensor python_module: special variants: function diff --git a/docs/source/special.rst b/docs/source/special.rst index 39aa0640c953..cc173dbc65ba 100644 --- a/docs/source/special.rst +++ b/docs/source/special.rst @@ -26,6 +26,8 @@ Functions .. autofunction:: expm1 .. autofunction:: exp2 .. autofunction:: gammaln +.. autofunction:: digamma +.. autofunction:: psi .. autofunction:: i0 .. autofunction:: i0e .. autofunction:: i1 diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 4e38a487f0e7..81fe8c007b95 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -2874,29 +2874,8 @@ def merge_dicts(*dicts): add_docstr(torch.digamma, r""" digamma(input, *, out=None) -> Tensor -Computes the logarithmic derivative of the gamma function on `input`. - -.. math:: - \psi(x) = \frac{d}{dx} \ln\left(\Gamma\left(x\right)\right) = \frac{\Gamma'(x)}{\Gamma(x)} -""" + r""" -Args: - input (Tensor): the tensor to compute the digamma function on - -Keyword args: - {out} - -.. note:: This function is similar to SciPy's `scipy.special.digamma`. - -.. note:: From PyTorch 1.8 onwards, the digamma function returns `-Inf` for `0`. - Previously it returned `NaN` for `0`. - -Example:: - - >>> a = torch.tensor([1, 0.5]) - >>> torch.digamma(a) - tensor([-0.5772, -1.9635]) -""".format(**common_args)) - +Alias for :func:`torch.special.digamma`. +""") add_docstr(torch.dist, r""" diff --git a/torch/csrc/api/include/torch/special.h b/torch/csrc/api/include/torch/special.h index d80a43981ae6..cf667f9412a7 100644 --- a/torch/csrc/api/include/torch/special.h +++ b/torch/csrc/api/include/torch/special.h @@ -21,6 +21,38 @@ inline Tensor& gammaln_out(Tensor& result, const Tensor& self) { return torch::special_gammaln_out(result, self); } +/// Computes the logarithmic derivative of the gamma function on input +/// See https://pytorch.org/docs/master/special.html#torch.special.psi +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::psi(t); +/// ``` +inline Tensor psi(const Tensor& self) { + return torch::special_psi(self); +} + +inline Tensor& psi_out(Tensor& result, const Tensor& self) { + return torch::special_psi_out(result, self); +} + +/// Computes the logarithmic derivative of the gamma function on input +/// See https://pytorch.org/docs/master/special.html#torch.special.digamma +/// +/// Example: +/// ``` +/// auto t = torch::randn(128, dtype=kDouble); +/// torch::special::digamma(t); +/// ``` +inline Tensor digamma(const Tensor& self) { + return torch::special_digamma(self); +} + +inline Tensor& digamma_out(Tensor& result, const Tensor& self) { + return torch::special_digamma_out(result, self); +} + /// Computes entropy of input, elementwise /// See https://pytorch.org/docs/master/special.html#torch.special.entr. /// diff --git a/torch/csrc/jit/passes/normalize_ops.cpp b/torch/csrc/jit/passes/normalize_ops.cpp index eda56fe22777..4e59467d7c13 100644 --- a/torch/csrc/jit/passes/normalize_ops.cpp +++ b/torch/csrc/jit/passes/normalize_ops.cpp @@ -117,6 +117,8 @@ const std::unordered_map& getOperatorAliasMap() { {aten::special_exp2, aten::exp2}, {aten::special_expm1, aten::expm1}, {aten::special_logit, aten::logit}, + {aten::special_digamma, aten::digamma}, + {aten::special_psi, aten::digamma}, {aten::special_i0, aten::i0}, {aten::orgqr, aten::linalg_householder_product}, {aten::special_gammaln, aten::lgamma}}; diff --git a/torch/overrides.py b/torch/overrides.py index 75bde5decb78..aa6876d43ea7 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -873,6 +873,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.special.exp2: lambda input: -1, torch.special.expm1: lambda input: -1, torch.special.expit: lambda input: -1, + torch.special.digamma: lambda input: -1, + torch.special.psi: lambda input: -1, torch.special.gammaln: lambda input: -1, torch.special.i0: lambda input: -1, torch.special.i0e: lambda input: -1, diff --git a/torch/special/__init__.py b/torch/special/__init__.py index d0aae87a0b23..68133bbe66f7 100644 --- a/torch/special/__init__.py +++ b/torch/special/__init__.py @@ -35,6 +35,41 @@ tensor([ -inf, 0.0000, 0.3466]) """) +psi = _add_docstr(_special.special_psi, + r""" +psi(input, *, out=None) -> Tensor + +Alias for :func:`torch.special.digamma`. +""") + +digamma = _add_docstr(_special.special_digamma, + r""" +digamma(input, *, out=None) -> Tensor + +Computes the logarithmic derivative of the gamma function on `input`. + +.. math:: + \digamma(x) = \frac{d}{dx} \ln\left(\Gamma\left(x\right)\right) = \frac{\Gamma'(x)}{\Gamma(x)} +""" + r""" +Args: + input (Tensor): the tensor to compute the digamma function on + +Keyword args: + {out} + +.. note:: This function is similar to SciPy's `scipy.special.digamma`. + +.. note:: From PyTorch 1.8 onwards, the digamma function returns `-Inf` for `0`. + Previously it returned `NaN` for `0`. + +Example:: + + >>> a = torch.tensor([1, 0.5]) + >>> torch.special.digamma(a) + tensor([-0.5772, -1.9635]) + +""".format(**common_args)) + gammaln = _add_docstr(_special.special_gammaln, r""" gammaln(input, *, out=None) -> Tensor diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 1cda7f822db5..08e35af018b0 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -7071,6 +7071,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): assert_autodiffed=True), UnaryUfuncInfo('digamma', ref=scipy.special.digamma if TEST_SCIPY else _NOTHING, + aliases=('special.psi', 'special.digamma',), decorators=(precisionOverride({torch.float16: 5e-1}),), dtypes=all_types_and(torch.bool), dtypesIfCUDA=all_types_and(torch.bool, torch.half),