From 653419edbd153ccb64ac3aa516d99000d60f4cb9 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Mon, 3 Feb 2025 09:01:39 -0800 Subject: [PATCH] Implement portable abs for complex input Absolute value of a complex number is straightforward enough. Had to fix a couple other things because this is (I think) the first use of complex types in ExecuTorch. Differential Revision: [D69058051](https://our.internmc.facebook.com/intern/diff/D69058051/) [ghstack-poisoned] --- kernels/portable/cpu/op_abs.cpp | 54 ++++++++++++++----- kernels/test/op_abs_test.cpp | 26 +++++++++ .../core/exec_aten/util/scalar_type_util.h | 30 +++++++++-- 3 files changed, 91 insertions(+), 19 deletions(-) diff --git a/kernels/portable/cpu/op_abs.cpp b/kernels/portable/cpu/op_abs.cpp index 61c2cd44ddf..f59edf7ac29 100644 --- a/kernels/portable/cpu/op_abs.cpp +++ b/kernels/portable/cpu/op_abs.cpp @@ -27,23 +27,49 @@ Tensor& abs_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) { out, "Failed to resize output tensor."); - ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out); + const bool in_is_complex = + executorch::runtime::isComplexType(in.scalar_type()); + ET_KERNEL_CHECK( + ctx, + in_is_complex || tensors_have_same_dtype(in, out), + InvalidArgument, + out); ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); - ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE, [&] { - apply_unary_map_fn( - [](const CTYPE val_in) { - if (val_in < 0) { - return static_cast(-val_in); - } else { - return static_cast(val_in); - } - }, - in.const_data_ptr(), - out.mutable_data_ptr(), - in.numel()); - }); + if (in_is_complex) { + // NOTE: Elected not to add COMPLEXH to dtype_util.h for now + // because I am not planning wide rollout of complex support; if + // we do add SupportedTensorDtypes::COMPLEXH support, then we + // should use it here. + ET_SWITCH_COMPLEXH_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE_IN, [&] { + ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, "abs.out", CTYPE_OUT, [&] { + apply_unary_map_fn( + [](const CTYPE_IN val_in) -> CTYPE_OUT { + return sqrt( + val_in.real_ * val_in.real_ + val_in.imag_ * val_in.imag_); + }, + // XXX: switch to in/out new-style portable op impl + in.const_data_ptr(), + out.mutable_data_ptr(), + in.numel()); + }); + }); + } else { + ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE, [&] { + apply_unary_map_fn( + [](const CTYPE val_in) { + if (val_in < 0) { + return static_cast(-val_in); + } else { + return static_cast(val_in); + } + }, + in.const_data_ptr(), + out.mutable_data_ptr(), + in.numel()); + }); + } return out; } diff --git a/kernels/test/op_abs_test.cpp b/kernels/test/op_abs_test.cpp index 80e1db6f01b..0d022d0a839 100644 --- a/kernels/test/op_abs_test.cpp +++ b/kernels/test/op_abs_test.cpp @@ -38,6 +38,24 @@ class OpAbsTest : public OperatorTest { EXPECT_TENSOR_EQ(out, ret); EXPECT_TENSOR_EQ(out, expected); } + + template + void run_complex_smoke_test() { + TensorFactory tf; + constexpr auto REAL_DTYPE = executorch::runtime::toRealValueType(DTYPE); + TensorFactory tf_out; + using REAL_CTYPE = + typename executorch::runtime::ScalarTypeToCppType::type; + Tensor in = tf.make( + {1, 2}, + {CTYPE{REAL_CTYPE(3), REAL_CTYPE(4)}, + CTYPE{REAL_CTYPE(5), REAL_CTYPE(12)}}); + Tensor out = tf_out.zeros({1, 2}); + Tensor expected = tf_out.make({1, 2}, {5, 13}); + Tensor ret = op_abs_out(in, out); + EXPECT_TENSOR_EQ(out, ret); + EXPECT_TENSOR_CLOSE(out, expected); + } }; TEST_F(OpAbsTest, SmokeTest) { @@ -45,6 +63,14 @@ TEST_F(OpAbsTest, SmokeTest) { // TODO: cover all REALHBF16 types with generalized unary function test // harness. ET_FORALL_FLOATHBF16_TYPES(RUN_SMOKE_TEST); +#undef RUN_SMOKE_TEST +} + +TEST_F(OpAbsTest, ComplexSmokeTest) { +#define RUN_SMOKE_TEST(ctype, dtype) \ + run_complex_smoke_test(); + ET_FORALL_COMPLEXH_TYPES(RUN_SMOKE_TEST); +#undef RUN_SMOKE_TEST } TEST_F(OpAbsTest, MemoryFormatCheck) { diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h index 735a46c68f7..7c8162b3cdc 100644 --- a/runtime/core/exec_aten/util/scalar_type_util.h +++ b/runtime/core/exec_aten/util/scalar_type_util.h @@ -348,9 +348,14 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType) // In this context, "COMPLEX" means complex types based on primitive C types, // which is why ComplexHalf is not included. -#define ET_FORALL_COMPLEX_TYPES(_) \ - _(::torch::executor::complex, ComplexFloat) \ - _(::torch::executor::complex, ComplexDouble) +#define ET_FORALL_COMPLEX_TYPES(_) \ + _(::executorch::aten::complex, ComplexFloat) \ + _(::executorch::aten::complex, ComplexDouble) + +#define ET_FORALL_COMPLEXH_TYPES(_) \ + _(::executorch::aten::complex<::executorch::aten::Half>, ComplexHalf) \ + _(::executorch::aten::complex, ComplexFloat) \ + _(::executorch::aten::complex, ComplexDouble) // // Utility functions to retrieve metadata for a given ScalarType @@ -593,7 +598,7 @@ inline bool isUnderlying( return type == ::executorch::runtime::toUnderlying(qtype); } -inline ::executorch::aten::ScalarType toRealValueType( +inline constexpr ::executorch::aten::ScalarType toRealValueType( ::executorch::aten::ScalarType t) { switch (t) { case ::executorch::aten::ScalarType::ComplexHalf: @@ -607,7 +612,7 @@ inline ::executorch::aten::ScalarType toRealValueType( } } -inline ::executorch::aten::ScalarType toComplexType( +inline constexpr ::executorch::aten::ScalarType toComplexType( ::executorch::aten::ScalarType t) { switch (t) { case ::executorch::aten::ScalarType::BFloat16: @@ -1060,6 +1065,14 @@ struct promote_types { ET_INTERNAL_SWITCH_CASE( \ ::executorch::aten::ScalarType::ComplexDouble, CTYPE_ALIAS, __VA_ARGS__) +#define ET_INTERNAL_SWITCH_CASE_COMPLEXH_TYPES(CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::ComplexHalf, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::ComplexFloat, CTYPE_ALIAS, __VA_ARGS__) \ + ET_INTERNAL_SWITCH_CASE( \ + ::executorch::aten::ScalarType::ComplexDouble, CTYPE_ALIAS, __VA_ARGS__) + #define ET_INTERNAL_SWITCH_CASE_SCALAR_OBJ_TYPES(CTYPE_ALIAS, ...) \ ET_INTERNAL_SWITCH_CASE( \ ::executorch::aten::ScalarType::Bool, CTYPE_ALIAS, __VA_ARGS__) \ @@ -1278,6 +1291,13 @@ struct promote_types { NAME, \ ET_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, __VA_ARGS__)) +#define ET_SWITCH_COMPLEXH_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ + ET_INTERNAL_SWITCH( \ + TYPE, \ + CONTEXT, \ + NAME, \ + ET_INTERNAL_SWITCH_CASE_COMPLEXH_TYPES(CTYPE_ALIAS, __VA_ARGS__)) + #define ET_SWITCH_SCALAR_OBJ_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \ ET_INTERNAL_SWITCH( \ TYPE, \