diff --git a/kernels/optimized/cpu/op_add.cpp b/kernels/optimized/cpu/op_add.cpp index a2a05891e54..d46dd85fb3f 100644 --- a/kernels/optimized/cpu/op_add.cpp +++ b/kernels/optimized/cpu/op_add.cpp @@ -83,7 +83,8 @@ Tensor& opt_add_out( ScalarType out_type = out.scalar_type(); if (b.numel() == 1) { - if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) { + if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half && + a_type != ScalarType::BFloat16) { auto error = resize_tensor(out, a.sizes()); ET_KERNEL_CHECK_MSG( ctx, @@ -186,12 +187,12 @@ Tensor& opt_add_out( InvalidArgument, out); - ET_SWITCH_REALHB_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() { - ET_SWITCH_REALHB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() { + ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() { + ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() { using CTYPE_IN = typename torch::executor:: promote_types::type; ET_DCHECK(CppTypeToScalarType::value == common_type); - ET_SWITCH_REALHB_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() { + ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() { CTYPE_IN alpha_val; ET_KERNEL_CHECK( ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, ); @@ -226,7 +227,7 @@ Tensor& opt_add_scalar_out( ET_CHECK(common_type == out_type); - if (common_type == ScalarType::Half) { + if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) { common_type = ScalarType::Float; } @@ -235,7 +236,7 @@ Tensor& opt_add_scalar_out( ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor."); if (a_type == common_type && a_type == out_type && - a_type != ScalarType::Half) { + a_type != ScalarType::Half && a_type != ScalarType::BFloat16) { ET_SWITCH_REALB_TYPES(a_type, ctx, "add.Scalar_out", CTYPE, [&]() { ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() { CTYPE_B b_val; @@ -255,11 +256,11 @@ Tensor& opt_add_scalar_out( }); }); } else { - 
ET_SWITCH_REALHB_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() { + ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() { ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() { ET_SWITCH_REALB_TYPES( common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() { - ET_SWITCH_REALHB_TYPES( + ET_SWITCH_REALHBBF16_TYPES( out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() { CTYPE_B b_val; ET_EXTRACT_SCALAR(b, b_val); diff --git a/kernels/portable/cpu/op_add.cpp b/kernels/portable/cpu/op_add.cpp index a435e4ee658..2cc01a97fa6 100644 --- a/kernels/portable/cpu/op_add.cpp +++ b/kernels/portable/cpu/op_add.cpp @@ -78,7 +78,11 @@ Tensor& add_out( InvalidArgument, out); - ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensor_is_realhbbf16_type(out), + InvalidArgument, + out); ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); @@ -94,15 +98,15 @@ Tensor& add_out( constexpr auto name = "add.out"; - ET_SWITCH_REALHB_TYPES(a_type, ctx, name, CTYPE_A, [&]() { - ET_SWITCH_REALHB_TYPES(b_type, ctx, name, CTYPE_B, [&]() { + ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() { + ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() { using CTYPE_IN = typename torch::executor:: promote_types::type; ET_DCHECK(CppTypeToScalarType::value == common_type); CTYPE_IN alpha_val; utils::extract_scalar(alpha, &alpha_val); - ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() { + ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() { AddInner< can_cast::value, CTYPE_A, @@ -132,7 +136,11 @@ Tensor& add_scalar_out( out, "Failed to resize output tensor."); - ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensor_is_realhbbf16_type(out), + InvalidArgument, + out); ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out); @@ 
-153,7 +161,7 @@ Tensor& add_scalar_out( constexpr auto name = "add.Scalar_out"; - ET_SWITCH_REALHB_TYPES(a_type, ctx, name, CTYPE_A, [&]() { + ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() { ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, name, CTYPE_B, [&]() { using CTYPE_IN = typename utils::promote_type_with_scalar_type< CTYPE_A, diff --git a/kernels/portable/cpu/op_copy.cpp b/kernels/portable/cpu/op_copy.cpp index 764a50a5d20..86f2d5c62be 100644 --- a/kernels/portable/cpu/op_copy.cpp +++ b/kernels/portable/cpu/op_copy.cpp @@ -45,8 +45,8 @@ Tensor& copy_out( ScalarType in_type = in.scalar_type(); ScalarType src_type = src.scalar_type(); - ET_SWITCH_REALHB_TYPES(in_type, ctx, "copy.out", CTYPE, [&]() { - ET_SWITCH_REALHB_TYPES(src_type, ctx, "copy.out", CTYPE_SRC, [&]() { + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "copy.out", CTYPE, [&]() { + ET_SWITCH_REALHBBF16_TYPES(src_type, ctx, "copy.out", CTYPE_SRC, [&]() { apply_binary_elementwise_fn( [](const CTYPE val_in, const CTYPE_SRC val_src) { return convert(val_src); @@ -75,8 +75,8 @@ copy_(RuntimeContext& ctx, Tensor& in, const Tensor& src, bool non_blocking) { ScalarType in_type = in.scalar_type(); ScalarType src_type = src.scalar_type(); - ET_SWITCH_REALHB_TYPES(in_type, ctx, "copy_", CTYPE, [&]() { - ET_SWITCH_REALHB_TYPES(src_type, ctx, "copy_", CTYPE_SRC, [&]() { + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "copy_", CTYPE, [&]() { + ET_SWITCH_REALHBBF16_TYPES(src_type, ctx, "copy_", CTYPE_SRC, [&]() { apply_binary_elementwise_fn( [](const CTYPE val_in, const CTYPE_SRC val_src) { return convert(val_src); diff --git a/kernels/portable/cpu/op_mm.cpp b/kernels/portable/cpu/op_mm.cpp index 4a6a8f3cfdc..1241182e4a9 100644 --- a/kernels/portable/cpu/op_mm.cpp +++ b/kernels/portable/cpu/op_mm.cpp @@ -34,19 +34,20 @@ mm_out(RuntimeContext& ctx, const Tensor& in, const Tensor& mat2, Tensor& out) { ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out); - ET_SWITCH_REAL_TYPES_AND(Half, 
in.scalar_type(), ctx, "mm.out", CTYPE, [&]() { - size_t m = in.size(0); - size_t n = in.size(1); - size_t p = mat2.size(1); - - vec_matmul( - out.mutable_data_ptr(), - in.const_data_ptr(), - mat2.const_data_ptr(), - m, - n, - p); - }); + ET_SWITCH_REAL_TYPES_AND2( + Half, BFloat16, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() { + size_t m = in.size(0); + size_t n = in.size(1); + size_t p = mat2.size(1); + + vec_matmul( + out.mutable_data_ptr(), + in.const_data_ptr(), + mat2.const_data_ptr(), + m, + n, + p); + }); return out; } diff --git a/kernels/portable/cpu/op_scalar_tensor.cpp b/kernels/portable/cpu/op_scalar_tensor.cpp index b69267c9917..b79d447f6af 100644 --- a/kernels/portable/cpu/op_scalar_tensor.cpp +++ b/kernels/portable/cpu/op_scalar_tensor.cpp @@ -24,13 +24,14 @@ Tensor& scalar_tensor_out(RuntimeContext& ctx, const Scalar& s, Tensor& out) { constexpr auto name = "scalar_tensor.out"; - ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE, [&]() { - ET_SWITCH_SCALAR_OBJ_TYPES(s_type, ctx, name, CTYPE_S, [&]() { - CTYPE_S val_s; - utils::extract_scalar(s, &val_s); - out.mutable_data_ptr()[0] = convert(val_s); - }); - }); + ET_SWITCH_REAL_TYPES_AND3( + Half, Bool, BFloat16, out_type, ctx, name, CTYPE, [&]() { + ET_SWITCH_SCALAR_OBJ_TYPES(s_type, ctx, name, CTYPE_S, [&]() { + CTYPE_S val_s; + utils::extract_scalar(s, &val_s); + out.mutable_data_ptr()[0] = convert(val_s); + }); + }); return out; } diff --git a/kernels/portable/cpu/op_slice_scatter.cpp b/kernels/portable/cpu/op_slice_scatter.cpp index a1f9ce4d921..47374716b4e 100644 --- a/kernels/portable/cpu/op_slice_scatter.cpp +++ b/kernels/portable/cpu/op_slice_scatter.cpp @@ -74,8 +74,8 @@ Tensor& slice_scatter_out( ScalarType in_type = input.scalar_type(); ScalarType src_type = src.scalar_type(); - ET_SWITCH_REALHB_TYPES(in_type, ctx, "slice_scatter.out", CTYPE, [&]() { - ET_SWITCH_REALHB_TYPES( + ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "slice_scatter.out", CTYPE, [&]() { + ET_SWITCH_REALHBBF16_TYPES( 
src_type, ctx, "slice_scatter.out", CTYPE_SRC, [&]() { CTYPE* out_data = out.mutable_data_ptr(); const CTYPE_SRC* src_data = src.const_data_ptr(); diff --git a/kernels/portable/cpu/op_where.cpp b/kernels/portable/cpu/op_where.cpp index bf42447582e..6ff4cb85fb3 100644 --- a/kernels/portable/cpu/op_where.cpp +++ b/kernels/portable/cpu/op_where.cpp @@ -41,8 +41,8 @@ Tensor& where_out( cond_type == ScalarType::Bool || cond_type == ScalarType::Byte, "Unhandled dtype %s for where.self_out", torch::executor::toString(cond_type)); - ET_SWITCH_REALHB_TYPES(a_type, ctx, name, CTYPE_A, [&]() { - ET_SWITCH_REALHB_TYPES(b_type, ctx, name, CTYPE_B, [&]() { + ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() { + ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() { using CTYPE_OUT = typename torch::executor::promote_types::type; apply_ternary_elementwise_fn( diff --git a/kernels/test/op_add_test.cpp b/kernels/test/op_add_test.cpp index 79a58a0c7ce..51ace05b752 100644 --- a/kernels/test/op_add_test.cpp +++ b/kernels/test/op_add_test.cpp @@ -58,6 +58,7 @@ class OpAddOutKernelTest : public OperatorTest { template void test_add_enumerate_out_types() { + test_add(); test_add(); test_add(); test_add(); @@ -73,7 +74,7 @@ class OpAddOutKernelTest : public OperatorTest { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_add_enumerate_out_types(); - ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY) + ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY) #undef ENUMERATE_TEST_ENTRY } @@ -82,7 +83,7 @@ class OpAddOutKernelTest : public OperatorTest { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_add_enumerate_b_types(); - ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY) + ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY) #undef ENUMERATE_TEST_ENTRY } @@ -99,13 +100,15 @@ class OpAddOutKernelTest : public OperatorTest { // Add two tensors. 
op_add_out( - tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}), + tf.make(sizes, /*data=*/{1.25, 2.25, 4.5, 8.875}), tf.ones(sizes), - /*alpha=*/1.1, + /*alpha=*/1.25, out); - // Check that it matches the expected output. - EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{2.2, 3.3, 5.5, 9.9})); + // Check that it matches the expected output. Values selected to + // be exactly representable to avoid throwing off half/bfloat16 + // tests. + EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{2.5, 3.5, 5.75, 10.125})); } }; @@ -136,6 +139,14 @@ TEST_F(OpAddOutKernelTest, DoubleTensors) { test_floating_point_add_out(); } +TEST_F(OpAddOutKernelTest, HalfTensors) { + test_floating_point_add_out(); +} + +TEST_F(OpAddOutKernelTest, BFloat16Tensors) { + test_floating_point_add_out(); +} + TEST_F(OpAddOutKernelTest, BoolAndIntInputTensor) { TensorFactory tf; TensorFactory tfi; diff --git a/kernels/test/op_copy_test.cpp b/kernels/test/op_copy_test.cpp index 82332f85eb2..007b10a7636 100644 --- a/kernels/test/op_copy_test.cpp +++ b/kernels/test/op_copy_test.cpp @@ -125,13 +125,13 @@ class OpCopyInplaceTest : public OperatorTest { // regular test for copy.out TEST_F(OpCopyTest, AllRealDtypesSupported) { #define TEST_ENTRY(ctype, dtype) test_dtype(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } TEST_F(OpCopyTest, EmptyInputSupported) { #define TEST_ENTRY(ctype, dtype) test_empty_input(); - ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY } diff --git a/kernels/test/op_mm_test.cpp b/kernels/test/op_mm_test.cpp index 70d4b5ff0f5..c05792523f2 100644 --- a/kernels/test/op_mm_test.cpp +++ b/kernels/test/op_mm_test.cpp @@ -81,7 +81,7 @@ TEST_F(OpMmOutTest, OutputDim) { /// zeros(). 
TEST_F(OpMmOutTest, AllDtypesSupported) { #define TEST_ENTRY(ctype, dtype) test_dtype(); - ET_FORALL_REAL_TYPES_AND(Half, TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY // TODO: Also add tests for half, complex, quantized, and other types. Easiest // way to do that would be to make TensorFactory support zeros() and ones() diff --git a/kernels/test/op_scalar_tensor_test.cpp b/kernels/test/op_scalar_tensor_test.cpp index 7a2f5ca9dab..482f6073a69 100644 --- a/kernels/test/op_scalar_tensor_test.cpp +++ b/kernels/test/op_scalar_tensor_test.cpp @@ -80,7 +80,7 @@ class OpScalarTensorOutTest : public OperatorTest { test_scalar_tensor_out_0d(9); \ } -ET_FORALL_REAL_TYPES(GENERATE_TEST_0D) +ET_FORALL_REAL_TYPES_AND3(Half, Bool, BFloat16, GENERATE_TEST_0D) #define GENERATE_TEST(ctype, dtype) \ TEST_F(OpScalarTensorOutTest, dtype##Tensors) { \ @@ -98,7 +98,7 @@ ET_FORALL_REAL_TYPES(GENERATE_TEST_0D) test_scalar_tensor_out_3d(7); \ } -ET_FORALL_REAL_TYPES(GENERATE_TEST) +ET_FORALL_REAL_TYPES_AND3(Half, Bool, BFloat16, GENERATE_TEST) TEST_F(OpScalarTensorOutTest, InvalidOutShapeFails) { if (torch::executor::testing::SupportedFeatures::get()->is_aten) { diff --git a/kernels/test/op_slice_scatter_test.cpp b/kernels/test/op_slice_scatter_test.cpp index 4901f832a33..1d5c8a43b10 100644 --- a/kernels/test/op_slice_scatter_test.cpp +++ b/kernels/test/op_slice_scatter_test.cpp @@ -49,7 +49,7 @@ class OpSliceScatterTensorOutTest : public OperatorTest { 5, 6, 7, 8, // [1, :] 9, 10, 11, 12, // [2, :] }); - + // op_slice_scatter_out(input, src, /*dim=*/0, /*start=*/0, /*end=*/2, /*step=*/1, out), // src shape should equal to input[0:2:1, :] Tensor src = tf.make( @@ -670,7 +670,7 @@ TEST_F(OpSliceScatterTensorOutTest, LegalStepsSupported) { /// zeros(). 
TEST_F(OpSliceScatterTensorOutTest, AllRealDtypesSupported) { #define TEST_ENTRY(ctype, dtype) test_dtype(); - ET_FORALL_REAL_TYPES(TEST_ENTRY); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); #undef TEST_ENTRY // TODO: Also add tests for half, complex, quantized, and other types. Easiest // way to do that would be to make TensorFactory support zeros() and ones() diff --git a/kernels/test/op_where_test.cpp b/kernels/test/op_where_test.cpp index 3388e62e2f5..7ddbbef2d74 100644 --- a/kernels/test/op_where_test.cpp +++ b/kernels/test/op_where_test.cpp @@ -80,7 +80,7 @@ class OpWhereOutTest : public OperatorTest { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_where(); - ET_FORALL_FLOAT_TYPES(ENUMERATE_TEST_ENTRY) + ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY) #undef ENUMERATE_TEST_ENTRY } @@ -90,7 +90,7 @@ class OpWhereOutTest : public OperatorTest { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_where(); - ET_FORALL_REAL_TYPES(ENUMERATE_TEST_ENTRY) + ET_FORALL_REALHBBF16_TYPES(ENUMERATE_TEST_ENTRY) #undef ENUMERATE_TEST_ENTRY } @@ -148,7 +148,7 @@ class OpWhereOutTest : public OperatorTest { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_where_enumerate_b_types(); - ET_FORALL_REAL_TYPES(ENUMERATE_TEST_ENTRY) + ET_FORALL_REALHBBF16_TYPES(ENUMERATE_TEST_ENTRY) #undef ENUMERATE_TEST_ENTRY } @@ -157,7 +157,7 @@ class OpWhereOutTest : public OperatorTest { #define ENUMERATE_TEST_ENTRY(ctype, dtype) \ test_where(); - ET_FORALL_REAL_TYPES(ENUMERATE_TEST_ENTRY) + ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY) #undef ENUMERATE_TEST_ENTRY } diff --git a/runtime/core/portable_type/scalar.h b/runtime/core/portable_type/scalar.h index 2619f9e2614..1147fee7cc9 100644 --- a/runtime/core/portable_type/scalar.h +++ b/runtime/core/portable_type/scalar.h @@ -8,6 +8,8 @@ #pragma once +#include +#include #include #include @@ -39,6 +41,8 @@ class Scalar { /*implicit*/ Scalar(double val) : tag(Tag::Double) { v.as_double = val; } + /*implicit*/ Scalar(BFloat16 val) : 
Scalar((double)(float)val) {} + /*implicit*/ Scalar(Half val) : Scalar((double)(float)val) {} /// Returns the concrete scalar value stored within. template <typename T>