From 071ce021bf09381dafe5f5f853f5487df9b76017 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Thu, 26 Sep 2024 13:53:43 -0700 Subject: [PATCH] [ExecuTorch] Support bf16 for binary logical ops Differential Revision: [D63486223](https://our.internmc.facebook.com/intern/diff/D63486223/) [ghstack-poisoned] --- ...ary_ufunc_realb_realb_to_realb_logical.cpp | 6 +- kernels/test/BinaryLogicalOpTest.cpp | 27 +++++++ kernels/test/BinaryLogicalOpTest.h | 72 +++++++++++++++++++ kernels/test/op_logical_and_test.cpp | 23 +++--- kernels/test/op_logical_or_test.cpp | 23 +++--- kernels/test/op_logical_xor_test.cpp | 21 +++--- kernels/test/targets.bzl | 2 + 7 files changed, 142 insertions(+), 32 deletions(-) create mode 100644 kernels/test/BinaryLogicalOpTest.cpp create mode 100644 kernels/test/BinaryLogicalOpTest.h diff --git a/kernels/portable/cpu/pattern/binary_ufunc_realb_realb_to_realb_logical.cpp b/kernels/portable/cpu/pattern/binary_ufunc_realb_realb_to_realb_logical.cpp index 0c454cae792..ebc685afa51 100644 --- a/kernels/portable/cpu/pattern/binary_ufunc_realb_realb_to_realb_logical.cpp +++ b/kernels/portable/cpu/pattern/binary_ufunc_realb_realb_to_realb_logical.cpp @@ -34,9 +34,9 @@ Tensor& binary_ufunc_realb_realb_to_realb_logical( ScalarType b_type = b.scalar_type(); ScalarType out_type = out.scalar_type(); - ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, __func__, CTYPE_A, [&]() { - ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, __func__, CTYPE_B, [&]() { - ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, __func__, CTYPE_OUT, [&]() { + ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, __func__, CTYPE_A, [&]() { + ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, __func__, CTYPE_B, [&]() { + ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, __func__, CTYPE_OUT, [&]() { apply_binary_elementwise_fn( [fn](const CTYPE_A val_a, const CTYPE_B val_b) { bool a_casted = static_cast(val_a); diff --git a/kernels/test/BinaryLogicalOpTest.cpp b/kernels/test/BinaryLogicalOpTest.cpp new file mode 100644 index 00000000000..7557e7c9068 --- /dev/null +++ b/kernels/test/BinaryLogicalOpTest.cpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +namespace torch::executor::testing { + +void BinaryLogicalOpTest::test_all_dtypes() { +#define TEST_ENTRY(ctype, dtype) \ + test_op_out(); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); +#undef TEST_ENTRY +#define TEST_ENTRY(ctype, dtype) \ + test_op_out(); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); +#undef TEST_ENTRY +#define TEST_ENTRY(ctype, dtype) \ + test_op_out(); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); +#undef TEST_ENTRY +} +} // namespace torch::executor::testing diff --git a/kernels/test/BinaryLogicalOpTest.h b/kernels/test/BinaryLogicalOpTest.h new file mode 100644 index 00000000000..0cf412c3373 --- /dev/null +++ b/kernels/test/BinaryLogicalOpTest.h @@ -0,0 +1,72 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include + +namespace torch::executor::testing { +class BinaryLogicalOpTest : public OperatorTest { + protected: + // Implement this to call the torch::executor::aten::op_outf function for the + // op. + virtual exec_aten::Tensor& op_out( + const exec_aten::Tensor& lhs, + const exec_aten::Tensor& rhs, + exec_aten::Tensor& out) = 0; + + // Scalar reference implementation of the function in question for testing. + virtual double op_reference(double x, double y) const = 0; + + template < + exec_aten::ScalarType IN_DTYPE, + exec_aten::ScalarType IN_DTYPE2, + exec_aten::ScalarType OUT_DTYPE> + void test_op_out() { + TensorFactory tf_in; + TensorFactory tf_in2; + TensorFactory tf_out; + + exec_aten::Tensor out = tf_out.zeros({1, 4}); + + using CTYPE1 = typename decltype(tf_in)::ctype; + std::vector test_vector1 = {0, CTYPE1(-1), CTYPE1(0), CTYPE1(31)}; + + using CTYPE2 = typename decltype(tf_in2)::ctype; + std::vector test_vector2 = { + CTYPE2(0), + CTYPE2(0), + CTYPE2(15), + CTYPE2(12), + }; + + std::vector expected_vector; + for (int ii = 0; ii < test_vector1.size(); ++ii) { + expected_vector.push_back( + op_reference(test_vector1[ii], test_vector2[ii])); + } + + op_out( + tf_in.make({1, 4}, test_vector1), + tf_in2.make({1, 4}, test_vector2), + out); + + EXPECT_TENSOR_CLOSE(out, tf_out.make({1, 4}, expected_vector)); + } + + void test_all_dtypes(); +}; + +#define IMPLEMENT_BINARY_LOGICAL_OP_TEST(TestName) \ + TEST_F(TestName, SimpleTestAllTypes) { \ + test_all_dtypes(); \ + } +} // namespace torch::executor::testing diff --git a/kernels/test/op_logical_and_test.cpp b/kernels/test/op_logical_and_test.cpp index 68422ee7493..454b2f0d663 100644 --- a/kernels/test/op_logical_and_test.cpp +++ b/kernels/test/op_logical_and_test.cpp @@ -6,23 +6,26 @@ * LICENSE file in the root directory of this source tree. */ +#include #include // Declares the operator -#include -#include -#include -#include #include -using namespace ::testing; -using exec_aten::ScalarType; using exec_aten::Tensor; -using torch::executor::testing::TensorFactory; -class OpLogicalAndTest : public OperatorTest { +class OpLogicalAndTest : public torch::executor::testing::BinaryLogicalOpTest { protected: - Tensor& - op_logical_and_out(const Tensor& self, const Tensor& other, Tensor& out) { + Tensor& op_out(const Tensor& self, const Tensor& other, Tensor& out) + override { return torch::executor::aten::logical_and_outf(context_, self, other, out); } + + double op_reference(double x, double y) const override { + uint64_t lhs, rhs; + std::memcpy(&lhs, &x, sizeof(lhs)); + std::memcpy(&rhs, &y, sizeof(rhs)); + return lhs && rhs; + } }; + +IMPLEMENT_BINARY_LOGICAL_OP_TEST(OpLogicalAndTest) diff --git a/kernels/test/op_logical_or_test.cpp b/kernels/test/op_logical_or_test.cpp index e8dfb5e589e..1a274f6f7d4 100644 --- a/kernels/test/op_logical_or_test.cpp +++ b/kernels/test/op_logical_or_test.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) Meta Platforms, Inc. and affiliates. + * Copyright (c) Meta Platforms, Inc. andaffiliates. * All rights reserved. * * This source code is licensed under the BSD-style license found in the @@ -7,22 +7,25 @@ */ #include // Declares the operator -#include -#include -#include -#include +#include #include -using namespace ::testing; -using exec_aten::ScalarType; using exec_aten::Tensor; -using torch::executor::testing::TensorFactory; -class OpLogicalOrTest : public OperatorTest { +class OpLogicalOrTest : public torch::executor::testing::BinaryLogicalOpTest { protected: Tensor& - op_logical_or_out(const Tensor& self, const Tensor& other, Tensor& out) { + op_out(const Tensor& self, const Tensor& other, Tensor& out) override { return torch::executor::aten::logical_or_outf(context_, self, other, out); } + + double op_reference(double x, double y) const override { + uint64_t lhs, rhs; + std::memcpy(&lhs, &x, sizeof(lhs)); + std::memcpy(&rhs, &y, sizeof(rhs)); + return lhs || rhs; + } }; + +IMPLEMENT_BINARY_LOGICAL_OP_TEST(OpLogicalOrTest) diff --git a/kernels/test/op_logical_xor_test.cpp b/kernels/test/op_logical_xor_test.cpp index ab162a27967..969d82367ea 100644 --- a/kernels/test/op_logical_xor_test.cpp +++ b/kernels/test/op_logical_xor_test.cpp @@ -7,22 +7,25 @@ */ #include // Declares the operator -#include -#include -#include -#include +#include #include -using namespace ::testing; -using exec_aten::ScalarType; using exec_aten::Tensor; -using torch::executor::testing::TensorFactory; -class OpLogicalXorTest : public OperatorTest { +class OpLogicalXorTest : public torch::executor::testing::BinaryLogicalOpTest { protected: Tensor& - op_logical_xor_out(const Tensor& self, const Tensor& other, Tensor& out) { + op_out(const Tensor& self, const Tensor& other, Tensor& out) override { return torch::executor::aten::logical_xor_outf(context_, self, other, out); } + + double op_reference(double x, double y) const override { + uint64_t lhs, rhs; + std::memcpy(&lhs, &x, sizeof(lhs)); + std::memcpy(&rhs, &y, sizeof(rhs)); + return bool(lhs) != bool(rhs); + } }; + +IMPLEMENT_BINARY_LOGICAL_OP_TEST(OpLogicalXorTest) diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl index 0c002db02ad..7bc2e7555c6 100644 --- a/kernels/test/targets.bzl +++ b/kernels/test/targets.bzl @@ -44,9 +44,11 @@ def define_common_targets(): runtime.cxx_library( name = "test_util" + aten_suffix, srcs = [ + "BinaryLogicalOpTest.cpp", "UnaryUfuncRealHBBF16ToFloatHBF16Test.cpp", ], exported_headers = [ + "BinaryLogicalOpTest.h", "TestUtil.h", "UnaryUfuncRealHBBF16ToFloatHBF16Test.h", ],