diff --git a/kernels/portable/cpu/op_to_copy.cpp b/kernels/portable/cpu/op_to_copy.cpp index 7ecd4f3b5e1..c0c04e65e93 100644 --- a/kernels/portable/cpu/op_to_copy.cpp +++ b/kernels/portable/cpu/op_to_copy.cpp @@ -46,10 +46,11 @@ Tensor& to_copy_out( InvalidArgument, out); - ET_SWITCH_REALHB_TYPES(self.scalar_type(), ctx, "to_copy", CTYPE_IN, [&] { - ET_SWITCH_REALHB_TYPES(out.scalar_type(), ctx, "to_copy", CTYPE_OUT, [&] { - _to_impl(self, out); - }); + ET_SWITCH_REALHBBF16_TYPES(self.scalar_type(), ctx, "to_copy", CTYPE_IN, [&] { + ET_SWITCH_REALHBBF16_TYPES( + out.scalar_type(), ctx, "to_copy", CTYPE_OUT, [&] { + _to_impl(self, out); + }); }); return out; diff --git a/kernels/portable/cpu/scalar_utils.h b/kernels/portable/cpu/scalar_utils.h index 3daf3e72526..3d6dfb75e47 100644 --- a/kernels/portable/cpu/scalar_utils.h +++ b/kernels/portable/cpu/scalar_utils.h @@ -94,12 +94,6 @@ struct promote_type_with_scalar_type { static_assert( !is_bits_type::value, "promote_type_with_scalar_type not valid for bits dtypes"); - static_assert( - !std::is_same< - T1, - typename ScalarTypeToCppType::type>:: - value, - "promote_type_with_scalar_type not valid for BFloat16"); using promote_type_with_scalar_type_not_respecting_half_to_float = typename std::conditional< is_complex_type::value || @@ -119,10 +113,14 @@ struct promote_type_with_scalar_type { public: using type = typename std::conditional< half_to_float && - std::is_same< - promote_type_with_scalar_type_not_respecting_half_to_float, - typename ScalarTypeToCppType::type>:: - value, + (std::is_same< + promote_type_with_scalar_type_not_respecting_half_to_float, + typename ScalarTypeToCppType< + exec_aten::ScalarType::Half>::type>::value || + std::is_same< + promote_type_with_scalar_type_not_respecting_half_to_float, + typename ScalarTypeToCppType< + exec_aten::ScalarType::BFloat16>::type>::value), typename ScalarTypeToCppType::type, promote_type_with_scalar_type_not_respecting_half_to_float>::type; }; diff --git a/kernels/test/op_to_copy_test.cpp b/kernels/test/op_to_copy_test.cpp index 1cc892dedbe..0a6529e736d 100644 --- a/kernels/test/op_to_copy_test.cpp +++ b/kernels/test/op_to_copy_test.cpp @@ -36,7 +36,9 @@ typedef std::map< std::type_index, std::variant< std::vector, - std::vector>> + std::vector, + std::vector, + std::vector>> FloatingTypeToDataMap; typedef std::map< @@ -309,9 +311,9 @@ TEST_F(OpToTest, AllDtypesSupported) { ScalarType::OUTPUT_DTYPE>(test_cases); #define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \ - ET_FORALL_REAL_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL); + ET_FORALL_REALHBF16_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY #undef TEST_KERNEL @@ -323,14 +325,14 @@ TEST_F(OpToTest, BoolTests) { #define TEST_TO_BOOL(INPUT_CTYPE, INPUT_DTYPE) \ test_runner_to_bool( \ test_case_to_bool, result_to_bool); - ET_FORALL_REAL_TYPES(TEST_TO_BOOL); + ET_FORALL_REALHBF16_TYPES(TEST_TO_BOOL); std::vector test_case_from_bool = {true, true, false}; std::vector result_from_bool = {1.0, 1.0, 0}; #define TEST_FROM_BOOL(OUTPUT_CTYPE, OUTPUT_DTYPE) \ test_runner_from_bool( \ test_case_from_bool, result_from_bool); - ET_FORALL_REAL_TYPES(TEST_FROM_BOOL); + ET_FORALL_REALHBF16_TYPES(TEST_FROM_BOOL); } TEST_F(OpToTest, NanInfSupported) { @@ -349,9 +351,9 @@ TEST_F(OpToTest, NanInfSupported) { ScalarType::OUTPUT_DTYPE>(test_cases); #define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \ - ET_FORALL_FLOAT_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL); + ET_FORALL_FLOATHBF16_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL); - ET_FORALL_FLOAT_TYPES(TEST_ENTRY); + ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY #undef TEST_KERNEL @@ -381,6 +383,13 @@ TEST_F(OpToTest, HardcodeFloatConvertInt) { -0.30919688936285893988}; // clang-format on + std::vector half_data; + std::vector bf16_data; + for (auto d : double_data) { + half_data.emplace_back(d); + bf16_data.emplace_back(d); + } + std::vector int64_data = { -1, -4, 2, -2, 3, 3, -3, -4, 3, 3, 0, 2, 0, -1, 0}; std::vector int32_data = { @@ -394,6 +403,8 @@ TEST_F(OpToTest, HardcodeFloatConvertInt) { FloatingTypeToDataMap floating_point_data; floating_point_data[typeid(float)] = float_data; floating_point_data[typeid(double)] = double_data; + floating_point_data[typeid(exec_aten::Half)] = half_data; + floating_point_data[typeid(exec_aten::BFloat16)] = bf16_data; // Gathering all int data together for better traversial IntTypeToDataMap int_data; @@ -412,7 +423,7 @@ TEST_F(OpToTest, HardcodeFloatConvertInt) { #define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \ ET_FORALL_INT_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL); - ET_FORALL_FLOAT_TYPES(TEST_ENTRY); + ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY); } TEST_F(OpToTest, MismatchedSizesDie) { diff --git a/runtime/core/exec_aten/exec_aten.h b/runtime/core/exec_aten/exec_aten.h index 919b5420b3a..808d31502a9 100644 --- a/runtime/core/exec_aten/exec_aten.h +++ b/runtime/core/exec_aten/exec_aten.h @@ -17,6 +17,7 @@ #include // @manual #include // @manual #include // @manual +#include // @manual #include // @manual #include // @manual #include // @manual @@ -31,6 +32,7 @@ #else // use executor #include // @manual #include // @manual +#include // @manual #include // @manual #include // @manual #include // @manual diff --git a/runtime/core/exec_aten/testing_util/tensor_util.cpp b/runtime/core/exec_aten/testing_util/tensor_util.cpp index 03dffd208f0..1fd751dc882 100644 --- a/runtime/core/exec_aten/testing_util/tensor_util.cpp +++ b/runtime/core/exec_aten/testing_util/tensor_util.cpp @@ -16,6 +16,8 @@ #include #include +using exec_aten::BFloat16; +using exec_aten::Half; using exec_aten::ScalarType; using exec_aten::Tensor; @@ -32,9 +34,7 @@ namespace { * T must be a floating point type. Non-floating point data should be compared * directly. */ -template < - typename T, - typename = std::enable_if_t::value>> +template bool data_is_close( const T* a, const T* b, @@ -119,6 +119,20 @@ bool tensors_are_close( a.numel(), rtol, atol); + } else if (a.scalar_type() == ScalarType::Half) { + return data_is_close( + a.const_data_ptr(), + b.const_data_ptr(), + a.numel(), + rtol, + atol); + } else if (a.scalar_type() == ScalarType::BFloat16) { + return data_is_close( + a.const_data_ptr(), + b.const_data_ptr(), + a.numel(), + rtol, + atol); } else { // Non-floating-point types can be compared bitwise. return memcmp(a.const_data_ptr(), b.const_data_ptr(), a.nbytes()) == 0; diff --git a/runtime/core/exec_aten/util/genScalarTypeTable.py b/runtime/core/exec_aten/util/genScalarTypeTable.py index 07100472ae4..c2bc84c2700 100644 --- a/runtime/core/exec_aten/util/genScalarTypeTable.py +++ b/runtime/core/exec_aten/util/genScalarTypeTable.py @@ -4,20 +4,35 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -indexToType = ["U1", "I1", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "B1"] +indexToType = [ + "U1", + "I1", + "I2", + "I4", + "I8", + "F2", + "F4", + "F8", + "C2", + "C4", + "C8", + "B1", + "BF", +] promoteTypesLookup = [ - ["U1", "I2", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "U1"], - ["I2", "I1", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I1"], - ["I2", "I2", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I2"], - ["I4", "I4", "I4", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I4"], - ["I8", "I8", "I8", "I8", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I8"], - ["F2", "F2", "F2", "F2", "F2", "F2", "F4", "F8", "C2", "C4", "C8", "F2"], - ["F4", "F4", "F4", "F4", "F4", "F4", "F4", "F8", "C4", "C4", "C8", "F4"], - ["F8", "F8", "F8", "F8", "F8", "F8", "F8", "F8", "C8", "C8", "C8", "F8"], - ["C2", "C2", "C2", "C2", "C2", "C2", "C4", "C8", "C2", "C4", "C8", "C2"], - ["C4", "C4", "C4", "C4", "C4", "C4", "C4", "C8", "C4", "C4", "C8", "C4"], - ["C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8"], - ["U1", "I1", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "B1"], + ["U1", "I2", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "U1", "BF"], + ["I2", "I1", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I1", "BF"], + ["I2", "I2", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I2", "BF"], + ["I4", "I4", "I4", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I4", "BF"], + ["I8", "I8", "I8", "I8", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I8", "BF"], + ["F2", "F2", "F2", "F2", "F2", "F2", "F4", "F8", "C2", "C4", "C8", "F2", "F4"], + ["F4", "F4", "F4", "F4", "F4", "F4", "F4", "F8", "C4", "C4", "C8", "F4", "F4"], + ["F8", "F8", "F8", "F8", "F8", "F8", "F8", "F8", "C8", "C8", "C8", "F8", "F8"], + ["C2", "C2", "C2", "C2", "C2", "C2", "C4", "C8", "C2", "C4", "C8", "C2", "C4"], + ["C4", "C4", "C4", "C4", "C4", "C4", "C4", "C8", "C4", "C4", "C8", "C4", "C4"], + ["C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8"], + ["U1", "I1", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "B1", "BF"], + ["BF", "BF", "BF", "BF", "BF", "F4", "F4", "F8", "C4", "C4", "C8", "BF", "BF"], ] for rowIndex, row in enumerate(promoteTypesLookup): for colIndex, col in enumerate(row): diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h index c92f910431f..479767b4abb 100644 --- a/runtime/core/exec_aten/util/scalar_type_util.h +++ b/runtime/core/exec_aten/util/scalar_type_util.h @@ -21,6 +21,7 @@ #pragma once +#include #include #include #include @@ -164,8 +165,21 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType) ::exec_aten::ScalarType::SCALARTYPE>::type, \ SCALARTYPE) +#define ET_FORALL_FLOAT_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \ + _(float, Float) \ + _(double, Double) \ + _(::executorch::runtime::ScalarTypeToCppType< \ + ::exec_aten::ScalarType::SCALARTYPE1>::type, \ + SCALARTYPE1) \ + _(::executorch::runtime::ScalarTypeToCppType< \ + ::exec_aten::ScalarType::SCALARTYPE2>::type, \ + SCALARTYPE2) + #define ET_FORALL_FLOATH_TYPES(_) ET_FORALL_FLOAT_TYPES_AND(Half, _) +#define ET_FORALL_FLOATHBF16_TYPES(_) \ + ET_FORALL_FLOAT_TYPES_AND2(Half, BFloat16, _) + // Here `ANOTHER_INPUT` should be another variable to be forwarded to a given // function. Not to be confused with another scalar type as in // `ET_FORALL_FLOAT_TYPES_AND`. @@ -177,6 +191,12 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType) _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float) \ _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double) +#define ET_FORALL_FLOATHBF16_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, ::exec_aten::Half, Half) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, ::exec_aten::BFloat16, BFloat16) + // In this context, "REAL" means integer/float C types, which is why BFloat16 // and Half are not included. #define ET_FORALL_REAL_TYPES(_) \ @@ -209,6 +229,17 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType) _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float) \ _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double) +#define ET_FORALL_REALHBF16_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, uint8_t, Byte) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, int8_t, Char) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, int16_t, Short) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, int32_t, Int) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, int64_t, Long) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, exec_aten::Half, Half) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, exec_aten::BFloat16, BFloat16) + // For macros that take `SCALARTYPEn` parameters, those parameters should be // an unquoted/unqualified enumerator name like `Int` or `Float`. #define ET_FORALL_REAL_TYPES_AND(SCALARTYPE, _) \ @@ -223,8 +254,29 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType) ::exec_aten::ScalarType::SCALARTYPE>::type, \ SCALARTYPE) +#define ET_FORALL_REAL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int32_t, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(::executorch::runtime::ScalarTypeToCppType< \ + ::exec_aten::ScalarType::SCALARTYPE1>::type, \ + SCALARTYPE1) \ + _(::executorch::runtime::ScalarTypeToCppType< \ + ::exec_aten::ScalarType::SCALARTYPE2>::type, \ + SCALARTYPE2) + #define ET_FORALL_REALH_TYPES(_) ET_FORALL_REAL_TYPES_AND(Half, _) +#define ET_FORALL_REALHBF16_TYPES(_) \ + ET_FORALL_REAL_TYPES_AND2(Half, BFloat16, _) + +#define ET_FORALL_REALHBBF16_TYPES(_) \ + ET_FORALL_REAL_TYPES_AND3(Bool, Half, BFloat16, _) + #define ET_FORALL_REAL_TYPES_AND_WITH(SCALARTYPE, ANOTHER_INPUT, _) \ _(ANOTHER_INPUT, uint8_t, Byte) \ _(ANOTHER_INPUT, int8_t, Char) \ @@ -381,6 +433,10 @@ inline bool isRealHBType(exec_aten::ScalarType t) { return (isRealHType(t) || t == exec_aten::ScalarType::Bool); } +inline bool isRealHBBF16Type(exec_aten::ScalarType t) { + return (isRealHBType(t) || t == exec_aten::ScalarType::BFloat16); +} + inline constexpr bool isComplexType(exec_aten::ScalarType t) { return ( t == exec_aten::ScalarType::ComplexHalf || @@ -589,6 +645,7 @@ using C4 = using C8 = typename ScalarTypeToCppType::type; using B1 = typename ScalarTypeToCppType::type; +using BF = typename ScalarTypeToCppType::type; #define TABLE_ENTRY(key1, key2, value) \ template <> \ @@ -613,6 +670,7 @@ TABLE_ENTRY(U1, C2, C2); TABLE_ENTRY(U1, C4, C4); TABLE_ENTRY(U1, C8, C8); TABLE_ENTRY(U1, B1, U1); +TABLE_ENTRY(U1, BF, BF); TABLE_ENTRY(I1, U1, I2); TABLE_ENTRY(I1, I1, I1); TABLE_ENTRY(I1, I2, I2); @@ -625,6 +683,7 @@ TABLE_ENTRY(I1, C2, C2); TABLE_ENTRY(I1, C4, C4); TABLE_ENTRY(I1, C8, C8); TABLE_ENTRY(I1, B1, I1); +TABLE_ENTRY(I1, BF, BF); TABLE_ENTRY(I2, U1, I2); TABLE_ENTRY(I2, I1, I2); TABLE_ENTRY(I2, I2, I2); @@ -637,6 +696,7 @@ TABLE_ENTRY(I2, C2, C2); TABLE_ENTRY(I2, C4, C4); TABLE_ENTRY(I2, C8, C8); TABLE_ENTRY(I2, B1, I2); +TABLE_ENTRY(I2, BF, BF); TABLE_ENTRY(I4, U1, I4); TABLE_ENTRY(I4, I1, I4); TABLE_ENTRY(I4, I2, I4); @@ -649,6 +709,7 @@ TABLE_ENTRY(I4, C2, C2); TABLE_ENTRY(I4, C4, C4); TABLE_ENTRY(I4, C8, C8); TABLE_ENTRY(I4, B1, I4); +TABLE_ENTRY(I4, BF, BF); TABLE_ENTRY(I8, U1, I8); TABLE_ENTRY(I8, I1, I8); TABLE_ENTRY(I8, I2, I8); @@ -661,6 +722,7 @@ TABLE_ENTRY(I8, C2, C2); TABLE_ENTRY(I8, C4, C4); TABLE_ENTRY(I8, C8, C8); TABLE_ENTRY(I8, B1, I8); +TABLE_ENTRY(I8, BF, BF); TABLE_ENTRY(F2, U1, F2); TABLE_ENTRY(F2, I1, F2); TABLE_ENTRY(F2, I2, F2); @@ -673,6 +735,7 @@ TABLE_ENTRY(F2, C2, C2); TABLE_ENTRY(F2, C4, C4); TABLE_ENTRY(F2, C8, C8); TABLE_ENTRY(F2, B1, F2); +TABLE_ENTRY(F2, BF, F4); TABLE_ENTRY(F4, U1, F4); TABLE_ENTRY(F4, I1, F4); TABLE_ENTRY(F4, I2, F4); @@ -685,6 +748,7 @@ TABLE_ENTRY(F4, C2, C4); TABLE_ENTRY(F4, C4, C4); TABLE_ENTRY(F4, C8, C8); TABLE_ENTRY(F4, B1, F4); +TABLE_ENTRY(F4, BF, F4); TABLE_ENTRY(F8, U1, F8); TABLE_ENTRY(F8, I1, F8); TABLE_ENTRY(F8, I2, F8); @@ -697,6 +761,7 @@ TABLE_ENTRY(F8, C2, C8); TABLE_ENTRY(F8, C4, C8); TABLE_ENTRY(F8, C8, C8); TABLE_ENTRY(F8, B1, F8); +TABLE_ENTRY(F8, BF, F8); TABLE_ENTRY(C2, U1, C2); TABLE_ENTRY(C2, I1, C2); TABLE_ENTRY(C2, I2, C2); @@ -709,6 +774,7 @@ TABLE_ENTRY(C2, C2, C2); TABLE_ENTRY(C2, C4, C4); TABLE_ENTRY(C2, C8, C8); TABLE_ENTRY(C2, B1, C2); +TABLE_ENTRY(C2, BF, C4); TABLE_ENTRY(C4, U1, C4); TABLE_ENTRY(C4, I1, C4); TABLE_ENTRY(C4, I2, C4); @@ -721,6 +787,7 @@ TABLE_ENTRY(C4, C2, C4); TABLE_ENTRY(C4, C4, C4); TABLE_ENTRY(C4, C8, C8); TABLE_ENTRY(C4, B1, C4); +TABLE_ENTRY(C4, BF, C4); TABLE_ENTRY(C8, U1, C8); TABLE_ENTRY(C8, I1, C8); TABLE_ENTRY(C8, I2, C8); @@ -733,6 +800,7 @@ TABLE_ENTRY(C8, C2, C8); TABLE_ENTRY(C8, C4, C8); TABLE_ENTRY(C8, C8, C8); TABLE_ENTRY(C8, B1, C8); +TABLE_ENTRY(C8, BF, C8); TABLE_ENTRY(B1, U1, U1); TABLE_ENTRY(B1, I1, I1); TABLE_ENTRY(B1, I2, I2); @@ -745,6 +813,20 @@ TABLE_ENTRY(B1, C2, C2); TABLE_ENTRY(B1, C4, C4); TABLE_ENTRY(B1, C8, C8); TABLE_ENTRY(B1, B1, B1); +TABLE_ENTRY(B1, BF, BF); +TABLE_ENTRY(BF, U1, BF); +TABLE_ENTRY(BF, I1, BF); +TABLE_ENTRY(BF, I2, BF); +TABLE_ENTRY(BF, I4, BF); +TABLE_ENTRY(BF, I8, BF); +TABLE_ENTRY(BF, F2, F4); +TABLE_ENTRY(BF, F4, F4); +TABLE_ENTRY(BF, F8, F8); +TABLE_ENTRY(BF, C2, C4); +TABLE_ENTRY(BF, C4, C4); +TABLE_ENTRY(BF, C8, C8); +TABLE_ENTRY(BF, B1, BF); +TABLE_ENTRY(BF, BF, BF); } // namespace internal @@ -760,26 +842,20 @@ struct promote_types { (!is_bits_type::value && !is_bits_type::value), "promote_types not valid for bits dtypes"); - static_assert( - !std::is_same< - T1, - typename ScalarTypeToCppType::type>:: - value && - !std::is_same< - T2, - typename ScalarTypeToCppType< - exec_aten::ScalarType::BFloat16>::type>::value, - "promote_types not valid for BFloat16"); using promoted_type_not_respecting_half_to_float = typename internal::promote_types_lookup::type; public: using type = typename std::conditional< half_to_float && - std::is_same< - promoted_type_not_respecting_half_to_float, - typename ScalarTypeToCppType::type>:: - value, + (std::is_same< + promoted_type_not_respecting_half_to_float, + typename ScalarTypeToCppType< + exec_aten::ScalarType::Half>::type>::value || + std::is_same< + promoted_type_not_respecting_half_to_float, + typename ScalarTypeToCppType< + exec_aten::ScalarType::BFloat16>::type>::value), typename ScalarTypeToCppType::type, promoted_type_not_respecting_half_to_float>::type; }; @@ -787,7 +863,8 @@ struct promote_types { /** * Implements type promotion rules that are consistent with ATen behaviour, * which in turn is consistent with NumPy's promote_types. - * If half_to_float is set to true, then half will be promoted to float instead + * If half_to_float is set to true, then half and bfloat16 will be promoted to + * float instead */ inline exec_aten::ScalarType promoteTypes( exec_aten::ScalarType a, @@ -806,6 +883,7 @@ inline exec_aten::ScalarType promoteTypes( constexpr auto c4 = exec_aten::ScalarType::ComplexFloat; constexpr auto c8 = exec_aten::ScalarType::ComplexDouble; constexpr auto b1 = exec_aten::ScalarType::Bool; + constexpr auto bf = exec_aten::ScalarType::BFloat16; // For QInt types, only allow exact match if (executorch::runtime::isQIntType(a) && a == b) { @@ -825,34 +903,41 @@ inline exec_aten::ScalarType promoteTypes( ET_CHECK_MSG(false, "promoteTypes not valid for bits dtypes"); } - ET_CHECK_MSG( - a != exec_aten::ScalarType::BFloat16 && - b != exec_aten::ScalarType::BFloat16, - "promoteTypes not valid for BFloat16"); // 12 types are handled by this function, see the constexpr definitions above - const int NUM_PROMOTE_TYPES = 12; - + const int NUM_PROMOTE_TYPES = 13; + + static constexpr std::array + dtype2index = {{ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + -1, -1, -1, 12, -1, -1, -1, -1, -1, -1, -1, -1, + }}; + auto ix_a = dtype2index[(int)a]; + ET_CHECK(ix_a != -1); + auto ix_b = dtype2index[(int)b]; + ET_CHECK(ix_b != -1); static constexpr exec_aten::ScalarType _promoteTypesLookup[NUM_PROMOTE_TYPES][NUM_PROMOTE_TYPES] = { - /* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 */ - /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1}, - /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1}, - /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2}, - /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4}, - /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8}, - /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2}, - /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4}, - /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8}, - /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2}, - /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4}, - /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8}, - /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1}, + /* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 bf*/ + /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, bf}, + /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, bf}, + /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, bf}, + /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, bf}, + /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, bf}, + /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, f4}, + /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, f4}, + /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, f8}, + /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, c4}, + /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, c4}, + /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8}, + /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf}, + /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, bf}, }; - exec_aten::ScalarType promoted_type = - _promoteTypesLookup[static_cast(a)][static_cast(b)]; + exec_aten::ScalarType promoted_type = _promoteTypesLookup[ix_a][ix_b]; - if (half_to_float && promoted_type == exec_aten::ScalarType::Half) { + if (half_to_float && + (promoted_type == exec_aten::ScalarType::Half || + promoted_type == exec_aten::ScalarType::BFloat16)) { promoted_type = exec_aten::ScalarType::Float; } @@ -974,6 +1059,13 @@ inline exec_aten::ScalarType promoteTypes( ET_INTERNAL_SWITCH_CASE( \ exec_aten::ScalarType::ADDITIONAL2, CTYPE_ALIAS, __VA_ARGS__) +#define ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND3( \ + ADDITIONAL1, ADDITIONAL2, ADDITIONAL3, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND2( \ + ADDITIONAL1, ADDITIONAL2, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + exec_aten::ScalarType::ADDITIONAL3, CTYPE_ALIAS, __VA_ARGS__) + #define ET_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, ...) \ ET_INTERNAL_SWITCH_CASE( \ exec_aten::ScalarType::Byte, CTYPE_ALIAS, __VA_ARGS__) \ @@ -1001,6 +1093,13 @@ inline exec_aten::ScalarType promoteTypes( ET_INTERNAL_SWITCH_CASE( \ exec_aten::ScalarType::ADDITIONAL, CTYPE_ALIAS, __VA_ARGS__) +#define ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES_AND2( \ + ADDITIONAL1, ADDITIONAL2, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES_AND( \ + ADDITIONAL1, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + exec_aten::ScalarType::ADDITIONAL2, CTYPE_ALIAS, __VA_ARGS__) + #define ET_INTERNAL_SWITCH_CASE_QINT_TYPES(CTYPE_ALIAS, ...) \ ET_INTERNAL_SWITCH_CASE( \ exec_aten::ScalarType::QInt8, CTYPE_ALIAS, __VA_ARGS__) \ @@ -1112,6 +1211,22 @@ inline exec_aten::ScalarType promoteTypes( ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND2( \ ADDITIONAL1, ADDITIONAL2, CTYPE_ALIAS, __VA_ARGS__)) +#define ET_SWITCH_REAL_TYPES_AND3( \ + ADDITIONAL1, \ + ADDITIONAL2, \ + ADDITIONAL3, \ + TYPE, \ + CONTEXT, \ + NAME, \ + CTYPE_ALIAS, \ + ...) \ + ET_INTERNAL_SWITCH( \ + TYPE, \ + CONTEXT, \ + NAME, \ + ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND3( \ + ADDITIONAL1, ADDITIONAL2, ADDITIONAL3, CTYPE_ALIAS, __VA_ARGS__)) + #define ET_SWITCH_REALH_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ ET_SWITCH_REAL_TYPES_AND(Half, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__) @@ -1122,6 +1237,10 @@ inline exec_aten::ScalarType promoteTypes( ET_SWITCH_REAL_TYPES_AND2( \ Half, Bool, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__) +#define ET_SWITCH_REALHBBF16_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_SWITCH_REAL_TYPES_AND3( \ + Half, Bool, BFloat16, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__) + #define ET_SWITCH_INT_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ ET_INTERNAL_SWITCH( \ TYPE, \ @@ -1154,9 +1273,22 @@ inline exec_aten::ScalarType promoteTypes( ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES_AND( \ ADDITIONAL, CTYPE_ALIAS, __VA_ARGS__)) +#define ET_SWITCH_FLOAT_TYPES_AND2( \ + ADDITIONAL1, ADDITIONAL2, TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH( \ + TYPE, \ + CONTEXT, \ + NAME, \ + ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES_AND2( \ + ADDITIONAL1, ADDITIONAL2, CTYPE_ALIAS, __VA_ARGS__)) + #define ET_SWITCH_FLOATH_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ ET_SWITCH_FLOAT_TYPES_AND(Half, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__) +#define ET_SWITCH_FLOATHBF16_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_SWITCH_FLOAT_TYPES_AND2( \ + Half, BFloat16, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__) + #define ET_SWITCH_QINT_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ ET_INTERNAL_SWITCH( \ TYPE, \ diff --git a/runtime/core/exec_aten/util/test/scalar_type_util_test.cpp b/runtime/core/exec_aten/util/test/scalar_type_util_test.cpp index b91c7009f45..9df01b7be9f 100644 --- a/runtime/core/exec_aten/util/test/scalar_type_util_test.cpp +++ b/runtime/core/exec_aten/util/test/scalar_type_util_test.cpp @@ -139,37 +139,38 @@ TEST(ScalarTypeUtilTest, promoteTypesTest) { // Check some common cases - ET_CHECK( - promoteTypes(ScalarType::Float, ScalarType::Double) == - ScalarType::Double); - ET_CHECK( - promoteTypes(ScalarType::Float, ScalarType::Short) == ScalarType::Float); - - ET_CHECK( - promoteTypes(ScalarType::Float, ScalarType::Int) == ScalarType::Float); - ET_CHECK( - promoteTypes(ScalarType::Long, ScalarType::Float) == ScalarType::Float); - - ET_CHECK( - promoteTypes(ScalarType::Bool, ScalarType::Bool) == ScalarType::Bool); - - ET_CHECK(promoteTypes(ScalarType::Byte, ScalarType::Int) == ScalarType::Int); - ET_CHECK( - promoteTypes(ScalarType::Char, ScalarType::Bool) == ScalarType::Char); - ET_CHECK(promoteTypes(ScalarType::Bool, ScalarType::Int) == ScalarType::Int); + EXPECT_EQ( + promoteTypes(ScalarType::Float, ScalarType::Double), ScalarType::Double); + EXPECT_EQ( + promoteTypes(ScalarType::Float, ScalarType::Short), ScalarType::Float); + + EXPECT_EQ( + promoteTypes(ScalarType::Float, ScalarType::Int), ScalarType::Float); + EXPECT_EQ( + promoteTypes(ScalarType::Long, ScalarType::Float), ScalarType::Float); + + EXPECT_EQ(promoteTypes(ScalarType::Bool, ScalarType::Bool), ScalarType::Bool); + + EXPECT_EQ(promoteTypes(ScalarType::Byte, ScalarType::Int), ScalarType::Int); + EXPECT_EQ(promoteTypes(ScalarType::Char, ScalarType::Bool), ScalarType::Char); + EXPECT_EQ(promoteTypes(ScalarType::Bool, ScalarType::Int), ScalarType::Int); + + EXPECT_EQ( + promoteTypes(ScalarType::BFloat16, ScalarType::Half), ScalarType::Float); + EXPECT_EQ( + promoteTypes(ScalarType::BFloat16, ScalarType::Bool), + ScalarType::BFloat16); } template struct promote_types_is_valid : std::integral_constant< bool, - !std::is_same::value && - !std::is_same::value && - (std::is_same::value || - (!executorch::runtime::is_qint_type::value && - !executorch::runtime::is_qint_type::value && - !executorch::runtime::is_bits_type::value && - !executorch::runtime::is_bits_type::value))> {}; + (std::is_same::value || + (!executorch::runtime::is_qint_type::value && + !executorch::runtime::is_qint_type::value && + !executorch::runtime::is_bits_type::value && + !executorch::runtime::is_bits_type::value))> {}; template struct CompileTimePromoteTypesTestCase { @@ -195,7 +196,8 @@ struct CompileTimePromoteTypesTestCase { auto expected = executorch::runtime::promoteTypes( scalarType1, scalarType2, half_to_float); EXPECT_EQ(actual, expected) - << "promoting " << (int)scalarType1 << " to " << (int)scalarType2; + << "promoting " << (int)scalarType1 << " to " << (int)scalarType2 + << " (half to float: " << half_to_float << ')'; } template < diff --git a/runtime/core/portable_type/bfloat16.h b/runtime/core/portable_type/bfloat16.h index a1ceb0c56a7..e665e6152e3 100644 --- a/runtime/core/portable_type/bfloat16.h +++ b/runtime/core/portable_type/bfloat16.h @@ -8,11 +8,41 @@ #pragma once +#include #include +#include +#include +#include namespace torch { namespace executor { +namespace internal { +inline float f32_from_bits(uint16_t src) { + float res = 0; + uint32_t tmp = src; + tmp <<= 16; + std::memcpy(&res, &tmp, sizeof(tmp)); + return res; +} + +inline uint16_t bits_from_f32(float src) { + uint32_t res = 0; + std::memcpy(&res, &src, sizeof(res)); + return res >> 16; +} + +inline uint16_t round_to_nearest_even(float src) { + if (std::isnan(src)) { + return UINT16_C(0x7FC0); + } + uint32_t U32 = 0; + std::memcpy(&U32, &src, sizeof(U32)); + uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); + return static_cast((U32 + rounding_bias) >> 16); +} +} // namespace internal + /** * The "brain floating-point" type, compatible with c10/util/BFloat16.h from * pytorch core. @@ -22,7 +52,288 @@ namespace executor { */ struct alignas(2) BFloat16 { uint16_t x; + + BFloat16() = default; + struct from_bits_t {}; + static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + constexpr BFloat16(unsigned short bits, from_bits_t) : x(bits) {} + /* implicit */ BFloat16(float value) + : x(internal::round_to_nearest_even(value)) {} + operator float() const { + return internal::f32_from_bits(x); + } }; +inline std::ostream& operator<<(std::ostream& out, const BFloat16& value) { + out << (float)value; + return out; +} + +/// Arithmetic + +inline BFloat16 operator+(const BFloat16& a, const BFloat16& b) { + return static_cast(a) + static_cast(b); +} + +inline BFloat16 operator-(const BFloat16& a, const BFloat16& b) { + return static_cast(a) - static_cast(b); +} + +inline BFloat16 operator*(const BFloat16& a, const BFloat16& b) { + return static_cast(a) * static_cast(b); +} + +inline BFloat16 operator/(const BFloat16& a, const BFloat16& b) { + return static_cast(a) / static_cast(b); +} + +inline BFloat16 operator-(const BFloat16& a) { + return -static_cast(a); +} + +inline BFloat16& operator+=(BFloat16& a, const BFloat16& b) { + a = a + b; + return a; +} + +inline BFloat16& operator-=(BFloat16& a, const BFloat16& b) { + a = a - b; + return a; +} + +inline BFloat16& operator*=(BFloat16& a, const BFloat16& b) { + a = a * b; + return a; +} + +inline BFloat16& operator/=(BFloat16& a, const BFloat16& b) { + a = a / b; + return a; +} + +inline BFloat16& operator|(BFloat16& a, const BFloat16& b) { + a.x = a.x | b.x; + return a; +} + +inline BFloat16& operator^(BFloat16& a, const BFloat16& b) { + a.x = a.x ^ b.x; + return a; +} + +inline BFloat16& operator&(BFloat16& a, const BFloat16& b) { + a.x = a.x & b.x; + return a; +} + +/// Arithmetic with floats + +inline float operator+(BFloat16 a, float b) { + return static_cast(a) + b; +} +inline float operator-(BFloat16 a, float b) { + return static_cast(a) - b; +} +inline float operator*(BFloat16 a, float b) { + return static_cast(a) * b; +} +inline float operator/(BFloat16 a, float b) { + return static_cast(a) / b; +} + +inline float operator+(float a, BFloat16 b) { + return a + static_cast(b); +} +inline float operator-(float a, BFloat16 b) { + return a - static_cast(b); +} +inline float operator*(float a, BFloat16 b) { + return a * static_cast(b); +} +inline float operator/(float a, BFloat16 b) { + return a / static_cast(b); +} + +inline float& operator+=(float& a, const BFloat16& b) { + return a += static_cast(b); +} +inline float& operator-=(float& a, const BFloat16& b) { + return a -= static_cast(b); +} +inline float& operator*=(float& a, const BFloat16& b) { + return a *= static_cast(b); +} +inline float& operator/=(float& a, const BFloat16& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline double operator+(BFloat16 a, double b) { + return static_cast(a) + b; +} +inline double operator-(BFloat16 a, double b) { + return static_cast(a) - b; +} +inline double operator*(BFloat16 a, double b) { + return static_cast(a) * b; +} +inline double operator/(BFloat16 a, double b) { + return static_cast(a) / b; +} + +inline double operator+(double a, BFloat16 b) { + return a + static_cast(b); +} +inline double operator-(double a, BFloat16 b) { + return a - static_cast(b); +} +inline double operator*(double a, BFloat16 b) { + return a * static_cast(b); +} +inline double operator/(double a, BFloat16 b) { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline BFloat16 operator+(BFloat16 a, int b) { + return a + static_cast(b); +} +inline BFloat16 operator-(BFloat16 a, int b) { + return a - static_cast(b); +} +inline BFloat16 operator*(BFloat16 a, int b) { + return a * static_cast(b); +} +inline BFloat16 operator/(BFloat16 a, int b) { + return a / static_cast(b); +} + +inline BFloat16 operator+(int a, BFloat16 b) { + return static_cast(a) + b; +} +inline BFloat16 operator-(int a, BFloat16 b) { + return static_cast(a) - b; +} +inline BFloat16 operator*(int a, BFloat16 b) { + return static_cast(a) * b; +} +inline BFloat16 operator/(int a, BFloat16 b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline BFloat16 operator+(BFloat16 a, int64_t b) { + return a + static_cast(b); +} +inline BFloat16 operator-(BFloat16 a, int64_t b) { + return a - static_cast(b); +} +inline BFloat16 operator*(BFloat16 a, int64_t b) { + return a * static_cast(b); +} +inline BFloat16 operator/(BFloat16 a, int64_t b) { + return a / static_cast(b); +} + +inline BFloat16 operator+(int64_t a, BFloat16 b) { + return static_cast(a) + b; +} +inline BFloat16 operator-(int64_t a, BFloat16 b) { + return static_cast(a) - b; +} +inline BFloat16 operator*(int64_t a, BFloat16 b) { + return static_cast(a) * b; +} +inline BFloat16 operator/(int64_t a, BFloat16 b) { + return static_cast(a) / b; +} + +// Overloading < and > operators, because std::max and std::min use them. + +inline bool operator>(BFloat16& lhs, BFloat16& rhs) { + return float(lhs) > float(rhs); +} + +inline bool operator<(BFloat16& lhs, BFloat16& rhs) { + return float(lhs) < float(rhs); +} + } // namespace executor } // namespace torch + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_signed = true; + static constexpr bool is_specialized = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr auto has_denorm = numeric_limits::has_denorm; + static constexpr auto has_denorm_loss = + numeric_limits::has_denorm_loss; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 8; + static constexpr int digits10 = 2; + static constexpr int max_digits10 = 4; + static constexpr int radix = 2; + static constexpr int min_exponent = -125; + static constexpr int min_exponent10 = -37; + static constexpr int max_exponent = 128; + static constexpr int max_exponent10 = 38; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + + static constexpr torch::executor::BFloat16 min() { + return torch::executor::BFloat16( + 0x0080, torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 lowest() { + return torch::executor::BFloat16( + 0xFF7F, torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 max() { + return torch::executor::BFloat16( + 0x7F7F, torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 epsilon() { + return torch::executor::BFloat16( + 0x3C00, torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 round_error() { + return torch::executor::BFloat16( + 0x3F00, torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 infinity() { + return torch::executor::BFloat16( + 0x7F80, torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 quiet_NaN() { + return torch::executor::BFloat16( + 0x7FC0, torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 signaling_NaN() { + return torch::executor::BFloat16( + 0x7F80, torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 denorm_min() { + return torch::executor::BFloat16( + 0x0001, torch::executor::BFloat16::from_bits()); + } +}; + +} // namespace std diff --git a/runtime/core/portable_type/bfloat16_math.h b/runtime/core/portable_type/bfloat16_math.h new file mode 100644 index 00000000000..68ee77cf340 --- /dev/null +++ b/runtime/core/portable_type/bfloat16_math.h @@ -0,0 +1,290 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include + +namespace std { + +template +struct is_reduced_floating_point + : std::integral_constant< + bool, + std::is_same::value || + std::is_same::value> {}; + +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T acos(T a) { + return std::acos(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T asin(T a) { + return std::asin(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T atan(T a) { + return std::atan(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T atanh(T a) { + return std::atanh(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T erf(T a) { + return std::erf(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T erfc(T a) { + return std::erfc(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T exp(T a) { + return std::exp(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T expm1(T a) { + return std::expm1(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline bool isfinite(T a) { + return std::isfinite(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T log(T a) { + return std::log(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T log10(T a) { + return std::log10(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T log1p(T a) { + return std::log1p(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T log2(T a) { + return std::log2(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T ceil(T a) { + return std::ceil(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T cos(T a) { + return std::cos(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T floor(T a) { + return std::floor(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T nearbyint(T a) { + return std::nearbyint(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T sin(T a) { + return std::sin(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T tan(T a) { + return std::tan(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T sinh(T a) { + return std::sinh(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T cosh(T a) { + return std::cosh(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T tanh(T a) { + return std::tanh(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T trunc(T a) { + return std::trunc(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T lgamma(T a) { + return std::lgamma(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T sqrt(T a) { + return std::sqrt(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T rsqrt(T a) { + return 1.0 / std::sqrt(float(a)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T abs(T a) { + return std::abs(float(a)); +} +#if defined(_MSC_VER) && defined(__CUDACC__) +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T pow(T a, double b) { + return std::pow(float(a), float(b)); +} +#else +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T pow(T a, double b) { + return std::pow(float(a), b); +} +#endif +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T pow(T a, T b) { + return std::pow(float(a), float(b)); +} +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T fmod(T a, T b) { + return std::fmod(float(a), float(b)); +} + +/* + The following function is inspired from the implementation in `musl` + Link to License: https://git.musl-libc.org/cgit/musl/tree/COPYRIGHT + ---------------------------------------------------------------------- + Copyright © 2005-2020 Rich Felker, et al. + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files (the + "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + ---------------------------------------------------------------------- + */ +template < + typename T, + typename std::enable_if::value, int>::type = 0> +inline T nextafter(T from, T to) { + // Reference: + // https://git.musl-libc.org/cgit/musl/tree/src/math/nextafter.c + using int_repr_t = uint16_t; + constexpr uint8_t bits = 16; + union { + T f; + int_repr_t i; + } ufrom = {from}, uto = {to}; + + // get a mask to get the sign bit i.e. MSB + int_repr_t sign_mask = int_repr_t{1} << (bits - 1); + + // short-circuit: if either is NaN, return NaN + if (from != from || to != to) { + return from + to; + } + + // short-circuit: if they are exactly the same. + if (ufrom.i == uto.i) { + return from; + } + + // mask the sign-bit to zero i.e. positive + // equivalent to abs(x) + int_repr_t abs_from = ufrom.i & ~sign_mask; + int_repr_t abs_to = uto.i & ~sign_mask; + if (abs_from == 0) { + // if both are zero but with different sign, + // preserve the sign of `to`. + if (abs_to == 0) { + return to; + } + // smallest subnormal with sign of `to`. + ufrom.i = (uto.i & sign_mask) | int_repr_t{1}; + return ufrom.f; + } + + // if abs(from) > abs(to) or sign(from) != sign(to) + if (abs_from > abs_to || ((ufrom.i ^ uto.i) & sign_mask)) { + ufrom.i--; + } else { + ufrom.i++; + } + + return ufrom.f; +} + +} // namespace std diff --git a/runtime/core/portable_type/targets.bzl b/runtime/core/portable_type/targets.bzl index 0d65ef36b85..b8ccbe602ed 100644 --- a/runtime/core/portable_type/targets.bzl +++ b/runtime/core/portable_type/targets.bzl @@ -43,6 +43,7 @@ def define_common_targets(): name = "scalar_type", exported_headers = [ "bfloat16.h", + "bfloat16_math.h", "complex.h", "half.h", "scalar_type.h", diff --git a/runtime/core/portable_type/test/CMakeLists.txt b/runtime/core/portable_type/test/CMakeLists.txt index 21eb4feae0f..58a69f656eb 100644 --- a/runtime/core/portable_type/test/CMakeLists.txt +++ b/runtime/core/portable_type/test/CMakeLists.txt @@ -24,7 +24,7 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) include(${EXECUTORCH_ROOT}/build/Test.cmake) set(_test_srcs optional_test.cpp tensor_test.cpp half_test.cpp scalar_test.cpp - tensor_impl_test.cpp + tensor_impl_test.cpp bfloat16_test.cpp ) et_cxx_test(runtime_core_portable_type_test SOURCES ${_test_srcs} EXTRA_LIBS) diff --git a/runtime/core/portable_type/test/bfloat16_test.cpp b/runtime/core/portable_type/test/bfloat16_test.cpp new file mode 100644 index 00000000000..9ea53e6cba2 --- /dev/null +++ b/runtime/core/portable_type/test/bfloat16_test.cpp @@ -0,0 +1,191 @@ +#include + +#include + +using torch::executor::BFloat16; + +namespace { +float float_from_bytes(uint32_t sign, uint32_t exponent, uint32_t fraction) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + uint32_t bytes; + bytes = 0; + bytes |= sign; + bytes <<= 8; + bytes |= exponent; + bytes <<= 23; + bytes |= fraction; + + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + float res; + std::memcpy(&res, &bytes, sizeof(res)); + return res; +} + +TEST(BFloat16Conversion, FloatToBFloat16AndBack) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) + float in[100]; + for (int i = 0; i < 100; ++i) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers) + in[i] = i + 1.25; + } + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) + BFloat16 bfloats[100]; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) + float out[100]; + + for (int i = 0; i < 100; ++i) { + bfloats[i].x = torch::executor::internal::bits_from_f32(in[i]); + out[i] = torch::executor::internal::f32_from_bits(bfloats[i].x); + + // The relative error should be less than 1/(2^7) since BFloat16 + // has 7 bits mantissa. + EXPECT_LE(std::fabs(out[i] - in[i]) / in[i], 1.0 / 128); + } +} + +TEST(BFloat16Conversion, FloatToBFloat16RNEAndBack) { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) + float in[100]; + for (int i = 0; i < 100; ++i) { + // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers) + in[i] = i + 1.25; + } + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) + BFloat16 bfloats[100]; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays) + float out[100]; + + for (int i = 0; i < 100; ++i) { + bfloats[i].x = torch::executor::internal::round_to_nearest_even(in[i]); + out[i] = torch::executor::internal::f32_from_bits(bfloats[i].x); + + // The relative error should be less than 1/(2^7) since BFloat16 + // has 7 bits mantissa. + EXPECT_LE(std::fabs(out[i] - in[i]) / in[i], 1.0 / 128); + } +} + +TEST(BFloat16Conversion, NaN) { + float inNaN = float_from_bytes(0, 0xFF, 0x7FFFFF); + EXPECT_TRUE(std::isnan(inNaN)); + + BFloat16 a = BFloat16(inNaN); + float out = torch::executor::internal::f32_from_bits(a.x); + + EXPECT_TRUE(std::isnan(out)); +} + +TEST(BFloat16Conversion, Inf) { + float inInf = float_from_bytes(0, 0xFF, 0); + EXPECT_TRUE(std::isinf(inInf)); + + BFloat16 a = BFloat16(inInf); + float out = torch::executor::internal::f32_from_bits(a.x); + + EXPECT_TRUE(std::isinf(out)); +} + +TEST(BFloat16Conversion, SmallestDenormal) { + float in = std::numeric_limits::denorm_min(); // The smallest non-zero + // subnormal number + BFloat16 a = BFloat16(in); + float out = torch::executor::internal::f32_from_bits(a.x); + + EXPECT_FLOAT_EQ(in, out); +} + +TEST(BFloat16Math, Addition) { + // This test verifies that if only first 7 bits of float's mantissa are + // changed after addition, we should have no loss in precision. + + // input bits + // S | Exponent | Mantissa + // 0 | 10000000 | 10010000000000000000000 = 3.125 + float input = float_from_bytes(0, 0, 0x40480000); + + // expected bits + // S | Exponent | Mantissa + // 0 | 10000001 | 10010000000000000000000 = 6.25 + float expected = float_from_bytes(0, 0, 0x40c80000); + + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + BFloat16 b; + b.x = torch::executor::internal::bits_from_f32(input); + b = b + b; + + float res = torch::executor::internal::f32_from_bits(b.x); + EXPECT_EQ(res, expected); +} + +TEST(BFloat16Math, Subtraction) { + // This test verifies that if only first 7 bits of float's mantissa are + // changed after subtraction, we should have no loss in precision. + + // input bits + // S | Exponent | Mantissa + // 0 | 10000001 | 11101000000000000000000 = 7.625 + float input = float_from_bytes(0, 0, 0x40f40000); + + // expected bits + // S | Exponent | Mantissa + // 0 | 10000000 | 01010000000000000000000 = 2.625 + float expected = float_from_bytes(0, 0, 0x40280000); + + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) + BFloat16 b; + b.x = torch::executor::internal::bits_from_f32(input); + b = b - 5; + + float res = torch::executor::internal::f32_from_bits(b.x); + EXPECT_EQ(res, expected); +} + +// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables) +TEST(BFloat16Math, NextAfterZero) { + const BFloat16 zero{0}; + + auto check_nextafter = [](BFloat16 from, BFloat16 to, BFloat16 expected) { + BFloat16 actual = std::nextafter(from, to); + // Check for bitwise equality! + ASSERT_EQ(actual.x ^ expected.x, uint16_t{0}); + }; + check_nextafter(zero, zero, /*expected=*/zero); + check_nextafter(zero, -zero, /*expected=*/-zero); + check_nextafter(-zero, zero, /*expected=*/zero); + check_nextafter(-zero, -zero, /*expected=*/-zero); +} + +float BinaryToFloat(uint32_t bytes) { + // NOLINTNEXTLINE(cppcoreguidelines-init-variables) + float res; + std::memcpy(&res, &bytes, sizeof(res)); + return res; +} + +struct BFloat16TestParam { + uint32_t input; + uint16_t rne; +}; + +class BFloat16Test : public ::testing::Test, + public ::testing::WithParamInterface {}; + +TEST_P(BFloat16Test, BFloat16RNETest) { + float value = BinaryToFloat(GetParam().input); + uint16_t rounded = torch::executor::internal::round_to_nearest_even(value); + EXPECT_EQ(GetParam().rne, rounded); +} + +INSTANTIATE_TEST_SUITE_P( + BFloat16TestInstantiation, + BFloat16Test, + ::testing::Values( + BFloat16TestParam{0x3F848000, 0x3F84}, + BFloat16TestParam{0x3F848010, 0x3F85}, + BFloat16TestParam{0x3F850000, 0x3F85}, + BFloat16TestParam{0x3F858000, 0x3F86}, + BFloat16TestParam{0x3FFF8000, 0x4000})); + +} // namespace diff --git a/runtime/core/portable_type/test/targets.bzl b/runtime/core/portable_type/test/targets.bzl index af55f95e45e..c0b4ef00c78 100644 --- a/runtime/core/portable_type/test/targets.bzl +++ b/runtime/core/portable_type/test/targets.bzl @@ -6,6 +6,14 @@ def define_common_targets(): The directory containing this targets.bzl file should also contain both TARGETS and BUCK files that call this function. """ + runtime.cxx_test( + name = "bfloat16_test", + srcs = ["bfloat16_test.cpp"], + deps = [ + "//executorch/runtime/core/portable_type:portable_type", + ], + ) + runtime.cxx_test( name = "optional_test", srcs = ["optional_test.cpp"],