From 4fb8d62ae1c1bff2e4b9651d260ceae2b2b0b4a3 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 29 Aug 2024 13:11:05 -0700 Subject: [PATCH 01/11] [ExecuTorch] Implement BFloat16 and hook it up to scalar_type_util bfloat16.h was a stub. I've filled it out by porting the c10 implementation, added it to ET_SWITCH and ET_FORALL macros, and hooked it up to promoteTypes. I extended the half_to_float argument to promoteTypes to also coerce bfloat16 to float because I figured anybody who wants to ignore half probably also wants to ignore bf16. Differential Revision: [D61981361](https://our.internmc.facebook.com/intern/diff/D61981361/) [ghstack-poisoned] --- kernels/portable/cpu/scalar_utils.h | 18 +- .../core/exec_aten/util/genScalarTypeTable.py | 41 ++- .../core/exec_aten/util/scalar_type_util.h | 206 +++++++++--- .../util/test/scalar_type_util_test.cpp | 54 +-- runtime/core/portable_type/bfloat16.h | 310 ++++++++++++++++++ .../core/portable_type/test/CMakeLists.txt | 2 +- .../core/portable_type/test/bfloat16_test.cpp | 191 +++++++++++ runtime/core/portable_type/test/targets.bzl | 8 + 8 files changed, 743 insertions(+), 87 deletions(-) create mode 100644 runtime/core/portable_type/test/bfloat16_test.cpp diff --git a/kernels/portable/cpu/scalar_utils.h b/kernels/portable/cpu/scalar_utils.h index 3daf3e72526..3d6dfb75e47 100644 --- a/kernels/portable/cpu/scalar_utils.h +++ b/kernels/portable/cpu/scalar_utils.h @@ -94,12 +94,6 @@ struct promote_type_with_scalar_type { static_assert( !is_bits_type::value, "promote_type_with_scalar_type not valid for bits dtypes"); - static_assert( - !std::is_same< - T1, - typename ScalarTypeToCppType::type>:: - value, - "promote_type_with_scalar_type not valid for BFloat16"); using promote_type_with_scalar_type_not_respecting_half_to_float = typename std::conditional< is_complex_type::value || @@ -119,10 +113,14 @@ struct promote_type_with_scalar_type { public: using type = typename std::conditional< half_to_float && - std::is_same< - promote_type_with_scalar_type_not_respecting_half_to_float, - typename ScalarTypeToCppType::type>:: - value, + (std::is_same< + promote_type_with_scalar_type_not_respecting_half_to_float, + typename ScalarTypeToCppType< + exec_aten::ScalarType::Half>::type>::value || + std::is_same< + promote_type_with_scalar_type_not_respecting_half_to_float, + typename ScalarTypeToCppType< + exec_aten::ScalarType::BFloat16>::type>::value), typename ScalarTypeToCppType::type, promote_type_with_scalar_type_not_respecting_half_to_float>::type; }; diff --git a/runtime/core/exec_aten/util/genScalarTypeTable.py b/runtime/core/exec_aten/util/genScalarTypeTable.py index 07100472ae4..c2bc84c2700 100644 --- a/runtime/core/exec_aten/util/genScalarTypeTable.py +++ b/runtime/core/exec_aten/util/genScalarTypeTable.py @@ -4,20 +4,35 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-indexToType = ["U1", "I1", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "B1"] +indexToType = [ + "U1", + "I1", + "I2", + "I4", + "I8", + "F2", + "F4", + "F8", + "C2", + "C4", + "C8", + "B1", + "BF", +] promoteTypesLookup = [ - ["U1", "I2", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "U1"], - ["I2", "I1", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I1"], - ["I2", "I2", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I2"], - ["I4", "I4", "I4", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I4"], - ["I8", "I8", "I8", "I8", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I8"], - ["F2", "F2", "F2", "F2", "F2", "F2", "F4", "F8", "C2", "C4", "C8", "F2"], - ["F4", "F4", "F4", "F4", "F4", "F4", "F4", "F8", "C4", "C4", "C8", "F4"], - ["F8", "F8", "F8", "F8", "F8", "F8", "F8", "F8", "C8", "C8", "C8", "F8"], - ["C2", "C2", "C2", "C2", "C2", "C2", "C4", "C8", "C2", "C4", "C8", "C2"], - ["C4", "C4", "C4", "C4", "C4", "C4", "C4", "C8", "C4", "C4", "C8", "C4"], - ["C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8"], - ["U1", "I1", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "B1"], + ["U1", "I2", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "U1", "BF"], + ["I2", "I1", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I1", "BF"], + ["I2", "I2", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I2", "BF"], + ["I4", "I4", "I4", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I4", "BF"], + ["I8", "I8", "I8", "I8", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "I8", "BF"], + ["F2", "F2", "F2", "F2", "F2", "F2", "F4", "F8", "C2", "C4", "C8", "F2", "F4"], + ["F4", "F4", "F4", "F4", "F4", "F4", "F4", "F8", "C4", "C4", "C8", "F4", "F4"], + ["F8", "F8", "F8", "F8", "F8", "F8", "F8", "F8", "C8", "C8", "C8", "F8", "F8"], + ["C2", "C2", "C2", "C2", "C2", "C2", "C4", "C8", "C2", "C4", "C8", "C2", "C4"], + ["C4", "C4", "C4", "C4", "C4", "C4", "C4", "C8", "C4", "C4", "C8", "C4", "C4"], + ["C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8", "C8"], + ["U1", "I1", "I2", "I4", "I8", "F2", "F4", "F8", "C2", "C4", "C8", "B1", "BF"], + ["BF", "BF", "BF", "BF", "BF", "F4", "F4", "F8", "C4", "C4", "C8", "BF", "BF"], ] for rowIndex, row in enumerate(promoteTypesLookup): for colIndex, col in enumerate(row): diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h index c92f910431f..479767b4abb 100644 --- a/runtime/core/exec_aten/util/scalar_type_util.h +++ b/runtime/core/exec_aten/util/scalar_type_util.h @@ -21,6 +21,7 @@ #pragma once +#include #include #include #include @@ -164,8 +165,21 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType) ::exec_aten::ScalarType::SCALARTYPE>::type, \ SCALARTYPE) +#define ET_FORALL_FLOAT_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \ + _(float, Float) \ + _(double, Double) \ + _(::executorch::runtime::ScalarTypeToCppType< \ + ::exec_aten::ScalarType::SCALARTYPE1>::type, \ + SCALARTYPE1) \ + _(::executorch::runtime::ScalarTypeToCppType< \ + ::exec_aten::ScalarType::SCALARTYPE2>::type, \ + SCALARTYPE2) + #define ET_FORALL_FLOATH_TYPES(_) ET_FORALL_FLOAT_TYPES_AND(Half, _) +#define ET_FORALL_FLOATHBF16_TYPES(_) \ + ET_FORALL_FLOAT_TYPES_AND2(Half, BFloat16, _) + // Here `ANOTHER_INPUT` should be another variable to be forwarded to a given // function. Not to be confused with another scalar type as in // `ET_FORALL_FLOAT_TYPES_AND`. 
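To make the new BF column concrete: the generated lookup now covers 13 dtypes (the original 12 plus BFloat16), and BF behaves like a second Half except that BF16 x F16 has no common 16-bit supertype, so that pair widens to Float. A minimal sanity check against the table above (an illustrative sketch, not part of the patch; it assumes only the scalar_type_util.h header changed below):

#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <cassert>

using exec_aten::ScalarType;
using executorch::runtime::promoteTypes;

int main() {
  // BFloat16 absorbs bool and every integer type, mirroring Half.
  assert(
      promoteTypes(ScalarType::BFloat16, ScalarType::Long) ==
      ScalarType::BFloat16);
  // There is no common 16-bit type for BF16 x F16, so both widen to Float.
  assert(
      promoteTypes(ScalarType::BFloat16, ScalarType::Half) ==
      ScalarType::Float);
  // With half_to_float=true, a Half or BFloat16 result is coerced to Float.
  assert(
      promoteTypes(
          ScalarType::BFloat16,
          ScalarType::BFloat16,
          /*half_to_float=*/true) == ScalarType::Float);
  return 0;
}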
@@ -177,6 +191,12 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType) _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float) \ _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double) +#define ET_FORALL_FLOATHBF16_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, ::exec_aten::Half, Half) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, ::exec_aten::BFloat16, BFloat16) + // In this context, "REAL" means integer/float C types, which is why BFloat16 // and Half are not included. #define ET_FORALL_REAL_TYPES(_) \ @@ -209,6 +229,17 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType) _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float) \ _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double) +#define ET_FORALL_REALHBF16_TYPES_WITH2(ANOTHER_INPUT1, ANOTHER_INPUT2, _) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, uint8_t, Byte) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, int8_t, Char) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, int16_t, Short) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, int32_t, Int) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, int64_t, Long) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, float, Float) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, double, Double) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, exec_aten::Half, Half) \ + _(ANOTHER_INPUT1, ANOTHER_INPUT2, exec_aten::BFloat16, BFloat16) + // For macros that take `SCALARTYPEn` parameters, those parameters should be // an unquoted/unqualified enumerator name like `Int` or `Float`. #define ET_FORALL_REAL_TYPES_AND(SCALARTYPE, _) \ @@ -223,8 +254,29 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType) ::exec_aten::ScalarType::SCALARTYPE>::type, \ SCALARTYPE) +#define ET_FORALL_REAL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, _) \ + _(uint8_t, Byte) \ + _(int8_t, Char) \ + _(int16_t, Short) \ + _(int32_t, Int) \ + _(int64_t, Long) \ + _(float, Float) \ + _(double, Double) \ + _(::executorch::runtime::ScalarTypeToCppType< \ + ::exec_aten::ScalarType::SCALARTYPE1>::type, \ + SCALARTYPE1) \ + _(::executorch::runtime::ScalarTypeToCppType< \ + ::exec_aten::ScalarType::SCALARTYPE2>::type, \ + SCALARTYPE2) + #define ET_FORALL_REALH_TYPES(_) ET_FORALL_REAL_TYPES_AND(Half, _) +#define ET_FORALL_REALHBF16_TYPES(_) \ + ET_FORALL_REAL_TYPES_AND2(Half, BFloat16, _) + +#define ET_FORALL_REALHBBF16_TYPES(_) \ + ET_FORALL_REAL_TYPES_AND3(Bool, Half, BFloat16, _) + #define ET_FORALL_REAL_TYPES_AND_WITH(SCALARTYPE, ANOTHER_INPUT, _) \ _(ANOTHER_INPUT, uint8_t, Byte) \ _(ANOTHER_INPUT, int8_t, Char) \ @@ -381,6 +433,10 @@ inline bool isRealHBType(exec_aten::ScalarType t) { return (isRealHType(t) || t == exec_aten::ScalarType::Bool); } +inline bool isRealHBBF16Type(exec_aten::ScalarType t) { + return (isRealHBType(t) || t == exec_aten::ScalarType::BFloat16); +} + inline constexpr bool isComplexType(exec_aten::ScalarType t) { return ( t == exec_aten::ScalarType::ComplexHalf || @@ -589,6 +645,7 @@ using C4 = using C8 = typename ScalarTypeToCppType::type; using B1 = typename ScalarTypeToCppType::type; +using BF = typename ScalarTypeToCppType::type; #define TABLE_ENTRY(key1, key2, value) \ template <> \ @@ -613,6 +670,7 @@ TABLE_ENTRY(U1, C2, C2); TABLE_ENTRY(U1, C4, C4); TABLE_ENTRY(U1, C8, C8); TABLE_ENTRY(U1, B1, U1); +TABLE_ENTRY(U1, BF, BF); TABLE_ENTRY(I1, U1, I2); TABLE_ENTRY(I1, I1, I1); TABLE_ENTRY(I1, I2, I2); @@ -625,6 +683,7 @@ TABLE_ENTRY(I1, C2, C2); TABLE_ENTRY(I1, C4, C4); TABLE_ENTRY(I1, C8, C8); TABLE_ENTRY(I1, B1, I1); +TABLE_ENTRY(I1, BF, BF); TABLE_ENTRY(I2, U1, I2); 
TABLE_ENTRY(I2, I1, I2); TABLE_ENTRY(I2, I2, I2); @@ -637,6 +696,7 @@ TABLE_ENTRY(I2, C2, C2); TABLE_ENTRY(I2, C4, C4); TABLE_ENTRY(I2, C8, C8); TABLE_ENTRY(I2, B1, I2); +TABLE_ENTRY(I2, BF, BF); TABLE_ENTRY(I4, U1, I4); TABLE_ENTRY(I4, I1, I4); TABLE_ENTRY(I4, I2, I4); @@ -649,6 +709,7 @@ TABLE_ENTRY(I4, C2, C2); TABLE_ENTRY(I4, C4, C4); TABLE_ENTRY(I4, C8, C8); TABLE_ENTRY(I4, B1, I4); +TABLE_ENTRY(I4, BF, BF); TABLE_ENTRY(I8, U1, I8); TABLE_ENTRY(I8, I1, I8); TABLE_ENTRY(I8, I2, I8); @@ -661,6 +722,7 @@ TABLE_ENTRY(I8, C2, C2); TABLE_ENTRY(I8, C4, C4); TABLE_ENTRY(I8, C8, C8); TABLE_ENTRY(I8, B1, I8); +TABLE_ENTRY(I8, BF, BF); TABLE_ENTRY(F2, U1, F2); TABLE_ENTRY(F2, I1, F2); TABLE_ENTRY(F2, I2, F2); @@ -673,6 +735,7 @@ TABLE_ENTRY(F2, C2, C2); TABLE_ENTRY(F2, C4, C4); TABLE_ENTRY(F2, C8, C8); TABLE_ENTRY(F2, B1, F2); +TABLE_ENTRY(F2, BF, F4); TABLE_ENTRY(F4, U1, F4); TABLE_ENTRY(F4, I1, F4); TABLE_ENTRY(F4, I2, F4); @@ -685,6 +748,7 @@ TABLE_ENTRY(F4, C2, C4); TABLE_ENTRY(F4, C4, C4); TABLE_ENTRY(F4, C8, C8); TABLE_ENTRY(F4, B1, F4); +TABLE_ENTRY(F4, BF, F4); TABLE_ENTRY(F8, U1, F8); TABLE_ENTRY(F8, I1, F8); TABLE_ENTRY(F8, I2, F8); @@ -697,6 +761,7 @@ TABLE_ENTRY(F8, C2, C8); TABLE_ENTRY(F8, C4, C8); TABLE_ENTRY(F8, C8, C8); TABLE_ENTRY(F8, B1, F8); +TABLE_ENTRY(F8, BF, F8); TABLE_ENTRY(C2, U1, C2); TABLE_ENTRY(C2, I1, C2); TABLE_ENTRY(C2, I2, C2); @@ -709,6 +774,7 @@ TABLE_ENTRY(C2, C2, C2); TABLE_ENTRY(C2, C4, C4); TABLE_ENTRY(C2, C8, C8); TABLE_ENTRY(C2, B1, C2); +TABLE_ENTRY(C2, BF, C4); TABLE_ENTRY(C4, U1, C4); TABLE_ENTRY(C4, I1, C4); TABLE_ENTRY(C4, I2, C4); @@ -721,6 +787,7 @@ TABLE_ENTRY(C4, C2, C4); TABLE_ENTRY(C4, C4, C4); TABLE_ENTRY(C4, C8, C8); TABLE_ENTRY(C4, B1, C4); +TABLE_ENTRY(C4, BF, C4); TABLE_ENTRY(C8, U1, C8); TABLE_ENTRY(C8, I1, C8); TABLE_ENTRY(C8, I2, C8); @@ -733,6 +800,7 @@ TABLE_ENTRY(C8, C2, C8); TABLE_ENTRY(C8, C4, C8); TABLE_ENTRY(C8, C8, C8); TABLE_ENTRY(C8, B1, C8); +TABLE_ENTRY(C8, BF, C8); TABLE_ENTRY(B1, U1, U1); TABLE_ENTRY(B1, I1, I1); TABLE_ENTRY(B1, I2, I2); @@ -745,6 +813,20 @@ TABLE_ENTRY(B1, C2, C2); TABLE_ENTRY(B1, C4, C4); TABLE_ENTRY(B1, C8, C8); TABLE_ENTRY(B1, B1, B1); +TABLE_ENTRY(B1, BF, BF); +TABLE_ENTRY(BF, U1, BF); +TABLE_ENTRY(BF, I1, BF); +TABLE_ENTRY(BF, I2, BF); +TABLE_ENTRY(BF, I4, BF); +TABLE_ENTRY(BF, I8, BF); +TABLE_ENTRY(BF, F2, F4); +TABLE_ENTRY(BF, F4, F4); +TABLE_ENTRY(BF, F8, F8); +TABLE_ENTRY(BF, C2, C4); +TABLE_ENTRY(BF, C4, C4); +TABLE_ENTRY(BF, C8, C8); +TABLE_ENTRY(BF, B1, BF); +TABLE_ENTRY(BF, BF, BF); } // namespace internal @@ -760,26 +842,20 @@ struct promote_types { (!is_bits_type::value && !is_bits_type::value), "promote_types not valid for bits dtypes"); - static_assert( - !std::is_same< - T1, - typename ScalarTypeToCppType::type>:: - value && - !std::is_same< - T2, - typename ScalarTypeToCppType< - exec_aten::ScalarType::BFloat16>::type>::value, - "promote_types not valid for BFloat16"); using promoted_type_not_respecting_half_to_float = typename internal::promote_types_lookup::type; public: using type = typename std::conditional< half_to_float && - std::is_same< - promoted_type_not_respecting_half_to_float, - typename ScalarTypeToCppType::type>:: - value, + (std::is_same< + promoted_type_not_respecting_half_to_float, + typename ScalarTypeToCppType< + exec_aten::ScalarType::Half>::type>::value || + std::is_same< + promoted_type_not_respecting_half_to_float, + typename ScalarTypeToCppType< + exec_aten::ScalarType::BFloat16>::type>::value), typename ScalarTypeToCppType::type, 
promoted_type_not_respecting_half_to_float>::type; }; @@ -787,7 +863,8 @@ struct promote_types { /** * Implements type promotion rules that are consistent with ATen behaviour, * which in turn is consistent with NumPy's promote_types. - * If half_to_float is set to true, then half will be promoted to float instead + * If half_to_float is set to true, then half and bfloat16 will be promoted to + * float instead */ inline exec_aten::ScalarType promoteTypes( exec_aten::ScalarType a, @@ -806,6 +883,7 @@ inline exec_aten::ScalarType promoteTypes( constexpr auto c4 = exec_aten::ScalarType::ComplexFloat; constexpr auto c8 = exec_aten::ScalarType::ComplexDouble; constexpr auto b1 = exec_aten::ScalarType::Bool; + constexpr auto bf = exec_aten::ScalarType::BFloat16; // For QInt types, only allow exact match if (executorch::runtime::isQIntType(a) && a == b) { @@ -825,34 +903,41 @@ inline exec_aten::ScalarType promoteTypes( ET_CHECK_MSG(false, "promoteTypes not valid for bits dtypes"); } - ET_CHECK_MSG( - a != exec_aten::ScalarType::BFloat16 && - b != exec_aten::ScalarType::BFloat16, - "promoteTypes not valid for BFloat16"); // 12 types are handled by this function, see the constexpr definitions above - const int NUM_PROMOTE_TYPES = 12; - + const int NUM_PROMOTE_TYPES = 13; + + static constexpr std::array + dtype2index = {{ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + -1, -1, -1, 12, -1, -1, -1, -1, -1, -1, -1, -1, + }}; + auto ix_a = dtype2index[(int)a]; + ET_CHECK(ix_a != -1); + auto ix_b = dtype2index[(int)b]; + ET_CHECK(ix_b != -1); static constexpr exec_aten::ScalarType _promoteTypesLookup[NUM_PROMOTE_TYPES][NUM_PROMOTE_TYPES] = { - /* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 */ - /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1}, - /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1}, - /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2}, - /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4}, - /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8}, - /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2}, - /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4}, - /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8}, - /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2}, - /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4}, - /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8}, - /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1}, + /* u1 i1 i2 i4 i8 f2 f4 f8 c2 c4 c8 b1 bf*/ + /* u1 */ {u1, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, u1, bf}, + /* i1 */ {i2, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, i1, bf}, + /* i2 */ {i2, i2, i2, i4, i8, f2, f4, f8, c2, c4, c8, i2, bf}, + /* i4 */ {i4, i4, i4, i4, i8, f2, f4, f8, c2, c4, c8, i4, bf}, + /* i8 */ {i8, i8, i8, i8, i8, f2, f4, f8, c2, c4, c8, i8, bf}, + /* f2 */ {f2, f2, f2, f2, f2, f2, f4, f8, c2, c4, c8, f2, f4}, + /* f4 */ {f4, f4, f4, f4, f4, f4, f4, f8, c4, c4, c8, f4, f4}, + /* f8 */ {f8, f8, f8, f8, f8, f8, f8, f8, c8, c8, c8, f8, f8}, + /* c2 */ {c2, c2, c2, c2, c2, c2, c4, c8, c2, c4, c8, c2, c4}, + /* c4 */ {c4, c4, c4, c4, c4, c4, c4, c8, c4, c4, c8, c4, c4}, + /* c8 */ {c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8, c8}, + /* b1 */ {u1, i1, i2, i4, i8, f2, f4, f8, c2, c4, c8, b1, bf}, + /* bf */ {bf, bf, bf, bf, bf, f4, f4, f8, c4, c4, c8, bf, bf}, }; - exec_aten::ScalarType promoted_type = - _promoteTypesLookup[static_cast(a)][static_cast(b)]; + exec_aten::ScalarType promoted_type = _promoteTypesLookup[ix_a][ix_b]; - if (half_to_float && promoted_type == exec_aten::ScalarType::Half) 
{ + if (half_to_float && + (promoted_type == exec_aten::ScalarType::Half || + promoted_type == exec_aten::ScalarType::BFloat16)) { promoted_type = exec_aten::ScalarType::Float; } @@ -974,6 +1059,13 @@ inline exec_aten::ScalarType promoteTypes( ET_INTERNAL_SWITCH_CASE( \ exec_aten::ScalarType::ADDITIONAL2, CTYPE_ALIAS, __VA_ARGS__) +#define ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND3( \ + ADDITIONAL1, ADDITIONAL2, ADDITIONAL3, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND2( \ + ADDITIONAL1, ADDITIONAL2, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + exec_aten::ScalarType::ADDITIONAL3, CTYPE_ALIAS, __VA_ARGS__) + #define ET_INTERNAL_SWITCH_CASE_INT_TYPES(CTYPE_ALIAS, ...) \ ET_INTERNAL_SWITCH_CASE( \ exec_aten::ScalarType::Byte, CTYPE_ALIAS, __VA_ARGS__) \ @@ -1001,6 +1093,13 @@ inline exec_aten::ScalarType promoteTypes( ET_INTERNAL_SWITCH_CASE( \ exec_aten::ScalarType::ADDITIONAL, CTYPE_ALIAS, __VA_ARGS__) +#define ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES_AND2( \ + ADDITIONAL1, ADDITIONAL2, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES_AND( \ + ADDITIONAL1, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + exec_aten::ScalarType::ADDITIONAL2, CTYPE_ALIAS, __VA_ARGS__) + #define ET_INTERNAL_SWITCH_CASE_QINT_TYPES(CTYPE_ALIAS, ...) \ ET_INTERNAL_SWITCH_CASE( \ exec_aten::ScalarType::QInt8, CTYPE_ALIAS, __VA_ARGS__) \ @@ -1112,6 +1211,22 @@ inline exec_aten::ScalarType promoteTypes( ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND2( \ ADDITIONAL1, ADDITIONAL2, CTYPE_ALIAS, __VA_ARGS__)) +#define ET_SWITCH_REAL_TYPES_AND3( \ + ADDITIONAL1, \ + ADDITIONAL2, \ + ADDITIONAL3, \ + TYPE, \ + CONTEXT, \ + NAME, \ + CTYPE_ALIAS, \ + ...) \ + ET_INTERNAL_SWITCH( \ + TYPE, \ + CONTEXT, \ + NAME, \ + ET_INTERNAL_SWITCH_CASE_REAL_TYPES_AND3( \ + ADDITIONAL1, ADDITIONAL2, ADDITIONAL3, CTYPE_ALIAS, __VA_ARGS__)) + #define ET_SWITCH_REALH_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ ET_SWITCH_REAL_TYPES_AND(Half, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__) @@ -1122,6 +1237,10 @@ inline exec_aten::ScalarType promoteTypes( ET_SWITCH_REAL_TYPES_AND2( \ Half, Bool, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__) +#define ET_SWITCH_REALHBBF16_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_SWITCH_REAL_TYPES_AND3( \ + Half, Bool, BFloat16, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__) + #define ET_SWITCH_INT_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ ET_INTERNAL_SWITCH( \ TYPE, \ @@ -1154,9 +1273,22 @@ inline exec_aten::ScalarType promoteTypes( ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES_AND( \ ADDITIONAL, CTYPE_ALIAS, __VA_ARGS__)) +#define ET_SWITCH_FLOAT_TYPES_AND2( \ + ADDITIONAL1, ADDITIONAL2, TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH( \ + TYPE, \ + CONTEXT, \ + NAME, \ + ET_INTERNAL_SWITCH_CASE_FLOAT_TYPES_AND2( \ + ADDITIONAL1, ADDITIONAL2, CTYPE_ALIAS, __VA_ARGS__)) + #define ET_SWITCH_FLOATH_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ ET_SWITCH_FLOAT_TYPES_AND(Half, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__) +#define ET_SWITCH_FLOATHBF16_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_SWITCH_FLOAT_TYPES_AND2( \ + Half, BFloat16, TYPE, CONTEXT, NAME, CTYPE_ALIAS, __VA_ARGS__) + #define ET_SWITCH_QINT_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) 
\ ET_INTERNAL_SWITCH( \ TYPE, \ diff --git a/runtime/core/exec_aten/util/test/scalar_type_util_test.cpp b/runtime/core/exec_aten/util/test/scalar_type_util_test.cpp index b91c7009f45..9df01b7be9f 100644 --- a/runtime/core/exec_aten/util/test/scalar_type_util_test.cpp +++ b/runtime/core/exec_aten/util/test/scalar_type_util_test.cpp @@ -139,37 +139,38 @@ TEST(ScalarTypeUtilTest, promoteTypesTest) { // Check some common cases - ET_CHECK( - promoteTypes(ScalarType::Float, ScalarType::Double) == - ScalarType::Double); - ET_CHECK( - promoteTypes(ScalarType::Float, ScalarType::Short) == ScalarType::Float); - - ET_CHECK( - promoteTypes(ScalarType::Float, ScalarType::Int) == ScalarType::Float); - ET_CHECK( - promoteTypes(ScalarType::Long, ScalarType::Float) == ScalarType::Float); - - ET_CHECK( - promoteTypes(ScalarType::Bool, ScalarType::Bool) == ScalarType::Bool); - - ET_CHECK(promoteTypes(ScalarType::Byte, ScalarType::Int) == ScalarType::Int); - ET_CHECK( - promoteTypes(ScalarType::Char, ScalarType::Bool) == ScalarType::Char); - ET_CHECK(promoteTypes(ScalarType::Bool, ScalarType::Int) == ScalarType::Int); + EXPECT_EQ( + promoteTypes(ScalarType::Float, ScalarType::Double), ScalarType::Double); + EXPECT_EQ( + promoteTypes(ScalarType::Float, ScalarType::Short), ScalarType::Float); + + EXPECT_EQ( + promoteTypes(ScalarType::Float, ScalarType::Int), ScalarType::Float); + EXPECT_EQ( + promoteTypes(ScalarType::Long, ScalarType::Float), ScalarType::Float); + + EXPECT_EQ(promoteTypes(ScalarType::Bool, ScalarType::Bool), ScalarType::Bool); + + EXPECT_EQ(promoteTypes(ScalarType::Byte, ScalarType::Int), ScalarType::Int); + EXPECT_EQ(promoteTypes(ScalarType::Char, ScalarType::Bool), ScalarType::Char); + EXPECT_EQ(promoteTypes(ScalarType::Bool, ScalarType::Int), ScalarType::Int); + + EXPECT_EQ( + promoteTypes(ScalarType::BFloat16, ScalarType::Half), ScalarType::Float); + EXPECT_EQ( + promoteTypes(ScalarType::BFloat16, ScalarType::Bool), + ScalarType::BFloat16); } template struct promote_types_is_valid : std::integral_constant< bool, - !std::is_same::value && - !std::is_same::value && - (std::is_same::value || - (!executorch::runtime::is_qint_type::value && - !executorch::runtime::is_qint_type::value && - !executorch::runtime::is_bits_type::value && - !executorch::runtime::is_bits_type::value))> {}; + (std::is_same::value || + (!executorch::runtime::is_qint_type::value && + !executorch::runtime::is_qint_type::value && + !executorch::runtime::is_bits_type::value && + !executorch::runtime::is_bits_type::value))> {}; template struct CompileTimePromoteTypesTestCase { @@ -195,7 +196,8 @@ struct CompileTimePromoteTypesTestCase { auto expected = executorch::runtime::promoteTypes( scalarType1, scalarType2, half_to_float); EXPECT_EQ(actual, expected) - << "promoting " << (int)scalarType1 << " to " << (int)scalarType2; + << "promoting " << (int)scalarType1 << " to " << (int)scalarType2 + << " (half to float: " << half_to_float << ')'; } template < diff --git a/runtime/core/portable_type/bfloat16.h b/runtime/core/portable_type/bfloat16.h index a1ceb0c56a7..67f7478adff 100644 --- a/runtime/core/portable_type/bfloat16.h +++ b/runtime/core/portable_type/bfloat16.h @@ -8,11 +8,40 @@ #pragma once +#include #include +#include +#include namespace torch { namespace executor { +namespace internal { +inline float f32_from_bits(uint16_t src) { + float res = 0; + uint32_t tmp = src; + tmp <<= 16; + std::memcpy(&res, &tmp, sizeof(tmp)); + return res; +} + +inline uint16_t bits_from_f32(float src) { + uint32_t res = 0; + 
std::memcpy(&res, &src, sizeof(res)); + return res >> 16; +} + +inline uint16_t round_to_nearest_even(float src) { + if (std::isnan(src)) { + return UINT16_C(0x7FC0); + } + uint32_t U32 = 0; + std::memcpy(&U32, &src, sizeof(U32)); + uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); + return static_cast((U32 + rounding_bias) >> 16); +} +} // namespace internal + /** * The "brain floating-point" type, compatible with c10/util/BFloat16.h from * pytorch core. @@ -22,7 +51,288 @@ namespace executor { */ struct alignas(2) BFloat16 { uint16_t x; + + BFloat16() = default; + struct from_bits_t {}; + static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + + constexpr BFloat16(unsigned short bits, from_bits_t) : x(bits) {} + /* implicit */ BFloat16(float value) + : x(internal::round_to_nearest_even(value)) {} + operator float() const { + return internal::f32_from_bits(x); + } }; +inline std::ostream& operator<<(std::ostream& out, const BFloat16& value) { + out << (float)value; + return out; +} + +/// Arithmetic + +inline BFloat16 operator+(const BFloat16& a, const BFloat16& b) { + return static_cast(a) + static_cast(b); +} + +inline BFloat16 operator-(const BFloat16& a, const BFloat16& b) { + return static_cast(a) - static_cast(b); +} + +inline BFloat16 operator*(const BFloat16& a, const BFloat16& b) { + return static_cast(a) * static_cast(b); +} + +inline BFloat16 operator/(const BFloat16& a, const BFloat16& b) { + return static_cast(a) / static_cast(b); +} + +inline BFloat16 operator-(const BFloat16& a) { + return -static_cast(a); +} + +inline BFloat16& operator+=(BFloat16& a, const BFloat16& b) { + a = a + b; + return a; +} + +inline BFloat16& operator-=(BFloat16& a, const BFloat16& b) { + a = a - b; + return a; +} + +inline BFloat16& operator*=(BFloat16& a, const BFloat16& b) { + a = a * b; + return a; +} + +inline BFloat16& operator/=(BFloat16& a, const BFloat16& b) { + a = a / b; + return a; +} + +inline BFloat16& operator|(BFloat16& a, const BFloat16& b) { + a.x = a.x | b.x; + return a; +} + +inline BFloat16& operator^(BFloat16& a, const BFloat16& b) { + a.x = a.x ^ b.x; + return a; +} + +inline BFloat16& operator&(BFloat16& a, const BFloat16& b) { + a.x = a.x & b.x; + return a; +} + +/// Arithmetic with floats + +inline float operator+(BFloat16 a, float b) { + return static_cast(a) + b; +} +inline float operator-(BFloat16 a, float b) { + return static_cast(a) - b; +} +inline float operator*(BFloat16 a, float b) { + return static_cast(a) * b; +} +inline float operator/(BFloat16 a, float b) { + return static_cast(a) / b; +} + +inline float operator+(float a, BFloat16 b) { + return a + static_cast(b); +} +inline float operator-(float a, BFloat16 b) { + return a - static_cast(b); +} +inline float operator*(float a, BFloat16 b) { + return a * static_cast(b); +} +inline float operator/(float a, BFloat16 b) { + return a / static_cast(b); +} + +inline float& operator+=(float& a, const BFloat16& b) { + return a += static_cast(b); +} +inline float& operator-=(float& a, const BFloat16& b) { + return a -= static_cast(b); +} +inline float& operator*=(float& a, const BFloat16& b) { + return a *= static_cast(b); +} +inline float& operator/=(float& a, const BFloat16& b) { + return a /= static_cast(b); +} + +/// Arithmetic with doubles + +inline double operator+(BFloat16 a, double b) { + return static_cast(a) + b; +} +inline double operator-(BFloat16 a, double b) { + return static_cast(a) - b; +} +inline double operator*(BFloat16 a, double b) { + return static_cast(a) * b; +} 
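// Aside on the design of the mixed-type overloads in this header: bf16
// (op) float/double widens the BFloat16 operand and returns float/double,
// so nothing is rounded mid-expression, while the bf16 (op) int overloads
// below convert the integer to BFloat16 first and round once via
// round_to_nearest_even. A worked example (illustrative, not from the
// patch): BFloat16(3.0f) * 2.0 == 6.0 exactly, but BFloat16(257) stores
// 256.0f, because 257 needs 9 significand bits and the resulting tie
// (bits 0x43808000, low half exactly 0x8000) rounds to the even 8-bit
// mantissa 0x4380.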
+inline double operator/(BFloat16 a, double b) { + return static_cast(a) / b; +} + +inline double operator+(double a, BFloat16 b) { + return a + static_cast(b); +} +inline double operator-(double a, BFloat16 b) { + return a - static_cast(b); +} +inline double operator*(double a, BFloat16 b) { + return a * static_cast(b); +} +inline double operator/(double a, BFloat16 b) { + return a / static_cast(b); +} + +/// Arithmetic with ints + +inline BFloat16 operator+(BFloat16 a, int b) { + return a + static_cast(b); +} +inline BFloat16 operator-(BFloat16 a, int b) { + return a - static_cast(b); +} +inline BFloat16 operator*(BFloat16 a, int b) { + return a * static_cast(b); +} +inline BFloat16 operator/(BFloat16 a, int b) { + return a / static_cast(b); +} + +inline BFloat16 operator+(int a, BFloat16 b) { + return static_cast(a) + b; +} +inline BFloat16 operator-(int a, BFloat16 b) { + return static_cast(a) - b; +} +inline BFloat16 operator*(int a, BFloat16 b) { + return static_cast(a) * b; +} +inline BFloat16 operator/(int a, BFloat16 b) { + return static_cast(a) / b; +} + +//// Arithmetic with int64_t + +inline BFloat16 operator+(BFloat16 a, int64_t b) { + return a + static_cast(b); +} +inline BFloat16 operator-(BFloat16 a, int64_t b) { + return a - static_cast(b); +} +inline BFloat16 operator*(BFloat16 a, int64_t b) { + return a * static_cast(b); +} +inline BFloat16 operator/(BFloat16 a, int64_t b) { + return a / static_cast(b); +} + +inline BFloat16 operator+(int64_t a, BFloat16 b) { + return static_cast(a) + b; +} +inline BFloat16 operator-(int64_t a, BFloat16 b) { + return static_cast(a) - b; +} +inline BFloat16 operator*(int64_t a, BFloat16 b) { + return static_cast(a) * b; +} +inline BFloat16 operator/(int64_t a, BFloat16 b) { + return static_cast(a) / b; +} + +// Overloading < and > operators, because std::max and std::min use them. 
+ +inline bool operator>(BFloat16& lhs, BFloat16& rhs) { + return float(lhs) > float(rhs); +} + +inline bool operator<(BFloat16& lhs, BFloat16& rhs) { + return float(lhs) < float(rhs); +} + } // namespace executor } // namespace torch + +namespace std { + +template <> +class numeric_limits { + public: + static constexpr bool is_signed = true; + static constexpr bool is_specialized = true; + static constexpr bool is_integer = false; + static constexpr bool is_exact = false; + static constexpr bool has_infinity = true; + static constexpr bool has_quiet_NaN = true; + static constexpr bool has_signaling_NaN = true; + static constexpr auto has_denorm = numeric_limits::has_denorm; + static constexpr auto has_denorm_loss = + numeric_limits::has_denorm_loss; + static constexpr auto round_style = numeric_limits::round_style; + static constexpr bool is_iec559 = false; + static constexpr bool is_bounded = true; + static constexpr bool is_modulo = false; + static constexpr int digits = 8; + static constexpr int digits10 = 2; + static constexpr int max_digits10 = 4; + static constexpr int radix = 2; + static constexpr int min_exponent = -125; + static constexpr int min_exponent10 = -37; + static constexpr int max_exponent = 128; + static constexpr int max_exponent10 = 38; + static constexpr auto traps = numeric_limits::traps; + static constexpr auto tinyness_before = + numeric_limits::tinyness_before; + + static constexpr torch::executor::BFloat16 min() { + return torch::executor::BFloat16( + 0x0080, torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 lowest() { + return torch::executor::BFloat16( + 0xFF7F, torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 max() { + return torch::executor::BFloat16( + 0x7F7F, torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 epsilon() { + return torch::executor::BFloat16( + 0x3C00, torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 round_error() { + return torch::executor::BFloat16( + 0x3F00, torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 infinity() { + return torch::executor::BFloat16( + 0x7F80, torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 quiet_NaN() { + return torch::executor::BFloat16( + 0x7FC0, torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 signaling_NaN() { + return torch::executor::BFloat16( + 0x7F80, torch::executor::BFloat16::from_bits()); + } + static constexpr torch::executor::BFloat16 denorm_min() { + return torch::executor::BFloat16( + 0x0001, torch::executor::BFloat16::from_bits()); + } +}; + +} // namespace std diff --git a/runtime/core/portable_type/test/CMakeLists.txt b/runtime/core/portable_type/test/CMakeLists.txt index 21eb4feae0f..58a69f656eb 100644 --- a/runtime/core/portable_type/test/CMakeLists.txt +++ b/runtime/core/portable_type/test/CMakeLists.txt @@ -24,7 +24,7 @@ set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../..) 
include(${EXECUTORCH_ROOT}/build/Test.cmake)
 
 set(_test_srcs optional_test.cpp tensor_test.cpp half_test.cpp scalar_test.cpp
-    tensor_impl_test.cpp
+    tensor_impl_test.cpp bfloat16_test.cpp
 )
 
 et_cxx_test(runtime_core_portable_type_test SOURCES ${_test_srcs} EXTRA_LIBS)
diff --git a/runtime/core/portable_type/test/bfloat16_test.cpp b/runtime/core/portable_type/test/bfloat16_test.cpp
new file mode 100644
index 00000000000..9ea53e6cba2
--- /dev/null
+++ b/runtime/core/portable_type/test/bfloat16_test.cpp
@@ -0,0 +1,191 @@
+#include <executorch/runtime/core/portable_type/bfloat16.h>
+
+#include <gtest/gtest.h>
+
+using torch::executor::BFloat16;
+
+namespace {
+float float_from_bytes(uint32_t sign, uint32_t exponent, uint32_t fraction) {
+  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+  uint32_t bytes;
+  bytes = 0;
+  bytes |= sign;
+  bytes <<= 8;
+  bytes |= exponent;
+  bytes <<= 23;
+  bytes |= fraction;
+
+  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+  float res;
+  std::memcpy(&res, &bytes, sizeof(res));
+  return res;
+}
+
+TEST(BFloat16Conversion, FloatToBFloat16AndBack) {
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
+  float in[100];
+  for (int i = 0; i < 100; ++i) {
+    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers)
+    in[i] = i + 1.25;
+  }
+
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
+  BFloat16 bfloats[100];
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
+  float out[100];
+
+  for (int i = 0; i < 100; ++i) {
+    bfloats[i].x = torch::executor::internal::bits_from_f32(in[i]);
+    out[i] = torch::executor::internal::f32_from_bits(bfloats[i].x);
+
+    // The relative error should be less than 1/(2^7) since BFloat16
+    // has 7 bits mantissa.
+    EXPECT_LE(std::fabs(out[i] - in[i]) / in[i], 1.0 / 128);
+  }
+}
+
+TEST(BFloat16Conversion, FloatToBFloat16RNEAndBack) {
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
+  float in[100];
+  for (int i = 0; i < 100; ++i) {
+    // NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions,cppcoreguidelines-avoid-magic-numbers)
+    in[i] = i + 1.25;
+  }
+
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
+  BFloat16 bfloats[100];
+  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,modernize-avoid-c-arrays)
+  float out[100];
+
+  for (int i = 0; i < 100; ++i) {
+    bfloats[i].x = torch::executor::internal::round_to_nearest_even(in[i]);
+    out[i] = torch::executor::internal::f32_from_bits(bfloats[i].x);
+
+    // The relative error should be less than 1/(2^7) since BFloat16
+    // has 7 bits mantissa.
+    EXPECT_LE(std::fabs(out[i] - in[i]) / in[i], 1.0 / 128);
+  }
+}
+
+TEST(BFloat16Conversion, NaN) {
+  float inNaN = float_from_bytes(0, 0xFF, 0x7FFFFF);
+  EXPECT_TRUE(std::isnan(inNaN));
+
+  BFloat16 a = BFloat16(inNaN);
+  float out = torch::executor::internal::f32_from_bits(a.x);
+
+  EXPECT_TRUE(std::isnan(out));
+}
+
+TEST(BFloat16Conversion, Inf) {
+  float inInf = float_from_bytes(0, 0xFF, 0);
+  EXPECT_TRUE(std::isinf(inInf));
+
+  BFloat16 a = BFloat16(inInf);
+  float out = torch::executor::internal::f32_from_bits(a.x);
+
+  EXPECT_TRUE(std::isinf(out));
+}
+
+TEST(BFloat16Conversion, SmallestDenormal) {
+  float in = std::numeric_limits<float>::denorm_min(); // The smallest non-zero
+                                                       // subnormal number
+  BFloat16 a = BFloat16(in);
+  float out = torch::executor::internal::f32_from_bits(a.x);
+
+  EXPECT_FLOAT_EQ(in, out);
+}
+
+TEST(BFloat16Math, Addition) {
+  // This test verifies that if only first 7 bits of float's mantissa are
+  // changed after addition, we should have no loss in precision.
+
+  // input bits
+  // S | Exponent | Mantissa
+  // 0 | 10000000 | 10010000000000000000000 = 3.125
+  float input = float_from_bytes(0, 0, 0x40480000);
+
+  // expected bits
+  // S | Exponent | Mantissa
+  // 0 | 10000001 | 10010000000000000000000 = 6.25
+  float expected = float_from_bytes(0, 0, 0x40c80000);
+
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+  BFloat16 b;
+  b.x = torch::executor::internal::bits_from_f32(input);
+  b = b + b;
+
+  float res = torch::executor::internal::f32_from_bits(b.x);
+  EXPECT_EQ(res, expected);
+}
+
+TEST(BFloat16Math, Subtraction) {
+  // This test verifies that if only first 7 bits of float's mantissa are
+  // changed after subtraction, we should have no loss in precision.
+
+  // input bits
+  // S | Exponent | Mantissa
+  // 0 | 10000001 | 11101000000000000000000 = 7.625
+  float input = float_from_bytes(0, 0, 0x40f40000);
+
+  // expected bits
+  // S | Exponent | Mantissa
+  // 0 | 10000000 | 01010000000000000000000 = 2.625
+  float expected = float_from_bytes(0, 0, 0x40280000);
+
+  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
+  BFloat16 b;
+  b.x = torch::executor::internal::bits_from_f32(input);
+  b = b - 5;
+
+  float res = torch::executor::internal::f32_from_bits(b.x);
+  EXPECT_EQ(res, expected);
+}
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+TEST(BFloat16Math, NextAfterZero) {
+  const BFloat16 zero{0};
+
+  auto check_nextafter = [](BFloat16 from, BFloat16 to, BFloat16 expected) {
+    BFloat16 actual = std::nextafter(from, to);
+    // Check for bitwise equality!
+    ASSERT_EQ(actual.x ^ expected.x, uint16_t{0});
+  };
+  check_nextafter(zero, zero, /*expected=*/zero);
+  check_nextafter(zero, -zero, /*expected=*/-zero);
+  check_nextafter(-zero, zero, /*expected=*/zero);
+  check_nextafter(-zero, -zero, /*expected=*/-zero);
+}
+
+float BinaryToFloat(uint32_t bytes) {
+  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+  float res;
+  std::memcpy(&res, &bytes, sizeof(res));
+  return res;
+}
+
+struct BFloat16TestParam {
+  uint32_t input;
+  uint16_t rne;
+};
+
+class BFloat16Test : public ::testing::Test,
+                     public ::testing::WithParamInterface<BFloat16TestParam> {};
+
+TEST_P(BFloat16Test, BFloat16RNETest) {
+  float value = BinaryToFloat(GetParam().input);
+  uint16_t rounded = torch::executor::internal::round_to_nearest_even(value);
+  EXPECT_EQ(GetParam().rne, rounded);
+}
+
+INSTANTIATE_TEST_SUITE_P(
+    BFloat16TestInstantiation,
+    BFloat16Test,
+    ::testing::Values(
+        BFloat16TestParam{0x3F848000, 0x3F84},
+        BFloat16TestParam{0x3F848010, 0x3F85},
+        BFloat16TestParam{0x3F850000, 0x3F85},
+        BFloat16TestParam{0x3F858000, 0x3F86},
+        BFloat16TestParam{0x3FFF8000, 0x4000}));
+
+} // namespace
diff --git a/runtime/core/portable_type/test/targets.bzl b/runtime/core/portable_type/test/targets.bzl
index af55f95e45e..c0b4ef00c78 100644
--- a/runtime/core/portable_type/test/targets.bzl
+++ b/runtime/core/portable_type/test/targets.bzl
@@ -6,6 +6,14 @@ def define_common_targets():
     The directory containing this targets.bzl file should also contain both
     TARGETS and BUCK files that call this function.
     """
+    runtime.cxx_test(
+        name = "bfloat16_test",
+        srcs = ["bfloat16_test.cpp"],
+        deps = [
+            "//executorch/runtime/core/portable_type:portable_type",
+        ],
+    )
+
     runtime.cxx_test(
         name = "optional_test",
         srcs = ["optional_test.cpp"],

From 4e8b86b5de7e7ef7422ca39b35c86fc9a34cf41a Mon Sep 17 00:00:00 2001
From: Scott Wolchok
Date: Thu, 29 Aug 2024 13:11:08 -0700
Subject: [PATCH 02/11] [ExecuTorch] support BF16 in op_to_copy

Adding bfloat16 support to important ops for LLMs to start.
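For context on why to_copy needs real BFloat16 coverage rather than a bitcast path: conversion through bf16 rounds to an 8-bit significand, so float values only round-trip approximately. A minimal standalone sketch using the BFloat16 type from patch 01 (illustrative, not part of the diff):

#include <executorch/runtime/core/portable_type/bfloat16.h>
#include <cassert>

using torch::executor::BFloat16;

int main() {
  BFloat16 b = 3.14159f; // implicit float -> bf16 via round_to_nearest_even
  float back = b;        // implicit bf16 -> float widening (exact)
  assert(back == 3.140625f); // bits 0x4049; only ~2-3 decimal digits survive
  return 0;
}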
Differential Revision: [D61981356](https://our.internmc.facebook.com/intern/diff/D61981356/)

[ghstack-poisoned]
---
 kernels/portable/cpu/op_to_copy.cpp |  9 +++++----
 kernels/test/op_to_copy_test.cpp    | 27 +++++++++++++++++++--------
 2 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/kernels/portable/cpu/op_to_copy.cpp b/kernels/portable/cpu/op_to_copy.cpp
index 7ecd4f3b5e1..c0c04e65e93 100644
--- a/kernels/portable/cpu/op_to_copy.cpp
+++ b/kernels/portable/cpu/op_to_copy.cpp
@@ -46,10 +46,11 @@ Tensor& to_copy_out(
       InvalidArgument,
       out);
 
-  ET_SWITCH_REALHB_TYPES(self.scalar_type(), ctx, "to_copy", CTYPE_IN, [&] {
-    ET_SWITCH_REALHB_TYPES(out.scalar_type(), ctx, "to_copy", CTYPE_OUT, [&] {
-      _to_impl<CTYPE_IN, CTYPE_OUT>(self, out);
-    });
+  ET_SWITCH_REALHBBF16_TYPES(self.scalar_type(), ctx, "to_copy", CTYPE_IN, [&] {
+    ET_SWITCH_REALHBBF16_TYPES(
+        out.scalar_type(), ctx, "to_copy", CTYPE_OUT, [&] {
+          _to_impl<CTYPE_IN, CTYPE_OUT>(self, out);
+        });
   });
 
   return out;
diff --git a/kernels/test/op_to_copy_test.cpp b/kernels/test/op_to_copy_test.cpp
index 1cc892dedbe..0a6529e736d 100644
--- a/kernels/test/op_to_copy_test.cpp
+++ b/kernels/test/op_to_copy_test.cpp
@@ -36,7 +36,9 @@ typedef std::map<
     std::type_index,
     std::variant<
        std::vector<float>,
-       std::vector<double>>>
+       std::vector<double>,
+       std::vector<exec_aten::Half>,
+       std::vector<exec_aten::BFloat16>>>
    FloatingTypeToDataMap;
 
 typedef std::map<
@@ -309,9 +311,9 @@ TEST_F(OpToTest, AllDtypesSupported) {
       ScalarType::OUTPUT_DTYPE>(test_cases);
 
 #define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \
-  ET_FORALL_REAL_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
+  ET_FORALL_REALHBF16_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
 
-  ET_FORALL_REAL_TYPES(TEST_ENTRY);
+  ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
 #undef TEST_KERNEL
@@ -323,14 +325,14 @@ TEST_F(OpToTest, BoolTests) {
 #define TEST_TO_BOOL(INPUT_CTYPE, INPUT_DTYPE) \
   test_runner_to_bool<INPUT_CTYPE, ScalarType::INPUT_DTYPE>( \
       test_case_to_bool, result_to_bool);
-  ET_FORALL_REAL_TYPES(TEST_TO_BOOL);
+  ET_FORALL_REALHBF16_TYPES(TEST_TO_BOOL);
 
   std::vector<bool> test_case_from_bool = {true, true, false};
   std::vector<double> result_from_bool = {1.0, 1.0, 0};
 #define TEST_FROM_BOOL(OUTPUT_CTYPE, OUTPUT_DTYPE) \
   test_runner_from_bool<OUTPUT_CTYPE, ScalarType::OUTPUT_DTYPE>( \
       test_case_from_bool, result_from_bool);
-  ET_FORALL_REAL_TYPES(TEST_FROM_BOOL);
+  ET_FORALL_REALHBF16_TYPES(TEST_FROM_BOOL);
 }
@@ -349,9 +351,9 @@ TEST_F(OpToTest, NanInfSupported) {
       ScalarType::OUTPUT_DTYPE>(test_cases);
 
 #define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \
-  ET_FORALL_FLOAT_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
+  ET_FORALL_FLOATHBF16_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
 
-  ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
+  ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
 #undef TEST_ENTRY
 #undef TEST_KERNEL
@@ -381,6 +383,13 @@ TEST_F(OpToTest, HardcodeFloatConvertInt) {
       -0.30919688936285893988};
   // clang-format on
 
+  std::vector<exec_aten::Half> half_data;
+  std::vector<exec_aten::BFloat16> bf16_data;
+  for (auto d : double_data) {
+    half_data.emplace_back(d);
+    bf16_data.emplace_back(d);
+  }
+
   std::vector<int64_t> int64_data = {
       -1, -4, 2, -2, 3, 3, -3, -4, 3, 3, 0, 2, 0, -1, 0};
   std::vector<int32_t> int32_data = {
@@ -394,6 +403,8 @@
   FloatingTypeToDataMap floating_point_data;
   floating_point_data[typeid(float)] = float_data;
   floating_point_data[typeid(double)] = double_data;
+  floating_point_data[typeid(exec_aten::Half)] = half_data;
+  floating_point_data[typeid(exec_aten::BFloat16)] = bf16_data;
 
   // Gathering all int data together for better traversal
   IntTypeToDataMap int_data;
@@ -412,7 +423,7 @@ TEST_F(OpToTest, HardcodeFloatConvertInt)
{ #define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \ ET_FORALL_INT_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL); - ET_FORALL_FLOAT_TYPES(TEST_ENTRY); + ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY); } TEST_F(OpToTest, MismatchedSizesDie) { From ff889745571d9641f75b8a954fa21ee459066d18 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 29 Aug 2024 13:11:11 -0700 Subject: [PATCH 03/11] [ExecuTorch] support BF16 in op_mul Adding bfloat16 support to important ops for LLMs to start. Differential Revision: [D61981355](https://our.internmc.facebook.com/intern/diff/D61981355/) [ghstack-poisoned] --- kernels/optimized/cpu/binary_ops.h | 3 +- kernels/optimized/cpu/op_mul.cpp | 16 +- kernels/portable/cpu/op_mul.cpp | 14 +- kernels/test/op_mul_test.cpp | 158 +++++++++++------- .../exec_aten/testing_util/tensor_util.cpp | 2 +- runtime/core/exec_aten/util/tensor_util.h | 9 + 6 files changed, 125 insertions(+), 77 deletions(-) diff --git a/kernels/optimized/cpu/binary_ops.h b/kernels/optimized/cpu/binary_ops.h index 01f3eed401e..6d941509f72 100644 --- a/kernels/optimized/cpu/binary_ops.h +++ b/kernels/optimized/cpu/binary_ops.h @@ -75,7 +75,8 @@ ElementwiseOptimizedPath inline select_optimized_path( ScalarType b_type = b.scalar_type(); ScalarType out_type = out.scalar_type(); - if (a_type != b_type || a_type != out_type || a_type == ScalarType::Half) { + if (a_type != b_type || a_type != out_type || a_type == ScalarType::Half || + a_type == ScalarType::BFloat16) { return ElementwiseOptimizedPath::kNone; } if (a.sizes().equals(b.sizes()) || diff --git a/kernels/optimized/cpu/op_mul.cpp b/kernels/optimized/cpu/op_mul.cpp index 3b93870a610..4f7af01ed9b 100644 --- a/kernels/optimized/cpu/op_mul.cpp +++ b/kernels/optimized/cpu/op_mul.cpp @@ -80,7 +80,7 @@ Tensor& opt_mul_out( ScalarType out_type = out.scalar_type(); if (b.numel() == 1) { - if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) { + if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half && a_type != ScalarType::BFloat16) { auto error = resize_tensor(out, a.sizes()); ET_KERNEL_CHECK_MSG( ctx, @@ -170,12 +170,12 @@ Tensor& opt_mul_out( InvalidArgument, out); - ET_SWITCH_REALHB_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() { - ET_SWITCH_REALHB_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() { + ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() { + ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() { using CTYPE_IN = typename torch::executor:: promote_types::type; ET_DCHECK(CppTypeToScalarType::value == common_type); - ET_SWITCH_REALHB_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() { + ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() { apply_binary_elementwise_fn( [](const CTYPE_A val_a, const CTYPE_B val_b) { CTYPE_IN a_casted = static_cast(val_a); @@ -210,7 +210,7 @@ Tensor& opt_mul_scalar_out( ET_CHECK(common_type == out_type); - if (common_type == ScalarType::Half) { + if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) { common_type = ScalarType::Float; } @@ -219,7 +219,7 @@ Tensor& opt_mul_scalar_out( ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor."); if (a_type == common_type && a_type == out_type && - a_type != ScalarType::Half) { + a_type != ScalarType::Half && a_type != ScalarType::BFloat16) { ET_SWITCH_REALB_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE, [&]() { ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() { CTYPE_B b_val; @@ -235,11 +235,11 @@ Tensor& opt_mul_scalar_out( }); }); } else 
{ - ET_SWITCH_REALHB_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() { + ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() { ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() { ET_SWITCH_REALB_TYPES( common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() { - ET_SWITCH_REALHB_TYPES( + ET_SWITCH_REALHBBF16_TYPES( out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() { CTYPE_B b_val; ET_EXTRACT_SCALAR(b, b_val); diff --git a/kernels/portable/cpu/op_mul.cpp b/kernels/portable/cpu/op_mul.cpp index c933d10d274..1a6a57eb4a3 100644 --- a/kernels/portable/cpu/op_mul.cpp +++ b/kernels/portable/cpu/op_mul.cpp @@ -70,7 +70,7 @@ mul_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) { InvalidArgument, out); - ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out); + ET_KERNEL_CHECK(ctx, executorch::runtime::tensor_is_realhbbf16_type(out), InvalidArgument, out); ScalarType a_type = a.scalar_type(); ScalarType b_type = b.scalar_type(); @@ -79,12 +79,12 @@ mul_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) { ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out); - ET_SWITCH_REALHB_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() { - ET_SWITCH_REALHB_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() { + ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() { + ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() { using CTYPE_IN = typename torch::executor:: promote_types::type; ET_DCHECK(CppTypeToScalarType::value == common_type); - ET_SWITCH_REALHB_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() { + ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() { MulInner< can_cast::value, CTYPE_A, @@ -123,15 +123,15 @@ Tensor& mul_scalar_out( ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out); - if (common_type == ScalarType::Half) { + if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) { common_type = ScalarType::Float; } - ET_SWITCH_REALHB_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() { + ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() { ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "mul.Scalar_out", CTYPE_B, [&]() { ET_SWITCH_REALB_TYPES( common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() { - ET_SWITCH_REALHB_TYPES( + ET_SWITCH_REALHBBF16_TYPES( out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() { CTYPE_B b_val; utils::extract_scalar(b, &b_val); diff --git a/kernels/test/op_mul_test.cpp b/kernels/test/op_mul_test.cpp index 32b69352ef1..0a6abb03516 100644 --- a/kernels/test/op_mul_test.cpp +++ b/kernels/test/op_mul_test.cpp @@ -72,7 +72,7 @@ class OpMulOutTest : public OperatorTest { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_mul_enumerate_out_types(); - ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY) + ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY) #undef ENUMERATE_TEST_ENTRY } @@ -89,29 +89,99 @@ class OpMulOutTest : public OperatorTest { // Multiply two tensors op_mul_out( - tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}), tf.ones(sizes), out); - EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8})); + tf.make(sizes, /*data=*/{1.25, 2.5, 4.75, 8.875}), tf.ones(sizes), out); + EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{1.25, 2.5, 4.75, 8.875})); op_mul_out( tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}), tf.zeros(sizes), out); EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{0.0, 0.0, 0.0, 0.0})); op_mul_out( - tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 
8.8}), - tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}), + tf.make(sizes, /*data=*/{1.25, 2.5, 4.75, 8.875}), + tf.make(sizes, /*data=*/{1.25, 2.5, 4.75, 8.875}), out); EXPECT_TENSOR_CLOSE( - out, tf.make(sizes, /*data=*/{1.21, 4.84, 19.36, 77.44})); + out, tf.make(sizes, /*data=*/{1.5625, 6.25, 22.5625, 78.765625})); } void test_mul_enumerate_a_types() { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_mul_enumerate_b_types(); - ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY) + ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY) #undef ENUMERATE_TEST_ENTRY } + + template + void test_optimized_path_ignores_leading_1_dimensions() { + TensorFactory tf; + + const std::vector sizes1 = {1, 1, 2, 2}; + const std::vector sizes2 = {1, 2, 2}; + + // Destination for the mul. + Tensor out = tf.zeros(sizes1); + + // Multiply two tensors + op_mul_out( + tf.make(sizes1, /*data=*/{1.1, 2.2, 4.4, 8.8}), tf.ones(sizes2), out); + EXPECT_TENSOR_CLOSE(out, tf.make(sizes1, /*data=*/{1.1, 2.2, 4.4, 8.8})); + } + + template + void test_broadcast_a2b() { + TensorFactory tf_a; + + std::vector> b_sizeses = { + {2}, + {1, 2}, + }; + for (const auto& b_sizes : b_sizeses) { + // a and b of different shapes + Tensor a = tf_a.make({2, 2}, /*data=*/{1, 2, 3, 4}); + Tensor b = tf_a.make(b_sizes, /*data=*/{2, 2}); + + // Destination for output of mul. + Tensor out = tf_a.zeros({2, 2}); + + // Check that it matches the expected output. + EXPECT_TENSOR_CLOSE( + op_mul_out(a, b, out), tf_a.make({2, 2}, /*data=*/{2, 4, 6, 8})); + } + } + + template + void test_broadcast_b2a() { + TensorFactory tf_a; + // a and b of different shapes + Tensor a = tf_a.make({2}, /*data=*/{2, 2}); + Tensor b = tf_a.make({2, 2}, /*data=*/{1, 2, 3, 4}); + + // Destination for output of mul. + Tensor out = tf_a.zeros({2, 2}); + + // Check that it matches the expected output. + EXPECT_TENSOR_CLOSE( + op_mul_out(a, b, out), tf_a.make({2, 2}, /*data=*/{2, 4, 6, 8})); + } + + template + void test_scalar_input_broadcast() { + TensorFactory tf_a; + + // a is a 1d tensor and b is a scalar + Tensor a = tf_a.make({2}, /*data=*/{2, 2}); + Tensor b = tf_a.make({}, /*data=*/{2}); + + // Destination for output of mul. + Tensor out = tf_a.make({2}, /*data=*/{2, 2}); + Tensor expected = tf_a.make({2}, /*data=*/{4, 4}); + + // Check that it matches the expected output. + EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected); + EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected); + } }; class OpMulScalarOutTest : public OperatorTest { @@ -141,6 +211,14 @@ TEST_F(OpMulOutTest, DoubleTensors) { test_floating_point_mul_out(); } +TEST_F(OpMulOutTest, HalfTensors) { + test_floating_point_mul_out(); +} + +TEST_F(OpMulOutTest, BFloat16Tensors) { + test_floating_point_mul_out(); +} + TEST_F(OpMulOutTest, BoolTensors) { TensorFactory tf; @@ -166,18 +244,12 @@ TEST_F(OpMulOutTest, BoolTensors) { } TEST_F(OpMulOutTest, OptimizedPathIgnoresLeading1Dimensions) { - TensorFactory tf; - - const std::vector sizes1 = {1, 1, 2, 2}; - const std::vector sizes2 = {1, 2, 2}; +#define ENUMERATE_TEST_ENTRY(ctype, dtype) \ + test_optimized_path_ignores_leading_1_dimensions(); - // Destination for the mul. - Tensor out = tf.zeros(sizes1); + ET_FORALL_FLOATHBF16_TYPES(ENUMERATE_TEST_ENTRY); - // Multiply two tensors - op_mul_out( - tf.make(sizes1, /*data=*/{1.1, 2.2, 4.4, 8.8}), tf.ones(sizes2), out); - EXPECT_TENSOR_CLOSE(out, tf.make(sizes1, /*data=*/{1.1, 2.2, 4.4, 8.8})); +#undef ENUMERATE_TEST_ENTRY } // Mismatched shape tests. 
@@ -202,40 +274,16 @@ TEST_F(OpMulOutTest, MismatchedNonBroadcastableInputShapesDies) { // Broadcast tensor b's size to tensor a's size TEST_F(OpMulOutTest, BroadcastA2BTest) { - TensorFactory tf_a; - - std::vector> b_sizeses = { - {2}, - {1, 2}, - }; - for (const auto& b_sizes : b_sizeses) { - // a and b of different shapes - Tensor a = tf_a.make({2, 2}, /*data=*/{1, 2, 3, 4}); - Tensor b = tf_a.make(b_sizes, /*data=*/{2, 2}); - - // Destination for output of mul. - Tensor out = tf_a.zeros({2, 2}); - - // Check that it matches the expected output. - EXPECT_TENSOR_CLOSE( - op_mul_out(a, b, out), tf_a.make({2, 2}, /*data=*/{2, 4, 6, 8})); - } + test_broadcast_a2b(); + test_broadcast_a2b(); + test_broadcast_a2b(); } // Broadcast tensor a's size to tensor b's size TEST_F(OpMulOutTest, BroadcastB2ATest) { - TensorFactory tf_a; - - // a and b of different shapes - Tensor a = tf_a.make({2}, /*data=*/{2, 2}); - Tensor b = tf_a.make({2, 2}, /*data=*/{1, 2, 3, 4}); - - // Destination for output of mul. - Tensor out = tf_a.zeros({2, 2}); - - // Check that it matches the expected output. - EXPECT_TENSOR_CLOSE( - op_mul_out(a, b, out), tf_a.make({2, 2}, /*data=*/{2, 4, 6, 8})); + test_broadcast_b2a(); + test_broadcast_b2a(); + test_broadcast_b2a(); } // Broadcast tensor a and b's size to a new size c. @@ -256,19 +304,9 @@ TEST_F(OpMulOutTest, BroadcastAB2CTest) { } TEST_F(OpMulOutTest, ScalarInputBroadcastTest) { - TensorFactory tf_a; - - // a is a 1d tensor and b is a scalar - Tensor a = tf_a.make({2}, /*data=*/{2, 2}); - Tensor b = tf_a.make({}, /*data=*/{2}); - - // Destination for output of mul. - Tensor out = tf_a.make({2}, /*data=*/{2, 2}); - Tensor expected = tf_a.make({2}, /*data=*/{4, 4}); - - // Check that it matches the expected output. - EXPECT_TENSOR_CLOSE(op_mul_out(a, b, out), expected); - EXPECT_TENSOR_CLOSE(op_mul_out(b, a, out), expected); + test_scalar_input_broadcast(); + test_scalar_input_broadcast(); + test_scalar_input_broadcast(); } TEST_F(OpMulOutTest, MismatchedOutputShapesDies) { diff --git a/runtime/core/exec_aten/testing_util/tensor_util.cpp b/runtime/core/exec_aten/testing_util/tensor_util.cpp index 03dffd208f0..0712b7177bf 100644 --- a/runtime/core/exec_aten/testing_util/tensor_util.cpp +++ b/runtime/core/exec_aten/testing_util/tensor_util.cpp @@ -269,7 +269,7 @@ std::ostream& operator<<(std::ostream& os, const Tensor& t) { break; switch (t.scalar_type()) { - ET_FORALL_REAL_TYPES_AND2(Half, Bool, PRINT_CASE) + ET_FORALL_REAL_TYPES_AND3(Half, Bool, BFloat16, PRINT_CASE) default: ET_CHECK_MSG( false, diff --git a/runtime/core/exec_aten/util/tensor_util.h b/runtime/core/exec_aten/util/tensor_util.h index b18cd349a62..b920a2aebaf 100644 --- a/runtime/core/exec_aten/util/tensor_util.h +++ b/runtime/core/exec_aten/util/tensor_util.h @@ -510,6 +510,15 @@ inline bool tensor_is_realhb_type(exec_aten::Tensor t) { return true; } +inline bool tensor_is_realhbbf16_type(exec_aten::Tensor t) { + ET_LOG_MSG_AND_RETURN_IF_FALSE( + executorch::runtime::isRealHBBF16Type(t.scalar_type()), + "Expected to find a real type, but tensor has type %s", + torch::executor::toString(t.scalar_type())); + + return true; +} + inline bool tensor_is_complex_type(exec_aten::Tensor t) { ET_LOG_MSG_AND_RETURN_IF_FALSE( torch::executor::isComplexType(t.scalar_type()), From 46579e447481c8c20a2f1dddca54894969e9cb1f Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 29 Aug 2024 13:11:14 -0700 Subject: [PATCH 04/11] [ExecuTorch] support BF16 in op_mm Adding bfloat16 support to important ops for LLMs to 
start. Differential Revision: [D61981353](https://our.internmc.facebook.com/intern/diff/D61981353/) [ghstack-poisoned] --- kernels/portable/cpu/op_mm.cpp | 27 ++++++++++++++------------- kernels/test/op_mm_test.cpp | 2 +- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/kernels/portable/cpu/op_mm.cpp b/kernels/portable/cpu/op_mm.cpp index 6903bf3cad5..2c1a426ea44 100644 --- a/kernels/portable/cpu/op_mm.cpp +++ b/kernels/portable/cpu/op_mm.cpp @@ -29,19 +29,20 @@ mm_out(RuntimeContext& ctx, const Tensor& in, const Tensor& mat2, Tensor& out) { InvalidArgument, out); - ET_SWITCH_REAL_TYPES_AND(Half, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() { - size_t m = in.size(0); - size_t n = in.size(1); - size_t p = mat2.size(1); - - vec_matmul( - out.mutable_data_ptr(), - in.const_data_ptr(), - mat2.const_data_ptr(), - m, - n, - p); - }); + ET_SWITCH_REAL_TYPES_AND2( + Half, BFloat16, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() { + size_t m = in.size(0); + size_t n = in.size(1); + size_t p = mat2.size(1); + + vec_matmul( + out.mutable_data_ptr(), + in.const_data_ptr(), + mat2.const_data_ptr(), + m, + n, + p); + }); return out; } diff --git a/kernels/test/op_mm_test.cpp b/kernels/test/op_mm_test.cpp index 70d4b5ff0f5..c05792523f2 100644 --- a/kernels/test/op_mm_test.cpp +++ b/kernels/test/op_mm_test.cpp @@ -81,7 +81,7 @@ TEST_F(OpMmOutTest, OutputDim) { /// zeros(). TEST_F(OpMmOutTest, AllDtypesSupported) { #define TEST_ENTRY(ctype, dtype) test_dtype(); - ET_FORALL_REAL_TYPES_AND(Half, TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY // TODO: Also add tests for half, complex, quantized, and other types. Easiest // way to do that would be to make TensorFactory support zeros() and ones() From 8388e56b203ae123ec2e218a02a54b0140087c33 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 29 Aug 2024 13:11:17 -0700 Subject: [PATCH 05/11] [ExecuTorch] support BF16 in op_copy Adding bfloat16 support to important ops for LLMs to start. 
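On the op_mm change above: ET_SWITCH_REAL_TYPES_AND2 instantiates the lambda once per real dtype plus Half and BFloat16, and each instantiation calls vec_matmul with a different CTYPE. A minimal sketch of a kernel with the shape contract the call site implies (naive_matmul is a hypothetical stand-in, not the actual vec_matmul; it assumes BFloat16 provides the usual arithmetic operators, which the c10-derived bfloat16.h supplies):

    #include <cstddef>

    // out[m x p] = in[m x n] * mat2[n x p], accumulating in CTYPE. Real
    // kernels may accumulate bf16 products in float for precision; this
    // illustration does not.
    template <typename CTYPE>
    void naive_matmul(
        CTYPE* out, const CTYPE* in, const CTYPE* mat2,
        size_t m, size_t n, size_t p) {
      for (size_t i = 0; i < m; ++i) {
        for (size_t j = 0; j < p; ++j) {
          CTYPE sum = static_cast<CTYPE>(0);
          for (size_t k = 0; k < n; ++k) {
            sum += in[i * n + k] * mat2[k * p + j];
          }
          out[i * p + j] = sum;
        }
      }
    }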
Differential Revision: [D61981357](https://our.internmc.facebook.com/intern/diff/D61981357/) [ghstack-poisoned] --- kernels/portable/cpu/op_copy.cpp | 8 ++++---- kernels/test/op_copy_test.cpp | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/kernels/portable/cpu/op_copy.cpp b/kernels/portable/cpu/op_copy.cpp index 900b6e39d34..58b3ab5521d 100644 --- a/kernels/portable/cpu/op_copy.cpp +++ b/kernels/portable/cpu/op_copy.cpp @@ -42,8 +42,8 @@ Tensor& copy_out( ScalarType in_type = in.scalar_type(); ScalarType src_type = src.scalar_type(); - ET_SWITCH_REALHB_TYPES(in_type, ctx, "copy.out", CTYPE, [&]() { - ET_SWITCH_REALHB_TYPES(src_type, ctx, "copy.out", CTYPE_SRC, [&]() { + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "copy.out", CTYPE, [&]() { + ET_SWITCH_REALHBBF16_TYPES(src_type, ctx, "copy.out", CTYPE_SRC, [&]() { apply_binary_elementwise_fn( [](const CTYPE val_in, const CTYPE_SRC val_src) { return convert(val_src); @@ -69,8 +69,8 @@ copy_(RuntimeContext& ctx, Tensor& in, const Tensor& src, bool non_blocking) { ScalarType in_type = in.scalar_type(); ScalarType src_type = src.scalar_type(); - ET_SWITCH_REALHB_TYPES(in_type, ctx, "copy_", CTYPE, [&]() { - ET_SWITCH_REALHB_TYPES(src_type, ctx, "copy_", CTYPE_SRC, [&]() { + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "copy_", CTYPE, [&]() { + ET_SWITCH_REALHBBF16_TYPES(src_type, ctx, "copy_", CTYPE_SRC, [&]() { apply_binary_elementwise_fn( [](const CTYPE val_in, const CTYPE_SRC val_src) { return convert(val_src); diff --git a/kernels/test/op_copy_test.cpp b/kernels/test/op_copy_test.cpp index 82332f85eb2..007b10a7636 100644 --- a/kernels/test/op_copy_test.cpp +++ b/kernels/test/op_copy_test.cpp @@ -125,13 +125,13 @@ class OpCopyInplaceTest : public OperatorTest { // regular test for copy.out TEST_F(OpCopyTest, AllRealDtypesSupported) { #define TEST_ENTRY(ctype, dtype) test_dtype(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } TEST_F(OpCopyTest, EmptyInputSupported) { #define TEST_ENTRY(ctype, dtype) test_empty_input(); - ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } From 8e0b9d339457d3ecb39bab2b74a3988d69550d3d Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 29 Aug 2024 13:11:20 -0700 Subject: [PATCH 06/11] [ExecuTorch] support BF16 in op_slice_scatter Adding bfloat16 support to important ops for LLMs to start. 
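One detail worth noting in the copy kernels above: copy is phrased as a binary elementwise op whose lambda ignores the destination's current value and simply converts the source element, which lets the shared broadcast machinery handle layout. Nesting two ET_SWITCH_REALHBBF16_TYPES switches also instantiates that lambda for every (dst, src) dtype pair, so adding BFloat16 grows the instantiation count along both axes. A same-shape illustration of the lambda's effect (copy_convert is hypothetical, static_cast stands in for convert, and broadcasting is omitted):

    #include <cstddef>

    template <typename CTYPE, typename CTYPE_SRC>
    void copy_convert(CTYPE* out, const CTYPE_SRC* src, size_t numel) {
      for (size_t i = 0; i < numel; ++i) {
        // As in the lambda above, the old value of out[i] is ignored.
        out[i] = static_cast<CTYPE>(src[i]);
      }
    }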
Differential Revision: [D61981364](https://our.internmc.facebook.com/intern/diff/D61981364/) [ghstack-poisoned] --- kernels/portable/cpu/op_slice_scatter.cpp | 4 ++-- kernels/test/op_slice_scatter_test.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/kernels/portable/cpu/op_slice_scatter.cpp b/kernels/portable/cpu/op_slice_scatter.cpp index 367b626696f..e284af5fe97 100644 --- a/kernels/portable/cpu/op_slice_scatter.cpp +++ b/kernels/portable/cpu/op_slice_scatter.cpp @@ -74,8 +74,8 @@ Tensor& slice_scatter_out( ScalarType in_type = input.scalar_type(); ScalarType src_type = src.scalar_type(); - ET_SWITCH_REALHB_TYPES(in_type, ctx, "slice_scatter.out", CTYPE, [&]() { - ET_SWITCH_REALHB_TYPES( + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "slice_scatter.out", CTYPE, [&]() { + ET_SWITCH_REALHBBF16_TYPES( src_type, ctx, "slice_scatter.out", CTYPE_SRC, [&]() { CTYPE* out_data = out.mutable_data_ptr(); const CTYPE_SRC* src_data = src.const_data_ptr(); diff --git a/kernels/test/op_slice_scatter_test.cpp b/kernels/test/op_slice_scatter_test.cpp index 4901f832a33..1d5c8a43b10 100644 --- a/kernels/test/op_slice_scatter_test.cpp +++ b/kernels/test/op_slice_scatter_test.cpp @@ -49,7 +49,7 @@ class OpSliceScatterTensorOutTest : public OperatorTest { 5, 6, 7, 8, // [1, :] 9, 10, 11, 12, // [2, :] }); - + // op_slice_scatter_out(input, src, /*dim=*/0, /*start=*/0, /*end=*/2, /*step=*/1, out), // src shape should equal to input[0:2:1, :] Tensor src = tf.make( @@ -670,7 +670,7 @@ TEST_F(OpSliceScatterTensorOutTest, LegalStepsSupported) { /// zeros(). TEST_F(OpSliceScatterTensorOutTest, AllRealDtypesSupported) { #define TEST_ENTRY(ctype, dtype) test_dtype(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY // TODO: Also add tests for half, complex, quantized, and other types. Easiest // way to do that would be to make TensorFactory support zeros() and ones() From 741f777e298ae06373d673f29698956ca5984018 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 29 Aug 2024 13:11:23 -0700 Subject: [PATCH 07/11] [ExecuTorch] support BF16 in op_scalar_tensor Adding bfloat16 support to important ops for LLMs to start. 
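For readers unfamiliar with the slice_scatter semantics the first test above exercises: the output starts as a copy of input, and the slice selected by (dim, start, end, step) is overwritten with src, converting dtypes when they differ. A dim=0, 2-D sketch under those assumptions (slice_scatter_dim0 is a hypothetical illustration, not the kernel):

    #include <algorithm>
    #include <cstddef>

    template <typename CTYPE, typename CTYPE_SRC>
    void slice_scatter_dim0(
        CTYPE* out, const CTYPE* input, const CTYPE_SRC* src,
        size_t rows, size_t cols, size_t start, size_t end, size_t step) {
      std::copy(input, input + rows * cols, out);  // out = input
      size_t src_row = 0;
      for (size_t r = start; r < end; r += step, ++src_row) {
        for (size_t c = 0; c < cols; ++c) {
          out[r * cols + c] = static_cast<CTYPE>(src[src_row * cols + c]);
        }
      }
    }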
Differential Revision: [D61981360](https://our.internmc.facebook.com/intern/diff/D61981360/) [ghstack-poisoned] --- kernels/portable/cpu/op_scalar_tensor.cpp | 15 ++++++++------- kernels/test/op_scalar_tensor_test.cpp | 4 ++-- runtime/core/portable_type/scalar.h | 4 ++++ 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/kernels/portable/cpu/op_scalar_tensor.cpp b/kernels/portable/cpu/op_scalar_tensor.cpp index b69267c9917..b79d447f6af 100644 --- a/kernels/portable/cpu/op_scalar_tensor.cpp +++ b/kernels/portable/cpu/op_scalar_tensor.cpp @@ -24,13 +24,14 @@ Tensor& scalar_tensor_out(RuntimeContext& ctx, const Scalar& s, Tensor& out) { constexpr auto name = "scalar_tensor.out"; - ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE, [&]() { - ET_SWITCH_SCALAR_OBJ_TYPES(s_type, ctx, name, CTYPE_S, [&]() { - CTYPE_S val_s; - utils::extract_scalar(s, &val_s); - out.mutable_data_ptr()[0] = convert(val_s); - }); - }); + ET_SWITCH_REAL_TYPES_AND3( + Half, Bool, BFloat16, out_type, ctx, name, CTYPE, [&]() { + ET_SWITCH_SCALAR_OBJ_TYPES(s_type, ctx, name, CTYPE_S, [&]() { + CTYPE_S val_s; + utils::extract_scalar(s, &val_s); + out.mutable_data_ptr()[0] = convert(val_s); + }); + }); return out; } diff --git a/kernels/test/op_scalar_tensor_test.cpp b/kernels/test/op_scalar_tensor_test.cpp index 7a2f5ca9dab..482f6073a69 100644 --- a/kernels/test/op_scalar_tensor_test.cpp +++ b/kernels/test/op_scalar_tensor_test.cpp @@ -80,7 +80,7 @@ class OpScalarTensorOutTest : public OperatorTest { test_scalar_tensor_out_0d(9); \ } -ET_FORALL_REAL_TYPES(GENERATE_TEST_0D) +ET_FORALL_REAL_TYPES_AND3(Half, Bool, BFloat16, GENERATE_TEST_0D) #define GENERATE_TEST(ctype, dtype) \ TEST_F(OpScalarTensorOutTest, dtype##Tensors) { \ @@ -98,7 +98,7 @@ ET_FORALL_REAL_TYPES(GENERATE_TEST_0D) test_scalar_tensor_out_3d(7); \ } -ET_FORALL_REAL_TYPES(GENERATE_TEST) +ET_FORALL_REAL_TYPES_AND3(Half, Bool, BFloat16, GENERATE_TEST) TEST_F(OpScalarTensorOutTest, InvalidOutShapeFails) { if (torch::executor::testing::SupportedFeatures::get()->is_aten) { diff --git a/runtime/core/portable_type/scalar.h b/runtime/core/portable_type/scalar.h index 2619f9e2614..1147fee7cc9 100644 --- a/runtime/core/portable_type/scalar.h +++ b/runtime/core/portable_type/scalar.h @@ -8,6 +8,8 @@ #pragma once +#include +#include #include #include @@ -39,6 +41,8 @@ class Scalar { /*implicit*/ Scalar(double val) : tag(Tag::Double) { v.as_double = val; } + /*implicit*/ Scalar(BFloat16 val) : Scalar((double)(float)val) {} + /*implicit*/ Scalar(Half val) : Scalar((double)(float)val) {} /// Returns the concrete scalar value stored within. template From ebdde778b810daf2036dac12b6b3cb7913aad25e Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 29 Aug 2024 13:11:26 -0700 Subject: [PATCH 08/11] [ExecuTorch] support BF16 in op_where Adding bfloat16 support to important ops for LLMs to start. 
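On the new Scalar constructor above: the (double)(float)val chain is lossless. Every bfloat16 value widens exactly to float, since bfloat16 is a truncated float32 with the same 8-bit exponent, and every float widens exactly to double, so nothing is rounded on the way into Scalar's double slot. A two-line illustration (assuming the include path from the file list and that BFloat16 is implicitly constructible from float, as in the c10 original):

    #include <executorch/runtime/core/portable_type/bfloat16.h>

    torch::executor::BFloat16 bf(1.5f);  // 1.5 is exactly representable
    double d = (double)(float)bf;        // d == 1.5; no rounding at either step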
Differential Revision: [D61981359](https://our.internmc.facebook.com/intern/diff/D61981359/) [ghstack-poisoned] --- kernels/portable/cpu/op_where.cpp | 4 ++-- kernels/test/op_where_test.cpp | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/kernels/portable/cpu/op_where.cpp b/kernels/portable/cpu/op_where.cpp index bf42447582e..6ff4cb85fb3 100644 --- a/kernels/portable/cpu/op_where.cpp +++ b/kernels/portable/cpu/op_where.cpp @@ -41,8 +41,8 @@ Tensor& where_out( cond_type == ScalarType::Bool || cond_type == ScalarType::Byte, "Unhandled dtype %s for where.self_out", torch::executor::toString(cond_type)); - ET_SWITCH_REALHB_TYPES(a_type, ctx, name, CTYPE_A, [&]() { - ET_SWITCH_REALHB_TYPES(b_type, ctx, name, CTYPE_B, [&]() { + ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() { + ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() { using CTYPE_OUT = typename torch::executor::promote_types::type; apply_ternary_elementwise_fn( diff --git a/kernels/test/op_where_test.cpp b/kernels/test/op_where_test.cpp index 3388e62e2f5..7ddbbef2d74 100644 --- a/kernels/test/op_where_test.cpp +++ b/kernels/test/op_where_test.cpp @@ -80,7 +80,7 @@ class OpWhereOutTest : public OperatorTest { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_where(); - ET_FORALL_FLOAT_TYPES(ENUMERATE_TEST_ENTRY) + ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY) #undef ENUMERATE_TEST_ENTRY } @@ -90,7 +90,7 @@ class OpWhereOutTest : public OperatorTest { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_where(); - ET_FORALL_REAL_TYPES(ENUMERATE_TEST_ENTRY) + ET_FORALL_REALHBBF16_TYPES(ENUMERATE_TEST_ENTRY) #undef ENUMERATE_TEST_ENTRY } @@ -148,7 +148,7 @@ class OpWhereOutTest : public OperatorTest { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_where_enumerate_b_types(); - ET_FORALL_REAL_TYPES(ENUMERATE_TEST_ENTRY) + ET_FORALL_REALHBBF16_TYPES(ENUMERATE_TEST_ENTRY) #undef ENUMERATE_TEST_ENTRY } @@ -157,7 +157,7 @@ class OpWhereOutTest : public OperatorTest { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_where(); - ET_FORALL_REAL_TYPES(ENUMERATE_TEST_ENTRY) + ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY) #undef ENUMERATE_TEST_ENTRY } From 1addb7d6f68e5e07ce68a54d92ef2c3e0c7b8434 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 29 Aug 2024 13:11:29 -0700 Subject: [PATCH 09/11] [ExecuTorch] support BF16 in op_add Adding bfloat16 support to important ops for LLMs to start. 
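The output dtype in the where kernel above comes from promote_types<CTYPE_A, CTYPE_B>. With the BFloat16 row added to the promotion table earlier in this stack, mixed inputs behave as sketched below (illustrative static_asserts, assuming scalar_type_util.h, <cstdint>, and <type_traits> are included; the expected results are read off the promotion table):

    // BFloat16 x Float -> Float; BFloat16 x Int -> BFloat16.
    static_assert(
        std::is_same<
            torch::executor::promote_types<exec_aten::BFloat16, float>::type,
            float>::value,
        "BFloat16 x Float promotes to Float");
    static_assert(
        std::is_same<
            torch::executor::promote_types<exec_aten::BFloat16, int32_t>::type,
            exec_aten::BFloat16>::value,
        "BFloat16 x Int promotes to BFloat16");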
Differential Revision: [D61981362](https://our.internmc.facebook.com/intern/diff/D61981362/) [ghstack-poisoned] --- kernels/optimized/cpu/op_add.cpp | 16 ++++++++-------- kernels/portable/cpu/op_add.cpp | 20 ++++++++++++++------ kernels/test/op_add_test.cpp | 23 +++++++++++++++++------ 3 files changed, 39 insertions(+), 20 deletions(-) diff --git a/kernels/optimized/cpu/op_add.cpp b/kernels/optimized/cpu/op_add.cpp index a2a05891e54..6c2ebca0111 100644 --- a/kernels/optimized/cpu/op_add.cpp +++ b/kernels/optimized/cpu/op_add.cpp @@ -83,7 +83,7 @@ Tensor& opt_add_out( ScalarType out_type = out.scalar_type(); if (b.numel() == 1) { - if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) { + if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half && a_type != ScalarType::BFloat16) { auto error = resize_tensor(out, a.sizes()); ET_KERNEL_CHECK_MSG( ctx, @@ -186,12 +186,12 @@ Tensor& opt_add_out( InvalidArgument, out); - ET_SWITCH_REALHB_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() { - ET_SWITCH_REALHB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() { + ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() { + ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() { using CTYPE_IN = typename torch::executor:: promote_types::type; ET_DCHECK(CppTypeToScalarType::value == common_type); - ET_SWITCH_REALHB_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() { + ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() { CTYPE_IN alpha_val; ET_KERNEL_CHECK( ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, ); @@ -226,7 +226,7 @@ Tensor& opt_add_scalar_out( ET_CHECK(common_type == out_type); - if (common_type == ScalarType::Half) { + if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) { common_type = ScalarType::Float; } @@ -235,7 +235,7 @@ Tensor& opt_add_scalar_out( ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor."); if (a_type == common_type && a_type == out_type && - a_type != ScalarType::Half) { + a_type != ScalarType::Half && a_type != ScalarType::BFloat16) { ET_SWITCH_REALB_TYPES(a_type, ctx, "add.Scalar_out", CTYPE, [&]() { ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() { CTYPE_B b_val; @@ -255,11 +255,11 @@ Tensor& opt_add_scalar_out( }); }); } else { - ET_SWITCH_REALHB_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() { + ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() { ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() { ET_SWITCH_REALB_TYPES( common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() { - ET_SWITCH_REALHB_TYPES( + ET_SWITCH_REALHBBF16_TYPES( out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() { CTYPE_B b_val; ET_EXTRACT_SCALAR(b, b_val); diff --git a/kernels/portable/cpu/op_add.cpp b/kernels/portable/cpu/op_add.cpp index 33662ecc55a..22460e8821d 100644 --- a/kernels/portable/cpu/op_add.cpp +++ b/kernels/portable/cpu/op_add.cpp @@ -78,7 +78,11 @@ Tensor& add_out( InvalidArgument, out); - ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensor_is_realhbbf16_type(out), + InvalidArgument, + out); ScalarType a_type = a.scalar_type(); ScalarType b_type = b.scalar_type(); @@ -92,15 +96,15 @@ Tensor& add_out( constexpr auto name = "add.out"; - ET_SWITCH_REALHB_TYPES(a_type, ctx, name, CTYPE_A, [&]() { - ET_SWITCH_REALHB_TYPES(b_type, ctx, name, CTYPE_B, [&]() { + ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() { + 
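// (Clarifying comment, not part of the original diff: this dtype switch
// and the two nested below it instantiate the lambda body once per
// a_type x b_type x out_type combination, so widening REALHB to
// REALHBBF16 increases the generated code along all three axes.)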
ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() { using CTYPE_IN = typename torch::executor:: promote_types::type; ET_DCHECK(CppTypeToScalarType::value == common_type); CTYPE_IN alpha_val; utils::extract_scalar(alpha, &alpha_val); - ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() { + ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() { AddInner< can_cast::value, CTYPE_A, @@ -130,7 +134,11 @@ Tensor& add_scalar_out( out, "Failed to resize output tensor."); - ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensor_is_realhbbf16_type(out), + InvalidArgument, + out); ScalarType a_type = a.scalar_type(); ScalarType b_type = utils::get_scalar_dtype(b); @@ -149,7 +157,7 @@ Tensor& add_scalar_out( constexpr auto name = "add.Scalar_out"; - ET_SWITCH_REALHB_TYPES(a_type, ctx, name, CTYPE_A, [&]() { + ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() { ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, name, CTYPE_B, [&]() { using CTYPE_IN = typename utils::promote_type_with_scalar_type< CTYPE_A, diff --git a/kernels/test/op_add_test.cpp b/kernels/test/op_add_test.cpp index 79a58a0c7ce..51ace05b752 100644 --- a/kernels/test/op_add_test.cpp +++ b/kernels/test/op_add_test.cpp @@ -58,6 +58,7 @@ class OpAddOutKernelTest : public OperatorTest { template void test_add_enumerate_out_types() { + test_add(); test_add(); test_add(); test_add(); @@ -73,7 +74,7 @@ class OpAddOutKernelTest : public OperatorTest { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_add_enumerate_out_types(); - ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY) + ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY) #undef ENUMERATE_TEST_ENTRY } @@ -82,7 +83,7 @@ class OpAddOutKernelTest : public OperatorTest { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_add_enumerate_b_types(); - ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY) + ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY) #undef ENUMERATE_TEST_ENTRY } @@ -99,13 +100,15 @@ class OpAddOutKernelTest : public OperatorTest { // Add two tensors. op_add_out( - tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}), + tf.make(sizes, /*data=*/{1.25, 2.25, 4.5, 8.875}), tf.ones(sizes), - /*alpha=*/1.1, + /*alpha=*/1.25, out); - // Check that it matches the expected output. - EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{2.2, 3.3, 5.5, 9.9})); + // Check that it matches the expected output. Values selected to + // be exactly representable to avoid throwing off half/bfloat16 + // tests. + EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{2.5, 3.5, 5.75, 10.125})); } }; @@ -136,6 +139,14 @@ TEST_F(OpAddOutKernelTest, DoubleTensors) { test_floating_point_add_out(); } +TEST_F(OpAddOutKernelTest, HalfTensors) { + test_floating_point_add_out(); +} + +TEST_F(OpAddOutKernelTest, BFloat16Tensors) { + test_floating_point_add_out(); +} + TEST_F(OpAddOutKernelTest, BoolAndIntInputTensor) { TensorFactory tf; TensorFactory tfi; From feeadb9ab0d4a7bf430fcf615307bf385e5bf19b Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Fri, 30 Aug 2024 10:58:30 -0700 Subject: [PATCH 10/11] Update base for Update on "[ExecuTorch] support BF16 in op_add" Adding bfloat16 support to important ops for LLMs to start. 
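As with the mul tests earlier in the stack, the add test constants above are chosen to be exactly representable even at bfloat16's 7 mantissa bits. The expectations follow out = a + alpha * b: 1.25 + 1.25 * 1 = 2.5, 2.25 + 1.25 = 3.5, 4.5 + 1.25 = 5.75, and 8.875 + 1.25 = 10.125, each of which is again exact, so EXPECT_TENSOR_CLOSE holds for Half and BFloat16 without loosening tolerances.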
Differential Revision: [D61981362](https://our.internmc.facebook.com/intern/diff/D61981362/) [ghstack-poisoned] --- kernels/optimized/cpu/op_mul.cpp | 3 ++- kernels/portable/cpu/op_mul.cpp | 6 +++++- kernels/test/op_mul_test.cpp | 4 ++-- .../exec_aten/testing_util/tensor_util.cpp | 20 ++++++++++++++++--- runtime/core/portable_type/bfloat16.h | 1 + 5 files changed, 27 insertions(+), 7 deletions(-) diff --git a/kernels/optimized/cpu/op_mul.cpp b/kernels/optimized/cpu/op_mul.cpp index 4f7af01ed9b..31b0f7754fb 100644 --- a/kernels/optimized/cpu/op_mul.cpp +++ b/kernels/optimized/cpu/op_mul.cpp @@ -80,7 +80,8 @@ Tensor& opt_mul_out( ScalarType out_type = out.scalar_type(); if (b.numel() == 1) { - if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half && a_type != ScalarType::BFloat16) { + if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half && + a_type != ScalarType::BFloat16) { auto error = resize_tensor(out, a.sizes()); ET_KERNEL_CHECK_MSG( ctx, diff --git a/kernels/portable/cpu/op_mul.cpp b/kernels/portable/cpu/op_mul.cpp index 1a6a57eb4a3..470eaf7b85d 100644 --- a/kernels/portable/cpu/op_mul.cpp +++ b/kernels/portable/cpu/op_mul.cpp @@ -70,7 +70,11 @@ mul_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) { InvalidArgument, out); - ET_KERNEL_CHECK(ctx, executorch::runtime::tensor_is_realhbbf16_type(out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensor_is_realhbbf16_type(out), + InvalidArgument, + out); ScalarType a_type = a.scalar_type(); ScalarType b_type = b.scalar_type(); diff --git a/kernels/test/op_mul_test.cpp b/kernels/test/op_mul_test.cpp index 0a6abb03516..41a8656f967 100644 --- a/kernels/test/op_mul_test.cpp +++ b/kernels/test/op_mul_test.cpp @@ -134,8 +134,8 @@ class OpMulOutTest : public OperatorTest { TensorFactory tf_a; std::vector> b_sizeses = { - {2}, - {1, 2}, + {2}, + {1, 2}, }; for (const auto& b_sizes : b_sizeses) { // a and b of different shapes diff --git a/runtime/core/exec_aten/testing_util/tensor_util.cpp b/runtime/core/exec_aten/testing_util/tensor_util.cpp index 0712b7177bf..0301cc9a519 100644 --- a/runtime/core/exec_aten/testing_util/tensor_util.cpp +++ b/runtime/core/exec_aten/testing_util/tensor_util.cpp @@ -16,6 +16,8 @@ #include #include +using exec_aten::BFloat16; +using exec_aten::Half; using exec_aten::ScalarType; using exec_aten::Tensor; @@ -32,9 +34,7 @@ namespace { * T must be a floating point type. Non-floating point data should be compared * directly. */ -template < - typename T, - typename = std::enable_if_t::value>> +template bool data_is_close( const T* a, const T* b, @@ -119,6 +119,20 @@ bool tensors_are_close( a.numel(), rtol, atol); + } else if (a.scalar_type() == ScalarType::Half) { + return data_is_close( + a.const_data_ptr(), + b.const_data_ptr(), + a.numel(), + rtol, + atol); + } else if (a.scalar_type() == ScalarType::BFloat16) { + return data_is_close( + a.const_data_ptr(), + b.const_data_ptr(), + a.numel(), + rtol, + atol); } else { // Non-floating-point types can be compared bitwise. 
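// (Added for clarity, not part of the diff: memcmp remains sound for the
// integral and bool dtypes that reach this branch, while Half and
// BFloat16 now take the tolerance-based data_is_close paths above.)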
return memcmp(a.const_data_ptr(), b.const_data_ptr(), a.nbytes()) == 0; diff --git a/runtime/core/portable_type/bfloat16.h b/runtime/core/portable_type/bfloat16.h index 67f7478adff..e665e6152e3 100644 --- a/runtime/core/portable_type/bfloat16.h +++ b/runtime/core/portable_type/bfloat16.h @@ -11,6 +11,7 @@ #include #include #include +#include #include namespace torch { From 8075ea87e380646292935a40dcf38978bb5c67f4 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Tue, 3 Sep 2024 14:13:39 -0700 Subject: [PATCH 11/11] Update base for Update on "[ExecuTorch] support BF16 in op_add" Adding bfloat16 support to important ops for LLMs to start. Differential Revision: [D61981362](https://our.internmc.facebook.com/intern/diff/D61981362/) [ghstack-poisoned]
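To close the loop on the tensor_util change above: Half and BFloat16 are routed through data_is_close rather than memcmp because low-precision floats accumulate rounding error that a bitwise compare would reject. A minimal per-element check in that spirit (element_is_close is hypothetical; the real helper's signature, default tolerances, and NaN policy live in tensor_util.cpp):

    #include <cmath>

    // Common |a - b| <= atol + rtol * |b| closeness rule; casting to
    // double first works uniformly for float, double, Half, and BFloat16.
    template <typename T>
    bool element_is_close(T a, T b, double rtol, double atol) {
      const double fa = static_cast<double>(a);
      const double fb = static_cast<double>(b);
      return std::fabs(fa - fb) <= atol + rtol * std::fabs(fb);
    }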