#ifdef __NVCC__ #include #endif // __NVCC__ namespace { using int8_t = signed char; using uint8_t = unsigned char; using int16_t = short int; using uint16_t = unsigned short int; using int32_t = int; using uint32_t = unsigned int; using int64_t = long long int; using uint64_t = unsigned long long int; // Modified from cuda.h struct TensorMap { alignas(64) uint64_t opaque[16]; }; typedef int nvfuser_index_t; // clang-format off /* * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on #ifdef __NVCC__ #include #else // The following namespace std is modified from LLVM, see the following // copyright information // // -*- C++ -*- //===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // copy-pasted from some llvm files: // - https://github.com/llvm/llvm-project/blob/main/libcxx/include/type_traits // - // https://github.com/llvm/llvm-project/blob/main/clang/test/Headers/Inputs/include/type_traits namespace std { template _Tp&& __declval(int); template _Tp __declval(long); template decltype(__declval<_Tp>(0)) declval() noexcept; template struct integral_constant { static const _Tp value = __v; typedef _Tp value_type; typedef integral_constant type; }; typedef integral_constant true_type; typedef integral_constant false_type; // is_same, functional template struct is_same : public false_type {}; template struct is_same<_Tp, _Tp> : public true_type {}; // is_integral, for some types. template struct is_integral : public integral_constant {}; template <> struct is_integral : public integral_constant {}; template <> struct is_integral : public integral_constant {}; template <> struct is_integral : public integral_constant {}; template <> struct is_integral : public integral_constant {}; template <> struct is_integral : public integral_constant {}; template <> struct is_integral : public integral_constant {}; // enable_if, functional template struct enable_if {}; template struct enable_if { using type = _Tp; }; template using enable_if_t = typename enable_if::type; template struct remove_const { typedef _Tp type; }; template struct remove_const { typedef _Tp type; }; template using remove_const_t = typename remove_const<_Tp>::type; template struct remove_volatile { typedef _Tp type; }; template struct remove_volatile { typedef _Tp type; }; template using remove_volatile_t = typename remove_volatile<_Tp>::type; template struct remove_cv { typedef typename remove_volatile::type>::type type; }; template using remove_cv_t = typename remove_cv<_Tp>::type; template struct __libcpp_is_floating_point : public false_type {}; template <> struct __libcpp_is_floating_point : public true_type {}; template <> struct __libcpp_is_floating_point : public true_type {}; template <> struct __libcpp_is_floating_point : public true_type {}; template struct is_floating_point : public __libcpp_is_floating_point::type> {}; template struct is_arithmetic : public integral_constant< bool, is_integral<_Tp>::value || is_floating_point<_Tp>::value> {}; template inline constexpr bool is_arithmetic_v = is_arithmetic<_Tp>::value; template struct __numeric_type { static void __test(...); static float __test(float); static double __test(char); static double __test(int); static double __test(unsigned); static double __test(long); static double __test(unsigned long); static double __test(long long); static double __test(unsigned long long); static double __test(double); static long double __test(long double); typedef decltype(__test(declval<_Tp>())) type; static const bool value = !is_same::value; }; template <> struct __numeric_type { static const bool value = true; }; // __promote template < class _A1, class _A2 = void, class _A3 = void, bool = __numeric_type<_A1>::value && __numeric_type<_A2>::value && __numeric_type<_A3>::value> class __promote_imp { public: static const bool value = false; }; template class __promote_imp<_A1, _A2, _A3, true> { private: typedef typename __promote_imp<_A1>::type __type1; typedef typename __promote_imp<_A2>::type __type2; typedef typename __promote_imp<_A3>::type __type3; public: typedef decltype(__type1() + __type2() + __type3()) type; static const bool value = true; }; template class __promote_imp<_A1, _A2, void, true> { private: typedef typename __promote_imp<_A1>::type __type1; typedef typename __promote_imp<_A2>::type __type2; public: typedef decltype(__type1() + __type2()) type; static const bool value = true; }; template class __promote_imp<_A1, void, void, true> { public: typedef typename __numeric_type<_A1>::type type; static const bool value = true; }; template class __promote : public __promote_imp<_A1, _A2, _A3> {}; } // namespace std #endif // clang-format off /* * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on #ifdef __NVCC__ #include #else namespace std { template std::enable_if_t bit_cast( const From& src) noexcept { return *reinterpret_cast(&src); } } // namespace std #endif // clang-format off /* * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on #ifndef __NVCC__ #define POS_INFINITY __int_as_float(0x7f800000) #define INFINITY POS_INFINITY #define NEG_INFINITY __int_as_float(0xff800000) #define NAN __int_as_float(0x7fffffff) //===----------------------------------------------------------------------===// // The following namespace std is modified from LLVM, see the following // copyright information // // -*- C++ -*- //===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // copy-pasted from the following llvm file: // https://github.com/llvm/llvm-project/blob/main/libcxx/include/complex namespace std { template class complex; template complex<_Tp> operator*(const complex<_Tp>& __z, const complex<_Tp>& __w); template complex<_Tp> operator/(const complex<_Tp>& __x, const complex<_Tp>& __y); template class complex { public: typedef _Tp value_type; private: value_type __re_; value_type __im_; public: constexpr complex( const value_type& __re = value_type(), const value_type& __im = value_type()) : __re_(__re), __im_(__im) {} template constexpr complex(const complex<_Xp>& __c) : __re_(__c.real()), __im_(__c.imag()) {} constexpr value_type real() const { return __re_; } constexpr value_type imag() const { return __im_; } void real(value_type __re) { __re_ = __re; } void imag(value_type __im) { __im_ = __im; } constexpr operator bool() const { return real() || imag(); } complex& operator=(const value_type& __re) { __re_ = __re; __im_ = value_type(); return *this; } complex& operator+=(const value_type& __re) { __re_ += __re; return *this; } complex& operator-=(const value_type& __re) { __re_ -= __re; return *this; } complex& operator*=(const value_type& __re) { __re_ *= __re; __im_ *= __re; return *this; } complex& operator/=(const value_type& __re) { __re_ /= __re; __im_ /= __re; return *this; } template complex& operator=(const complex<_Xp>& __c) { __re_ = __c.real(); __im_ = __c.imag(); return *this; } template complex& operator+=(const complex<_Xp>& __c) { __re_ += __c.real(); __im_ += __c.imag(); return *this; } template complex& operator-=(const complex<_Xp>& __c) { __re_ -= __c.real(); __im_ -= __c.imag(); return *this; } template complex& operator*=(const complex<_Xp>& __c) { *this = *this * complex(__c.real(), __c.imag()); return *this; } template complex& operator/=(const complex<_Xp>& __c) { *this = *this / complex(__c.real(), __c.imag()); return *this; } }; template <> class complex; template <> class complex { float __re_; float __im_; public: typedef float value_type; constexpr complex(float __re = 0.0f, float __im = 0.0f) : __re_(__re), __im_(__im) {} explicit constexpr complex(const complex& __c); // copy volatile to non-volatile constexpr complex(const volatile complex& other) : __re_(other.__re_), __im_(other.__im_) {} constexpr complex(const complex& other) : __re_(other.__re_), __im_(other.__im_) {} constexpr float real() const { return __re_; } constexpr float imag() const { return __im_; } void real(value_type __re) { __re_ = __re; } void imag(value_type __im) { __im_ = __im; } constexpr operator bool() const { return real() || imag(); } complex& operator=(float __re) { __re_ = __re; __im_ = value_type(); return *this; } complex& operator+=(float __re) { __re_ += __re; return *this; } complex& operator-=(float __re) { __re_ -= __re; return *this; } complex& operator*=(float __re) { __re_ *= __re; __im_ *= __re; return *this; } complex& operator/=(float __re) { __re_ /= __re; __im_ /= __re; return *this; } template complex& operator=(const complex<_Xp>& __c) { __re_ = __c.real(); __im_ = __c.imag(); return *this; } // non-volatile to volatile template volatile complex& operator=(const complex<_Xp>& __c) volatile { __re_ = __c.real(); __im_ = __c.imag(); return *this; } // volatile to non-volatile template complex& operator=(const volatile complex<_Xp>& __c) { __re_ = __c.real(); __im_ = __c.imag(); return *this; } // volatile to volatile template volatile complex& operator=(const volatile complex<_Xp>& __c) volatile { __re_ = __c.real(); __im_ = __c.imag(); return *this; } template complex& operator+=(const complex<_Xp>& __c) { __re_ += __c.real(); __im_ += __c.imag(); return *this; } template complex& operator-=(const complex<_Xp>& __c) { __re_ -= __c.real(); __im_ -= __c.imag(); return *this; } template complex& operator*=(const complex<_Xp>& __c) { *this = *this * complex(__c.real(), __c.imag()); return *this; } template complex& operator/=(const complex<_Xp>& __c) { *this = *this / complex(__c.real(), __c.imag()); return *this; } }; template <> class complex { double __re_; double __im_; public: typedef double value_type; constexpr complex(double __re = 0.0, double __im = 0.0) : __re_(__re), __im_(__im) {} constexpr complex(const complex& __c); // copy volatile to non-volatile constexpr complex(const volatile complex& other) : __re_(other.__re_), __im_(other.__im_) {} constexpr complex(const complex& other) : __re_(other.__re_), __im_(other.__im_) {} constexpr double real() const { return __re_; } constexpr double imag() const { return __im_; } void real(value_type __re) { __re_ = __re; } void imag(value_type __im) { __im_ = __im; } constexpr operator bool() const { return real() || imag(); } complex& operator=(double __re) { __re_ = __re; __im_ = value_type(); return *this; } complex& operator+=(double __re) { __re_ += __re; return *this; } complex& operator-=(double __re) { __re_ -= __re; return *this; } complex& operator*=(double __re) { __re_ *= __re; __im_ *= __re; return *this; } complex& operator/=(double __re) { __re_ /= __re; __im_ /= __re; return *this; } template complex& operator=(const complex<_Xp>& __c) { __re_ = __c.real(); __im_ = __c.imag(); return *this; } // non-volatile to volatile template volatile complex& operator=(const complex<_Xp>& __c) volatile { __re_ = __c.real(); __im_ = __c.imag(); return *this; } // volatile to non-volatile template complex& operator=(const volatile complex<_Xp>& __c) { __re_ = __c.real(); __im_ = __c.imag(); return *this; } // volatile to volatile template volatile complex& operator=(const volatile complex<_Xp>& __c) volatile { __re_ = __c.real(); __im_ = __c.imag(); return *this; } template complex& operator+=(const complex<_Xp>& __c) { __re_ += __c.real(); __im_ += __c.imag(); return *this; } template complex& operator-=(const complex<_Xp>& __c) { __re_ -= __c.real(); __im_ -= __c.imag(); return *this; } template complex& operator*=(const complex<_Xp>& __c) { *this = *this * complex(__c.real(), __c.imag()); return *this; } template complex& operator/=(const complex<_Xp>& __c) { *this = *this / complex(__c.real(), __c.imag()); return *this; } }; inline constexpr complex::complex(const complex& __c) : __re_(__c.real()), __im_(__c.imag()) {} inline constexpr complex::complex(const complex& __c) : __re_(__c.real()), __im_(__c.imag()) {} // 26.3.6 operators: template inline complex<_Tp> operator+( const complex<_Tp>& __x, const complex<_Tp>& __y) { complex<_Tp> __t(__x); __t += __y; return __t; } template inline complex<_Tp> operator+(const complex<_Tp>& __x, const _Tp& __y) { complex<_Tp> __t(__x); __t += __y; return __t; } template inline complex<_Tp> operator+(const _Tp& __x, const complex<_Tp>& __y) { complex<_Tp> __t(__y); __t += __x; return __t; } template inline complex<_Tp> operator-( const complex<_Tp>& __x, const complex<_Tp>& __y) { complex<_Tp> __t(__x); __t -= __y; return __t; } template inline complex<_Tp> operator-(const complex<_Tp>& __x, const _Tp& __y) { complex<_Tp> __t(__x); __t -= __y; return __t; } template inline complex<_Tp> operator-(const _Tp& __x, const complex<_Tp>& __y) { complex<_Tp> __t(-__y); __t += __x; return __t; } template complex<_Tp> operator*(const complex<_Tp>& __z, const complex<_Tp>& __w) { _Tp __a = __z.real(); _Tp __b = __z.imag(); _Tp __c = __w.real(); _Tp __d = __w.imag(); _Tp __ac = __a * __c; _Tp __bd = __b * __d; _Tp __ad = __a * __d; _Tp __bc = __b * __c; _Tp __x = __ac - __bd; _Tp __y = __ad + __bc; if (isnan(__x) && isnan(__y)) { bool __recalc = false; if (isinf(__a) || isinf(__b)) { __a = copysign(isinf(__a) ? _Tp(1) : _Tp(0), __a); __b = copysign(isinf(__b) ? _Tp(1) : _Tp(0), __b); if (isnan(__c)) __c = copysign(_Tp(0), __c); if (isnan(__d)) __d = copysign(_Tp(0), __d); __recalc = true; } if (isinf(__c) || isinf(__d)) { __c = copysign(isinf(__c) ? _Tp(1) : _Tp(0), __c); __d = copysign(isinf(__d) ? _Tp(1) : _Tp(0), __d); if (isnan(__a)) __a = copysign(_Tp(0), __a); if (isnan(__b)) __b = copysign(_Tp(0), __b); __recalc = true; } if (!__recalc && (isinf(__ac) || isinf(__bd) || isinf(__ad) || isinf(__bc))) { if (isnan(__a)) __a = copysign(_Tp(0), __a); if (isnan(__b)) __b = copysign(_Tp(0), __b); if (isnan(__c)) __c = copysign(_Tp(0), __c); if (isnan(__d)) __d = copysign(_Tp(0), __d); __recalc = true; } if (__recalc) { __x = _Tp(INFINITY) * (__a * __c - __b * __d); __y = _Tp(INFINITY) * (__a * __d + __b * __c); } } return complex<_Tp>(__x, __y); } template inline complex<_Tp> operator*(const complex<_Tp>& __x, const _Tp& __y) { complex<_Tp> __t(__x); __t *= __y; return __t; } template inline complex<_Tp> operator*(const _Tp& __x, const complex<_Tp>& __y) { complex<_Tp> __t(__y); __t *= __x; return __t; } template complex<_Tp> operator/(const complex<_Tp>& __z, const complex<_Tp>& __w) { int __ilogbw = 0; _Tp __a = __z.real(); _Tp __b = __z.imag(); _Tp __c = __w.real(); _Tp __d = __w.imag(); _Tp __logbw = logb(fmax(fabs(__c), fabs(__d))); if (isfinite(__logbw)) { __ilogbw = static_cast(__logbw); __c = scalbn(__c, -__ilogbw); __d = scalbn(__d, -__ilogbw); } _Tp __denom = __c * __c + __d * __d; _Tp __x = scalbn((__a * __c + __b * __d) / __denom, -__ilogbw); _Tp __y = scalbn((__b * __c - __a * __d) / __denom, -__ilogbw); if (isnan(__x) && isnan(__y)) { if ((__denom == _Tp(0)) && (!isnan(__a) || !isnan(__b))) { __x = copysign(_Tp(INFINITY), __c) * __a; __y = copysign(_Tp(INFINITY), __c) * __b; } else if ((isinf(__a) || isinf(__b)) && isfinite(__c) && isfinite(__d)) { __a = copysign(isinf(__a) ? _Tp(1) : _Tp(0), __a); __b = copysign(isinf(__b) ? _Tp(1) : _Tp(0), __b); __x = _Tp(INFINITY) * (__a * __c + __b * __d); __y = _Tp(INFINITY) * (__b * __c - __a * __d); } else if ( isinf(__logbw) && __logbw > _Tp(0) && isfinite(__a) && isfinite(__b)) { __c = copysign(isinf(__c) ? _Tp(1) : _Tp(0), __c); __d = copysign(isinf(__d) ? _Tp(1) : _Tp(0), __d); __x = _Tp(0) * (__a * __c + __b * __d); __y = _Tp(0) * (__b * __c - __a * __d); } } return complex<_Tp>(__x, __y); } template inline complex<_Tp> operator/(const complex<_Tp>& __x, const _Tp& __y) { return complex<_Tp>(__x.real() / __y, __x.imag() / __y); } template inline complex<_Tp> operator/(const _Tp& __x, const complex<_Tp>& __y) { complex<_Tp> __t(__x); __t /= __y; return __t; } template inline complex<_Tp> operator+(const complex<_Tp>& __x) { return __x; } template inline complex<_Tp> operator-(const complex<_Tp>& __x) { return complex<_Tp>(-__x.real(), -__x.imag()); } template inline constexpr bool operator==( const complex<_Tp>& __x, const complex<_Tp>& __y) { return __x.real() == __y.real() && __x.imag() == __y.imag(); } template inline constexpr bool operator==(const complex<_Tp>& __x, const _Tp& __y) { return __x.real() == __y && __x.imag() == 0; } template inline constexpr bool operator==(const _Tp& __x, const complex<_Tp>& __y) { return __x == __y.real() && 0 == __y.imag(); } template inline constexpr bool operator!=( const complex<_Tp>& __x, const complex<_Tp>& __y) { return !(__x == __y); } template inline constexpr bool operator!=(const complex<_Tp>& __x, const _Tp& __y) { return !(__x == __y); } template inline constexpr bool operator!=(const _Tp& __x, const complex<_Tp>& __y) { return !(__x == __y); } template inline constexpr bool operator&&( const complex<_Tp>& __x, const complex<_Tp>& __y) { return bool(__x) && bool(__y); } template inline constexpr bool isnan(const complex<_Tp>& __x) { return isnan(__x.real()) || isnan(__x.imag()); } template inline constexpr bool operator||( const complex<_Tp>& __x, const complex<_Tp>& __y) { return bool(__x) || bool(__y); } // 26.3.7 values: template < class _Tp, bool = is_integral<_Tp>::value, bool = is_floating_point<_Tp>::value> struct __libcpp_complex_overload_traits {}; // Integral Types template struct __libcpp_complex_overload_traits<_Tp, true, false> { typedef double _ValueType; typedef complex _ComplexType; }; // Floating point types template struct __libcpp_complex_overload_traits<_Tp, false, true> { typedef _Tp _ValueType; typedef complex<_Tp> _ComplexType; }; // real template inline constexpr _Tp real(const complex<_Tp>& __c) { return __c.real(); } template inline constexpr typename __libcpp_complex_overload_traits<_Tp>::_ValueType real( _Tp __re) { return __re; } // imag template inline constexpr _Tp imag(const complex<_Tp>& __c) { return __c.imag(); } template inline constexpr typename __libcpp_complex_overload_traits<_Tp>::_ValueType imag( _Tp) { return 0; } // abs template inline _Tp abs(const complex<_Tp>& __c) { return hypot(__c.real(), __c.imag()); } // arg template inline _Tp arg(const complex<_Tp>& __c) { return atan2(__c.imag(), __c.real()); } template inline typename enable_if< is_integral<_Tp>::value || is_same<_Tp, double>::value, double>::type arg(_Tp __re) { return atan2(0., __re); } template inline typename enable_if::value, float>::type arg( _Tp __re) { return atan2f(0.F, __re); } } // namespace std namespace std { using ::isfinite; using ::isinf; using ::isnan; using ::signbit; using ::abs; using ::acos; using ::acosf; using ::asin; using ::asinf; using ::atan; using ::atan2; using ::atan2f; using ::atanf; using ::ceil; using ::ceilf; using ::cos; using ::cosf; using ::cosh; using ::coshf; using ::exp; using ::expf; using ::fabs; using ::fabsf; using ::floor; using ::floorf; using ::fmod; using ::fmodf; using ::frexp; using ::frexpf; using ::ldexp; using ::ldexpf; using ::log; using ::logf; using ::log10; using ::log10f; using ::modf; using ::modff; using ::pow; using ::powf; using ::sin; using ::sinf; using ::sinh; using ::sinhf; using ::sqrt; using ::sqrtf; using ::tan; using ::tanf; using ::tanh; using ::tanhf; using ::acosh; using ::acoshf; using ::asinh; using ::asinhf; using ::atanh; using ::atanhf; using ::cbrt; using ::cbrtf; using ::copysign; using ::copysignf; using ::erf; using ::erfc; using ::erfcf; using ::erff; using ::exp2; using ::exp2f; using ::expm1; using ::expm1f; using ::fdim; using ::fdimf; using ::fma; using ::fmaf; using ::fmax; using ::fmaxf; using ::fmin; using ::fminf; using ::hypot; using ::hypotf; using ::ilogb; using ::ilogbf; using ::lgamma; using ::lgammaf; using ::llrint; using ::llrintf; using ::llround; using ::llroundf; using ::log1p; using ::log1pf; using ::log2; using ::log2f; using ::logb; using ::logbf; using ::lrint; using ::lrintf; using ::lround; using ::lroundf; using ::nan; using ::nanf; using ::nearbyint; using ::nearbyintf; using ::nextafter; using ::nextafterf; using ::remainder; using ::remainderf; using ::remquo; using ::remquof; using ::rint; using ::rintf; using ::round; using ::roundf; using ::scalbln; using ::scalblnf; using ::scalbn; using ::scalbnf; using ::tgamma; using ::tgammaf; using ::trunc; using ::truncf; } // namespace std namespace std { // norm template inline _Tp norm(const complex<_Tp>& __c) { if (isinf(__c.real())) return abs(__c.real()); if (isinf(__c.imag())) return abs(__c.imag()); return __c.real() * __c.real() + __c.imag() * __c.imag(); } template inline typename __libcpp_complex_overload_traits<_Tp>::_ValueType norm( _Tp __re) { typedef typename __libcpp_complex_overload_traits<_Tp>::_ValueType _ValueType; return static_cast<_ValueType>(__re) * __re; } // conj template inline complex<_Tp> conj(const complex<_Tp>& __c) { return complex<_Tp>(__c.real(), -__c.imag()); } template inline typename __libcpp_complex_overload_traits<_Tp>::_ComplexType conj( _Tp __re) { typedef typename __libcpp_complex_overload_traits<_Tp>::_ComplexType _ComplexType; return _ComplexType(__re); } // proj template inline complex<_Tp> proj(const complex<_Tp>& __c) { complex<_Tp> __r = __c; if (isinf(__c.real()) || isinf(__c.imag())) __r = complex<_Tp>(INFINITY, copysign(_Tp(0), __c.imag())); return __r; } template inline typename enable_if< is_floating_point<_Tp>::value, typename __libcpp_complex_overload_traits<_Tp>::_ComplexType>::type proj(_Tp __re) { if (isinf(__re)) __re = abs(__re); return complex<_Tp>(__re); } template inline typename enable_if< is_integral<_Tp>::value, typename __libcpp_complex_overload_traits<_Tp>::_ComplexType>::type proj(_Tp __re) { typedef typename __libcpp_complex_overload_traits<_Tp>::_ComplexType _ComplexType; return _ComplexType(__re); } // polar template complex<_Tp> polar(const _Tp& __rho, const _Tp& __theta = _Tp()) { if (isnan(__rho) || signbit(__rho)) return complex<_Tp>(_Tp(NAN), _Tp(NAN)); if (isnan(__theta)) { if (isinf(__rho)) return complex<_Tp>(__rho, __theta); return complex<_Tp>(__theta, __theta); } if (isinf(__theta)) { if (isinf(__rho)) return complex<_Tp>(__rho, _Tp(NAN)); return complex<_Tp>(_Tp(NAN), _Tp(NAN)); } _Tp __x = __rho * cos(__theta); if (isnan(__x)) __x = 0; _Tp __y = __rho * sin(__theta); if (isnan(__y)) __y = 0; return complex<_Tp>(__x, __y); } // log template inline complex<_Tp> log(const complex<_Tp>& __x) { return complex<_Tp>(log(abs(__x)), arg(__x)); } // log10 template inline complex<_Tp> log10(const complex<_Tp>& __x) { return log(__x) / log(_Tp(10)); } // log2 template inline complex<_Tp> log2(const complex<_Tp>& __x) { return log(__x) / log(_Tp(2)); } // sqrt template complex<_Tp> sqrt(const complex<_Tp>& __x) { if (isinf(__x.imag())) return complex<_Tp>(_Tp(INFINITY), __x.imag()); if (isinf(__x.real())) { if (__x.real() > _Tp(0)) return complex<_Tp>( __x.real(), isnan(__x.imag()) ? __x.imag() : copysign(_Tp(0), __x.imag())); return complex<_Tp>( isnan(__x.imag()) ? __x.imag() : _Tp(0), copysign(__x.real(), __x.imag())); } return polar(sqrt(abs(__x)), arg(__x) / _Tp(2)); } // exp template complex<_Tp> exp(const complex<_Tp>& __x) { _Tp __i = __x.imag(); if (__i == 0) { return complex<_Tp>(exp(__x.real()), copysign(_Tp(0), __x.imag())); } if (isinf(__x.real())) { if (__x.real() < _Tp(0)) { if (!isfinite(__i)) __i = _Tp(1); } else if (__i == 0 || !isfinite(__i)) { if (isinf(__i)) __i = _Tp(NAN); return complex<_Tp>(__x.real(), __i); } } _Tp __e = exp(__x.real()); return complex<_Tp>(__e * cos(__i), __e * sin(__i)); } // pow template inline complex<_Tp> pow(const complex<_Tp>& __x, const complex<_Tp>& __y) { return exp(__y * log(__x)); } template inline complex::type> pow( const complex<_Tp>& __x, const complex<_Up>& __y) { typedef complex::type> result_type; return std::pow(result_type(__x), result_type(__y)); } template inline typename enable_if< is_arithmetic<_Up>::value, complex::type>>::type pow(const complex<_Tp>& __x, const _Up& __y) { typedef complex::type> result_type; return std::pow(result_type(__x), result_type(__y)); } template inline typename enable_if< is_arithmetic<_Tp>::value, complex::type>>::type pow(const _Tp& __x, const complex<_Up>& __y) { typedef complex::type> result_type; return std::pow(result_type(__x), result_type(__y)); } // __sqr, computes pow(x, 2) template inline complex<_Tp> __sqr(const complex<_Tp>& __x) { return complex<_Tp>( (__x.real() - __x.imag()) * (__x.real() + __x.imag()), _Tp(2) * __x.real() * __x.imag()); } // asinh template complex<_Tp> asinh(const complex<_Tp>& __x) { const _Tp __pi(atan2(+0., -0.)); if (isinf(__x.real())) { if (isnan(__x.imag())) return __x; if (isinf(__x.imag())) return complex<_Tp>(__x.real(), copysign(__pi * _Tp(0.25), __x.imag())); return complex<_Tp>(__x.real(), copysign(_Tp(0), __x.imag())); } if (isnan(__x.real())) { if (isinf(__x.imag())) return complex<_Tp>(__x.imag(), __x.real()); if (__x.imag() == 0) return __x; return complex<_Tp>(__x.real(), __x.real()); } if (isinf(__x.imag())) return complex<_Tp>( copysign(__x.imag(), __x.real()), copysign(__pi / _Tp(2), __x.imag())); complex<_Tp> __z = log(__x + sqrt(__sqr(__x) + _Tp(1))); return complex<_Tp>( copysign(__z.real(), __x.real()), copysign(__z.imag(), __x.imag())); } // acosh template complex<_Tp> acosh(const complex<_Tp>& __x) { const _Tp __pi(atan2(+0., -0.)); if (isinf(__x.real())) { if (isnan(__x.imag())) return complex<_Tp>(abs(__x.real()), __x.imag()); if (isinf(__x.imag())) { if (__x.real() > 0) return complex<_Tp>(__x.real(), copysign(__pi * _Tp(0.25), __x.imag())); else return complex<_Tp>( -__x.real(), copysign(__pi * _Tp(0.75), __x.imag())); } if (__x.real() < 0) return complex<_Tp>(-__x.real(), copysign(__pi, __x.imag())); return complex<_Tp>(__x.real(), copysign(_Tp(0), __x.imag())); } if (isnan(__x.real())) { if (isinf(__x.imag())) return complex<_Tp>(abs(__x.imag()), __x.real()); return complex<_Tp>(__x.real(), __x.real()); } if (isinf(__x.imag())) return complex<_Tp>(abs(__x.imag()), copysign(__pi / _Tp(2), __x.imag())); complex<_Tp> __z = log(__x + sqrt(__sqr(__x) - _Tp(1))); return complex<_Tp>( copysign(__z.real(), _Tp(0)), copysign(__z.imag(), __x.imag())); } // atanh template complex<_Tp> atanh(const complex<_Tp>& __x) { const _Tp __pi(atan2(+0., -0.)); if (isinf(__x.imag())) { return complex<_Tp>( copysign(_Tp(0), __x.real()), copysign(__pi / _Tp(2), __x.imag())); } if (isnan(__x.imag())) { if (isinf(__x.real()) || __x.real() == 0) return complex<_Tp>(copysign(_Tp(0), __x.real()), __x.imag()); return complex<_Tp>(__x.imag(), __x.imag()); } if (isnan(__x.real())) { return complex<_Tp>(__x.real(), __x.real()); } if (isinf(__x.real())) { return complex<_Tp>( copysign(_Tp(0), __x.real()), copysign(__pi / _Tp(2), __x.imag())); } if (abs(__x.real()) == _Tp(1) && __x.imag() == _Tp(0)) { return complex<_Tp>( copysign(_Tp(INFINITY), __x.real()), copysign(_Tp(0), __x.imag())); } complex<_Tp> __z = log((_Tp(1) + __x) / (_Tp(1) - __x)) / _Tp(2); return complex<_Tp>( copysign(__z.real(), __x.real()), copysign(__z.imag(), __x.imag())); } // sinh template complex<_Tp> sinh(const complex<_Tp>& __x) { if (isinf(__x.real()) && !isfinite(__x.imag())) return complex<_Tp>(__x.real(), _Tp(NAN)); if (__x.real() == 0 && !isfinite(__x.imag())) return complex<_Tp>(__x.real(), _Tp(NAN)); if (__x.imag() == 0 && !isfinite(__x.real())) return __x; return complex<_Tp>( sinh(__x.real()) * cos(__x.imag()), cosh(__x.real()) * sin(__x.imag())); } // cosh template complex<_Tp> cosh(const complex<_Tp>& __x) { if (isinf(__x.real()) && !isfinite(__x.imag())) return complex<_Tp>(abs(__x.real()), _Tp(NAN)); if (__x.real() == 0 && !isfinite(__x.imag())) return complex<_Tp>(_Tp(NAN), __x.real()); if (__x.real() == 0 && __x.imag() == 0) return complex<_Tp>(_Tp(1), __x.imag()); if (__x.imag() == 0 && !isfinite(__x.real())) return complex<_Tp>(abs(__x.real()), __x.imag()); return complex<_Tp>( cosh(__x.real()) * cos(__x.imag()), sinh(__x.real()) * sin(__x.imag())); } // tanh template complex<_Tp> tanh(const complex<_Tp>& __x) { if (isinf(__x.real())) { if (!isfinite(__x.imag())) return complex<_Tp>(copysign(_Tp(1), __x.real()), _Tp(0)); return complex<_Tp>( copysign(_Tp(1), __x.real()), copysign(_Tp(0), sin(_Tp(2) * __x.imag()))); } if (isnan(__x.real()) && __x.imag() == 0) return __x; _Tp __2r(_Tp(2) * __x.real()); _Tp __2i(_Tp(2) * __x.imag()); _Tp __d(cosh(__2r) + cos(__2i)); _Tp __2rsh(sinh(__2r)); if (isinf(__2rsh) && isinf(__d)) return complex<_Tp>( __2rsh > _Tp(0) ? _Tp(1) : _Tp(-1), __2i > _Tp(0) ? _Tp(0) : _Tp(-0.)); return complex<_Tp>(__2rsh / __d, sin(__2i) / __d); } // asin template complex<_Tp> asin(const complex<_Tp>& __x) { complex<_Tp> __z = asinh(complex<_Tp>(-__x.imag(), __x.real())); return complex<_Tp>(__z.imag(), -__z.real()); } // acos template complex<_Tp> acos(const complex<_Tp>& __x) { const _Tp __pi(atan2(+0., -0.)); if (isinf(__x.real())) { if (isnan(__x.imag())) return complex<_Tp>(__x.imag(), __x.real()); if (isinf(__x.imag())) { if (__x.real() < _Tp(0)) return complex<_Tp>(_Tp(0.75) * __pi, -__x.imag()); return complex<_Tp>(_Tp(0.25) * __pi, -__x.imag()); } if (__x.real() < _Tp(0)) return complex<_Tp>(__pi, signbit(__x.imag()) ? -__x.real() : __x.real()); return complex<_Tp>(_Tp(0), signbit(__x.imag()) ? __x.real() : -__x.real()); } if (isnan(__x.real())) { if (isinf(__x.imag())) return complex<_Tp>(__x.real(), -__x.imag()); return complex<_Tp>(__x.real(), __x.real()); } if (isinf(__x.imag())) return complex<_Tp>(__pi / _Tp(2), -__x.imag()); if (__x.real() == 0 && (__x.imag() == 0 || isnan(__x.imag()))) return complex<_Tp>(__pi / _Tp(2), -__x.imag()); complex<_Tp> __z = log(__x + sqrt(__sqr(__x) - _Tp(1))); if (signbit(__x.imag())) return complex<_Tp>(abs(__z.imag()), abs(__z.real())); return complex<_Tp>(abs(__z.imag()), -abs(__z.real())); } // atan template complex<_Tp> atan(const complex<_Tp>& __x) { complex<_Tp> __z = atanh(complex<_Tp>(-__x.imag(), __x.real())); return complex<_Tp>(__z.imag(), -__z.real()); } // sin template complex<_Tp> sin(const complex<_Tp>& __x) { complex<_Tp> __z = sinh(complex<_Tp>(-__x.imag(), __x.real())); return complex<_Tp>(__z.imag(), -__z.real()); } // cos template inline complex<_Tp> cos(const complex<_Tp>& __x) { return cosh(complex<_Tp>(-__x.imag(), __x.real())); } // tan template complex<_Tp> tan(const complex<_Tp>& __x) { complex<_Tp> __z = tanh(complex<_Tp>(-__x.imag(), __x.real())); return complex<_Tp>(__z.imag(), -__z.real()); } // Literal suffix for complex number literals [complex.literals] inline namespace literals { inline namespace complex_literals { constexpr complex operator""i(long double __im) { return {0.0, static_cast(__im)}; } constexpr complex operator""i(unsigned long long __im) { return {0.0, static_cast(__im)}; } constexpr complex operator""if(long double __im) { return {0.0f, static_cast(__im)}; } constexpr complex operator""if(unsigned long long __im) { return {0.0f, static_cast(__im)}; } } // namespace complex_literals } // namespace literals } // namespace std __device__ std::complex lerp( std::complex start, std::complex end, std::complex weight) { if (abs(weight) < 0.5) { return start + weight * (end - start); } else { return end - (end - start) * (1.0 - weight); } } __device__ std::complex lerp( std::complex start, std::complex end, std::complex weight) { if (abs(weight) < 0.5f) { return start + weight * (end - start); } else { return end - (end - start) * (1.0f - weight); } } __device__ std::complex reciprocal(std::complex x) { return 1.0 / x; } __device__ std::complex reciprocal(std::complex x) { return 1.0f / x; } __device__ std::complex sigmoid(std::complex x) { return 1.0 / (1.0 + exp(-x)); } __device__ std::complex sigmoid(std::complex x) { return 1.0f / (1.0f + exp(-x)); } // The reciprocal of a complex number z is // 1/z = conj(z)/|z|^2. // The principal square root of a complex number z can be obtained by [1] // sqrt(z) = sqrt(|z|) (z + |z|) / |z + |z||. // Combining these formulas we have // 1/sqrt(z) = (conj(z) + |z|) / (sqrt(|z|) |z + |z||). // [1] https://math.stackexchange.com/a/44500 __device__ std::complex rsqrt(std::complex z) { auto a = std::real(z); auto b = std::imag(z); auto absa = ::fabsf(a); auto absb = ::fabsf(b); // scale to avoid precision loss due to underflow/overflow auto scale = fmax(absa, absb); a /= scale; b /= scale; auto a_sq = a * a; auto b_sq = b * b; auto modz_sq = a_sq + b_sq; auto modz = ::sqrtf(modz_sq); auto a_plus_modz = a + modz; auto mod_zplusmodz_sq = a_plus_modz * a_plus_modz + b_sq; auto fac = ::rsqrtf(scale * modz * mod_zplusmodz_sq); return std::complex(a_plus_modz * fac, -b * fac); } __device__ std::complex rsqrt(std::complex z) { auto a = std::real(z); auto b = std::imag(z); auto absa = ::abs(a); auto absb = ::abs(b); // scale to avoid precision loss due to underflow/overflow auto scale = fmax(absa, absb); a /= scale; b /= scale; auto a_sq = a * a; auto b_sq = b * b; auto modz_sq = a_sq + b_sq; auto modz = ::sqrt(modz_sq); auto a_plus_modz = a + modz; auto mod_zplusmodz_sq = a_plus_modz * a_plus_modz + b_sq; auto fac = ::rsqrt(scale * modz * mod_zplusmodz_sq); return std::complex(a_plus_modz * fac, -b * fac); } template bool isfinite(std::complex x) { return ::isfinite(std::real(x)) && ::isfinite(std::imag(x)); } template bool isinf(std::complex x) { return ::isinf(std::real(x)) || ::isinf(std::imag(x)); } template bool isreal(std::complex x) { return std::imag(x) == 0; } #endif // __NVCC__ // clang-format off /* * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on #define __NVFUSER_HALF_TO_US(var) *(reinterpret_cast(&(var))) #define __NVFUSER_HALF_TO_CUS(var) \ *(reinterpret_cast(&(var))) struct __half; __device__ __inline__ __half __float2half(const float); struct __align__(2) __half { __half() = default; __half(const __half& other) { __x = other.__x; } __half(const __half&& other) { __x = other.__x; } __half(const volatile __half& other) { __x = other.__x; } __half(const volatile __half&& other) { __x = other.__x; } // Note: not returning reference for `__half::operator=` // Doing so would requires us to return `volatile __half&` for the volatile // variants, which would trigger a gcc warning `implicit dereference will not // access object of type ‘volatile S’ in statement` __device__ void operator=(const __half& other) { __x = other.__x; } __device__ void operator=(const __half&& other) { __x = other.__x; } __device__ void operator=(const volatile __half& other) { __x = other.__x; } __device__ void operator=(const volatile __half&& other) { __x = other.__x; } __device__ void operator=(const __half& other) volatile { __x = other.__x; } __device__ void operator=(const __half&& other) volatile { __x = other.__x; } __device__ void operator=(const volatile __half& other) volatile { __x = other.__x; } __device__ void operator=(const volatile __half&& other) volatile { __x = other.__x; } __device__ __half(const float f) { __x = __float2half(f).__x; } __device__ uint16_t raw() const { return __x; } protected: unsigned short __x; }; __device__ __inline__ __half __float2half(const float f) { __half val; asm("{ cvt.rn.f16.f32 %0, %1;}\n" : "=h"(__NVFUSER_HALF_TO_US(val)) : "f"(f)); return val; } __device__ __inline__ __half __double2half(const double d) { __half val; asm("{ cvt.rn.f16.f64 %0, %1;}\n" : "=h"(__NVFUSER_HALF_TO_US(val)) : "d"(d)); return val; } __device__ __inline__ __half __int2half(const int i) { __half val; asm("{ cvt.rn.f16.s32 %0, %1;}\n" : "=h"(__NVFUSER_HALF_TO_US(val)) : "r"(i)); return val; } __device__ __inline__ __half __int2half(const int64_t i64) { __half val; asm("{ cvt.rn.f16.s64 %0, %1;}\n" : "=h"(__NVFUSER_HALF_TO_US(val)) : "l"(i64)); return val; } __device__ __inline__ __half __int2half(const uint32_t i) { __half val; asm("{ cvt.rn.f16.u32 %0, %1;}\n" : "=h"(__NVFUSER_HALF_TO_US(val)) : "r"(i)); return val; } __device__ __inline__ __half __int2half(const uint64_t i64) { __half val; asm("{ cvt.rn.f16.u64 %0, %1;}\n" : "=h"(__NVFUSER_HALF_TO_US(val)) : "l"(i64)); return val; } __device__ __inline__ __half __bool2half(const bool b) { return __int2half((int)b); } __device__ __inline__ float __half2float(const __half h) { float val; asm("{ cvt.f32.f16 %0, %1;}\n" : "=f"(val) : "h"(__NVFUSER_HALF_TO_CUS(h))); return val; } __device__ __inline__ double __half2double(const __half h) { double val; asm("{ cvt.f64.f16 %0, %1;}\n" : "=d"(val) : "h"(__NVFUSER_HALF_TO_CUS(h))); return val; } __device__ int __half2int32(const __half h) { int val; asm("{ cvt.rzi.s32.f16 %0, %1;}\n" : "=r"(val) : "h"(__NVFUSER_HALF_TO_CUS(h))); return val; } __device__ __inline__ int64_t __half2int(const __half h) { int64_t val; asm("{ cvt.rzi.s64.f16 %0, %1;}\n" : "=l"(val) : "h"(__NVFUSER_HALF_TO_CUS(h))); return val; } __device__ int __half2uint32(const __half h) { int val; asm("{ cvt.rzi.u32.f16 %0, %1;}\n" : "=r"(val) : "h"(__NVFUSER_HALF_TO_CUS(h))); return val; } __device__ __inline__ int64_t __half2uint(const __half h) { int64_t val; asm("{ cvt.rzi.u64.f16 %0, %1;}\n" : "=l"(val) : "h"(__NVFUSER_HALF_TO_CUS(h))); return val; } __device__ __inline__ void __half2int(const __half h, int& output) { output = __half2int32(h); } __device__ __inline__ void __half2int(const __half h, int64_t& output) { output = __half2int(h); } __device__ __inline__ void __half2int(const __half h, uint32_t& output) { output = __half2uint32(h); } __device__ __inline__ void __half2int(const __half h, uint64_t& output) { output = __half2uint(h); } __device__ __inline__ nvfuser_index_t __half2index(const __half h) { nvfuser_index_t result; __half2int(h, result); return result; } __device__ __inline__ bool __half2bool(const __half h) { return (bool)__half2float(h) != 0; } __device__ __inline__ __half __real_then_2half(const std::complex c) { return __float2half(std::real(c)); } __device__ __inline__ __half __real_then_2half(const std::complex c) { return __double2half(std::real(c)); } __device__ __inline__ bool __heq(const __half a, const __half b) { // From cuda_fp16.hpp unsigned short val; asm("{ .reg .pred __$temp3;\n" " setp.eq.f16 __$temp3, %1, %2;\n" " selp.u16 %0, 1, 0, __$temp3;}" : "=h"(val) : "h"(__NVFUSER_HALF_TO_CUS(a)), "h"(__NVFUSER_HALF_TO_CUS(b))); return (val != 0U) ? true : false; } // clang-format off /* * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on #define __NVFUSER_BFLOAT_TO_US(var) *(reinterpret_cast(&(var))) #define __NVFUSER_BFLOAT_TO_CUS(var) \ *(reinterpret_cast(&(var))) struct __bfloat; __device__ __inline__ __bfloat __float2bfloat(const float); struct __align__(2) __bfloat { __bfloat() = default; __bfloat(const __bfloat& other) { __x = other.__x; } __bfloat(const __bfloat&& other) { __x = other.__x; } __bfloat(const volatile __bfloat& other) { __x = other.__x; } __bfloat(const volatile __bfloat&& other) { __x = other.__x; } // Note: not returning reference for `__bfloat::operator=` // Doing so would requires us to return `volatile __bfloat&` for the volatile // variants, which would trigger a gcc warning `implicit dereference will not // access object of type ‘volatile S’ in statement` __device__ void operator=(const __bfloat& other) { __x = other.__x; } __device__ void operator=(const __bfloat&& other) { __x = other.__x; } __device__ void operator=(const volatile __bfloat& other) { __x = other.__x; } __device__ void operator=(const volatile __bfloat&& other) { __x = other.__x; } __device__ void operator=(const __bfloat& other) volatile { __x = other.__x; } __device__ void operator=(const __bfloat&& other) volatile { __x = other.__x; } __device__ void operator=(const volatile __bfloat& other) volatile { __x = other.__x; } __device__ void operator=(const volatile __bfloat&& other) volatile { __x = other.__x; } __device__ __bfloat(const float f) { __x = __float2bfloat(f).__x; } __device__ uint16_t raw() const { return __x; } protected: unsigned short __x; }; __device__ __inline__ __bfloat __float2bfloat(const float f) { __bfloat val; asm("{ cvt.rn.bf16.f32 %0, %1;}\n" : "=h"(__NVFUSER_BFLOAT_TO_US(val)) : "f"(f)); return val; } __device__ __inline__ __bfloat __double2bfloat(const double d) { #if __CUDA_ARCH__ >= 900 __bfloat val; asm("{ cvt.rn.bf16.f64 %0, %1;}\n" : "=h"(__NVFUSER_BFLOAT_TO_US(val)) : "d"(d)); return val; #else return __float2bfloat(static_cast(d)); #endif } __device__ __inline__ __bfloat __int2bfloat(const int i) { #if __CUDA_ARCH__ >= 900 __bfloat val; asm("{ cvt.rn.bf16.s32 %0, %1;}\n" : "=h"(__NVFUSER_BFLOAT_TO_US(val)) : "r"(i)); return val; #else return __float2bfloat(static_cast(i)); #endif } __device__ __inline__ __bfloat __int2bfloat(const int64_t i64) { #if __CUDA_ARCH__ >= 900 __bfloat val; asm("{ cvt.rn.bf16.s64 %0, %1;}\n" : "=h"(__NVFUSER_BFLOAT_TO_US(val)) : "l"(i64)); return val; #else return __float2bfloat(static_cast(i64)); #endif } __device__ __inline__ __bfloat __int2bfloat(const uint32_t i) { #if __CUDA_ARCH__ >= 900 __bfloat val; asm("{ cvt.rn.bf16.u32 %0, %1;}\n" : "=h"(__NVFUSER_BFLOAT_TO_US(val)) : "r"(i)); return val; #else return __float2bfloat(static_cast(i)); #endif } __device__ __inline__ __bfloat __int2bfloat(const uint64_t i64) { #if __CUDA_ARCH__ >= 900 __bfloat val; asm("{ cvt.rn.bf16.u64 %0, %1;}\n" : "=h"(__NVFUSER_BFLOAT_TO_US(val)) : "l"(i64)); return val; #else return __float2bfloat(static_cast(i64)); #endif } __device__ __inline__ __bfloat __bool2bfloat(const bool b) { return __int2bfloat((int)b); } __device__ __inline__ float __bfloat2float(const __bfloat h) { float val; asm("{ mov.b32 %0, {0,%1};}\n" : "=f"(val) : "h"(__NVFUSER_BFLOAT_TO_CUS(h))); return val; } __device__ __inline__ double __bfloat2double(const __bfloat h) { #if __CUDA_ARCH__ >= 900 double val; asm("{ cvt.f64.bf16 %0, %1;}\n" : "=d"(val) : "h"(__NVFUSER_BFLOAT_TO_CUS(h))); return val; #else return static_cast(__bfloat2float(h)); #endif } __device__ int __bfloat2int32(const __bfloat h) { #if __CUDA_ARCH__ >= 900 int val; asm("{ cvt.rzi.s32.bf16 %0, %1;}\n" : "=r"(val) : "h"(__NVFUSER_BFLOAT_TO_CUS(h))); return val; #else return static_cast(__bfloat2float(h)); #endif } __device__ __inline__ int64_t __bfloat2int(const __bfloat h) { #if __CUDA_ARCH__ >= 900 int64_t val; asm("{ cvt.rzi.s64.bf16 %0, %1;}\n" : "=l"(val) : "h"(__NVFUSER_BFLOAT_TO_CUS(h))); return val; #else return static_cast(__bfloat2float(h)); #endif } __device__ int __bfloat2uint32(const __bfloat h) { #if __CUDA_ARCH__ >= 900 int val; asm("{ cvt.rzi.u32.bf16 %0, %1;}\n" : "=r"(val) : "h"(__NVFUSER_BFLOAT_TO_CUS(h))); return val; #else return static_cast(__bfloat2float(h)); #endif } __device__ __inline__ int64_t __bfloat2uint(const __bfloat h) { #if __CUDA_ARCH__ >= 900 int64_t val; asm("{ cvt.rzi.u64.bf16 %0, %1;}\n" : "=l"(val) : "h"(__NVFUSER_BFLOAT_TO_CUS(h))); return val; #else return static_cast(__bfloat2float(h)); #endif } __device__ __inline__ void __bfloat2int(const __bfloat h, int& output) { output = __bfloat2int32(h); } __device__ __inline__ void __bfloat2int(const __bfloat h, int64_t& output) { output = __bfloat2int(h); } __device__ __inline__ void __bfloat2int(const __bfloat h, uint32_t& output) { output = __bfloat2uint32(h); } __device__ __inline__ void __bfloat2int(const __bfloat h, uint64_t& output) { output = __bfloat2uint(h); } __device__ __inline__ nvfuser_index_t __bfloat2index( const __bfloat h, bool& output) { nvfuser_index_t result; __bfloat2int(h, result); return result; } __device__ __inline__ bool __bfloat2bool(const __bfloat h) { return (bool)__bfloat2float(h) != 0; } __device__ __inline__ __bfloat __half2bfloat(const __half h) { #if __CUDA_ARCH__ >= 900 __bfloat val; asm("{ cvt.rn.bf16.f16 %0, %1;}\n" : "=h"(__NVFUSER_BFLOAT_TO_US(val)) : "h"(__NVFUSER_HALF_TO_CUS(h))); return val; #else return __float2bfloat(__half2float(h)); #endif } __device__ __inline__ __half __bfloat2half(const __bfloat h) { #if __CUDA_ARCH__ >= 900 __half val; asm("{ cvt.rn.f16.bf16 %0, %1;}\n" : "=h"(__NVFUSER_HALF_TO_US(val)) : "h"(__NVFUSER_BFLOAT_TO_CUS(h))); return val; #else return __float2half(__bfloat2float(h)); #endif } __device__ __inline__ __bfloat __real_then_2bfloat( const std::complex c) { return __float2bfloat(std::real(c)); } __device__ __inline__ __bfloat __real_then_2bfloat( const std::complex c) { return __double2bfloat(std::real(c)); } __device__ __inline__ bool __heq(const __bfloat a, const __bfloat b) { // From cuda_bf16.hpp #if __CUDA_ARCH__ >= 900 unsigned short val; asm("{ .reg .pred __$temp3;\n" " setp.eq.bf16 __$temp3, %1, %2;\n" " selp.u16 %0, 1, 0, __$temp3;}" : "=h"(val) : "h"(__NVFUSER_BFLOAT_TO_CUS(a)), "h"(__NVFUSER_BFLOAT_TO_CUS(b))); #else unsigned int val; asm("{.reg .b32 a,b;\n" " mov.b32 a, {0, %1};\n" " mov.b32 b, {0, %2};\n" " set.eq.f32.f32 %0, a, b;}\n" : "=r"(val) : "h"(__NVFUSER_BFLOAT_TO_CUS(a)), "h"(__NVFUSER_BFLOAT_TO_CUS(b))); #endif return (val != 0U) ? true : false; } // clang-format off /* * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on struct __e4m3; __device__ __inline__ __e4m3 __float2e4m3(const float); struct __align__(1) __e4m3 { __e4m3() = default; __e4m3(const __e4m3& other) { __x = other.__x; } __e4m3(const __e4m3&& other) { __x = other.__x; } __e4m3(const volatile __e4m3& other) { __x = other.__x; } __e4m3(const volatile __e4m3&& other) { __x = other.__x; } // Note: not returning reference for `__e4m3::operator=` // Doing so would requires us to return `volatile __e4m3&` for the volatile // variants, which would trigger a gcc warning `implicit dereference will not // access object of type ‘volatile S’ in statement` __device__ void operator=(const __e4m3& other) { __x = other.__x; } __device__ void operator=(const __e4m3&& other) { __x = other.__x; } __device__ void operator=(const volatile __e4m3& other) { __x = other.__x; } __device__ void operator=(const volatile __e4m3&& other) { __x = other.__x; } __device__ void operator=(const __e4m3& other) volatile { __x = other.__x; } __device__ void operator=(const __e4m3&& other) volatile { __x = other.__x; } __device__ void operator=(const volatile __e4m3& other) volatile { __x = other.__x; } __device__ void operator=(const volatile __e4m3&& other) volatile { __x = other.__x; } __device__ __e4m3(const float f) { __x = __float2e4m3(f).__x; } __device__ uint8_t raw() const { return __x; } protected: uint8_t __x; }; __device__ __inline__ __e4m3 __double2e4m3(const double f) { unsigned short _tmp_buffer; __e4m3 val; asm("{\n\t" ".reg .b16 buf0;\n\t" ".reg .b32 buf1;\n\t" "cvt.rn.f16.f64 buf0, %1;\n\t" "cvt.u32.u16 buf1, buf0;\n\t" "cvt.rn.satfinite.e4m3x2.f16x2 %0, buf1;\n\t" "}" : "=h"(_tmp_buffer) : "d"(f)); memcpy(&val, &_tmp_buffer, sizeof(uint8_t)); return val; } __device__ __inline__ double __e4m32double(const __e4m3 h) { unsigned short _tmp_buffer; memcpy(&_tmp_buffer, &h, sizeof(uint8_t)); double val; asm("{\n\t" ".reg .b32 buf0;\n\t" "cvt.rn.f16x2.e4m3x2 buf0, %1;\n\t" "cvt.u16.u32 %1, buf0;\n\t" "cvt.f64.f16 %0, %1;" "}" : "=d"(val) : "h"(_tmp_buffer)); return val; } __device__ __inline__ __e4m3 __float2e4m3(const float f) { constexpr float f_const_zero = 0.f; unsigned short _tmp_buffer; __e4m3 val; asm("{cvt.rn.satfinite.e4m3x2.f32 %0, %1, %2;}" : "=h"(_tmp_buffer) : "f"(f_const_zero), "f"(f)); memcpy(&val, &_tmp_buffer, sizeof(uint8_t)); return val; } __device__ __inline__ float __e4m32float(const __e4m3 h) { unsigned short _tmp_buffer; memcpy(&_tmp_buffer, &h, sizeof(uint8_t)); float val; asm("{\n\t" ".reg .b32 buf0;\n\t" "cvt.rn.f16x2.e4m3x2 buf0, %1;\n\t" "cvt.u16.u32 %1, buf0;\n\t" "cvt.f32.f16 %0, %1;\n\t" "}" : "=f"(val) : "h"(_tmp_buffer)); return val; } __device__ __inline__ __e4m3 __half2e4m3(const __half h) { uint32_t buffer; memcpy(&buffer, &h, sizeof(__half)); unsigned short _tmp_buffer; __e4m3 val; asm("{cvt.rn.satfinite.e4m3x2.f16x2 %0, %1;}\n\t" : "=h"(_tmp_buffer) : "r"(buffer)); memcpy(&val, &_tmp_buffer, sizeof(uint8_t)); return val; } __device__ __inline__ __half __e4m32half(const __e4m3 h) { unsigned short _tmp_buffer; memcpy(&_tmp_buffer, &h, sizeof(uint8_t)); __half val; asm("{\n\t" ".reg .b32 buf0;\n\t" "cvt.rn.f16x2.e4m3x2 buf0, %1;\n\t" "cvt.u16.u32 %0, buf0;\n\t" "}" : "=h"(__NVFUSER_HALF_TO_US(val)) : "h"(_tmp_buffer)); return val; } __device__ __inline__ __e4m3 __bfloat2e4m3(const __bfloat h) { unsigned short _tmp_buffer; __e4m3 val; asm("{\n\t" ".reg .b32 buf0;\n\t" "cvt.rn.f16.bf16 %1, %1;\n\t" "cvt.u32.u16 buf0, %1;\n\t" "cvt.rn.satfinite.e4m3x2.f16x2 %0, buf0;\n\t" "}" : "=h"(_tmp_buffer) : "h"(__NVFUSER_BFLOAT_TO_CUS(h))); memcpy(&val, &_tmp_buffer, sizeof(uint8_t)); return val; } __device__ __inline__ __bfloat __e4m32bfloat(const __e4m3 h) { unsigned short _tmp_buffer; memcpy(&_tmp_buffer, &h, sizeof(uint8_t)); __bfloat val; asm("{\n\t" ".reg .b32 buf0;\n\t" "cvt.rn.f16x2.e4m3x2 buf0, %1;\n\t" "cvt.u16.u32 %0, buf0;\n\t" "cvt.bf16.f16 %0, %0;\n\t" "}" : "=h"(__NVFUSER_BFLOAT_TO_US(val)) : "h"(_tmp_buffer)); return val; } struct __e5m2; __device__ __inline__ __e5m2 __float2e5m2(const float); struct __align__(1) __e5m2 { __e5m2() = default; __e5m2(const __e5m2& other) { __x = other.__x; } __e5m2(const __e5m2&& other) { __x = other.__x; } __e5m2(const volatile __e5m2& other) { __x = other.__x; } __e5m2(const volatile __e5m2&& other) { __x = other.__x; } // Note: not returning reference for `__e5m2::operator=` // Doing so would requires us to return `volatile __e5m2&` for the volatile // variants, which would trigger a gcc warning `implicit dereference will not // access object of type ‘volatile S’ in statement` __device__ void operator=(const __e5m2& other) { __x = other.__x; } __device__ void operator=(const __e5m2&& other) { __x = other.__x; } __device__ void operator=(const volatile __e5m2& other) { __x = other.__x; } __device__ void operator=(const volatile __e5m2&& other) { __x = other.__x; } __device__ void operator=(const __e5m2& other) volatile { __x = other.__x; } __device__ void operator=(const __e5m2&& other) volatile { __x = other.__x; } __device__ void operator=(const volatile __e5m2& other) volatile { __x = other.__x; } __device__ void operator=(const volatile __e5m2&& other) volatile { __x = other.__x; } __device__ __e5m2(const float f) { __x = __float2e5m2(f).__x; } __device__ uint8_t raw() const { return __x; } protected: uint8_t __x; }; __device__ __inline__ __e5m2 __double2e5m2(const double f) { unsigned short _tmp_buffer; __e5m2 val; asm("{\n\t" ".reg .b16 buf0;\n\t" ".reg .b32 buf1;\n\t" "cvt.rn.f16.f64 buf0, %1;\n\t" "cvt.u32.u16 buf1, buf0;\n\t" "cvt.rn.satfinite.e5m2x2.f16x2 %0, buf1;\n\t" "}" : "=h"(_tmp_buffer) : "d"(f)); memcpy(&val, &_tmp_buffer, sizeof(uint8_t)); return val; } __device__ __inline__ double __e5m22double(const __e5m2 h) { unsigned short _tmp_buffer; memcpy(&_tmp_buffer, &h, sizeof(uint8_t)); double val; asm("{\n\t" ".reg .b32 buf0;\n\t" "cvt.rn.f16x2.e5m2x2 buf0, %1;\n\t" "cvt.u16.u32 %1, buf0;\n\t" "cvt.f64.f16 %0, %1;" "}" : "=d"(val) : "h"(_tmp_buffer)); return val; } __device__ __inline__ __e5m2 __float2e5m2(const float f) { constexpr float f_const_zero = 0.f; unsigned short _tmp_buffer; __e5m2 val; asm("{cvt.rn.satfinite.e5m2x2.f32 %0, %1, %2;}" : "=h"(_tmp_buffer) : "f"(f_const_zero), "f"(f)); memcpy(&val, &_tmp_buffer, sizeof(uint8_t)); return val; } __device__ __inline__ float __e5m22float(const __e5m2 h) { unsigned short _tmp_buffer; memcpy(&_tmp_buffer, &h, sizeof(uint8_t)); float val; asm("{\n\t" ".reg .b32 buf0;\n\t" "cvt.rn.f16x2.e5m2x2 buf0, %1;\n\t" "cvt.u16.u32 %1, buf0;\n\t" "cvt.f32.f16 %0, %1;\n\t" "}" : "=f"(val) : "h"(_tmp_buffer)); return val; } __device__ __inline__ __e5m2 __half2e5m2(const __half h) { uint32_t buffer; memcpy(&buffer, &h, sizeof(__half)); unsigned short _tmp_buffer; __e5m2 val; asm("{cvt.rn.satfinite.e5m2x2.f16x2 %0, %1;}\n\t" : "=h"(_tmp_buffer) : "r"(buffer)); memcpy(&val, &_tmp_buffer, sizeof(uint8_t)); return val; } __device__ __inline__ __half __e5m22half(const __e5m2 h) { unsigned short _tmp_buffer; memcpy(&_tmp_buffer, &h, sizeof(uint8_t)); __half val; asm("{\n\t" ".reg .b32 buf0;\n\t" "cvt.rn.f16x2.e5m2x2 buf0, %1;\n\t" "cvt.u16.u32 %0, buf0;\n\t" "}" : "=h"(__NVFUSER_HALF_TO_US(val)) : "h"(_tmp_buffer)); return val; } __device__ __inline__ __e5m2 __bfloat2e5m2(const __bfloat h) { unsigned short _tmp_buffer; __e5m2 val; asm("{\n\t" ".reg .b32 buf0;\n\t" "cvt.rn.f16.bf16 %1, %1;\n\t" "cvt.u32.u16 buf0, %1;\n\t" "cvt.rn.satfinite.e5m2x2.f16x2 %0, buf0;\n\t" "}" : "=h"(_tmp_buffer) : "h"(__NVFUSER_BFLOAT_TO_CUS(h))); memcpy(&val, &_tmp_buffer, sizeof(uint8_t)); return val; } __device__ __inline__ __bfloat __e5m22bfloat(const __e5m2 h) { unsigned short _tmp_buffer; memcpy(&_tmp_buffer, &h, sizeof(uint8_t)); __bfloat val; asm("{\n\t" ".reg .b32 buf0;\n\t" "cvt.rn.f16x2.e5m2x2 buf0, %1;\n\t" "cvt.u16.u32 %0, buf0;\n\t" "cvt.bf16.f16 %0, %0;\n\t" "}" : "=h"(__NVFUSER_BFLOAT_TO_US(val)) : "h"(_tmp_buffer)); return val; } // clang-format off /* * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on // Type trait utils template struct MaybeVolatile; template struct MaybeVolatile { using type = volatile Type; }; template struct MaybeVolatile { using type = Type; }; template struct TypeList {}; template struct TypeSelector { using type = typename TypeSelector::type; }; template struct TypeSelector<0, T, Types...> { using type = T; }; template struct IsSameType { static constexpr bool value = false; }; template struct IsSameType { static constexpr bool value = true; }; template struct IsPointerType { static constexpr bool value = false; }; template struct IsPointerType { static constexpr bool value = true; }; // clang-format off /* * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on // aligned register array for vectorized load/store template struct alignas(sizeof(scalar_t) * align_size) Array { scalar_t array[size]; __device__ void set(scalar_t v) { #pragma unroll for (int i = 0; i < size; ++i) { array[i] = v; } } __device__ scalar_t& operator[](const unsigned int i) { return array[i]; } __device__ const scalar_t& operator[](const unsigned int i) const { return array[i]; } Array& operator=(const Array& a) { #pragma unroll for (int i = 0; i < size; ++i) { array[i] = a[i]; } return *this; } }; // Used for vectorized allocations that are not in registers template __device__ void arraySet(scalar_t* buff, scalar_t val) { #pragma unroll for (int i = 0; i < vec_size; ++i) { buff[i] = val; } } template __device__ void loadGeneric(scalar_t* to, scalar_t* from) { // It would be really nice to use memcpy here, but one example was failing // with: // // memcpy(to, from, vec_size * sizeof(scalar_t)); // // Yet passing with: // // for(int i = 0; i < vec_size; i++){ // to[i] = from[i]; // } switch (sizeof(scalar_t) * vec_size) { case 1: *reinterpret_cast(to) = *reinterpret_cast(from); break; case 2: *reinterpret_cast(to) = *reinterpret_cast(from); break; case 4: *reinterpret_cast(to) = *reinterpret_cast(from); break; case 8: *reinterpret_cast(to) = *reinterpret_cast(from); break; case 12: *reinterpret_cast(to) = *reinterpret_cast(from); break; case 16: *reinterpret_cast(to) = *reinterpret_cast(from); break; } } // Volatile version only works with c++ fundamnetal types template < typename scalar_t, int vec_size, bool is_volatile_to, bool is_volatile_from> __device__ void loadGenericVolatile( typename MaybeVolatile::type* to, typename MaybeVolatile::type* from) { switch (sizeof(scalar_t) * vec_size) { // Reinterpret cast like this with volatile types only works for C++ // fundamental types otherwise the = operator is not defined case 1: *reinterpret_cast< typename MaybeVolatile::type*>(to) = *reinterpret_cast< typename MaybeVolatile::type*>( from); break; case 2: *reinterpret_cast::type*>( to) = *reinterpret_cast< typename MaybeVolatile::type*>(from); break; case 4: *reinterpret_cast< typename MaybeVolatile::type*>(to) = *reinterpret_cast< typename MaybeVolatile::type*>( from); break; case 8: *reinterpret_cast::type*>( to) = *reinterpret_cast< typename MaybeVolatile::type*>(from); break; } } template __device__ void loadLocalToGlobal( typename MaybeVolatile::type* to, scalar_t* from) { switch (sizeof(scalar_t) * vec_size) { case 1: case 2: case 4: loadGenericVolatile(to, from); break; case 8: { uint2 const& data = *reinterpret_cast(from); if (is_volatile) { asm volatile( "st.volatile.global.v2.s32 [%0], {%1,%2};" ::"l"( (typename MaybeVolatile::type*)to), "r"(data.x), "r"(data.y)); } else { asm volatile( "st.global.cs.v2.s32 [%0], {%1,%2};" ::"l"( (typename MaybeVolatile::type*)to), "r"(data.x), "r"(data.y)); } break; } case 16: { uint4 const& data = *reinterpret_cast(from); if (is_volatile) { asm volatile( "st.volatile.global.v4.s32 [%0], {%1,%2,%3,%4};" ::"l"( (typename MaybeVolatile::type*)to), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w)); } else { asm volatile( "st.global.cs.v4.s32 [%0], {%1,%2,%3,%4};" ::"l"( (typename MaybeVolatile::type*)to), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w)); } break; } } } // This is copied from csrc/type.h and should be kept consistent. enum class CacheOp { AllLevels, Streaming, Global, }; template __device__ void loadGlobalToLocalCached(void* to, void* from) { T* typed_to = reinterpret_cast(to); T* typed_from = reinterpret_cast(from); switch (cache_op) { case CacheOp::AllLevels: *typed_to = __ldca(typed_from); break; case CacheOp::Streaming: *typed_to = __ldcs(typed_from); break; case CacheOp::Global: *typed_to = __ldcg(typed_from); break; } } // For simplicity, cache_op is only used for non-volatile loads written in // inline assembly. Other loads are done with the default cache operator -- // cache all levels. ld.volatile doesn't accept cache operator anyway. template __device__ void loadGlobalToLocal( scalar_t* to, typename MaybeVolatile::type* from) { switch (sizeof(scalar_t) * vec_size) { case 1: case 2: case 4: loadGenericVolatile(to, from); break; case 8: { if (is_volatile) { uint2& data = *reinterpret_cast(to); asm volatile("ld.volatile.global.v2.s32 {%0,%1}, [%2];" : "=r"(data.x), "=r"(data.y) : "l"((uint2*)from)); } else { loadGlobalToLocalCached( to, const_cast(from)); } break; } case 16: { if (is_volatile) { uint4& data = *reinterpret_cast(to); asm volatile("ld.volatile.global.v4.s32 {%0,%1,%2,%3}, [%4];" : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) : "l"((uint4*)from)); } else { loadGlobalToLocalCached( to, const_cast(from)); } break; } } } template < typename scalar_t, int vec_size, bool is_volatile_to, bool is_volatile_from> __device__ void loadGlobalToGlobal( typename MaybeVolatile::type* to, typename MaybeVolatile::type* from) { switch (sizeof(scalar_t) * vec_size) { // Reinterpret cast like this with volatile types only works for C++ // fundamental types otherwise the = operator is not defined case 1: case 2: case 4: case 8: loadGenericVolatile( to, from); break; case 12: { uint3 local_intermediate; loadGlobalToLocal< scalar_t, vec_size, is_volatile_from, CacheOp::Streaming>( reinterpret_cast(&local_intermediate), from); loadLocalToGlobal( to, reinterpret_cast(&local_intermediate)); break; } case 16: { uint4 local_intermediate; loadGlobalToLocal< scalar_t, vec_size, is_volatile_from, CacheOp::Streaming>( reinterpret_cast(&local_intermediate), from); loadLocalToGlobal( to, reinterpret_cast(&local_intermediate)); break; } } } // clang-format off /* * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on template struct Tensor { __device__ T& operator[](nvfuser_index_t ind) { return data[ind]; }; T* data; Array logical_size; Array alloc_stride; }; // Specialization for 0-dim case as it does not need size and stride arrays. // They will be an error as well since zero-length arrays are not allowed. template struct Tensor { __device__ T& operator[](nvfuser_index_t i) { return *data; }; T* data; }; // Specialization for 0-dim case that's easy to pass in a CPU based tensor. template struct CpuScalarTensor { __device__ T& operator[](int i) { return data; }; T data; }; // clang-format off /* * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on __device__ unsigned int mulhilo32( unsigned int a, unsigned int b, unsigned int* result_high) { *result_high = __umulhi(a, b); return a * b; } __device__ uint4 single_round(uint4 ctr, uint2 key) { constexpr unsigned long kPhiloxSA = 0xD2511F53; constexpr unsigned long kPhiloxSB = 0xCD9E8D57; unsigned int hi0; unsigned int hi1; unsigned int lo0 = mulhilo32(kPhiloxSA, ctr.x, &hi0); unsigned int lo1 = mulhilo32(kPhiloxSB, ctr.z, &hi1); uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0}; return ret; } __device__ uint4 philox( unsigned long long seed, unsigned long long subsequence, unsigned long long offset) { constexpr unsigned long kPhilox10A = 0x9E3779B9; constexpr unsigned long kPhilox10B = 0xBB67AE85; uint2 key = {}; key.x = (unsigned int)seed; key.y = (unsigned int)(seed >> 32); uint4 counter = make_uint4(0, 0, 0, 0); counter.x = (unsigned int)(offset); counter.y = (unsigned int)(offset >> 32); counter.z = (unsigned int)(subsequence); counter.w = (unsigned int)(subsequence >> 32); uint4 output = {}; uint2 key_ = key; uint4 counter_ = counter; for (int i = 0; i < 9; i++) { counter_ = single_round(counter_, key_); key_.x += (kPhilox10A); key_.y += (kPhilox10B); } output = single_round(counter_, key_); return output; } // This is a uniform double in the range (0, 1] __device__ double raw_uniform_double(unsigned int x, unsigned int y) { constexpr double scale = 1.0 / (double)(1ll << 53); const unsigned long long z = (unsigned long long)x ^ ((unsigned long long)y << (53 - 32)); return (double)z * scale + 0.5 * scale; } // This is a uniform float in the range (0, 1] __device__ float raw_uniform_float(unsigned int x) { constexpr float scale = (float)(1.0 / (double)(1ll << 32)); return (float)x * scale + 0.5f * scale; } __device__ __half uniform_half(unsigned int x) { __half result = __float2half(raw_uniform_float(x)); return __heq(result, __float2half(1.0f)) ? __float2half(0.0f) : result; } __device__ __bfloat uniform_bfloat(unsigned int x) { __bfloat result = __float2bfloat(raw_uniform_float(x)); return __heq(result, __float2bfloat(1.0f)) ? __float2bfloat(0.0f) : result; } __device__ float uniformf(unsigned int x) { float result = raw_uniform_float(x); return result == 1.0f ? 0.0f : result; } __device__ double uniform(unsigned int x, unsigned int y) { double result = raw_uniform_double(x, y); return result == 1.0 ? 0.0 : result; } __device__ double rng_uniform(const uint4& rng_result, int rng_component) { return uniform( (&rng_result.x)[rng_component * 2], (&rng_result.x)[rng_component * 2 + 1]); } __device__ float rng_uniformf(const uint4& rng_result, int rng_component) { return uniformf((&rng_result.x)[rng_component]); } __device__ __half rng_uniform_half(const uint4& rng_result, int rng_component) { return uniform_half((&rng_result.x)[rng_component]); } __device__ __bfloat rng_uniform_bfloat(const uint4& rng_result, int rng_component) { return uniform_bfloat((&rng_result.x)[rng_component]); } __device__ double rng_uniform_range( const uint4& rng_result, int rng_component, double from, double to) { auto range = to - from; auto uniform01 = rng_uniform(rng_result, rng_component); return from + range * uniform01; } __device__ float rng_uniform_rangef( const uint4& rng_result, int rng_component, float from, float to) { auto range = to - from; auto uniform01 = rng_uniformf(rng_result, rng_component); return from + range * uniform01; } __device__ __half rng_uniform_range_half( const uint4& rng_result, int rng_component, float from, float to) { auto range = to - from; float uniform01 = raw_uniform_float((&rng_result.x)[rng_component]); __half result = __float2half(from + range * uniform01); return __heq(result, __float2half(to)) ? __float2half(from) : result; } __device__ __bfloat rng_uniform_range_bfloat( const uint4& rng_result, int rng_component, float from, float to) { auto range = to - from; float uniform01 = raw_uniform_float((&rng_result.x)[rng_component]); __bfloat result = __float2bfloat(from + range * uniform01); return __heq(result, __float2bfloat(to)) ? __float2bfloat(from) : result; } __device__ float normalf(unsigned int x, unsigned int y, int rng_component) { float u = uniformf(x); float v = uniformf(y) * 6.2831855f; if (rng_component % 2 == 0) { return sqrtf(-2.0f * logf(u)) * sinf(v); } else { return sqrtf(-2.0f * logf(u)) * cosf(v); } } __device__ double normal( unsigned int x0, unsigned int x1, unsigned int y0, unsigned int y1, int rng_component) { double u = uniform(x0, x1); double v = uniform(y0, y1) * 6.2831853071795860; if (rng_component % 2 == 0) { return sqrt(-2.0 * log(u)) * sin(v); } else { return sqrt(-2.0 * log(u)) * cos(v); } } __device__ double rng_normal_standard( const uint4& rng_result, int rng_component) { return normal( rng_result.x, rng_result.y, rng_result.z, rng_result.w, rng_component); } __device__ float rng_normal_standardf( const uint4& rng_result, int rng_component) { return normalf( (&rng_result.x)[rng_component / 2 * 2], (&rng_result.y)[rng_component / 2 * 2], rng_component); } __device__ __half rng_normal_standard_half(const uint4& rng_result, int rng_component) { return __float2half(normalf( (&rng_result.x)[rng_component / 2 * 2], (&rng_result.y)[rng_component / 2 * 2], rng_component)); } __device__ __bfloat rng_normal_standard_bfloat(const uint4& rng_result, int rng_component) { return __float2bfloat(normalf( (&rng_result.x)[rng_component / 2 * 2], (&rng_result.y)[rng_component / 2 * 2], rng_component)); } __device__ double rng_normal_general( const uint4& rng_result, int rng_component, double mean, double std) { auto normal01 = rng_normal_standard(rng_result, rng_component); return normal01 * std + mean; } __device__ float rng_normal_generalf( const uint4& rng_result, int rng_component, float mean, float std) { auto normal01 = rng_normal_standardf(rng_result, rng_component); return normal01 * std + mean; } __device__ __half rng_normal_general_half( const uint4& rng_result, int rng_component, float mean, float std) { auto normal01 = normalf( (&rng_result.x)[rng_component / 2 * 2], (&rng_result.y)[rng_component / 2 * 2], rng_component); return __float2half(normal01 * std + mean); } __device__ __bfloat rng_normal_general_bfloat( const uint4& rng_result, int rng_component, float mean, float std) { auto normal01 = normalf( (&rng_result.x)[rng_component / 2 * 2], (&rng_result.y)[rng_component / 2 * 2], rng_component); return __float2bfloat(normal01 * std + mean); } // clang-format off /* * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on #define NVFUSER_DEFINE_MAGIC_ZERO \ __shared__ int nvfuser_zero_s; \ if (threadIdx.x == 0) \ nvfuser_zero_s = 0; \ __syncthreads(); \ atomicMin(&nvfuser_zero_s, threadIdx.x); \ int nvfuser_zero = nvfuser_zero_s; #define NVFUSER_UPDATE_MAGIC_ZERO \ do { \ nvfuser_zero <<= 1; \ } while (0); #ifdef __NVCC__ #include #endif // __NVCC__ __device__ constexpr int ceilDiv(int a, int b) { return (a + b - 1) / b; } __device__ constexpr int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; } __device__ constexpr int64_t ceilDiv(int64_t a, int b) { return ceilDiv(a, (int64_t)b); } __device__ constexpr int64_t ceilDiv(int a, int64_t b) { return ceilDiv((int64_t)a, b); } __device__ constexpr double ceilDiv(double a, double b) { return std::ceil(a / b); } __device__ constexpr double ceilDiv(double a, int64_t b) { return std::ceil(a / b); } __device__ constexpr double ceilDiv(int64_t a, double b) { return std::ceil(a / b); } // Monotonic and precise lerp is described here: // https://math.stackexchange.com/a/1798323 __device__ double lerp(double start, double end, double weight) { if (weight < 0.5) { return start + weight * (end - start); } else { return end - (end - start) * (1.0 - weight); } } __device__ float lerp(float start, float end, float weight) { if (weight < 0.5f) { return start + weight * (end - start); } else { return end - (end - start) * (1.0f - weight); } } __device__ float lerp(float start, float end, double weight) { return lerp(start, end, static_cast(weight)); } __device__ constexpr int max(int a, int b) { return a > b ? a : b; } __device__ constexpr int64_t max(int64_t a, int b) { return a > (int64_t)b ? a : (int64_t)b; } __device__ constexpr int64_t max(int a, int64_t b) { return (int64_t)a > b ? (int64_t)a : b; } __device__ constexpr int64_t max(int64_t a, int64_t b) { return a > b ? a : b; } __device__ double fmax(double a, double b) { // check and propagate NaN if (a != a) { return a; } else { // If b is nan, it will be returned in the next line return a > b ? a : b; } } __device__ float fmax(float a, float b) { // check and propagate NaN if (a != a) { return a; } else { // If b is nan, it will be returned in the next line return a > b ? a : b; } } __device__ constexpr int min(int a, int b) { return a > b ? b : a; } __device__ constexpr int64_t min(int64_t a, int b) { return (int64_t)a > b ? b : (int64_t)a; } __device__ constexpr int64_t min(int a, int64_t b) { return a > (int64_t)b ? (int64_t)b : a; } __device__ constexpr int64_t min(int64_t a, int64_t b) { return a > b ? b : a; } __device__ double fmin(double a, double b) { // check and propagate NaN if (b != b) { return b; } else { // If a is nan, it will be returned in the next line return a > b ? b : a; } } __device__ float fmin(float a, float b) { // check and propagate NaN if (b != b) { return b; } else { // If a is nan, it will be returned in the next line return a > b ? b : a; } } __device__ constexpr int alignBufferSize(int buffer, int size) { return (buffer + (size - 1)) & ~(size - 1); } __device__ double clamp(double x, double minv, double maxv) { return fmin(fmax(x, minv), maxv); } __device__ float clamp(float x, double minv, double maxv) { return fmin(fmax((double)x, minv), maxv); } __device__ int clamp(int x, int64_t minv, int64_t maxv) { return min(max((int64_t)x, minv), maxv); } __device__ int64_t clamp(int64_t x, int64_t minv, int64_t maxv) { return min(max(x, minv), maxv); } __device__ double frac(double x) { return x - trunc(x); } __device__ float frac(float x) { return x - trunc(x); } __device__ double reciprocal(double x) { return 1 / x; } __device__ float reciprocal(float x) { return 1 / x; } __device__ double relu(double x) { return x <= 0 ? 0 : x; } __device__ float relu(float x) { return x <= 0 ? 0 : x; } __device__ float relu(int64_t x) { return x <= 0 ? 0 : x; } __device__ float relu(int x) { return x <= 0 ? 0 : x; } __device__ double remainder(double a, double b) { auto mod = ::fmod(a, b); if ((mod != 0) && ((b < 0) != (mod < 0))) mod += b; return mod; } __device__ float remainder(float a, float b) { auto mod = ::fmod(a, b); if ((mod != 0) && ((b < 0) != (mod < 0))) mod += b; return mod; } __device__ double sigmoid(double x) { return 1.0 / (1.0 + exp(-x)); } __device__ float sigmoid(float x) { return 1.0f / (1.0f + exp(-x)); } __device__ double silu(double x) { return x * sigmoid(x); } __device__ float silu(float x) { return x * sigmoid(x); } __device__ double threshold(double x, double t, double v) { return x <= t ? v : x; } __device__ float threshold(float x, double t, double v) { return x <= t ? v : x; } __device__ int threshold(int x, int64_t t, int64_t v) { return x <= t ? v : x; } __device__ int64_t threshold(int64_t x, int64_t t, int64_t v) { return x <= t ? v : x; } __device__ constexpr int64_t remainder(int64_t a, int64_t b) { auto mod = a % b; if ((mod != 0) && ((b < 0) != (mod < 0))) mod += b; return mod; } __device__ constexpr int remainder(int a, int b) { auto mod = a % b; if ((mod != 0) && ((b < 0) != (mod < 0))) mod += b; return mod; } __device__ constexpr int64_t fmod(int64_t a, int64_t b) { return a % b; } __device__ constexpr int fmod(int a, int b) { return a % b; } __device__ constexpr double fmod(double a, double b) { return ::fmod(a, b); } __device__ constexpr float fmod(float a, float b) { return ::fmod(a, b); } __device__ constexpr double nextafter(double a, double b) { return ::nextafter(a, b); } __device__ constexpr float nextafter(float a, float b) { return ::nextafterf(a, b); } template __device__ T pow(T a, T b) { if (b < 0) { if (a == 1) { return 1; } else if (a == -1) { auto negative = (-b) % static_cast(2); return negative ? -1 : 1; } else { return 0; } } else { T result = 1; while (b) { if (b & 1) { result *= a; } b /= 2; a *= a; } return result; } } template __device__ int pow(int a, int b); template __device__ int64_t pow(int64_t a, int64_t b); template <> __device__ float pow(float a, float b) { return ::pow(a, b); } template <> __device__ double pow(double a, double b) { return ::pow(a, b); } __device__ float pow(float a, int b) { return pow(a, (float)b); } __device__ double pow(double a, int b) { return pow(a, (double)b); } __device__ float pow(float a, int64_t b) { return pow(a, (float)b); } __device__ double pow(double a, int64_t b) { return pow(a, (double)b); } __device__ int64_t pow(int64_t a, int b) { return pow(a, (int64_t)b); } __device__ int64_t pow(int a, int64_t b) { return pow((int64_t)a, b); } __device__ double rsqrt(double z) { return ::rsqrt(z); } __device__ float rsqrt(float z) { return ::rsqrtf(z); } __device__ int rsqrt(int z) { return ::rsqrtf((float)z); } __device__ int64_t rsqrt(int64_t z) { return ::rsqrt((double)z); } __device__ double signbit(double a) { return ::signbit(a); } __device__ float signbit(float a) { return ::signbit(a); } __device__ int signbit(int a) { return a < 0; } __device__ int64_t signbit(int64_t a) { return a < 0; } // Reference: // https://en.wikipedia.org/wiki/Euclidean_algorithm#Implementations // https://github.com/pytorch/pytorch/blob/c9f4f01981fd73fcc7c27676cc50230cd1b5bc22/aten/src/ATen/native/Math.h#L1232 template __device__ T gcd(T a, T b) { a = abs(a); b = abs(b); while (b != 0) { auto t = b; b = a % b; a = t; } return a; } template bool isfinite(T x) { return ::isfinite(x); } // ref: // https://github.com/NVIDIA/cutlass/blob/6fbc0d33800008d3180d3fefed4e1a653e5f72a0/include/cutlass/bfloat16.h#L213 template <> bool isfinite<__bfloat>(__bfloat x) { const auto exponent_biased = int((x.raw() >> 7) & 0x0ff); return exponent_biased != 0x0ff; } // ref: // https://github.com/NVIDIA/cutlass/blob/6fbc0d33800008d3180d3fefed4e1a653e5f72a0/include/cutlass/half.h#L511 template <> bool isfinite<__half>(__half x) { const auto exponent_biased = int((x.raw() >> 10) & 0x1f); return exponent_biased != 0x1f; } template bool isinf(T x) { return ::isinf(x); } //////////////////////////////////////////////////////////// // TODO: the following overloads are only needed for CUDA // // 10.2 Please remove when CUDA 10.2 support is dropped // //////////////////////////////////////////////////////////// bool isinf(int64_t x) { return false; } bool isinf(int x) { return false; } bool isinf(short x) { return false; } bool isinf(char x) { return false; } bool isinf(unsigned char x) { return false; } bool isinf(bool x) { return false; } bool isfinite(int64_t x) { return true; } bool isfinite(int x) { return true; } bool isfinite(short x) { return true; } bool isfinite(char x) { return true; } bool isfinite(unsigned char x) { return true; } bool isfinite(bool x) { return true; } //////////////////////////////////////////////////////////// // End TODO // //////////////////////////////////////////////////////////// template bool isnan(T x) { return x != x; } template bool isneginf(T x) { return x < 0 && isinf(x); } template bool isposinf(T x) { return x > 0 && isinf(x); } template bool isreal(T x) { return true; } // Return the current value of the cycle counter __device__ inline int64_t readCycleCounter() { // Ensures preceding memory operations are completed. Doing this // would make sense for measuring elapsed times enclosed with this // function. __threadfence(); return clock64(); } __device__ float print_impl(const char* name, float value) { printf( "%s = %f @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n", name, value, (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z); return value; } __device__ double print_impl(const char* name, double value) { printf( "%s = %lf @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n", name, value, (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z); return value; } __device__ int print_impl(const char* name, int value) { printf( "%s = %d @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n", name, value, (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z); return value; } __device__ int64_t print_impl(const char* name, int64_t value) { printf( "%s = %ld @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n", name, value, (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z); return value; } __device__ bool print_impl(const char* name, bool value) { printf( "%s = %s @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n", name, value ? "true" : "false", (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z); return value; } __device__ __half print_impl(const char* name, __half value) { printf( "%s = %f @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n", name, __half2float(value), (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z); return value; } #if __CUDACC_VER_MAJOR__ >= 11 __device__ __bfloat print_impl(const char* name, __bfloat value) { printf( "%s = %f @ threadIdx=(%d,%d,%d), blockIdx=(%d,%d,%d)\n", name, __bfloat2float(value), (int)threadIdx.x, (int)threadIdx.y, (int)threadIdx.z, (int)blockIdx.x, (int)blockIdx.y, (int)blockIdx.z); return value; } #endif #define print(...) print_impl(#__VA_ARGS__, (__VA_ARGS__)) // clang-format off /* * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on namespace index_utils { // Utility functions // Total size of provided dimension template __device__ __forceinline__ nvfuser_index_t size(const _dim3& d) { return (nvfuser_index_t)d.x * (nvfuser_index_t)d.y * (nvfuser_index_t)d.z; } // Linearized indexing of idx based on dim, if bool==false that dimension does // not participate template __device__ nvfuser_index_t maskedOffset(const _dim3& idx, const _dim3_2& dim) { nvfuser_index_t offset = 0; if (Z) offset += idx.z; if (Y) offset = offset * dim.y + idx.y; if (X) offset = offset * dim.x + idx.x; return offset; } // Linearized indexing of idx based on dim. All dimensions participate. template __device__ nvfuser_index_t offset(const _dim3& idx, const _dim3_2& dim) { nvfuser_index_t offset = idx.z; offset = offset * dim.y + idx.y; offset = offset * dim.x + idx.x; return offset; } // Masks the provided dim3, those == false get truncated to 1 template __device__ dim3 maskedDims(const _dim3& dim) { return dim3{ X ? (unsigned)dim.x : 1U, Y ? (unsigned)dim.y : 1U, Z ? (unsigned)dim.z : 1U}; } // Provides total size of dim with masking, those dims == false do not // participate in the size calculation template __device__ nvfuser_index_t maskedSize(const _dim3& dim) { return size(maskedDims(dim)); } // Checks if provided idx is zero on those dims == true template __device__ bool maskedIsZero(const _dim3& idx) { bool isZero = true; if (X) isZero = isZero && idx.x == 0; if (Y) isZero = isZero && idx.y == 0; if (Z) isZero = isZero && idx.z == 0; return isZero; } // Checks if provided idx is zero on those dims == true template __device__ bool maskedIsLast(const _dim3& idx, const _dim3_2& dim) { bool isZero = true; if (X) isZero = isZero && idx.x == dim.x - 1; if (Y) isZero = isZero && idx.y == dim.y - 1; if (Z) isZero = isZero && idx.z == dim.z - 1; return isZero; } } // namespace index_utils // clang-format off /* * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. * All rights reserved. * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on // std::tuple-like type template struct Tuple; #define TUPLE_INCREMENT_PTR(idx) \ do { \ static_assert( \ IsPointerType::value, "Invalid for non-pointer types"); \ val##idx += offset; \ } while (0) template struct Tuple { T0 val0; Tuple() = default; __device__ Tuple(T0 _val0) : val0(_val0) {} // Only valid when instantiated for pointer types __device__ void operator+=(nvfuser_index_t offset) { TUPLE_INCREMENT_PTR(0); } }; template struct Tuple { T0 val0; T1 val1; Tuple() = default; __device__ Tuple(T0 _val0, T1 _val1) : val0(_val0), val1(_val1) {} // Only valid when instantiated for pointer types __device__ void operator+=(nvfuser_index_t offset) { TUPLE_INCREMENT_PTR(0); TUPLE_INCREMENT_PTR(1); } }; template struct Tuple { T0 val0; T1 val1; T2 val2; Tuple() = default; __device__ Tuple(T0 _val0, T1 _val1, T2 _val2) : val0(_val0), val1(_val1), val2(_val2) {} // Only valid when instantiated for pointer types __device__ void operator+=(nvfuser_index_t offset) { TUPLE_INCREMENT_PTR(0); TUPLE_INCREMENT_PTR(1); TUPLE_INCREMENT_PTR(2); } }; template struct Tuple { T0 val0; T1 val1; T2 val2; T3 val3; Tuple() = default; __device__ Tuple(T0 _val0, T1 _val1, T2 _val2, T3 _val3) : val0(_val0), val1(_val1), val2(_val2), val3(_val3) {} // Only valid when instantiated for pointer types __device__ void operator+=(nvfuser_index_t offset) { TUPLE_INCREMENT_PTR(0); TUPLE_INCREMENT_PTR(1); TUPLE_INCREMENT_PTR(2); TUPLE_INCREMENT_PTR(3); } }; template struct Tuple { T0 val0; T1 val1; T2 val2; T3 val3; T4 val4; Tuple() = default; __device__ Tuple(T0 _val0, T1 _val1, T2 _val2, T3 _val3, T4 _val4) : val0(_val0), val1(_val1), val2(_val2), val3(_val3), val4(_val4) {} // Only valid when instantiated for pointer types __device__ void operator+=(nvfuser_index_t offset) { TUPLE_INCREMENT_PTR(0); TUPLE_INCREMENT_PTR(1); TUPLE_INCREMENT_PTR(2); TUPLE_INCREMENT_PTR(3); TUPLE_INCREMENT_PTR(4); } }; template < typename T0, typename T1, typename T2, typename T3, typename T4, typename T5> struct Tuple { T0 val0; T1 val1; T2 val2; T3 val3; T4 val4; T5 val5; Tuple() = default; __device__ Tuple(T0 _val0, T1 _val1, T2 _val2, T3 _val3, T4 _val4, T5 _val5) : val0(_val0), val1(_val1), val2(_val2), val3(_val3), val4(_val4), val5(_val5) {} // Only valid when instantiated for pointer types __device__ void operator+=(nvfuser_index_t offset) { TUPLE_INCREMENT_PTR(0); TUPLE_INCREMENT_PTR(1); TUPLE_INCREMENT_PTR(2); TUPLE_INCREMENT_PTR(3); TUPLE_INCREMENT_PTR(4); TUPLE_INCREMENT_PTR(5); } }; template < typename T0, typename T1, typename T2, typename T3, typename T4, typename T5, typename T6> struct Tuple { T0 val0; T1 val1; T2 val2; T3 val3; T4 val4; T5 val5; T6 val6; Tuple() = default; __device__ Tuple( T0 _val0, T1 _val1, T2 _val2, T3 _val3, T4 _val4, T5 _val5, T6 _val6) : val0(_val0), val1(_val1), val2(_val2), val3(_val3), val4(_val4), val5(_val5), val6(_val6) {} // Only valid when instantiated for pointer types __device__ void operator+=(nvfuser_index_t offset) { TUPLE_INCREMENT_PTR(0); TUPLE_INCREMENT_PTR(1); TUPLE_INCREMENT_PTR(2); TUPLE_INCREMENT_PTR(3); TUPLE_INCREMENT_PTR(4); TUPLE_INCREMENT_PTR(5); TUPLE_INCREMENT_PTR(6); } }; template < typename T0, typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename T7> struct Tuple { T0 val0; T1 val1; T2 val2; T3 val3; T4 val4; T5 val5; T6 val6; T7 val7; Tuple() = default; __device__ Tuple( T0 _val0, T1 _val1, T2 _val2, T3 _val3, T4 _val4, T5 _val5, T6 _val6, T7 _val7) : val0(_val0), val1(_val1), val2(_val2), val3(_val3), val4(_val4), val5(_val5), val6(_val6), val7(_val7) {} // Only valid when instantiated for pointer types __device__ void operator+=(nvfuser_index_t offset) { TUPLE_INCREMENT_PTR(0); TUPLE_INCREMENT_PTR(1); TUPLE_INCREMENT_PTR(2); TUPLE_INCREMENT_PTR(3); TUPLE_INCREMENT_PTR(4); TUPLE_INCREMENT_PTR(5); TUPLE_INCREMENT_PTR(6); TUPLE_INCREMENT_PTR(7); } }; template < typename T0, typename T1, typename T2, typename T3, typename T4, typename T5, typename T6, typename T7, typename T8, typename T9, typename T10, typename T11, typename T12, typename T13, typename T14, typename T15> struct Tuple< T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15> { T0 val0; T1 val1; T2 val2; T3 val3; T4 val4; T5 val5; T6 val6; T7 val7; T8 val8; T9 val9; T10 val10; T11 val11; T12 val12; T13 val13; T14 val14; T15 val15; Tuple() = default; __device__ Tuple( T0 _val0, T1 _val1, T2 _val2, T3 _val3, T4 _val4, T5 _val5, T6 _val6, T7 _val7, T8 _val8, T9 _val9, T10 _val10, T11 _val11, T12 _val12, T13 _val13, T14 _val14, T15 _val15) : val0(_val0), val1(_val1), val2(_val2), val3(_val3), val4(_val4), val5(_val5), val6(_val6), val7(_val7), val8(_val8), val9(_val9), val10(_val10), val11(_val11), val12(_val12), val13(_val13), val14(_val14), val15(_val15) {} // Only valid when instantiated for pointer types __device__ void operator+=(nvfuser_index_t offset) { TUPLE_INCREMENT_PTR(0); TUPLE_INCREMENT_PTR(1); TUPLE_INCREMENT_PTR(2); TUPLE_INCREMENT_PTR(3); TUPLE_INCREMENT_PTR(4); TUPLE_INCREMENT_PTR(5); TUPLE_INCREMENT_PTR(6); TUPLE_INCREMENT_PTR(7); TUPLE_INCREMENT_PTR(8); TUPLE_INCREMENT_PTR(9); TUPLE_INCREMENT_PTR(10); TUPLE_INCREMENT_PTR(11); TUPLE_INCREMENT_PTR(12); TUPLE_INCREMENT_PTR(13); TUPLE_INCREMENT_PTR(14); TUPLE_INCREMENT_PTR(15); } }; #undef TUPLE_INCREMENT_PTR // Accessor for Tuple template struct get; #define DEFINE_TUPLE_GET(idx) \ template <> \ struct get { \ template \ __device__ auto& operator()(Tuple& vals) { \ return vals.val##idx; \ } \ template \ __device__ const auto& operator()(const Tuple& vals) { \ return vals.val##idx; \ } \ }; DEFINE_TUPLE_GET(0); DEFINE_TUPLE_GET(1); DEFINE_TUPLE_GET(2); DEFINE_TUPLE_GET(3); DEFINE_TUPLE_GET(4); DEFINE_TUPLE_GET(5); DEFINE_TUPLE_GET(6); DEFINE_TUPLE_GET(7); DEFINE_TUPLE_GET(8); DEFINE_TUPLE_GET(9); DEFINE_TUPLE_GET(10); DEFINE_TUPLE_GET(11); DEFINE_TUPLE_GET(12); DEFINE_TUPLE_GET(13); DEFINE_TUPLE_GET(14); DEFINE_TUPLE_GET(15); #undef DEFINE_TUPLE_GET template __inline__ __device__ static void copyTuple( DstType& dst, nvfuser_index_t dst_offset, const SrcType& src, nvfuser_index_t src_offset = 0); template __inline__ __device__ static void copyTuple( DstType& dst, const SrcType& src, nvfuser_index_t src_offset = 0); template __inline__ __device__ static void setTuple( DstType& dst, typename DstType::template ValType<0> src); template class LocalTuple { public: static constexpr int num_vals = sizeof...(Types); using ValTypes = TypeList; template using ValType = typename TypeSelector::type; LocalTuple() = default; __device__ explicit LocalTuple(Types... args) : vals_(args...) {} __device__ LocalTuple(const LocalTuple& other) : vals_(other.vals_) {} template