Implement torch.igamma (#46183)
Summary:
Fixes #41637
This is the regularized lower incomplete gamma function, equivalent to SciPy's `gammainc` and TensorFlow's `igamma`.
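
For orientation, the regularized lower incomplete gamma function is P(a, x) = γ(a, x) / Γ(a). The following is a minimal sketch of one textbook way to evaluate it (a power series), not the `calc_igamma` implementation added by this commit, which is derived from the Cephes, SciPy and Boost code credited in NOTICE. The name `igamma_series` and the iteration cap are assumptions made for illustration, and the series is only practical for moderate `x`.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Hypothetical helper, NOT the PR's calc_igamma.
// Evaluates P(a, x) = x^a * exp(-x) / Gamma(a + 1)
//                     * sum_{n >= 0} x^n / ((a + 1) * ... * (a + n))
// which converges for all x >= 0 but slowly when x is large.
double igamma_series(double a, double x) {
  assert(a > 0.0 && x >= 0.0);
  if (x == 0.0) {
    return 0.0;
  }
  double term = 1.0;  // n = 0 term of the sum
  double sum = 1.0;
  for (int n = 1; n < 1000; ++n) {
    term *= x / (a + n);
    sum += term;
    if (term < sum * std::numeric_limits<double>::epsilon()) {
      break;  // further terms no longer change the sum
    }
  }
  // Compute the prefactor in log space to avoid overflow in x^a or Gamma.
  const double log_prefactor = a * std::log(x) - x - std::lgamma(a + 1.0);
  return std::exp(log_prefactor) * sum;
}
```

Two closed forms are handy as sanity checks: P(1, x) = 1 - exp(-x) and P(1/2, x) = erf(sqrt(x)).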

cc fritzo mruberry

Pull Request resolved: #46183

Reviewed By: gchanan

Differential Revision: D24479126

Pulled By: mruberry

fbshipit-source-id: fdf8ea289fe4ca1b408810732192411e948fcdfe
mfkasim1 authored and facebook-github-bot committed Oct 29, 2020
1 parent dd95bf6 commit 6eaa324
Showing 26 changed files with 1,590 additions and 8 deletions.
106 changes: 106 additions & 0 deletions NOTICE
@@ -284,6 +284,112 @@ Apache License Version 2.0:
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

=======================================================================
Cephes's 3-Clause BSD License
=======================================================================

Code derived from implementations in the Cephes Math Library should mention
its derivation and reference the following license:

3-Clause BSD License for the Cephes Math Library
Copyright (c) 2018, Steven Moshier
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.

* Neither the name of the <organization> nor the
names of its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL Steven Moshier BE LIABLE FOR ANY
DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


=======================================================================
SciPy's 3-Clause BSD License
=======================================================================

Code derived from implementations in SciPy should mention its derivation
and reference the following license:

Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers.
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:

1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following
disclaimer in the documentation and/or other materials provided
with the distribution.

3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

=======================================================================
Boost's 1.0 Software License
=======================================================================

Code derived from implementations in Boost 1.0 should mention its
derivation and reference the following license:

Boost Software License - Version 1.0 - August 17th, 2003

Permission is hereby granted, free of charge, to any person or organization
obtaining a copy of the software and accompanying documentation covered by
this license (the "Software") to use, reproduce, display, distribute,
execute, and transmit the Software, and to prepare derivative works of the
Software, and to permit third-parties to whom the Software is furnished to
do so, all subject to the following:

The copyright notices in the Software and this entire statement, including
the above license grant, this restriction and the following disclaimer,
must be included in all copies of the Software, in whole or in part, and
all derivative works of the Software, unless such copies or derivative
works are solely in the form of machine-executable object code generated by
a source language processor.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.
3 changes: 3 additions & 0 deletions aten/src/ATen/core/NamedRegistrations.cpp
@@ -210,6 +210,9 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
m.impl("i0", CppFunction::makeFallthrough());
m.impl("i0.out", CppFunction::makeFallthrough());
m.impl("i0_", CppFunction::makeFallthrough());
m.impl("igamma", CppFunction::makeFallthrough());
m.impl("igamma.out", CppFunction::makeFallthrough());
m.impl("igamma_", CppFunction::makeFallthrough());
m.impl("imag", CppFunction::makeFallthrough());
m.impl("index_fill.Dimname_Scalar", CppFunction::makeFallthrough());
m.impl("index_fill.Dimname_Tensor", CppFunction::makeFallthrough());
2 changes: 2 additions & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -371,6 +371,8 @@ _(aten, hstack) \
_(aten, hypot) \
_(aten, i0) \
_(aten, i0_) \
_(aten, igamma) \
_(aten, igamma_) \
_(aten, ifft) \
_(aten, index) \
_(aten, index_add) \
7 changes: 7 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_base.h
@@ -394,6 +394,13 @@ struct Vec256 {
  Vec256<T> i0() const {
    return map(calc_i0);
  }
  Vec256<T> igamma(const Vec256<T> &x) const {
    Vec256<T> ret;
    for (int64_t i = 0; i < size(); i++) {
      ret[i] = calc_igamma(values[i], x[i]);
    }
    return ret;
  }
  Vec256<T> neg() const {
    // NB: the trailing return type is needed because we need to coerce the
    // return value back to T in the case of unary operator- incurring a
19 changes: 19 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_bfloat16.h
@@ -290,6 +290,25 @@ template <> class Vec256<BFloat16> {
    auto o2 = _mm256_loadu_ps(tmp2);
    return cvtfp32_bf16(o1, o2);
  }
  Vec256<BFloat16> igamma(const Vec256<BFloat16> &x) const {
    __m256 lo, hi;
    __m256 xlo, xhi;
    cvtbf16_fp32(values, lo, hi);
    cvtbf16_fp32(x.values, xlo, xhi);
    __at_align32__ float tmp1[size() / 2], tmp2[size() / 2];
    _mm256_storeu_ps(reinterpret_cast<float*>(tmp1), lo);
    _mm256_storeu_ps(reinterpret_cast<float*>(tmp2), hi);
    __at_align32__ float tmpx1[size() / 2], tmpx2[size() / 2];
    _mm256_storeu_ps(reinterpret_cast<float*>(tmpx1), xlo);
    _mm256_storeu_ps(reinterpret_cast<float*>(tmpx2), xhi);
    for (int64_t i = 0; i < size() / 2; ++i) {
      tmp1[i] = calc_igamma(tmp1[i], tmpx1[i]);
      tmp2[i] = calc_igamma(tmp2[i], tmpx2[i]);
    }
    auto o1 = _mm256_loadu_ps(tmp1);
    auto o2 = _mm256_loadu_ps(tmp2);
    return cvtfp32_bf16(o1, o2);
  }
  Vec256<BFloat16> log() const {
    return map(Sleef_logf8_u10);
  }
3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_complex_double.h
@@ -252,6 +252,9 @@ template <> class Vec256<c10::complex<double>> {
  Vec256<c10::complex<double>> hypot(const Vec256<c10::complex<double>> &b) const {
    AT_ERROR("not supported for complex numbers");
  }
  Vec256<c10::complex<double>> igamma(const Vec256<c10::complex<double>> &x) const {
    AT_ERROR("not supported for complex numbers");
  }
  Vec256<c10::complex<double>> neg() const {
    auto zero = _mm256_setzero_pd();
    return _mm256_sub_pd(zero, values);
3 changes: 3 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_complex_float.h
@@ -290,6 +290,9 @@ template <> class Vec256<c10::complex<float>> {
  Vec256<c10::complex<float>> hypot(const Vec256<c10::complex<float>> &b) const {
    AT_ERROR("not supported for complex numbers");
  }
  Vec256<c10::complex<float>> igamma(const Vec256<c10::complex<float>> &x) const {
    AT_ERROR("not supported for complex numbers");
  }
  Vec256<c10::complex<float>> neg() const {
    auto zero = _mm256_setzero_ps();
    return _mm256_sub_ps(zero, values);
10 changes: 10 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_double.h
@@ -155,6 +155,16 @@ template <> class Vec256<double> {
  Vec256<double> i0() const {
    return map(calc_i0);
  }
  Vec256<double> igamma(const Vec256<double> &x) const {
    __at_align32__ double tmp[size()];
    __at_align32__ double tmp_x[size()];
    store(tmp);
    x.store(tmp_x);
    for (int64_t i = 0; i < size(); i++) {
      tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
    }
    return loadu(tmp);
  }
  Vec256<double> log() const {
    return Vec256<double>(Sleef_logd4_u10(values));
  }
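
The `igamma` methods in the `double`, `float` and NEON specializations share one shape: spill both vectors to aligned scratch arrays, apply the scalar routine lane by lane, then reload. A minimal, portable sketch of that pattern follows. `Vec4f` and `binary_fallback` are hypothetical stand-ins invented for illustration, not ATen types, and plain `memcpy` takes the place of the SIMD load/store intrinsics.

```cpp
#include <cassert>
#include <cstring>

// Hypothetical stand-in for a 4-lane SIMD vector; not ATen's Vec256.
struct Vec4f {
  float values[4];

  static constexpr int size() { return 4; }

  // Load four floats from (possibly unaligned) memory.
  static Vec4f loadu(const float* ptr) {
    assert(ptr != nullptr);
    Vec4f v;
    std::memcpy(v.values, ptr, sizeof(v.values));
    return v;
  }

  // Write the four lanes out to memory.
  void store(float* ptr) const {
    assert(ptr != nullptr);
    std::memcpy(ptr, values, sizeof(values));
  }

  // Scalar fallback for a binary function with no vector implementation:
  // spill both operands, apply f per lane, reload the result.
  template <typename F>
  Vec4f binary_fallback(const Vec4f& x, F f) const {
    alignas(32) float tmp[size()];
    alignas(32) float tmp_x[size()];
    store(tmp);
    x.store(tmp_x);
    for (int i = 0; i < size(); ++i) {
      tmp[i] = f(tmp[i], tmp_x[i]);
    }
    return loadu(tmp);
  }
};
```

Under these assumptions, `a.binary_fallback(x, f)` plays the role of `a.igamma(x)` with `f` standing in for `calc_igamma`. The approach gains nothing from SIMD for this operation; it only lets the scalar routine fit the vectorized kernel interface.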
10 changes: 10 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_float.h
@@ -193,6 +193,16 @@ template <> class Vec256<float> {
  Vec256<float> i0() const {
    return map(calc_i0);
  }
  Vec256<float> igamma(const Vec256<float> &x) const {
    __at_align32__ float tmp[size()];
    __at_align32__ float tmp_x[size()];
    store(tmp);
    x.store(tmp_x);
    for (int64_t i = 0; i < size(); i++) {
      tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
    }
    return loadu(tmp);
  }
  Vec256<float> neg() const {
    return _mm256_xor_ps(_mm256_set1_ps(-0.f), values);
  }
10 changes: 10 additions & 0 deletions aten/src/ATen/cpu/vec256/vec256_float_neon.h
@@ -362,6 +362,16 @@ template <> class Vec256<float> {
  Vec256<float> i0() const {
    return map(calc_i0);
  }
  Vec256<float> igamma(const Vec256<float> &x) const {
    __at_align32__ float tmp[size()];
    __at_align32__ float tmp_x[size()];
    store(tmp);
    x.store(tmp_x);
    for (int64_t i = 0; i < size(); i++) {
      tmp[i] = calc_igamma(tmp[i], tmp_x[i]);
    }
    return loadu(tmp);
  }
  Vec256<float> log() const {
    return map(std::log);
  }
18 changes: 18 additions & 0 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -46,6 +46,7 @@ DEFINE_DISPATCH(logaddexp2_stub);
DEFINE_DISPATCH(gcd_stub);
DEFINE_DISPATCH(lcm_stub);
DEFINE_DISPATCH(hypot_stub);
DEFINE_DISPATCH(igamma_stub);
DEFINE_DISPATCH(nextafter_stub);
DEFINE_DISPATCH(heaviside_stub);

@@ -968,6 +969,23 @@ Tensor& hypot_(Tensor& self, const Tensor& other) {
  return at::hypot_out(self, self, other);
}

Tensor& igamma_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_op(result, self, other);
  igamma_stub(iter.device_type(), iter);
  return result;
}

Tensor igamma(const Tensor& self, const Tensor& other) {
  Tensor result;
  auto iter = TensorIterator::binary_op(result, self, other);
  igamma_stub(iter.device_type(), iter);
  return iter.output();
}

Tensor& igamma_(Tensor& self, const Tensor& other) {
  return at::igamma_out(self, self, other);
}

Tensor& nextafter_out(Tensor& result, const Tensor& self, const Tensor& other) {
  auto iter = TensorIterator::binary_op(result, self, other);
  nextafter_stub(iter.device_type(), iter);
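
The three entry points above (`igamma_out`, `igamma`, `igamma_`) follow the usual out / functional / in-place trio, all routed through a single kernel stub. The sketch below shows that layering in isolation, under loud assumptions: `Buffer`, `multiply_kernel`, `example_stub` and the `example*` functions are hypothetical names, a plain function pointer stands in for the `DECLARE_DISPATCH` / `DEFINE_DISPATCH` machinery, and no broadcasting or device selection is modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins: a flat buffer instead of a Tensor, and a plain
// function pointer instead of ATen's per-device dispatch stub.
using Buffer = std::vector<double>;
using binary_fn = void (*)(Buffer& out, const Buffer& a, const Buffer& b);

// Placeholder elementwise kernel; a real one would compute igamma.
void multiply_kernel(Buffer& out, const Buffer& a, const Buffer& b) {
  for (std::size_t i = 0; i < a.size(); ++i) {
    out[i] = a[i] * b[i];
  }
}

binary_fn example_stub = &multiply_kernel;

// "out" variant: writes into a caller-provided result.
Buffer& example_out(Buffer& result, const Buffer& self, const Buffer& other) {
  assert(self.size() == other.size());
  result.resize(self.size());  // no-op when result aliases self
  example_stub(result, self, other);
  return result;
}

// Functional variant: allocates its own result.
Buffer example(const Buffer& self, const Buffer& other) {
  Buffer result;
  example_out(result, self, other);
  return result;
}

// In-place variant: reuses self as the output, like igamma_ above.
Buffer& example_(Buffer& self, const Buffer& other) {
  return example_out(self, self, other);
}
```

The in-place form is safe here because the kernel reads and writes index `i` only, so aliasing the output with the first input cannot corrupt later elements.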
3 changes: 2 additions & 1 deletion aten/src/ATen/native/BinaryOps.h
@@ -10,7 +10,7 @@ namespace at { namespace native {
inline void alpha_check(const ScalarType dtype, Scalar alpha) {
  TORCH_CHECK(! alpha.isBoolean() || dtype == ScalarType::Bool,
              "Boolean alpha only supported for Boolean results.");
  TORCH_CHECK(isFloatingType(dtype) || isComplexType(dtype)
              || alpha.isIntegral(true),
              "For integral input tensors, argument alpha must not be a floating point number.");
}
@@ -68,6 +68,7 @@ DECLARE_DISPATCH(binary_fn, logaddexp2_stub);
DECLARE_DISPATCH(binary_fn, gcd_stub);
DECLARE_DISPATCH(binary_fn, lcm_stub);
DECLARE_DISPATCH(binary_fn, hypot_stub);
DECLARE_DISPATCH(binary_fn, igamma_stub);
DECLARE_DISPATCH(binary_fn, nextafter_stub);
DECLARE_DISPATCH(binary_fn, heaviside_stub);
