From e8400376760907ae37bd4af649d7c348e6d4c5b8 Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Mon, 14 Apr 2025 14:52:29 -0700 Subject: [PATCH] import complex.h from c10 (#10155) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/10155 Import complex type and utils from c10 in order to do arithmetic with complex numbers in ExecuTorch kernels. Reviewed By: digantdesai, swolchok Differential Revision: D72197508 --- .../core/portable_type/c10/c10/targets.bzl | 3 + .../core/portable_type/c10/c10/util/complex.h | 668 ++++++++++++++++++ .../portable_type/c10/c10/util/complex_math.h | 406 +++++++++++ .../c10/c10/util/complex_utils.h | 46 ++ runtime/core/portable_type/complex.h | 37 +- 5 files changed, 1129 insertions(+), 31 deletions(-) create mode 100644 runtime/core/portable_type/c10/c10/util/complex.h create mode 100644 runtime/core/portable_type/c10/c10/util/complex_math.h create mode 100644 runtime/core/portable_type/c10/c10/util/complex_utils.h diff --git a/runtime/core/portable_type/c10/c10/targets.bzl b/runtime/core/portable_type/c10/c10/targets.bzl index d9d72b5be3f..e9728745270 100644 --- a/runtime/core/portable_type/c10/c10/targets.bzl +++ b/runtime/core/portable_type/c10/c10/targets.bzl @@ -25,6 +25,9 @@ def define_common_targets(): "util/Half-inl.h", "util/TypeSafeSignMath.h", "util/bit_cast.h", + "util/complex.h", + "util/complex_math.h", + "util/complex_utils.h", "util/floating_point_utils.h", "util/irange.h", ], diff --git a/runtime/core/portable_type/c10/c10/util/complex.h b/runtime/core/portable_type/c10/c10/util/complex.h new file mode 100644 index 00000000000..b63710d9458 --- /dev/null +++ b/runtime/core/portable_type/c10/c10/util/complex.h @@ -0,0 +1,668 @@ +#pragma once + +#include + +#include +#include + +#if defined(__CUDACC__) || defined(__HIPCC__) +#include +#endif + +C10_CLANG_DIAGNOSTIC_PUSH() +#if C10_CLANG_HAS_WARNING("-Wimplicit-float-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wimplicit-float-conversion") +#endif 
+#if C10_CLANG_HAS_WARNING("-Wfloat-conversion") +C10_CLANG_DIAGNOSTIC_IGNORE("-Wfloat-conversion") +#endif + +namespace c10 { + +// c10::complex is an implementation of complex numbers that aims +// to work on all devices supported by PyTorch +// +// Most of the APIs duplicates std::complex +// Reference: https://en.cppreference.com/w/cpp/numeric/complex +// +// [NOTE: Complex Operator Unification] +// Operators currently use a mix of std::complex, thrust::complex, and +// c10::complex internally. The end state is that all operators will use +// c10::complex internally. Until then, there may be some hacks to support all +// variants. +// +// +// [Note on Constructors] +// +// The APIs of constructors are mostly copied from C++ standard: +// https://en.cppreference.com/w/cpp/numeric/complex/complex +// +// Since C++14, all constructors are constexpr in std::complex +// +// There are three types of constructors: +// - initializing from real and imag: +// `constexpr complex( const T& re = T(), const T& im = T() );` +// - implicitly-declared copy constructor +// - converting constructors +// +// Converting constructors: +// - std::complex defines converting constructor between float/double/long +// double, +// while we define converting constructor between float/double. +// - For these converting constructors, upcasting is implicit, downcasting is +// explicit. +// - We also define explicit casting from std::complex/thrust::complex +// - Note that the conversion from thrust is not constexpr, because +// thrust does not define them as constexpr ???? +// +// +// [Operator =] +// +// The APIs of operator = are mostly copied from C++ standard: +// https://en.cppreference.com/w/cpp/numeric/complex/operator%3D +// +// Since C++20, all operator= are constexpr. Although we are not building with +// C++20, we also obey this behavior. 
+// +// There are three types of assignment operator: +// - Assign a real value from the same scalar type +// - In std, this is templated as complex& operator=(const T& x) +// with specialization `complex& operator=(T x)` for float/double/long +// double. Since we only support float and double, we will use `complex& +// operator=(T x)` +// - Copy assignment operator and converting assignment operator +// - There is no specialization of converting assignment operators; which types +// are +// convertible depends solely on whether the scalar types are convertible +// +// In addition to the standard assignment, we also provide assignment operators +// with std and thrust +// +// +// [Casting operators] +// +// std::complex does not have casting operators. We define casting operators +// casting to std::complex and thrust::complex +// +// +// [Operator ""] +// +// std::complex has custom literals `i`, `if` and `il` defined in namespace +// `std::literals::complex_literals`. We define our own custom literals in the +// namespace `c10::complex_literals`. Our custom literals do not follow the +// same behavior as in std::complex; instead, we define _if, _id to construct +// float/double complex literals. +// +// +// [real() and imag()] +// +// In C++20, there are two overloads of these functions: one is to return the +// real/imag, the other is to set real/imag; they are both constexpr. We follow +// this design. +// +// +// [Operator +=,-=,*=,/=] +// +// Since C++20, these operators become constexpr. In our implementation, they +// are also constexpr. +// +// There are two types of such operators: operating with a real number, or +// operating with another complex number. For operating with a real number, +// the generic template form has argument type `const T &`, while the overload +// for float/double/long double has `T`. We will follow the same type as +// float/double/long double in std. +// +// [Unary operator +-] +// +// Since C++20, they are constexpr.
We also make them expr +// +// [Binary operators +-*/] +// +// Each operator has three versions (taking + as example): +// - complex + complex +// - complex + real +// - real + complex +// +// [Operator ==, !=] +// +// Each operator has three versions (taking == as example): +// - complex == complex +// - complex == real +// - real == complex +// +// Some of them are removed on C++20, but we decide to keep them +// +// [Operator <<, >>] +// +// These are implemented by casting to std::complex +// +// +// +// TODO(@zasdfgbnm): c10::complex is not currently supported, +// because: +// - lots of members and functions of c10::Half are not constexpr +// - thrust::complex only support float and double + +template +struct alignas(sizeof(T) * 2) complex { + using value_type = T; + + T real_ = T(0); + T imag_ = T(0); + + constexpr complex() = default; + C10_HOST_DEVICE constexpr complex(const T& re, const T& im = T()) + : real_(re), imag_(im) {} + template + explicit constexpr complex(const std::complex& other) + : complex(other.real(), other.imag()) {} +#if defined(__CUDACC__) || defined(__HIPCC__) + template + explicit C10_HOST_DEVICE complex(const thrust::complex& other) + : real_(other.real()), imag_(other.imag()) {} +// NOTE can not be implemented as follow due to ROCm bug: +// explicit C10_HOST_DEVICE complex(const thrust::complex &other): +// complex(other.real(), other.imag()) {} +#endif + + // Use SFINAE to specialize casting constructor for c10::complex and + // c10::complex + template + C10_HOST_DEVICE explicit constexpr complex( + const std::enable_if_t, complex>& other) + : real_(other.real_), imag_(other.imag_) {} + template + C10_HOST_DEVICE constexpr complex( + const std::enable_if_t, complex>& other) + : real_(other.real_), imag_(other.imag_) {} + + constexpr complex& operator=(T re) { + real_ = re; + imag_ = 0; + return *this; + } + + constexpr complex& operator+=(T re) { + real_ += re; + return *this; + } + + constexpr complex& operator-=(T re) { + real_ 
-= re; + return *this; + } + + constexpr complex& operator*=(T re) { + real_ *= re; + imag_ *= re; + return *this; + } + + constexpr complex& operator/=(T re) { + real_ /= re; + imag_ /= re; + return *this; + } + + template + constexpr complex& operator=(const complex& rhs) { + real_ = rhs.real(); + imag_ = rhs.imag(); + return *this; + } + + template + constexpr complex& operator+=(const complex& rhs) { + real_ += rhs.real(); + imag_ += rhs.imag(); + return *this; + } + + template + constexpr complex& operator-=(const complex& rhs) { + real_ -= rhs.real(); + imag_ -= rhs.imag(); + return *this; + } + + template + constexpr complex& operator*=(const complex& rhs) { + // (a + bi) * (c + di) = (a*c - b*d) + (a * d + b * c) i + T a = real_; + T b = imag_; + U c = rhs.real(); + U d = rhs.imag(); + real_ = a * c - b * d; + imag_ = a * d + b * c; + return *this; + } + +#ifdef __APPLE__ +#define FORCE_INLINE_APPLE __attribute__((always_inline)) +#else +#define FORCE_INLINE_APPLE +#endif + template + constexpr FORCE_INLINE_APPLE complex& operator/=(const complex& rhs) + __ubsan_ignore_float_divide_by_zero__ { + // (a + bi) / (c + di) = (ac + bd)/(c^2 + d^2) + (bc - ad)/(c^2 + d^2) i + // the calculation below follows numpy's complex division + T a = real_; + T b = imag_; + U c = rhs.real(); + U d = rhs.imag(); + +#if defined(__GNUC__) && !defined(__clang__) + // std::abs is already constexpr by gcc + auto abs_c = std::abs(c); + auto abs_d = std::abs(d); +#else + auto abs_c = c < 0 ? -c : c; + auto abs_d = d < 0 ? 
-d : d; +#endif + + if (abs_c >= abs_d) { + if (abs_c == U(0) && abs_d == U(0)) { + /* divide by zeros should yield a complex inf or nan */ + real_ = a / abs_c; + imag_ = b / abs_d; + } else { + auto rat = d / c; + auto scl = U(1.0) / (c + d * rat); + real_ = (a + b * rat) * scl; + imag_ = (b - a * rat) * scl; + } + } else { + auto rat = c / d; + auto scl = U(1.0) / (d + c * rat); + real_ = (a * rat + b) * scl; + imag_ = (b * rat - a) * scl; + } + return *this; + } +#undef FORCE_INLINE_APPLE + + template + constexpr complex& operator=(const std::complex& rhs) { + real_ = rhs.real(); + imag_ = rhs.imag(); + return *this; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + template + C10_HOST_DEVICE complex& operator=(const thrust::complex& rhs) { + real_ = rhs.real(); + imag_ = rhs.imag(); + return *this; + } +#endif + + template + explicit constexpr operator std::complex() const { + return std::complex(std::complex(real(), imag())); + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + template + C10_HOST_DEVICE explicit operator thrust::complex() const { + return static_cast>(thrust::complex(real(), imag())); + } +#endif + + // consistent with NumPy behavior + explicit constexpr operator bool() const { + return real() || imag(); + } + + C10_HOST_DEVICE constexpr T real() const { + return real_; + } + constexpr void real(T value) { + real_ = value; + } + C10_HOST_DEVICE constexpr T imag() const { + return imag_; + } + constexpr void imag(T value) { + imag_ = value; + } +}; + +namespace complex_literals { + +constexpr complex operator""_if(long double imag) { + return complex(0.0f, static_cast(imag)); +} + +constexpr complex operator""_id(long double imag) { + return complex(0.0, static_cast(imag)); +} + +constexpr complex operator""_if(unsigned long long imag) { + return complex(0.0f, static_cast(imag)); +} + +constexpr complex operator""_id(unsigned long long imag) { + return complex(0.0, static_cast(imag)); +} + +} // namespace complex_literals + +template 
+constexpr complex operator+(const complex& val) { + return val; +} + +template +constexpr complex operator-(const complex& val) { + return complex(-val.real(), -val.imag()); +} + +template +constexpr complex operator+(const complex& lhs, const complex& rhs) { + complex result = lhs; + return result += rhs; +} + +template +constexpr complex operator+(const complex& lhs, const T& rhs) { + complex result = lhs; + return result += rhs; +} + +template +constexpr complex operator+(const T& lhs, const complex& rhs) { + return complex(lhs + rhs.real(), rhs.imag()); +} + +template +constexpr complex operator-(const complex& lhs, const complex& rhs) { + complex result = lhs; + return result -= rhs; +} + +template +constexpr complex operator-(const complex& lhs, const T& rhs) { + complex result = lhs; + return result -= rhs; +} + +template +constexpr complex operator-(const T& lhs, const complex& rhs) { + complex result = -rhs; + return result += lhs; +} + +template +constexpr complex operator*(const complex& lhs, const complex& rhs) { + complex result = lhs; + return result *= rhs; +} + +template +constexpr complex operator*(const complex& lhs, const T& rhs) { + complex result = lhs; + return result *= rhs; +} + +template +constexpr complex operator*(const T& lhs, const complex& rhs) { + complex result = rhs; + return result *= lhs; +} + +template +constexpr complex operator/(const complex& lhs, const complex& rhs) { + complex result = lhs; + return result /= rhs; +} + +template +constexpr complex operator/(const complex& lhs, const T& rhs) { + complex result = lhs; + return result /= rhs; +} + +template +constexpr complex operator/(const T& lhs, const complex& rhs) { + complex result(lhs, T()); + return result /= rhs; +} + +// Define operators between integral scalars and c10::complex. std::complex does +// not support this when T is a floating-point number. This is useful because it +// saves a lot of "static_cast" when operate a complex and an integer. 
This +// makes the code both less verbose and potentially more efficient. +#define COMPLEX_INTEGER_OP_TEMPLATE_CONDITION \ + typename std::enable_if_t< \ + std::is_floating_point_v && std::is_integral_v, \ + int> = 0 + +template +constexpr c10::complex operator+(const c10::complex& a, const iT& b) { + return a + static_cast(b); +} + +template +constexpr c10::complex operator+(const iT& a, const c10::complex& b) { + return static_cast(a) + b; +} + +template +constexpr c10::complex operator-(const c10::complex& a, const iT& b) { + return a - static_cast(b); +} + +template +constexpr c10::complex operator-(const iT& a, const c10::complex& b) { + return static_cast(a) - b; +} + +template +constexpr c10::complex operator*(const c10::complex& a, const iT& b) { + return a * static_cast(b); +} + +template +constexpr c10::complex operator*(const iT& a, const c10::complex& b) { + return static_cast(a) * b; +} + +template +constexpr c10::complex operator/(const c10::complex& a, const iT& b) { + return a / static_cast(b); +} + +template +constexpr c10::complex operator/(const iT& a, const c10::complex& b) { + return static_cast(a) / b; +} + +#undef COMPLEX_INTEGER_OP_TEMPLATE_CONDITION + +template +constexpr bool operator==(const complex& lhs, const complex& rhs) { + return (lhs.real() == rhs.real()) && (lhs.imag() == rhs.imag()); +} + +template +constexpr bool operator==(const complex& lhs, const T& rhs) { + return (lhs.real() == rhs) && (lhs.imag() == T()); +} + +template +constexpr bool operator==(const T& lhs, const complex& rhs) { + return (lhs == rhs.real()) && (T() == rhs.imag()); +} + +template +constexpr bool operator!=(const complex& lhs, const complex& rhs) { + return !(lhs == rhs); +} + +template +constexpr bool operator!=(const complex& lhs, const T& rhs) { + return !(lhs == rhs); +} + +template +constexpr bool operator!=(const T& lhs, const complex& rhs) { + return !(lhs == rhs); +} + +template +std::basic_ostream& operator<<( + std::basic_ostream& os, + const 
complex& x) { + return (os << static_cast>(x)); +} + +template +std::basic_istream& operator>>( + std::basic_istream& is, + complex& x) { + std::complex tmp; + is >> tmp; + x = tmp; + return is; +} + +} // namespace c10 + +// std functions +// +// The implementation of these functions also follow the design of C++20 + +namespace std { + +template +constexpr T real(const c10::complex& z) { + return z.real(); +} + +template +constexpr T imag(const c10::complex& z) { + return z.imag(); +} + +template +C10_HOST_DEVICE T abs(const c10::complex& z) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return thrust::abs(static_cast>(z)); +#else + return std::abs(static_cast>(z)); +#endif +} + +#if defined(USE_ROCM) +#define ROCm_Bug(x) +#else +#define ROCm_Bug(x) x +#endif + +template +C10_HOST_DEVICE T arg(const c10::complex& z) { + return ROCm_Bug(std)::atan2(std::imag(z), std::real(z)); +} + +#undef ROCm_Bug + +template +constexpr T norm(const c10::complex& z) { + return z.real() * z.real() + z.imag() * z.imag(); +} + +// For std::conj, there are other versions of it: +// constexpr std::complex conj( float z ); +// template< class DoubleOrInteger > +// constexpr std::complex conj( DoubleOrInteger z ); +// constexpr std::complex conj( long double z ); +// These are not implemented +// TODO(@zasdfgbnm): implement them as c10::conj +template +constexpr c10::complex conj(const c10::complex& z) { + return c10::complex(z.real(), -z.imag()); +} + +// Thrust does not have complex --> complex version of thrust::proj, +// so this function is not implemented at c10 right now. +// TODO(@zasdfgbnm): implement it by ourselves + +// There is no c10 version of std::polar, because std::polar always +// returns std::complex. 
Use c10::polar instead; + +} // namespace std + +namespace c10 { + +template +C10_HOST_DEVICE complex polar(const T& r, const T& theta = T()) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>(thrust::polar(r, theta)); +#else + // std::polar() requires r >= 0, so spell out the explicit implementation to + // avoid a branch. + return complex(r * std::cos(theta), r * std::sin(theta)); +#endif +} + +template <> +struct alignas(4) complex { + Half real_; + Half imag_; + + // Constructors + complex() = default; + // Half constructor is not constexpr so the following constructor can't + // be constexpr + C10_HOST_DEVICE explicit inline complex(const Half& real, const Half& imag) + : real_(real), imag_(imag) {} + C10_HOST_DEVICE inline complex(const c10::complex& value) + : real_(value.real()), imag_(value.imag()) {} + + // Conversion operator + inline C10_HOST_DEVICE operator c10::complex() const { + return {real_, imag_}; + } + + constexpr C10_HOST_DEVICE Half real() const { + return real_; + } + constexpr C10_HOST_DEVICE Half imag() const { + return imag_; + } + + C10_HOST_DEVICE complex& operator+=(const complex& other) { + real_ = static_cast(real_) + static_cast(other.real_); + imag_ = static_cast(imag_) + static_cast(other.imag_); + return *this; + } + + C10_HOST_DEVICE complex& operator-=(const complex& other) { + real_ = static_cast(real_) - static_cast(other.real_); + imag_ = static_cast(imag_) - static_cast(other.imag_); + return *this; + } + + C10_HOST_DEVICE complex& operator*=(const complex& other) { + auto a = static_cast(real_); + auto b = static_cast(imag_); + auto c = static_cast(other.real()); + auto d = static_cast(other.imag()); + real_ = a * c - b * d; + imag_ = a * d + b * c; + return *this; + } +}; + +} // namespace c10 + +C10_CLANG_DIAGNOSTIC_POP() + +#define C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H +// math functions are included in a separate file +#include // IWYU pragma: keep +// utilities for complex types +#include // 
IWYU pragma: keep +#undef C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H diff --git a/runtime/core/portable_type/c10/c10/util/complex_math.h b/runtime/core/portable_type/c10/c10/util/complex_math.h new file mode 100644 index 00000000000..2b591026c94 --- /dev/null +++ b/runtime/core/portable_type/c10/c10/util/complex_math.h @@ -0,0 +1,406 @@ +#if !defined(C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H) +#error \ + "c10/util/complex_math.h is not meant to be individually included. Include c10/util/complex.h instead." +#endif + +namespace c10_complex_math { + +// Exponential functions + +template +C10_HOST_DEVICE inline c10::complex exp(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::exp(static_cast>(x))); +#else + return static_cast>( + std::exp(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex log(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::log(static_cast>(x))); +#else + return static_cast>( + std::log(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex log10(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::log10(static_cast>(x))); +#else + return static_cast>( + std::log10(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex log2(const c10::complex& x) { + const c10::complex log2 = c10::complex(::log(2.0), 0.0); + return c10_complex_math::log(x) / log2; +} + +// Power functions +// +#if defined(_LIBCPP_VERSION) || \ + (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX)) +namespace _detail { +C10_API c10::complex sqrt(const c10::complex& in); +C10_API c10::complex sqrt(const c10::complex& in); +C10_API c10::complex acos(const c10::complex& in); +C10_API c10::complex acos(const c10::complex& in); +} // namespace _detail +#endif + +template +C10_HOST_DEVICE inline c10::complex sqrt(const c10::complex& x) { +#if 
defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::sqrt(static_cast>(x))); +#elif !( \ + defined(_LIBCPP_VERSION) || \ + (defined(__GLIBCXX__) && !defined(_GLIBCXX11_USE_C99_COMPLEX))) + return static_cast>( + std::sqrt(static_cast>(x))); +#else + return _detail::sqrt(x); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex pow( + const c10::complex& x, + const c10::complex& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>(thrust::pow( + static_cast>(x), static_cast>(y))); +#else + return static_cast>(std::pow( + static_cast>(x), static_cast>(y))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex pow( + const c10::complex& x, + const T& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::pow(static_cast>(x), y)); +#else + return static_cast>( + std::pow(static_cast>(x), y)); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex pow( + const T& x, + const c10::complex& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::pow(x, static_cast>(y))); +#else + return static_cast>( + std::pow(x, static_cast>(y))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex pow( + const c10::complex& x, + const c10::complex& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>(thrust::pow( + static_cast>(x), static_cast>(y))); +#else + return static_cast>(std::pow( + static_cast>(x), static_cast>(y))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex pow( + const c10::complex& x, + const U& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::pow(static_cast>(x), y)); +#else + return static_cast>( + std::pow(static_cast>(x), y)); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex pow( + const T& x, + const c10::complex& y) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::pow(x, static_cast>(y))); +#else + return static_cast>( + 
std::pow(x, static_cast>(y))); +#endif +} + +// Trigonometric functions + +template +C10_HOST_DEVICE inline c10::complex sin(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::sin(static_cast>(x))); +#else + return static_cast>( + std::sin(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex cos(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::cos(static_cast>(x))); +#else + return static_cast>( + std::cos(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex tan(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::tan(static_cast>(x))); +#else + return static_cast>( + std::tan(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex asin(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::asin(static_cast>(x))); +#else + return static_cast>( + std::asin(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex acos(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::acos(static_cast>(x))); +#elif !defined(_LIBCPP_VERSION) + return static_cast>( + std::acos(static_cast>(x))); +#else + return _detail::acos(x); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex atan(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::atan(static_cast>(x))); +#else + return static_cast>( + std::atan(static_cast>(x))); +#endif +} + +// Hyperbolic functions + +template +C10_HOST_DEVICE inline c10::complex sinh(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::sinh(static_cast>(x))); +#else + return static_cast>( + std::sinh(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex cosh(const 
c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::cosh(static_cast>(x))); +#else + return static_cast>( + std::cosh(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex tanh(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::tanh(static_cast>(x))); +#else + return static_cast>( + std::tanh(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex asinh(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::asinh(static_cast>(x))); +#else + return static_cast>( + std::asinh(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex acosh(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::acosh(static_cast>(x))); +#else + return static_cast>( + std::acosh(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex atanh(const c10::complex& x) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return static_cast>( + thrust::atanh(static_cast>(x))); +#else + return static_cast>( + std::atanh(static_cast>(x))); +#endif +} + +template +C10_HOST_DEVICE inline c10::complex log1p(const c10::complex& z) { +#if defined(__APPLE__) || defined(__MACOSX) || defined(__CUDACC__) || \ + defined(__HIPCC__) + // For Mac, the new implementation yielded a high relative error. Falling back + // to the old version for now. 
+ // See https://github.com/numpy/numpy/pull/22611#issuecomment-1667945354 + // For CUDA we also use this one, as thrust::log(thrust::complex) takes + // *forever* to compile + + // log1p(z) = log(1 + z) + // Let's define 1 + z = r * e ^ (i * a), then we have + // log(r * e ^ (i * a)) = log(r) + i * a + // With z = x + iy, the term r can be written as + // r = ((1 + x) ^ 2 + y ^ 2) ^ 0.5 + // = (1 + x ^ 2 + 2 * x + y ^ 2) ^ 0.5 + // So, log(r) is + // log(r) = 0.5 * log(1 + x ^ 2 + 2 * x + y ^ 2) + // = 0.5 * log1p(x * (x + 2) + y ^ 2) + // we need to use the expression only on certain condition to avoid overflow + // and underflow from `(x * (x + 2) + y ^ 2)` + T x = z.real(); + T y = z.imag(); + T zabs = std::abs(z); + T theta = std::atan2(y, x + T(1)); + if (zabs < 0.5) { + T r = x * (T(2) + x) + y * y; + if (r == 0) { // handle underflow + return {x, theta}; + } + return {T(0.5) * std::log1p(r), theta}; + } else { + T z0 = std::hypot(x + 1, y); + return {std::log(z0), theta}; + } +#else + // CPU path + // Based on https://github.com/numpy/numpy/pull/22611#issuecomment-1667945354 + c10::complex u = z + T(1); + if (u == T(1)) { + return z; + } else { + auto log_u = log(u); + if (u - T(1) == z) { + return log_u; + } + return log_u * (z / (u - T(1))); + } +#endif +} + +template +C10_HOST_DEVICE inline c10::complex expm1(const c10::complex& z) { + // expm1(z) = exp(z) - 1 + // Define z = x + i * y + // f = e ^ (x + i * y) - 1 + // = e ^ x * e ^ (i * y) - 1 + // = (e ^ x * cos(y) - 1) + i * (e ^ x * sin(y)) + // = (e ^ x - 1) * cos(y) - (1 - cos(y)) + i * e ^ x * sin(y) + // = expm1(x) * cos(y) - 2 * sin(y / 2) ^ 2 + i * e ^ x * sin(y) + T x = z.real(); + T y = z.imag(); + T a = std::sin(y / 2); + T er = std::expm1(x) * std::cos(y) - T(2) * a * a; + T ei = std::exp(x) * std::sin(y); + return {er, ei}; +} + +} // namespace c10_complex_math + +using c10_complex_math::acos; +using c10_complex_math::acosh; +using c10_complex_math::asin; +using c10_complex_math::asinh; 
+using c10_complex_math::atan; +using c10_complex_math::atanh; +using c10_complex_math::cos; +using c10_complex_math::cosh; +using c10_complex_math::exp; +using c10_complex_math::expm1; +using c10_complex_math::log; +using c10_complex_math::log10; +using c10_complex_math::log1p; +using c10_complex_math::log2; +using c10_complex_math::pow; +using c10_complex_math::sin; +using c10_complex_math::sinh; +using c10_complex_math::sqrt; +using c10_complex_math::tan; +using c10_complex_math::tanh; + +namespace std { + +using c10_complex_math::acos; +using c10_complex_math::acosh; +using c10_complex_math::asin; +using c10_complex_math::asinh; +using c10_complex_math::atan; +using c10_complex_math::atanh; +using c10_complex_math::cos; +using c10_complex_math::cosh; +using c10_complex_math::exp; +using c10_complex_math::expm1; +using c10_complex_math::log; +using c10_complex_math::log10; +using c10_complex_math::log1p; +using c10_complex_math::log2; +using c10_complex_math::pow; +using c10_complex_math::sin; +using c10_complex_math::sinh; +using c10_complex_math::sqrt; +using c10_complex_math::tan; +using c10_complex_math::tanh; + +} // namespace std diff --git a/runtime/core/portable_type/c10/c10/util/complex_utils.h b/runtime/core/portable_type/c10/c10/util/complex_utils.h new file mode 100644 index 00000000000..1ca105f1d0a --- /dev/null +++ b/runtime/core/portable_type/c10/c10/util/complex_utils.h @@ -0,0 +1,46 @@ +#if !defined(C10_INTERNAL_INCLUDE_COMPLEX_REMAINING_H) +#error \ + "c10/util/complex_utils.h is not meant to be individually included. Include c10/util/complex.h instead." 
+#endif + +#include + +namespace c10 { + +template +struct is_complex : public std::false_type {}; + +template +struct is_complex> : public std::true_type {}; + +template +struct is_complex> : public std::true_type {}; + +// Extract double from std::complex; is identity otherwise +// TODO: Write in more idiomatic C++17 +template +struct scalar_value_type { + using type = T; +}; +template +struct scalar_value_type> { + using type = T; +}; +template +struct scalar_value_type> { + using type = T; +}; + +} // namespace c10 + +namespace std { + +template +class numeric_limits> : public numeric_limits {}; + +template +bool isnan(const c10::complex& v) { + return std::isnan(v.real()) || std::isnan(v.imag()); +} + +} // namespace std diff --git a/runtime/core/portable_type/complex.h b/runtime/core/portable_type/complex.h index e89a19e54d7..faf13a0432f 100644 --- a/runtime/core/portable_type/complex.h +++ b/runtime/core/portable_type/complex.h @@ -8,39 +8,14 @@ #pragma once -#include +#include -namespace executorch { -namespace runtime { -namespace etensor { +namespace executorch::runtime::etensor { +using c10::complex; +} // namespace executorch::runtime::etensor -/** - * An implementation of complex numbers, compatible with c10/util/complex.h from - * pytorch core. - */ -template -struct alignas(sizeof(T) * 2) complex { - T real_ = T(0); - T imag_ = T(0); -}; - -/** - * Specialization for Half, which is not a primitive C numeric type. - */ -template <> -struct alignas(4) complex { - Half real_; - Half imag_; -}; - -} // namespace etensor -} // namespace runtime -} // namespace executorch - -namespace torch { -namespace executor { +namespace torch::executor { // TODO(T197294990): Remove these deprecated aliases once all users have moved // to the new `::executorch` namespaces. using ::executorch::runtime::etensor::complex; -} // namespace executor -} // namespace torch +} // namespace torch::executor