diff --git a/kernels/portable/cpu/test/scalar_utils_test.cpp b/kernels/portable/cpu/test/scalar_utils_test.cpp index 1983f707da1..c0fa9323431 100644 --- a/kernels/portable/cpu/test/scalar_utils_test.cpp +++ b/kernels/portable/cpu/test/scalar_utils_test.cpp @@ -18,7 +18,12 @@ struct promote_type_with_scalar_type_is_valid std::is_same::value) && !std::is_same::value && !torch::executor::is_qint_type::value && - !torch::executor::is_bits_type::value> {}; + !torch::executor::is_bits_type::value && + !executorch::runtime::is_bits_type::value && + !executorch::runtime::is_float8_type::value && + !executorch::runtime::is_float8_type::value && + !executorch::runtime::is_barebones_unsigned_type::value && + !executorch::runtime::is_barebones_unsigned_type::value> {}; template struct CompileTimePromoteTypeWithScalarTypeTestCase { diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h index e500167fa04..3f186c3c647 100644 --- a/runtime/core/exec_aten/util/scalar_type_util.h +++ b/runtime/core/exec_aten/util/scalar_type_util.h @@ -503,6 +503,33 @@ struct is_qint_type : std::integral_constant::value)> { }; +constexpr bool isFloat8Type(::executorch::aten::ScalarType t) { + // Don't forget to extend this when adding new QInt types + return t == ::executorch::aten::ScalarType::Float8_e5m2 || + t == ::executorch::aten::ScalarType::Float8_e4m3fn || + t == ::executorch::aten::ScalarType::Float8_e5m2fnuz || + t == ::executorch::aten::ScalarType::Float8_e4m3fnuz; +} + +template +struct is_float8_type + : std:: + integral_constant::value)> { +}; + +constexpr bool isBarebonesUnsignedType(::executorch::aten::ScalarType t) { + // Don't forget to extend this when adding new QInt types + return t == ::executorch::aten::ScalarType::UInt16 || + t == ::executorch::aten::ScalarType::UInt32 || + t == ::executorch::aten::ScalarType::UInt64; +} + +template +struct is_barebones_unsigned_type + : std::integral_constant< + bool, + isBarebonesUnsignedType(CppTypeToScalarType::value)> {}; + inline ::executorch::aten::ScalarType toQIntType( ::executorch::aten::ScalarType t) { switch (t) { @@ -883,6 +910,15 @@ struct promote_types { std::is_same::value || (!is_bits_type::value && !is_bits_type::value), "promote_types not valid for bits dtypes"); + static_assert( + std::is_same::value || + (!is_float8_type::value && !is_float8_type::value), + "promote_types not valid for float8 dtypes"); + static_assert( + std::is_same::value || + (!is_barebones_unsigned_type::value && + !is_barebones_unsigned_type::value), + "promote_types not valid for barebones unsigned dtypes"); using promoted_type_not_respecting_half_to_float = typename internal::promote_types_lookup::type; @@ -945,6 +981,24 @@ inline ::executorch::aten::ScalarType promoteTypes( ET_CHECK_MSG(false, "promoteTypes not valid for bits dtypes"); } + // For Float8 types, only allow exact match + if (::executorch::runtime::isFloat8Type(a) && a == b) { + return a; + } + if (::executorch::runtime::isFloat8Type(a) || + ::executorch::runtime::isFloat8Type(b)) { + ET_CHECK_MSG(false, "promoteTypes not valid for float8 dtypes"); + } + + // For barebones uint types, only allow exact match + if (::executorch::runtime::isBarebonesUnsignedType(a) && a == b) { + return a; + } + if (::executorch::runtime::isBarebonesUnsignedType(a) || + ::executorch::runtime::isBarebonesUnsignedType(b)) { + ET_CHECK_MSG(false, "promoteTypes not valid for barebone unsigned dtypes"); + } + // 12 types are handled by this function, see the constexpr definitions above const int NUM_PROMOTE_TYPES = 13; @@ -1433,8 +1487,10 @@ using ::executorch::runtime::canCast; using ::executorch::runtime::convert; using ::executorch::runtime::CppTypeToScalarType; using ::executorch::runtime::elementSize; +using ::executorch::runtime::is_barebones_unsigned_type; using ::executorch::runtime::is_bits_type; using ::executorch::runtime::is_complex_type; +using ::executorch::runtime::is_float8_type; using ::executorch::runtime::is_integral_type; using ::executorch::runtime::is_qint_type; using ::executorch::runtime::isBitsType; diff --git a/runtime/core/exec_aten/util/test/scalar_type_util_test.cpp b/runtime/core/exec_aten/util/test/scalar_type_util_test.cpp index 9df01b7be9f..cd18be977cf 100644 --- a/runtime/core/exec_aten/util/test/scalar_type_util_test.cpp +++ b/runtime/core/exec_aten/util/test/scalar_type_util_test.cpp @@ -170,7 +170,11 @@ struct promote_types_is_valid (!executorch::runtime::is_qint_type::value && !executorch::runtime::is_qint_type::value && !executorch::runtime::is_bits_type::value && - !executorch::runtime::is_bits_type::value))> {}; + !executorch::runtime::is_bits_type::value && + !executorch::runtime::is_float8_type::value && + !executorch::runtime::is_float8_type::value && + !executorch::runtime::is_barebones_unsigned_type::value && + !executorch::runtime::is_barebones_unsigned_type::value))> {}; template struct CompileTimePromoteTypesTestCase { diff --git a/runtime/core/portable_type/scalar_type.h b/runtime/core/portable_type/scalar_type.h index 286aee3387c..dc8142862f8 100644 --- a/runtime/core/portable_type/scalar_type.h +++ b/runtime/core/portable_type/scalar_type.h @@ -47,6 +47,36 @@ namespace executorch { namespace runtime { namespace etensor { +// Placing a bunch of unused dtypes here as our macros don't make it easy +// to skip scalar types defined in aten that we dont have. +namespace unused_dtype { +struct alignas(1) Float8_e5m2 { + uint8_t x; + using underlying = uint8_t; + Float8_e5m2() = default; + explicit Float8_e5m2(uint8_t val) : x(val) {} +}; +struct alignas(1) Float8_e4m3fn { + uint8_t x; + using underlying = uint8_t; + Float8_e4m3fn() = default; + explicit Float8_e4m3fn(uint8_t val) : x(val) {} +}; +struct alignas(1) Float8_e5m2fnuz { + uint8_t x; + using underlying = uint8_t; + Float8_e5m2fnuz() = default; + explicit Float8_e5m2fnuz(uint8_t val) : x(val) {} +}; +struct alignas(1) Float8_e4m3fnuz { + uint8_t x; + using underlying = uint8_t; + Float8_e4m3fnuz() = default; + explicit Float8_e4m3fnuz(uint8_t val) : x(val) {} +}; + +} // namespace unused_dtype + /** * Calls the provided macro on every ScalarType, providing the C type and the * ScalarType name to each call. @@ -59,30 +89,42 @@ namespace etensor { * @param _ A macro that takes two parameters: the name of a C type, and the * name of the corresponding ScalarType enumerator. */ -#define ET_FORALL_SCALAR_TYPES(_) \ - _(uint8_t, Byte) /* 0 */ \ - _(int8_t, Char) /* 1 */ \ - _(int16_t, Short) /* 2 */ \ - _(int32_t, Int) /* 3 */ \ - _(int64_t, Long) /* 4 */ \ - _(::torch::executor::Half, Half) /* 5 */ \ - _(float, Float) /* 6 */ \ - _(double, Double) /* 7 */ \ - _(::torch::executor::complex<::torch::executor::Half>, ComplexHalf) /* 8 */ \ - _(::torch::executor::complex, ComplexFloat) /* 9 */ \ - _(::torch::executor::complex, ComplexDouble) /* 10 */ \ - _(bool, Bool) /* 11 */ \ - _(::torch::executor::qint8, QInt8) /* 12 */ \ - _(::torch::executor::quint8, QUInt8) /* 13 */ \ - _(::torch::executor::qint32, QInt32) /* 14 */ \ - _(::torch::executor::BFloat16, BFloat16) /* 15 */ \ - _(::torch::executor::quint4x2, QUInt4x2) /* 16 */ \ - _(::torch::executor::quint2x4, QUInt2x4) /* 17 */ \ - _(::torch::executor::bits1x8, Bits1x8) /* 18 */ \ - _(::torch::executor::bits2x4, Bits2x4) /* 19 */ \ - _(::torch::executor::bits4x2, Bits4x2) /* 20 */ \ - _(::torch::executor::bits8, Bits8) /* 21 */ \ - _(::torch::executor::bits16, Bits16) /* 22 */ +#define ET_FORALL_SCALAR_TYPES(_) \ + _(uint8_t, Byte) /* 0 */ \ + _(int8_t, Char) /* 1 */ \ + _(int16_t, Short) /* 2 */ \ + _(int32_t, Int) /* 3 */ \ + _(int64_t, Long) /* 4 */ \ + _(::executorch::runtime::etensor::Half, Half) /* 5 */ \ + _(float, Float) /* 6 */ \ + _(double, Double) /* 7 */ \ + _(::executorch::runtime::etensor::complex<::torch::executor::Half>, \ + ComplexHalf) /* 8 */ \ + _(::executorch::runtime::etensor::complex, ComplexFloat) /* 9 */ \ + _(::executorch::runtime::etensor::complex, ComplexDouble) /* 10 */ \ + _(bool, Bool) /* 11 */ \ + _(::executorch::runtime::etensor::qint8, QInt8) /* 12 */ \ + _(::executorch::runtime::etensor::quint8, QUInt8) /* 13 */ \ + _(::executorch::runtime::etensor::qint32, QInt32) /* 14 */ \ + _(::executorch::runtime::etensor::BFloat16, BFloat16) /* 15 */ \ + _(::executorch::runtime::etensor::quint4x2, QUInt4x2) /* 16 */ \ + _(::executorch::runtime::etensor::quint2x4, QUInt2x4) /* 17 */ \ + _(::executorch::runtime::etensor::bits1x8, Bits1x8) /* 18 */ \ + _(::executorch::runtime::etensor::bits2x4, Bits2x4) /* 19 */ \ + _(::executorch::runtime::etensor::bits4x2, Bits4x2) /* 20 */ \ + _(::executorch::runtime::etensor::bits8, Bits8) /* 21 */ \ + _(::executorch::runtime::etensor::bits16, Bits16) /* 22 */ \ + _(::executorch::runtime::etensor::unused_dtype::Float8_e5m2, \ + Float8_e5m2) /* 23 */ \ + _(::executorch::runtime::etensor::unused_dtype::Float8_e4m3fn, \ + Float8_e4m3fn) /* 24 */ \ + _(::executorch::runtime::etensor::unused_dtype::Float8_e5m2fnuz, \ + Float8_e5m2fnuz) /* 25 */ \ + _(::executorch::runtime::etensor::unused_dtype::Float8_e4m3fnuz, \ + Float8_e4m3fnuz) /* 26 */ \ + _(uint16_t, UInt16) /* 27 */ \ + _(uint32_t, UInt32) /* 28 */ \ + _(uint64_t, UInt64) /* 29 */ /** * Data types (dtypes) that can be used as element types in ETensors.