Alias for digamma as psi to special namespace (#59143)
Summary:
See #50345

cc: mruberry kshitij12345

Pull Request resolved: #59143

Reviewed By: jbschlosser

Differential Revision: D28986909

Pulled By: mruberry

fbshipit-source-id: bc8ff0375de968f3662b224689fa0a6b117f9c4e
krshrimali authored and facebook-github-bot committed Jun 14, 2021
1 parent ff15d93 commit cf38b20
Showing 11 changed files with 102 additions and 24 deletions.
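Taken together, the changes below expose `torch.digamma` under the `torch.special` namespace as both `psi` and `digamma`. A minimal sketch of the resulting Python surface (assuming a build that includes this commit; the expected values are the ones quoted in the digamma docstring further down):

```python
import torch

a = torch.tensor([1.0, 0.5])

# torch.special.psi and torch.special.digamma are aliases for torch.digamma:
# all three compute the logarithmic derivative of the gamma function.
print(torch.digamma(a))          # tensor([-0.5772, -1.9635])
print(torch.special.digamma(a))  # tensor([-0.5772, -1.9635])
print(torch.special.psi(a))      # tensor([-0.5772, -1.9635])
```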
1 change: 0 additions & 1 deletion aten/src/ATen/core/aten_interned_strings.h
@@ -296,7 +296,6 @@ _(aten, diagonal) \
_(aten, fill_diagonal_) \
_(aten, diff) \
_(aten, frexp) \
_(aten, digamma) \
_(aten, dim) \
_(aten, dist) \
_(aten, dot) \
3 changes: 3 additions & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -322,6 +322,9 @@ namespace c10 {
_(aten, moveaxis) \
_(aten, lgamma) \
_(aten, special_gammaln) \
_(aten, digamma) \
_(aten, special_psi) \
_(aten, special_digamma) \
_(aten, erf) \
_(aten, special_erf) \
_(aten, erfc) \
7 changes: 7 additions & 0 deletions aten/src/ATen/native/UnaryOps.cpp
@@ -441,6 +441,13 @@ Tensor special_erfc(const Tensor& self) { return self.erfc(); }
Tensor& special_erfinv_out(const Tensor& self, Tensor& result) { return at::erfinv_out(result, self); }
Tensor special_erfinv(const Tensor& self) { return self.erfinv(); }

// special_psi, alias for digamma
Tensor& special_psi_out(const Tensor& self, Tensor& result) { return at::digamma_out(result, self); }
Tensor special_psi(const Tensor& self) { return self.digamma(); }
// special_digamma, alias for digamma
Tensor& special_digamma_out(const Tensor& self, Tensor& result) { return at::digamma_out(result, self); }
Tensor special_digamma(const Tensor& self) { return self.digamma(); }

// special_i0, alias for i0
Tensor& special_i0_out(const Tensor& self, Tensor& result) { return at::i0_out(result, self); }
Tensor special_i0(const Tensor& self) { return self.i0(); }
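The `*_out` functions above simply forward to `at::digamma_out`, which is what backs the Python `out=` keyword for the new aliases. A small sketch of that usage (assuming a build with this commit):

```python
import torch

x = torch.tensor([1.0, 0.5])
result = torch.empty_like(x)

# The special_psi_out / special_digamma_out kernels above back the out= keyword;
# the values are written into the preallocated tensor.
torch.special.psi(x, out=result)
print(result)  # tensor([-0.5772, -1.9635])
```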
16 changes: 16 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -9426,6 +9426,22 @@
python_module: special
variants: function

- func: special_psi(Tensor self) -> Tensor
python_module: special
variants: function

- func: special_psi.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
python_module: special
variants: function

- func: special_digamma(Tensor self) -> Tensor
python_module: special
variants: function

- func: special_digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
python_module: special
variants: function

- func: special_gammaln(Tensor self) -> Tensor
python_module: special
variants: function
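In these schema entries, `python_module: special` routes the generated Python bindings into `torch.special` instead of the top-level `torch` namespace, while the operators themselves are registered as `aten::special_psi` and `aten::special_digamma`. A hedged sketch of what that means in practice (the `torch.ops.aten` spelling is the generic dispatcher path, not something specific to this commit):

```python
import torch

x = torch.tensor([0.5])

# Exposed under torch.special because of `python_module: special` ...
print(torch.special.psi(x))
print(torch.special.digamma(x))

# ... while the registered operator names remain reachable through the dispatcher.
print(torch.ops.aten.special_psi(x))
print(torch.ops.aten.special_digamma(x))
```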
2 changes: 2 additions & 0 deletions docs/source/special.rst
@@ -26,6 +26,8 @@ Functions
.. autofunction:: expm1
.. autofunction:: exp2
.. autofunction:: gammaln
.. autofunction:: digamma
.. autofunction:: psi
.. autofunction:: i0
.. autofunction:: i0e
.. autofunction:: i1
25 changes: 2 additions & 23 deletions torch/_torch_docs.py
@@ -2874,29 +2874,8 @@ def merge_dicts(*dicts):
add_docstr(torch.digamma, r"""
digamma(input, *, out=None) -> Tensor

Computes the logarithmic derivative of the gamma function on `input`.

.. math::
    \psi(x) = \frac{d}{dx} \ln\left(\Gamma\left(x\right)\right) = \frac{\Gamma'(x)}{\Gamma(x)}
""" + r"""
Args:
    input (Tensor): the tensor to compute the digamma function on

Keyword args:
    {out}

.. note:: This function is similar to SciPy's `scipy.special.digamma`.

.. note:: From PyTorch 1.8 onwards, the digamma function returns `-Inf` for `0`.
          Previously it returned `NaN` for `0`.

Example::

    >>> a = torch.tensor([1, 0.5])
    >>> torch.digamma(a)
    tensor([-0.5772, -1.9635])
""".format(**common_args))

Alias for :func:`torch.special.digamma`.
""")

add_docstr(torch.dist,
r"""
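The numerical behaviour of `torch.digamma` itself is untouched; only its documentation now defers to `torch.special.digamma`. The note that moved along with the docstring can be checked directly, e.g.:

```python
import torch

# Per the relocated note: from PyTorch 1.8 onwards digamma returns -inf at 0
# (previously NaN), and the new alias agrees with the original function.
z = torch.tensor([0.0])
print(torch.digamma(z))          # tensor([-inf])
print(torch.special.digamma(z))  # tensor([-inf])
```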
32 changes: 32 additions & 0 deletions torch/csrc/api/include/torch/special.h
@@ -21,6 +21,38 @@ inline Tensor& gammaln_out(Tensor& result, const Tensor& self) {
return torch::special_gammaln_out(result, self);
}

/// Computes the logarithmic derivative of the gamma function on input
/// See https://pytorch.org/docs/master/special.html#torch.special.psi
///
/// Example:
/// ```
/// auto t = torch::randn(128, torch::kDouble);
/// torch::special::psi(t);
/// ```
inline Tensor psi(const Tensor& self) {
return torch::special_psi(self);
}

inline Tensor& psi_out(Tensor& result, const Tensor& self) {
return torch::special_psi_out(result, self);
}

/// Computes the logarithmic derivative of the gamma function on input
/// See https://pytorch.org/docs/master/special.html#torch.special.digamma
///
/// Example:
/// ```
/// auto t = torch::randn(128, torch::kDouble);
/// torch::special::digamma(t);
/// ```
inline Tensor digamma(const Tensor& self) {
return torch::special_digamma(self);
}

inline Tensor& digamma_out(Tensor& result, const Tensor& self) {
return torch::special_digamma_out(result, self);
}

/// Computes entropy of input, elementwise
/// See https://pytorch.org/docs/master/special.html#torch.special.entr.
///
2 changes: 2 additions & 0 deletions torch/csrc/jit/passes/normalize_ops.cpp
@@ -117,6 +117,8 @@ const std::unordered_map<Symbol, Symbol>& getOperatorAliasMap() {
{aten::special_exp2, aten::exp2},
{aten::special_expm1, aten::expm1},
{aten::special_logit, aten::logit},
{aten::special_digamma, aten::digamma},
{aten::special_psi, aten::digamma},
{aten::special_i0, aten::i0},
{aten::orgqr, aten::linalg_householder_product},
{aten::special_gammaln, aten::lgamma}};
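The entries added to `getOperatorAliasMap()` let the JIT's op-normalization pass rewrite both aliases to the canonical `aten::digamma`, so TorchScript graphs only ever see one op. A sketch (assuming a build with this commit; the graph is expected to reference `aten::digamma` rather than `aten::special_psi`):

```python
import torch

@torch.jit.script
def f(x: torch.Tensor) -> torch.Tensor:
    return torch.special.psi(x)

# After normalization the alias is canonicalized, so the printed graph
# should contain aten::digamma instead of aten::special_psi.
print(f.graph)
```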
2 changes: 2 additions & 0 deletions torch/overrides.py
@@ -873,6 +873,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.special.exp2: lambda input: -1,
torch.special.expm1: lambda input: -1,
torch.special.expit: lambda input: -1,
torch.special.digamma: lambda input: -1,
torch.special.psi: lambda input: -1,
torch.special.gammaln: lambda input: -1,
torch.special.i0: lambda input: -1,
torch.special.i0e: lambda input: -1,
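The lambdas registered here describe each function's signature for the `__torch_function__` override tests; the new aliases are dispatched through that protocol like any other overridable torch function. A minimal, hypothetical sketch (the `LoggingTensor` class is made up for illustration):

```python
import torch

class LoggingTensor(torch.Tensor):
    # Hypothetical subclass, used only to show that torch.special.psi and
    # torch.special.digamma participate in the __torch_function__ protocol.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print("intercepted:", func)
        return super().__torch_function__(func, types, args, kwargs)

t = torch.tensor([1.0, 0.5]).as_subclass(LoggingTensor)
torch.special.psi(t)      # prints the intercepted function
torch.special.digamma(t)  # prints the intercepted function
```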
35 changes: 35 additions & 0 deletions torch/special/__init__.py
@@ -35,6 +35,41 @@
tensor([ -inf, 0.0000, 0.3466])
""")

psi = _add_docstr(_special.special_psi,
r"""
psi(input, *, out=None) -> Tensor

Alias for :func:`torch.special.digamma`.
""")

digamma = _add_docstr(_special.special_digamma,
r"""
digamma(input, *, out=None) -> Tensor

Computes the logarithmic derivative of the gamma function on `input`.

.. math::
    \psi(x) = \frac{d}{dx} \ln\left(\Gamma\left(x\right)\right) = \frac{\Gamma'(x)}{\Gamma(x)}
""" + r"""
Args:
    input (Tensor): the tensor to compute the digamma function on

Keyword args:
    {out}

.. note:: This function is similar to SciPy's `scipy.special.digamma`.

.. note:: From PyTorch 1.8 onwards, the digamma function returns `-Inf` for `0`.
          Previously it returned `NaN` for `0`.

Example::

    >>> a = torch.tensor([1, 0.5])
    >>> torch.special.digamma(a)
    tensor([-0.5772, -1.9635])
""".format(**common_args))

gammaln = _add_docstr(_special.special_gammaln,
r"""
gammaln(input, *, out=None) -> Tensor
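The docstring's claim that the function is similar to SciPy's `scipy.special.digamma` can be sanity-checked directly; a small sketch, assuming SciPy is installed:

```python
import torch
from scipy import special as sp

x = torch.linspace(0.1, 5.0, steps=8, dtype=torch.float64)

# torch.special.psi / torch.special.digamma should match SciPy's digamma
# for positive inputs (both compute d/dx ln(Gamma(x))).
expected = torch.from_numpy(sp.digamma(x.numpy()))
print(torch.allclose(torch.special.psi(x), expected))      # True (expected)
print(torch.allclose(torch.special.digamma(x), expected))  # True (expected)
```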
1 change: 1 addition & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -7071,6 +7071,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
assert_autodiffed=True),
UnaryUfuncInfo('digamma',
ref=scipy.special.digamma if TEST_SCIPY else _NOTHING,
aliases=('special.psi', 'special.digamma',),
decorators=(precisionOverride({torch.float16: 5e-1}),),
dtypes=all_types_and(torch.bool),
dtypesIfCUDA=all_types_and(torch.bool, torch.half),
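The `aliases=('special.psi', 'special.digamma',)` entry makes the existing digamma `UnaryUfuncInfo` drive the generic alias tests as well. Conceptually, what those tests assert is agreement between the aliases and the original op; a minimal sketch of that property (not the actual test harness):

```python
import torch

# The aliases must produce exactly the same results as torch.digamma on the
# sampled inputs; staying away from non-positive integers avoids the poles.
x = torch.randn(64, dtype=torch.float64).abs() + 0.1
assert torch.equal(torch.special.psi(x), torch.digamma(x))
assert torch.equal(torch.special.digamma(x), torch.digamma(x))
```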
