diff --git a/aten/src/ATen/core/NamedRegistrations.cpp b/aten/src/ATen/core/NamedRegistrations.cpp index 7e787dc0f5692..6712be56ebb2a 100644 --- a/aten/src/ATen/core/NamedRegistrations.cpp +++ b/aten/src/ATen/core/NamedRegistrations.cpp @@ -211,6 +211,9 @@ TORCH_LIBRARY_IMPL(aten, Named, m) { m.impl("hypot", CppFunction::makeFallthrough()); m.impl("hypot.out", CppFunction::makeFallthrough()); m.impl("hypot_", CppFunction::makeFallthrough()); + m.impl("i0", CppFunction::makeFallthrough()); + m.impl("i0.out", CppFunction::makeFallthrough()); + m.impl("i0_", CppFunction::makeFallthrough()); m.impl("imag", CppFunction::makeFallthrough()); m.impl("index_fill.Dimname_Scalar", CppFunction::makeFallthrough()); m.impl("index_fill.Dimname_Tensor", CppFunction::makeFallthrough()); diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index c4b0e7498c75d..afbff6ffbe567 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -373,6 +373,8 @@ _(aten, histc) \ _(aten, hspmm) \ _(aten, hstack) \ _(aten, hypot) \ +_(aten, i0) \ +_(aten, i0_) \ _(aten, ifft) \ _(aten, index) \ _(aten, index_add) \ diff --git a/aten/src/ATen/cpu/vec256/vec256_base.h b/aten/src/ATen/cpu/vec256/vec256_base.h index ea6da94dd08d6..d50f933f72be7 100644 --- a/aten/src/ATen/cpu/vec256/vec256_base.h +++ b/aten/src/ATen/cpu/vec256/vec256_base.h @@ -384,6 +384,9 @@ struct Vec256 { } return ret; } + Vec256 i0() const { + return map(calc_i0); + } Vec256 neg() const { // NB: the trailing return type is needed because we need to coerce the // return value back to T in the case of unary operator- incuring a diff --git a/aten/src/ATen/cpu/vec256/vec256_bfloat16.h b/aten/src/ATen/cpu/vec256/vec256_bfloat16.h index 540fdef6966a5..37d41676e53c9 100644 --- a/aten/src/ATen/cpu/vec256/vec256_bfloat16.h +++ b/aten/src/ATen/cpu/vec256/vec256_bfloat16.h @@ -276,6 +276,20 @@ template <> class Vec256 { auto o2 = 
Sleef_hypotf8_u05(hi, b2); return cvtfp32_bf16(o1, o2); } + Vec256 i0() const { + __m256 lo, hi; + cvtbf16_fp32(values, lo, hi); + __at_align32__ float tmp1[size() / 2], tmp2[size() / 2]; + _mm256_storeu_ps(reinterpret_cast(tmp1), lo); + _mm256_storeu_ps(reinterpret_cast(tmp2), hi); + for (int64_t i = 0; i < size() / 2; i++) { + tmp1[i] = calc_i0(tmp1[i]); + tmp2[i] = calc_i0(tmp2[i]); + } + auto o1 = _mm256_loadu_ps(tmp1); + auto o2 = _mm256_loadu_ps(tmp2); + return cvtfp32_bf16(o1, o2); + } Vec256 log() const { return map(Sleef_logf8_u10); } diff --git a/aten/src/ATen/cpu/vec256/vec256_double.h b/aten/src/ATen/cpu/vec256/vec256_double.h index 204609ad19e88..fcad154e68b2c 100644 --- a/aten/src/ATen/cpu/vec256/vec256_double.h +++ b/aten/src/ATen/cpu/vec256/vec256_double.h @@ -152,6 +152,9 @@ template <> class Vec256 { Vec256 hypot(const Vec256 &b) const { return Vec256(Sleef_hypotd4_u05(values, b)); } + Vec256 i0() const { + return map(calc_i0); + } Vec256 log() const { return Vec256(Sleef_logd4_u10(values)); } diff --git a/aten/src/ATen/cpu/vec256/vec256_float.h b/aten/src/ATen/cpu/vec256/vec256_float.h index ca87bf800718f..1ab11ea81529d 100644 --- a/aten/src/ATen/cpu/vec256/vec256_float.h +++ b/aten/src/ATen/cpu/vec256/vec256_float.h @@ -190,6 +190,9 @@ template <> class Vec256 { Vec256 hypot(const Vec256 &b) const { return Vec256(Sleef_hypotf8_u05(values, b)); } + Vec256 i0() const { + return map(calc_i0); + } Vec256 neg() const { return _mm256_xor_ps(_mm256_set1_ps(-0.f), values); } diff --git a/aten/src/ATen/cpu/vec256/vec256_float_neon.h b/aten/src/ATen/cpu/vec256/vec256_float_neon.h index f04a10bb927d8..cfe6b0ea0fb36 100644 --- a/aten/src/ATen/cpu/vec256/vec256_float_neon.h +++ b/aten/src/ATen/cpu/vec256/vec256_float_neon.h @@ -359,6 +359,9 @@ template <> class Vec256 { } return loadu(tmp); } + Vec256 i0() const { + return map(calc_i0); + } Vec256 log() const { return map(std::log); } diff --git a/aten/src/ATen/cpu/vml.h b/aten/src/ATen/cpu/vml.h index 
0670465f566cc..fe021a2b58c2b 100644 --- a/aten/src/ATen/cpu/vml.h +++ b/aten/src/ATen/cpu/vml.h @@ -128,6 +128,7 @@ IMPLEMENT_VML(erfinv) IMPLEMENT_VML_BUG(exp) IMPLEMENT_VML_BUG(expm1) IMPLEMENT_VML_BUG(floor) +IMPLEMENT_VML(i0) IMPLEMENT_VML(reciprocal) IMPLEMENT_VML_BUG(log) IMPLEMENT_VML_BUG(log10) diff --git a/aten/src/ATen/native/Distributions.h b/aten/src/ATen/native/Distributions.h index 923879dfd0f77..8b12ba6b287d6 100644 --- a/aten/src/ATen/native/Distributions.h +++ b/aten/src/ATen/native/Distributions.h @@ -259,11 +259,8 @@ C10_DEVICE scalar_t sample_binomial(scalar_t count, scalar_t prob, BaseSampler C10_DEVICE static inline scalar_t digamma_one(scalar_t x) { diff --git a/aten/src/ATen/native/Math.h b/aten/src/ATen/native/Math.h index 9e4fd20eb975d..c00ffec941199 100644 --- a/aten/src/ATen/native/Math.h +++ b/aten/src/ATen/native/Math.h @@ -114,13 +114,40 @@ Date: February 1996 #undef CENTRAL_RANGE /* - * The following function comes with the following copyright notice. - * It has been released under the BSD license. + * Note [3-Clause BSD License for the Cephes Math Library] + * Code derived from implementations in the Cephes Math Library should mention its derivation and reference + * this note (ex. 'This function is derived from the implementation of X in the Cephes Math Library. See note + * [3-Clause BSD License for the Cephes Math Library]. The license is: + * Copyright (c) 2018, Steven Moshier + * All rights reserved. * - * Cephes Math Library Release 2.8: June, 2000 - * Copyright 1984, 1987, 1992, 2000 by Stephen L. Moshier + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. 
+ * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL Steven Moshier BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. */ +/* + * This function is derived from the implementation of the zeta function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + */ static inline double zeta(double x, double q) { static double MACHEP = 1.11022302462515654042E-16; static double A[] = { @@ -244,14 +271,11 @@ static inline float trigamma(float x) { result += (1 + 1 / (2*x) + ixx * (1.f/6 - ixx * (1.f/30 - ixx * (1.f/42)))) / x; return sign * result; } + /* - * The following function comes with the following copyright notice. - * It has been released under the BSD license. - * - * Cephes Math Library Release 2.8: June, 2000 - * Copyright 1984, 1987, 1992, 2000 by Stephen L. Moshier + * This function is derived from the implementation of the digamma function in the Cephes Math Library. 
+ * See note [3-Clause BSD License for the Cephes Math Library]. */ - static inline double calc_digamma(double x) { static double PSI_10 = 2.25175258906672110764; if (x == 0) { @@ -296,11 +320,8 @@ static inline double calc_digamma(double x) { } /* - * The following function comes with the following copyright notice. - * It has been released under the BSD license. - * - * Cephes Math Library Release 2.8: June, 2000 - * Copyright 1984, 1987, 1992, 2000 by Stephen L. Moshier + * This function is derived from the implementation of the digamma function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. */ static inline float calc_digamma(float x) { static float PSI_10 = 2.25175258906672110764f; @@ -384,3 +405,138 @@ calc_gcd(T a, T b) { } return b; } + +/* + * This function is derived from the implementation of the chbevl function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Evaluates the series + * + * len-1 + * - ' + * y = > array[i] T (x/2) + * - i + * i=0 + * + * of Chebyshev polynomials Ti at argument x/2. + * + * Coefficients are stored in reverse order, i.e. the zero order term is last in the array. Note len is the number of + * coefficients, not the order. + * + * If coefficients are for the interval a to b, x must have been transformed to x -> 2(2x - b - a)/(b-a) before + * entering the routine. This maps x from (a, b) to (-1, 1), over which the Chebyshev polynomials are defined. + * + * If the coefficients are for the inverted interval, in which (a, b) is mapped to (1/b, 1/a), the transformation + * required is x -> 2(2ab/x - b - a)/(b-a). If b is infinity, this becomes x -> 4a/x - 1. 
+ */ +template +static inline typename std::enable_if::value, T>::type + chbevl(T x, const T array[], size_t len) { // coefficients are only read, so take them as const (matches the CUDA overload) + T b0, b1, b2; + + b0 = array[0]; + b1 = b2 = static_cast(0.0); // b2 must be defined when len == 1: the loop never runs and the result is 0.5 * array[0] + + for (size_t i = 1; i < len; ++i) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + array[i]; + } + + return (static_cast(0.5) * (b0 - b2)); +} + +/* + * This function is derived from the implementation of the i0 function in the Cephes Math Library. + * See note [3-Clause BSD License for the Cephes Math Library]. + * + * Computes an approximation of the zeroth order modified Bessel function of the first kind. + * The approximation is actually two (sub)approximations, both using a Chebyshev polynomial expansion. + * One approximates the function over [0, 8], and the other over (8, infinity). This function takes the absolute value + * of all inputs to convert them into the domain of the approximation. + */ +template +static inline typename std::enable_if::value, T>::type +calc_i0(T _x) { + T x = std::abs(_x); + /* Chebyshev coefficients for exp(-x) I0(x) + * in the interval [0,8]. + * + * lim(x->0){ exp(-x) I0(x) } = 1. 
+ */ + static T A[] = { + -4.41534164647933937950E-18, + 3.33079451882223809783E-17, + -2.43127984654795469359E-16, + 1.71539128555513303061E-15, + -1.16853328779934516808E-14, + 7.67618549860493561688E-14, + -4.85644678311192946090E-13, + 2.95505266312963983461E-12, + -1.72682629144155570723E-11, + 9.67580903537323691224E-11, + -5.18979560163526290666E-10, + 2.65982372468238665035E-9, + -1.30002500998624804212E-8, + 6.04699502254191894932E-8, + -2.67079385394061173391E-7, + 1.11738753912010371815E-6, + -4.41673835845875056359E-6, + 1.64484480707288970893E-5, + -5.75419501008210370398E-5, + 1.88502885095841655729E-4, + -5.76375574538582365885E-4, + 1.63947561694133579842E-3, + -4.32430999505057594430E-3, + 1.05464603945949983183E-2, + -2.37374148058994688156E-2, + 4.93052842396707084878E-2, + -9.49010970480476444210E-2, + 1.71620901522208775349E-1, + -3.04682672343198398683E-1, + 6.76795274409476084995E-1 + }; + + /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x) + * in the inverted interval [8,infinity]. + * + * lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi). 
+ */ + static T B[] = { + -7.23318048787475395456E-18, + -4.83050448594418207126E-18, + 4.46562142029675999901E-17, + 3.46122286769746109310E-17, + -2.82762398051658348494E-16, + -3.42548561967721913462E-16, + 1.77256013305652638360E-15, + 3.81168066935262242075E-15, + -9.55484669882830764870E-15, + -4.15056934728722208663E-14, + 1.54008621752140982691E-14, + 3.85277838274214270114E-13, + 7.18012445138366623367E-13, + -1.79417853150680611778E-12, + -1.32158118404477131188E-11, + -3.14991652796324136454E-11, + 1.18891471078464383424E-11, + 4.94060238822496958910E-10, + 3.39623202570838634515E-9, + 2.26666899049817806459E-8, + 2.04891858946906374183E-7, + 2.89137052083475648297E-6, + 6.88975834691682398426E-5, + 3.36911647825569408990E-3, + 8.04490411014108831608E-1 + }; + + if (x <= 8.0) { + T y = (x / 2.0) - 2.0; + return static_cast(std::exp(x) * chbevl(y, A, 30)); + } + + return static_cast(std::exp(x) * chbevl(static_cast(32.0 / x - 2.0), B, 25) / std::sqrt(x)); +} + +// Upcast bfloat16 input to float for numerical accuracy purposes +inline c10::BFloat16 calc_i0(c10::BFloat16 a) { return calc_i0(static_cast(a)); } diff --git a/aten/src/ATen/native/UnaryOps.cpp b/aten/src/ATen/native/UnaryOps.cpp index 73ad9840b0bea..aaed45cbef09b 100644 --- a/aten/src/ATen/native/UnaryOps.cpp +++ b/aten/src/ATen/native/UnaryOps.cpp @@ -247,6 +247,10 @@ Tensor& floor_out(Tensor& result, const Tensor& self) { Tensor floor(const Tensor& self) { return unary_op_impl(self, at::floor_out); } Tensor& floor_(Tensor& self) { return unary_op_impl_(self, at::floor_out); } +Tensor& i0_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, i0_stub); } +Tensor i0(const Tensor& self) { return unary_op_impl(self, at::i0_out); } +Tensor& i0_(Tensor& self) { return unary_op_impl_(self, at::i0_out); } + Tensor& log_out(Tensor& result, const Tensor& self) { return unary_op_impl_out(result, self, log_stub); } Tensor log(const Tensor& self) { return unary_op_impl(self, 
at::log_out); } Tensor& log_(Tensor& self) { return unary_op_impl_(self, at::log_out); } @@ -607,6 +611,7 @@ DEFINE_DISPATCH(exp_stub); DEFINE_DISPATCH(expm1_stub); DEFINE_DISPATCH(floor_stub); DEFINE_DISPATCH(frac_stub); +DEFINE_DISPATCH(i0_stub); DEFINE_DISPATCH(log_stub); DEFINE_DISPATCH(log10_stub); DEFINE_DISPATCH(log1p_stub); diff --git a/aten/src/ATen/native/UnaryOps.h b/aten/src/ATen/native/UnaryOps.h index df982c2f02ee2..0ae26f88f94ec 100644 --- a/aten/src/ATen/native/UnaryOps.h +++ b/aten/src/ATen/native/UnaryOps.h @@ -38,6 +38,7 @@ DECLARE_DISPATCH(unary_fn, exp_stub); DECLARE_DISPATCH(unary_fn, expm1_stub); DECLARE_DISPATCH(unary_fn, floor_stub); DECLARE_DISPATCH(unary_fn, frac_stub); +DECLARE_DISPATCH(unary_fn, i0_stub); DECLARE_DISPATCH(unary_fn, log_stub); DECLARE_DISPATCH(unary_fn, log10_stub); DECLARE_DISPATCH(unary_fn, log1p_stub); diff --git a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp index 3ce7d5ccfe038..a7a8175749e38 100644 --- a/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/UnaryOpsKernel.cpp @@ -646,6 +646,7 @@ IMPLEMENT_COMPLEX_KERNEL(FLOATING, log) IMPLEMENT_COMPLEX_KERNEL(FLOATING, log10) IMPLEMENT_FLOAT_KERNEL(FLOATING, log1p) IMPLEMENT_COMPLEX_KERNEL(FLOATING, log2) +IMPLEMENT_FLOAT_KERNEL(FLOATING, i0) IMPLEMENT_COMPLEX_KERNEL(FLOATING, round) IMPLEMENT_COMPLEX_KERNEL(FLOATING, sin) IMPLEMENT_COMPLEX_KERNEL(FLOATING, sqrt) diff --git a/aten/src/ATen/native/cuda/Math.cuh b/aten/src/ATen/native/cuda/Math.cuh index cc217e6e0f186..b08a00aebffcb 100644 --- a/aten/src/ATen/native/cuda/Math.cuh +++ b/aten/src/ATen/native/cuda/Math.cuh @@ -7,13 +7,8 @@ namespace at { namespace native { /* -* The following function was converted to CUDA form from code that comes -* with the following copyright notice. It has been released under the BSD license. - * - * Cephes Math Library Release 2.8: June, 2000 - * Copyright 1984, 1987, 1992, 2000 by Stephen L. 
Moshier + * For licensing information, please refer to the CPU implementation located in "ATen/native/Math.h". */ - template static inline __host__ __device__ scalar_t zeta(scalar_t _x, scalar_t _q) { using accscalar_t = at::acc_type; @@ -94,12 +89,8 @@ static inline __host__ __device__ scalar_t zeta(scalar_t _x, scalar_t _q) { } /* -* The following function was converted to CUDA form from code that comes -* with the following copyright notice. It has been released under the BSD license. -* -* Cephes Math Library Release 2.8: June, 2000 -* Copyright 1984, 1987, 1992, 2000 by Stephen L. Moshier -*/ + * For licensing information, please refer to the CPU implementation located in "ATen/native/Math.h". + */ template static inline __host__ __device__ scalar_t calc_digamma(scalar_t in) { using accscalar_t = at::acc_type; @@ -196,5 +187,117 @@ static inline C10_HOST_DEVICE scalar_t calc_gcd(scalar_t a_in, scalar_t b_in) { return b; } +/* + * For licensing information and documentation, please refer to the CPU implementation located in "ATen/native/Math.h". + */ +template +static inline C10_HOST_DEVICE scalar_t chbevl(scalar_t _x, const scalar_t array[], size_t len) { + using accscalar_t = at::acc_type; + + accscalar_t x = static_cast(_x); + accscalar_t b0, b1, b2; + + b0 = static_cast(array[0]); + b1 = b2 = 0; // b2 must be defined when len == 1: the loop never runs and the result is 0.5 * array[0] + + for (size_t i = 1; i < len; ++i) { + b2 = b1; + b1 = b0; + b0 = x * b1 - b2 + static_cast(array[i]); + } + + return static_cast(0.5 * (b0 - b2)); +} + +/* + * For licensing information and documentation, please refer to the CPU implementation located in "ATen/native/Math.h". + */ +template +static inline C10_HOST_DEVICE scalar_t calc_i0(scalar_t _x) { + using accscalar_t = at::acc_type; + + // Upcast input for numerical accuracy purposes + // Needed for accurate results if input is bfloat16 or float16 + accscalar_t x = ::abs(static_cast(_x)); + + /* Chebyshev coefficients for exp(-x) I0(x) + * in the interval [0,8]. 
+ * + * lim(x->0){ exp(-x) I0(x) } = 1. + */ + const accscalar_t A[] = { + -4.41534164647933937950E-18, + 3.33079451882223809783E-17, + -2.43127984654795469359E-16, + 1.71539128555513303061E-15, + -1.16853328779934516808E-14, + 7.67618549860493561688E-14, + -4.85644678311192946090E-13, + 2.95505266312963983461E-12, + -1.72682629144155570723E-11, + 9.67580903537323691224E-11, + -5.18979560163526290666E-10, + 2.65982372468238665035E-9, + -1.30002500998624804212E-8, + 6.04699502254191894932E-8, + -2.67079385394061173391E-7, + 1.11738753912010371815E-6, + -4.41673835845875056359E-6, + 1.64484480707288970893E-5, + -5.75419501008210370398E-5, + 1.88502885095841655729E-4, + -5.76375574538582365885E-4, + 1.63947561694133579842E-3, + -4.32430999505057594430E-3, + 1.05464603945949983183E-2, + -2.37374148058994688156E-2, + 4.93052842396707084878E-2, + -9.49010970480476444210E-2, + 1.71620901522208775349E-1, + -3.04682672343198398683E-1, + 6.76795274409476084995E-1 + }; + + /* Chebyshev coefficients for exp(-x) sqrt(x) I0(x) + * in the inverted interval [8,infinity]. + * + * lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi). 
+ */ + const accscalar_t B[] = { + -7.23318048787475395456E-18, + -4.83050448594418207126E-18, + 4.46562142029675999901E-17, + 3.46122286769746109310E-17, + -2.82762398051658348494E-16, + -3.42548561967721913462E-16, + 1.77256013305652638360E-15, + 3.81168066935262242075E-15, + -9.55484669882830764870E-15, + -4.15056934728722208663E-14, + 1.54008621752140982691E-14, + 3.85277838274214270114E-13, + 7.18012445138366623367E-13, + -1.79417853150680611778E-12, + -1.32158118404477131188E-11, + -3.14991652796324136454E-11, + 1.18891471078464383424E-11, + 4.94060238822496958910E-10, + 3.39623202570838634515E-9, + 2.26666899049817806459E-8, + 2.04891858946906374183E-7, + 2.89137052083475648297E-6, + 6.88975834691682398426E-5, + 3.36911647825569408990E-3, + 8.04490411014108831608E-1 + }; + + if (x <= 8.0) { + accscalar_t y = static_cast((x / 2.0) - 2.0); + return static_cast(::exp(x) * chbevl(y, A, 30)); + } + + return static_cast(::exp(x) * chbevl(static_cast(32.0 / x - 2.0), B, 25) / ::sqrt(x)); +} + } } diff --git a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu index 85269c2811766..000d66f37fc22 100644 --- a/aten/src/ATen/native/cuda/UnaryOpsKernel.cu +++ b/aten/src/ATen/native/cuda/UnaryOpsKernel.cu @@ -47,6 +47,16 @@ void expm1_kernel_cuda(TensorIterator& iter) { }); } +void i0_kernel_cuda(TensorIterator& iter) { + AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "i0_cuda", [&]() { + AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "i0_cuda", [&] { + gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t { + return calc_i0(a); + }); + }); + }); +} + // We manually overload rsqrt because std::rsqrt does not work with complex types. 
template __host__ __device__ static inline scalar_t rsqrt_wrapper(scalar_t v) { @@ -176,6 +186,7 @@ void clamp_max_kernel_cuda(TensorIterator& iter, Scalar max_value) { REGISTER_DISPATCH(bitwise_not_stub, &bitwise_not_kernel_cuda); REGISTER_DISPATCH(exp_stub, &exp_kernel_cuda); REGISTER_DISPATCH(expm1_stub, &expm1_kernel_cuda); +REGISTER_DISPATCH(i0_stub, &i0_kernel_cuda); REGISTER_DISPATCH(rsqrt_stub, &rsqrt_kernel_cuda); REGISTER_DISPATCH(sqrt_stub, &sqrt_kernel_cuda); REGISTER_DISPATCH(sigmoid_stub, &sigmoid_kernel_cuda); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 326af4d7d6497..bf7ef8ace4dac 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -5373,6 +5373,16 @@ CPU: _erfinv_out_cpu CUDA: _erfinv_out_cuda +- func: i0(Tensor self) -> Tensor + use_c10_dispatcher: full + variants: function, method + +- func: i0_(Tensor(a!) self) -> Tensor(a!) + use_c10_dispatcher: full + variants: function, method + +- func: i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + - func: sign(Tensor self) -> Tensor use_c10_dispatcher: full variants: function, method diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 1cd67d48a9885..2eb9c806d656f 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -338,6 +338,8 @@ view of a storage and defines numeric operations on it. .. automethod:: histc .. automethod:: hypot .. automethod:: hypot_ + .. automethod:: i0 + .. automethod:: i0_ .. automethod:: ifft .. automethod:: index_add_ .. 
automethod:: index_add diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 7c7b064c79cc2..f806f36d08a9e 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -306,6 +306,7 @@ Pointwise Ops logical_xor logit hypot + i0 mul mvlgamma neg diff --git a/test/test_torch.py b/test/test_torch.py index f88b92b7f82c1..f043ea14291f5 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -13369,6 +13369,8 @@ def test_unary_out_op_mem_overlap(self, device, dtype): ("floor", doubles, True, True, 'cuda'), ("frac", doubles, True, True, 'cpu'), ("frac", doubles, True, True, 'cuda'), + ("i0", doubles, True, True, 'cpu'), + ("i0", doubles, True, True, 'cuda'), ("log", positives, True, True, 'cpu'), ("log", positives, True, True, 'cuda'), ("log10", positives, True, True, 'cpu'), @@ -14547,6 +14549,8 @@ def _test_helper(x, y, bias, memory_format): lambda x, y: x.frac(), lambda x, y: x.hypot(y), lambda x, y: x.hypot_(y), + lambda x, y: x.i0(), + lambda x, y: x.i0_(), # lambda x, y: x.lerp(y, 0.5), # Need to update Lerp.cu with TensorIterator lambda x, y: x.log(), lambda x, y: x.log_(), @@ -16872,6 +16876,59 @@ def test_nextafter(self, device, dtype): expected = np.nextafter(a.cpu().numpy(), b.cpu().numpy()) self.assertEqual(actual, expected, atol=0, rtol=0) + def _i0_helper(self, t): + # Test by comparing to scipy + dtype = t.dtype + actual = torch.i0(t) + if dtype is torch.bfloat16: + t = t.to(torch.float32) + expected = scipy.special.i0(t.cpu().numpy()) + # Casting down for dtype float16 is required since scipy upcasts to float32 + if dtype is torch.bfloat16 or dtype is torch.float16: + expected = torch.from_numpy(expected).to(dtype) + self.assertEqual(actual, expected) + + def _i0_range_helper(self, range, device, dtype): + # i0 tests are broken up by the domain for which the function does not overflow for each dtype + # This is done to ensure that the function performs well across all possible input values, without worrying + # about inf or nan 
possibilities + for r in (range, -range): + t = torch.rand(1000, device=device).to(dtype) * r + self._i0_helper(t) + + @dtypesIfCUDA(*([torch.float16, torch.float32, torch.float64] + ([torch.bfloat16] if TEST_WITH_ROCM else []))) + @dtypes(torch.bfloat16, torch.float32, torch.float64) + @unittest.skipIf(not TEST_SCIPY, "SciPy not found") + def test_i0_range1(self, device, dtype): + # This tests the domain for i0 for which float16 does not overflow + # The domain is (-13.25, 13.25) + self._i0_range_helper(13.25, device, dtype) + + @dtypesIfCUDA(*([torch.float32, torch.float64] + ([torch.bfloat16] if TEST_WITH_ROCM else []))) + @dtypes(torch.bfloat16, torch.float32, torch.float64) + @unittest.skipIf(not TEST_SCIPY, "SciPy not found") + def test_i0_range2(self, device, dtype): + # This tests the domain for i0 for which float32 and bfloat16 does not overflow + # The domain is (-88.5, 88.5) + self._i0_range_helper(88.5, device, dtype) + + @dtypes(torch.float64) + @unittest.skipIf(not TEST_SCIPY, "SciPy not found") + def test_i0_range3(self, device, dtype): + # This tests the domain for i0 for which float64 does not overflow + # The domain is (-709.75, 709.75) + self._i0_range_helper(709.75, device, dtype) + + @dtypesIfCUDA(*([torch.float16, torch.float32, torch.float64] + ([torch.bfloat16] if TEST_WITH_ROCM else []))) + @dtypes(torch.bfloat16, torch.float32, torch.float64) + @unittest.skipIf(not TEST_SCIPY, "SciPy not found") + def test_i0_special(self, device, dtype): + t = torch.tensor([], device=device, dtype=dtype) + self._i0_helper(t) + + t = torch.tensor([inf, -inf, nan], device=device, dtype=dtype) + self.assertTrue(torch.i0(t).isnan().all()) + @slowTest @onlyOnCPUAndCUDA @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.int32, torch.int64, torch.cfloat, torch.cdouble) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index e577836043723..9557d7b86e83b 100644 --- a/tools/autograd/derivatives.yaml +++ 
b/tools/autograd/derivatives.yaml @@ -530,6 +530,9 @@ self: grad * self / result other: grad * other / result +- name: i0(Tensor self) -> Tensor + self: not_implemented("i0") + - name: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor self: index_backward(zeros_like(self), indices, grad) indices: TensorList() diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 9e85eecd49b4a..bc7341eed4477 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -1464,6 +1464,20 @@ def add_docstr_all(method, docstr): In-place version of :meth:`~Tensor.hypot` """) +add_docstr_all('i0', + r""" +i0() -> Tensor + +See :func:`torch.i0` +""") + +add_docstr_all('i0_', + r""" +i0_() -> Tensor + +In-place version of :meth:`~Tensor.i0` +""") + add_docstr_all('indices', r""" indices() -> Tensor diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index ef68288e6c12c..ae64efdc664c0 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -3137,6 +3137,29 @@ def merge_dicts(*dicts): """.format(**common_args)) +add_docstr(torch.i0, + r""" +i0(input, *, out=None) -> Tensor + +Computes the zeroth order modified Bessel function of the first kind for each element of :attr:`input`. + +.. 
math:: + \text{out}_{i} = I_0(\text{input}_{i}) = \sum_{k=0}^{\infty} \frac{(\text{input}_{i}^2/4)^k}{(k!)^2} + +""" + r""" +Args: + input (Tensor): the input tensor + +Keyword args: + {out} + +Example:: + + >>> torch.i0(torch.arange(5, dtype=torch.float32)) + tensor([ 1.0000, 1.2661, 2.2796, 4.8808, 11.3019]) + +""".format(**common_args)) + add_docstr(torch.index_select, r""" index_select(input, dim, index, *, out=None) -> Tensor diff --git a/torch/overrides.py b/torch/overrides.py index 492111fef37d1..f52aff10ce169 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -369,6 +369,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.gcd: lambda input, other, out=None: -1, torch.ge: lambda input, other, out=None: -1, torch.geqrf: lambda input, out=None: -1, + torch.i0: lambda input, out=None: -1, torch.outer: lambda input, vec2, out=None: -1, # alias for torch.ger torch.ger: lambda input, vec2, out=None: -1, torch.grid_sampler: lambda input, grid, interpolation_mode, padding_mode, align_corners: -1,