Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion kernels/portable/cpu/test/scalar_utils_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@ struct promote_type_with_scalar_type_is_valid
std::is_same<T2, torch::executor::internal::F8>::value) &&
!std::is_same<T1, exec_aten::BFloat16>::value &&
!torch::executor::is_qint_type<T1>::value &&
!torch::executor::is_bits_type<T1>::value> {};
!torch::executor::is_bits_type<T1>::value &&
!executorch::runtime::is_bits_type<T2>::value &&
!executorch::runtime::is_float8_type<T1>::value &&
!executorch::runtime::is_float8_type<T2>::value &&
!executorch::runtime::is_barebones_unsigned_type<T1>::value &&
!executorch::runtime::is_barebones_unsigned_type<T2>::value> {};

template <typename T1, bool half_to_float>
struct CompileTimePromoteTypeWithScalarTypeTestCase {
Expand Down
56 changes: 56 additions & 0 deletions runtime/core/exec_aten/util/scalar_type_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,33 @@ struct is_qint_type
: std::integral_constant<bool, isQIntType(CppTypeToScalarType<T>::value)> {
};

constexpr bool isFloat8Type(::executorch::aten::ScalarType t) {
// Don't forget to extend this when adding new QInt types
return t == ::executorch::aten::ScalarType::Float8_e5m2 ||
t == ::executorch::aten::ScalarType::Float8_e4m3fn ||
t == ::executorch::aten::ScalarType::Float8_e5m2fnuz ||
t == ::executorch::aten::ScalarType::Float8_e4m3fnuz;
}

template <typename T>
struct is_float8_type
: std::
integral_constant<bool, isFloat8Type(CppTypeToScalarType<T>::value)> {
};

constexpr bool isBarebonesUnsignedType(::executorch::aten::ScalarType t) {
// Don't forget to extend this when adding new QInt types
return t == ::executorch::aten::ScalarType::UInt16 ||
t == ::executorch::aten::ScalarType::UInt32 ||
t == ::executorch::aten::ScalarType::UInt64;
}

template <typename T>
struct is_barebones_unsigned_type
: std::integral_constant<
bool,
isBarebonesUnsignedType(CppTypeToScalarType<T>::value)> {};

inline ::executorch::aten::ScalarType toQIntType(
::executorch::aten::ScalarType t) {
switch (t) {
Expand Down Expand Up @@ -883,6 +910,15 @@ struct promote_types {
std::is_same<T1, T2>::value ||
(!is_bits_type<T1>::value && !is_bits_type<T2>::value),
"promote_types not valid for bits dtypes");
static_assert(
std::is_same<T1, T2>::value ||
(!is_float8_type<T1>::value && !is_float8_type<T2>::value),
"promote_types not valid for float8 dtypes");
static_assert(
std::is_same<T1, T2>::value ||
(!is_barebones_unsigned_type<T1>::value &&
!is_barebones_unsigned_type<T2>::value),
"promote_types not valid for barebones unsigned dtypes");

using promoted_type_not_respecting_half_to_float =
typename internal::promote_types_lookup<T1, T2>::type;
Expand Down Expand Up @@ -945,6 +981,24 @@ inline ::executorch::aten::ScalarType promoteTypes(
ET_CHECK_MSG(false, "promoteTypes not valid for bits dtypes");
}

// For Float8 types, only allow exact match
if (::executorch::runtime::isFloat8Type(a) && a == b) {
return a;
}
if (::executorch::runtime::isFloat8Type(a) ||
::executorch::runtime::isFloat8Type(b)) {
ET_CHECK_MSG(false, "promoteTypes not valid for float8 dtypes");
}

// For barebones uint types, only allow exact match
if (::executorch::runtime::isBarebonesUnsignedType(a) && a == b) {
return a;
}
if (::executorch::runtime::isBarebonesUnsignedType(a) ||
::executorch::runtime::isBarebonesUnsignedType(b)) {
ET_CHECK_MSG(false, "promoteTypes not valid for barebone unsigned dtypes");
}

// 12 types are handled by this function, see the constexpr definitions above
const int NUM_PROMOTE_TYPES = 13;

Expand Down Expand Up @@ -1433,8 +1487,10 @@ using ::executorch::runtime::canCast;
using ::executorch::runtime::convert;
using ::executorch::runtime::CppTypeToScalarType;
using ::executorch::runtime::elementSize;
using ::executorch::runtime::is_barebones_unsigned_type;
using ::executorch::runtime::is_bits_type;
using ::executorch::runtime::is_complex_type;
using ::executorch::runtime::is_float8_type;
using ::executorch::runtime::is_integral_type;
using ::executorch::runtime::is_qint_type;
using ::executorch::runtime::isBitsType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,11 @@ struct promote_types_is_valid
(!executorch::runtime::is_qint_type<T1>::value &&
!executorch::runtime::is_qint_type<T2>::value &&
!executorch::runtime::is_bits_type<T1>::value &&
!executorch::runtime::is_bits_type<T2>::value))> {};
!executorch::runtime::is_bits_type<T2>::value &&
!executorch::runtime::is_float8_type<T1>::value &&
!executorch::runtime::is_float8_type<T2>::value &&
!executorch::runtime::is_barebones_unsigned_type<T1>::value &&
!executorch::runtime::is_barebones_unsigned_type<T2>::value))> {};

template <typename T1, bool half_to_float>
struct CompileTimePromoteTypesTestCase {
Expand Down
90 changes: 66 additions & 24 deletions runtime/core/portable_type/scalar_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,36 @@ namespace executorch {
namespace runtime {
namespace etensor {

// Placing a bunch of unused dtypes here as our macros don't make it easy
// to skip scalar types defined in aten that we dont have.
namespace unused_dtype {
struct alignas(1) Float8_e5m2 {
uint8_t x;
using underlying = uint8_t;
Float8_e5m2() = default;
explicit Float8_e5m2(uint8_t val) : x(val) {}
};
struct alignas(1) Float8_e4m3fn {
uint8_t x;
using underlying = uint8_t;
Float8_e4m3fn() = default;
explicit Float8_e4m3fn(uint8_t val) : x(val) {}
};
struct alignas(1) Float8_e5m2fnuz {
uint8_t x;
using underlying = uint8_t;
Float8_e5m2fnuz() = default;
explicit Float8_e5m2fnuz(uint8_t val) : x(val) {}
};
struct alignas(1) Float8_e4m3fnuz {
uint8_t x;
using underlying = uint8_t;
Float8_e4m3fnuz() = default;
explicit Float8_e4m3fnuz(uint8_t val) : x(val) {}
};

} // namespace unused_dtype

/**
* Calls the provided macro on every ScalarType, providing the C type and the
* ScalarType name to each call.
Expand All @@ -59,30 +89,42 @@ namespace etensor {
* @param _ A macro that takes two parameters: the name of a C type, and the
* name of the corresponding ScalarType enumerator.
*/
#define ET_FORALL_SCALAR_TYPES(_) \
_(uint8_t, Byte) /* 0 */ \
_(int8_t, Char) /* 1 */ \
_(int16_t, Short) /* 2 */ \
_(int32_t, Int) /* 3 */ \
_(int64_t, Long) /* 4 */ \
_(::torch::executor::Half, Half) /* 5 */ \
_(float, Float) /* 6 */ \
_(double, Double) /* 7 */ \
_(::torch::executor::complex<::torch::executor::Half>, ComplexHalf) /* 8 */ \
_(::torch::executor::complex<float>, ComplexFloat) /* 9 */ \
_(::torch::executor::complex<double>, ComplexDouble) /* 10 */ \
_(bool, Bool) /* 11 */ \
_(::torch::executor::qint8, QInt8) /* 12 */ \
_(::torch::executor::quint8, QUInt8) /* 13 */ \
_(::torch::executor::qint32, QInt32) /* 14 */ \
_(::torch::executor::BFloat16, BFloat16) /* 15 */ \
_(::torch::executor::quint4x2, QUInt4x2) /* 16 */ \
_(::torch::executor::quint2x4, QUInt2x4) /* 17 */ \
_(::torch::executor::bits1x8, Bits1x8) /* 18 */ \
_(::torch::executor::bits2x4, Bits2x4) /* 19 */ \
_(::torch::executor::bits4x2, Bits4x2) /* 20 */ \
_(::torch::executor::bits8, Bits8) /* 21 */ \
_(::torch::executor::bits16, Bits16) /* 22 */
#define ET_FORALL_SCALAR_TYPES(_) \
_(uint8_t, Byte) /* 0 */ \
_(int8_t, Char) /* 1 */ \
_(int16_t, Short) /* 2 */ \
_(int32_t, Int) /* 3 */ \
_(int64_t, Long) /* 4 */ \
_(::executorch::runtime::etensor::Half, Half) /* 5 */ \
_(float, Float) /* 6 */ \
_(double, Double) /* 7 */ \
_(::executorch::runtime::etensor::complex<::torch::executor::Half>, \
ComplexHalf) /* 8 */ \
_(::executorch::runtime::etensor::complex<float>, ComplexFloat) /* 9 */ \
_(::executorch::runtime::etensor::complex<double>, ComplexDouble) /* 10 */ \
_(bool, Bool) /* 11 */ \
_(::executorch::runtime::etensor::qint8, QInt8) /* 12 */ \
_(::executorch::runtime::etensor::quint8, QUInt8) /* 13 */ \
_(::executorch::runtime::etensor::qint32, QInt32) /* 14 */ \
_(::executorch::runtime::etensor::BFloat16, BFloat16) /* 15 */ \
_(::executorch::runtime::etensor::quint4x2, QUInt4x2) /* 16 */ \
_(::executorch::runtime::etensor::quint2x4, QUInt2x4) /* 17 */ \
_(::executorch::runtime::etensor::bits1x8, Bits1x8) /* 18 */ \
_(::executorch::runtime::etensor::bits2x4, Bits2x4) /* 19 */ \
_(::executorch::runtime::etensor::bits4x2, Bits4x2) /* 20 */ \
_(::executorch::runtime::etensor::bits8, Bits8) /* 21 */ \
_(::executorch::runtime::etensor::bits16, Bits16) /* 22 */ \
_(::executorch::runtime::etensor::unused_dtype::Float8_e5m2, \
Float8_e5m2) /* 23 */ \
_(::executorch::runtime::etensor::unused_dtype::Float8_e4m3fn, \
Float8_e4m3fn) /* 24 */ \
_(::executorch::runtime::etensor::unused_dtype::Float8_e5m2fnuz, \
Float8_e5m2fnuz) /* 25 */ \
_(::executorch::runtime::etensor::unused_dtype::Float8_e4m3fnuz, \
Float8_e4m3fnuz) /* 26 */ \
_(uint16_t, UInt16) /* 27 */ \
_(uint32_t, UInt32) /* 28 */ \
_(uint64_t, UInt64) /* 29 */

/**
* Data types (dtypes) that can be used as element types in ETensors.
Expand Down
Loading