diff --git a/kernels/portable/cpu/op_eq.cpp b/kernels/portable/cpu/op_eq.cpp
index 552053556cc..bddb6181ee0 100644
--- a/kernels/portable/cpu/op_eq.cpp
+++ b/kernels/portable/cpu/op_eq.cpp
@@ -6,62 +6,20 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
-#include <executorch/runtime/kernel/kernel_includes.h>
-#include <executorch/runtime/platform/assert.h>
+#include <executorch/kernels/portable/cpu/pattern/comparison_op.h>
 
 namespace torch {
 namespace executor {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
-using ScalarType = exec_aten::ScalarType;
-
 Tensor& eq_tensor_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = b.scalar_type();
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
-
-  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "eq.Scalar_out", CTYPE_A, [&]() {
-    ET_SWITCH_REAL_TYPES_AND(
-        Bool, b_type, ctx, "eq.Scalar_out", CTYPE_B, [&]() {
-          using CTYPE_IN =
-              typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
-          ET_DCHECK(
-              CppTypeToScalarType<CTYPE_IN>::value ==
-              promoteTypes(a_type, b_type));
-          ET_SWITCH_REAL_TYPES_AND(
-              Bool, out_type, ctx, "eq.Scalar_out", CTYPE_OUT, [&]() {
-                apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-                    [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                      CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                      CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                      bool value = a_casted == b_casted;
-                      return static_cast<CTYPE_OUT>(value);
-                    },
-                    a,
-                    b,
-                    out);
-              });
-        });
-  });
-
-  return out;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "eq.Tensor_out";
+  return internal::comparison_tensor_out<op_name>(ctx, a, b, out);
 }
 
 Tensor& eq_scalar_out(
@@ -69,48 +27,9 @@ Tensor& eq_scalar_out(
     const Tensor& a,
     const Scalar& b,
     Tensor& out) {
-  (void)ctx;
-
-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
-
-  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "eq.Scalar_out", CTYPE_A, [&]() {
-    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "eq.Scalar_out", CTYPE_B, [&]() {
-      using CTYPE_IN =
-          typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
-      ET_DCHECK(
-          CppTypeToScalarType<CTYPE_IN>::value == promoteTypes(a_type, b_type));
-      ET_SWITCH_REAL_TYPES_AND(
-          Bool, out_type, ctx, "eq.Scalar_out", CTYPE_OUT, [&]() {
-            CTYPE_B val_b = 0;
-            utils::extract_scalar(b, &val_b);
-            apply_unary_map_fn(
-                [val_b](const CTYPE_A val_a) {
-                  CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                  CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                  bool value = a_casted == b_casted;
-                  return static_cast<CTYPE_OUT>(value);
-                },
-                a.const_data_ptr<CTYPE_A>(),
-                out.mutable_data_ptr<CTYPE_OUT>(),
-                out.numel());
-          });
-    });
-  });
-
-  return out;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "eq.Scalar_out";
+  return internal::comparison_scalar_out<op_name>(ctx, a, b, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_ge.cpp b/kernels/portable/cpu/op_ge.cpp
index 7195f80f65d..8457f91b548 100644
--- a/kernels/portable/cpu/op_ge.cpp
+++ b/kernels/portable/cpu/op_ge.cpp
@@ -6,63 +6,20 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
-#include <executorch/runtime/kernel/kernel_includes.h>
-#include <executorch/runtime/platform/assert.h>
+#include <executorch/kernels/portable/cpu/pattern/comparison_op.h>
 
 namespace torch {
 namespace executor {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
-using ScalarType = exec_aten::ScalarType;
-
 Tensor& ge_tensor_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  // Determine output size and resize for dynamic shapes
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = b.scalar_type();
-  ScalarType out_type = out.scalar_type();
-
-  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "ge.Tensor_out", CTYPE_A, [&]() {
-    ET_SWITCH_REAL_TYPES_AND(
-        Bool, b_type, ctx, "ge.Tensor_out", CTYPE_B, [&]() {
-          using CTYPE_IN =
-              typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
-          ET_DCHECK(
-              CppTypeToScalarType<CTYPE_IN>::value ==
-              promoteTypes(a_type, b_type));
-          ET_SWITCH_REAL_TYPES_AND(
-              Bool, out_type, ctx, "ge.Tensor_out", CTYPE_OUT, [&]() {
-                apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-                    [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                      CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                      CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                      bool value = a_casted >= b_casted;
-                      return static_cast<CTYPE_OUT>(value);
-                    },
-                    a,
-                    b,
-                    out);
-              });
-        });
-  });
-
-  return out;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "ge.Tensor_out";
+  return internal::comparison_tensor_out<op_name>(ctx, a, b, out);
 }
 
 Tensor& ge_scalar_out(
@@ -70,48 +27,9 @@ Tensor& ge_scalar_out(
     const Tensor& a,
     const Scalar& b,
     Tensor& out) {
-  (void)ctx;
-
-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
-  ScalarType out_type = out.scalar_type();
-
-  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "ge.Scalar_out", CTYPE_A, [&]() {
-    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "ge.Scalar_out", CTYPE_B, [&]() {
-      ET_SWITCH_REAL_TYPES_AND(
-          Bool, common_type, ctx, "ge.Scalar_out", CTYPE_IN, [&]() {
-            ET_SWITCH_REAL_TYPES_AND(
-                Bool, out_type, ctx, "ge.Scalar_out", CTYPE_OUT, [&]() {
-                  CTYPE_B val_b = 0;
-                  utils::extract_scalar(b, &val_b);
-                  apply_unary_map_fn(
-                      [val_b](const CTYPE_A val_a) {
-                        CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                        CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                        bool value = a_casted >= b_casted;
-                        return static_cast<CTYPE_OUT>(value);
-                      },
-                      a.const_data_ptr<CTYPE_A>(),
-                      out.mutable_data_ptr<CTYPE_OUT>(),
-                      out.numel());
-                });
-          });
-    });
-  });
-
-  return out;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "ge.Scalar_out";
+  return internal::comparison_scalar_out<op_name>(ctx, a, b, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_gt.cpp b/kernels/portable/cpu/op_gt.cpp
index e3cceedabc7..bb1f6a274cd 100644
--- a/kernels/portable/cpu/op_gt.cpp
+++ b/kernels/portable/cpu/op_gt.cpp
@@ -6,63 +6,20 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
-#include <executorch/runtime/kernel/kernel_includes.h>
-#include <executorch/runtime/platform/assert.h>
+#include <executorch/kernels/portable/cpu/pattern/comparison_op.h>
 
 namespace torch {
 namespace executor {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
-using ScalarType = exec_aten::ScalarType;
-
 Tensor& gt_tensor_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  // Determine output size and resize for dynamic shapes
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = b.scalar_type();
-  ScalarType out_type = out.scalar_type();
-
-  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "gt.Tensor_out", CTYPE_A, [&]() {
-    ET_SWITCH_REAL_TYPES_AND(
-        Bool, b_type, ctx, "gt.Tensor_out", CTYPE_B, [&]() {
-          using CTYPE_IN =
-              typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
-          ET_DCHECK(
-              CppTypeToScalarType<CTYPE_IN>::value ==
-              promoteTypes(a_type, b_type));
-          ET_SWITCH_REAL_TYPES_AND(
-              Bool, out_type, ctx, "gt.Tensor_out", CTYPE_OUT, [&]() {
-                apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-                    [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                      CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                      CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                      bool value = a_casted > b_casted;
-                      return static_cast<CTYPE_OUT>(value);
-                    },
-                    a,
-                    b,
-                    out);
-              });
-        });
-  });
-
-  return out;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "gt.Tensor_out";
+  return internal::comparison_tensor_out<op_name>(ctx, a, b, out);
 }
 
 Tensor& gt_scalar_out(
@@ -70,48 +27,9 @@ Tensor& gt_scalar_out(
     const Tensor& a,
     const Scalar& b,
     Tensor& out) {
-  (void)ctx;
-
-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
-  ScalarType out_type = out.scalar_type();
-
-  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "gt.Scalar_out", CTYPE_A, [&]() {
-    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "gt.Scalar_out", CTYPE_B, [&]() {
-      ET_SWITCH_REAL_TYPES_AND(
-          Bool, common_type, ctx, "gt.Scalar_out", CTYPE_IN, [&]() {
-            ET_SWITCH_REAL_TYPES_AND(
-                Bool, out_type, ctx, "gt.Scalar_out", CTYPE_OUT, [&]() {
-                  CTYPE_B val_b = 0;
-                  utils::extract_scalar(b, &val_b);
-                  apply_unary_map_fn(
-                      [val_b](const CTYPE_A val_a) {
-                        CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                        CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                        bool value = a_casted > b_casted;
-                        return static_cast<CTYPE_OUT>(value);
-                      },
-                      a.const_data_ptr<CTYPE_A>(),
-                      out.mutable_data_ptr<CTYPE_OUT>(),
-                      out.numel());
-                });
-          });
-    });
-  });
-
-  return out;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "gt.Scalar_out";
+  return internal::comparison_scalar_out<op_name>(ctx, a, b, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_le.cpp b/kernels/portable/cpu/op_le.cpp
index 5b241dbb8a3..e893678fc5e 100644
--- a/kernels/portable/cpu/op_le.cpp
+++ b/kernels/portable/cpu/op_le.cpp
@@ -6,63 +6,20 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
-#include <executorch/runtime/kernel/kernel_includes.h>
-#include <executorch/runtime/platform/assert.h>
+#include <executorch/kernels/portable/cpu/pattern/comparison_op.h>
 
 namespace torch {
 namespace executor {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
-using ScalarType = exec_aten::ScalarType;
-
 Tensor& le_tensor_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  // Determine output size and resize for dynamic shapes
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = b.scalar_type();
-  ScalarType out_type = out.scalar_type();
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
-
-  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "le.Tensor_out", CTYPE_A, [&]() {
-    ET_SWITCH_REAL_TYPES_AND(
-        Bool, b_type, ctx, "le.Tensor_out", CTYPE_B, [&]() {
-          using CTYPE_IN =
-              typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
-          ET_DCHECK(
-              CppTypeToScalarType<CTYPE_IN>::value ==
-              promoteTypes(a_type, b_type));
-          ET_SWITCH_REAL_TYPES_AND(
-              Bool, out_type, ctx, "le.Tensor_out", CTYPE_OUT, [&]() {
-                apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-                    [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                      CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                      CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                      bool value = a_casted <= b_casted;
-                      return static_cast<CTYPE_OUT>(value);
-                    },
-                    a,
-                    b,
-                    out);
-              });
-        });
-  });
-
-  return out;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "le.Tensor_out";
+  return internal::comparison_tensor_out<op_name>(ctx, a, b, out);
 }
 
 Tensor& le_scalar_out(
@@ -70,48 +27,9 @@ Tensor& le_scalar_out(
     const Tensor& a,
     const Scalar& b,
     Tensor& out) {
-  (void)ctx;
-
-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
-  ScalarType out_type = out.scalar_type();
-
-  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "le.Scalar_out", CTYPE_A, [&]() {
-    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "le.Scalar_out", CTYPE_B, [&]() {
-      ET_SWITCH_REAL_TYPES_AND(
-          Bool, common_type, ctx, "le.Scalar_out", CTYPE_IN, [&]() {
-            ET_SWITCH_REAL_TYPES_AND(
-                Bool, out_type, ctx, "le.Scalar_out", CTYPE_OUT, [&]() {
-                  CTYPE_B val_b = 0;
-                  utils::extract_scalar(b, &val_b);
-                  apply_unary_map_fn(
-                      [val_b](const CTYPE_A val_a) {
-                        CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                        CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                        bool value = a_casted <= b_casted;
-                        return static_cast<CTYPE_OUT>(value);
-                      },
-                      a.const_data_ptr<CTYPE_A>(),
-                      out.mutable_data_ptr<CTYPE_OUT>(),
-                      out.numel());
-                });
-          });
-    });
-  });
-
-  return out;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "le.Scalar_out";
+  return internal::comparison_scalar_out<op_name>(ctx, a, b, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_lt.cpp b/kernels/portable/cpu/op_lt.cpp
index 6f2ba4d6936..6f1ffb21153 100644
--- a/kernels/portable/cpu/op_lt.cpp
+++ b/kernels/portable/cpu/op_lt.cpp
@@ -6,63 +6,20 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
-#include <executorch/runtime/kernel/kernel_includes.h>
-#include <executorch/runtime/platform/assert.h>
+#include <executorch/kernels/portable/cpu/pattern/comparison_op.h>
 
 namespace torch {
 namespace executor {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
-using ScalarType = exec_aten::ScalarType;
-
 Tensor& lt_tensor_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  // Determine output size and resize for dynamic shapes
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = b.scalar_type();
-  ScalarType out_type = out.scalar_type();
-
-  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "lt.Tensor_out", CTYPE_A, [&]() {
-    ET_SWITCH_REAL_TYPES_AND(
-        Bool, b_type, ctx, "lt.Tensor_out", CTYPE_B, [&]() {
-          using CTYPE_IN =
-              typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
-          ET_DCHECK(
-              CppTypeToScalarType<CTYPE_IN>::value ==
-              promoteTypes(a_type, b_type));
-          ET_SWITCH_REAL_TYPES_AND(
-              Bool, out_type, ctx, "lt.Tensor_out", CTYPE_OUT, [&]() {
-                apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-                    [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                      CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                      CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                      bool value = a_casted < b_casted;
-                      return static_cast<CTYPE_OUT>(value);
-                    },
-                    a,
-                    b,
-                    out);
-              });
-        });
-  });
-
-  return out;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "lt.Tensor_out";
+  return internal::comparison_tensor_out<op_name>(ctx, a, b, out);
 }
 
 Tensor& lt_scalar_out(
@@ -70,48 +27,9 @@ Tensor& lt_scalar_out(
     const Tensor& a,
     const Scalar& b,
     Tensor& out) {
-  (void)ctx;
-
-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
-  ScalarType out_type = out.scalar_type();
-
-  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "lt.Scalar_out", CTYPE_A, [&]() {
-    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "lt.Scalar_out", CTYPE_B, [&]() {
-      ET_SWITCH_REAL_TYPES_AND(
-          Bool, common_type, ctx, "lt.Scalar_out", CTYPE_IN, [&]() {
-            ET_SWITCH_REAL_TYPES_AND(
-                Bool, out_type, ctx, "lt.Scalar_out", CTYPE_OUT, [&]() {
-                  CTYPE_B val_b = 0;
-                  utils::extract_scalar(b, &val_b);
-                  apply_unary_map_fn(
-                      [val_b](const CTYPE_A val_a) {
-                        CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                        CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                        bool value = a_casted < b_casted;
-                        return static_cast<CTYPE_OUT>(value);
-                      },
-                      a.const_data_ptr<CTYPE_A>(),
-                      out.mutable_data_ptr<CTYPE_OUT>(),
-                      out.numel());
-                });
-          });
-    });
-  });
-
-  return out;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "lt.Scalar_out";
+  return internal::comparison_scalar_out<op_name>(ctx, a, b, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/op_ne.cpp b/kernels/portable/cpu/op_ne.cpp
index 194b9de3720..5e5a2d38a33 100644
--- a/kernels/portable/cpu/op_ne.cpp
+++ b/kernels/portable/cpu/op_ne.cpp
@@ -6,62 +6,20 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <executorch/kernels/portable/cpu/scalar_utils.h>
-#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
-#include <executorch/kernels/portable/cpu/util/functional_util.h>
-#include <executorch/runtime/kernel/kernel_includes.h>
-#include <executorch/runtime/platform/assert.h>
+#include <executorch/kernels/portable/cpu/pattern/comparison_op.h>
 
 namespace torch {
 namespace executor {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
-using ScalarType = exec_aten::ScalarType;
-
 Tensor& ne_tensor_out(
     KernelRuntimeContext& ctx,
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  ET_KERNEL_CHECK(
-      ctx,
-      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
-      InvalidArgument,
-      out);
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = b.scalar_type();
-  ScalarType out_type = out.scalar_type();
-
-  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "ne.Tensor_out", CTYPE_A, [&]() {
-    ET_SWITCH_REAL_TYPES_AND(
-        Bool, b_type, ctx, "ne.Tensor_out", CTYPE_B, [&]() {
-          using CTYPE_IN =
-              typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
-          ET_DCHECK(
-              CppTypeToScalarType<CTYPE_IN>::value ==
-              promoteTypes(a_type, b_type));
-          ET_SWITCH_REAL_TYPES_AND(
-              Bool, out_type, ctx, "ne.Tensor_out", CTYPE_OUT, [&]() {
-                apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
-                    [](const CTYPE_A val_a, const CTYPE_B val_b) {
-                      CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                      CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                      bool value = a_casted != b_casted;
-                      return static_cast<CTYPE_OUT>(value);
-                    },
-                    a,
-                    b,
-                    out);
-              });
-        });
-  });
-
-  return out;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "ne.Tensor_out";
+  return internal::comparison_tensor_out<op_name>(ctx, a, b, out);
 }
 
 Tensor& ne_scalar_out(
@@ -69,47 +27,9 @@ Tensor& ne_scalar_out(
     const Tensor& a,
     const Scalar& b,
     Tensor& out) {
-  (void)ctx;
-  // Resize for dynamic shape
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      resize_tensor(out, a.sizes()) == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-
-  ET_KERNEL_CHECK(
-      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
-
-  ScalarType a_type = a.scalar_type();
-  ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType out_type = out.scalar_type();
-
-  ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "ne.Scalar_out", CTYPE_A, [&]() {
-    ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "ne.Scalar_out", CTYPE_B, [&]() {
-      using CTYPE_IN =
-          typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
-      ET_DCHECK(
-          CppTypeToScalarType<CTYPE_IN>::value == promoteTypes(a_type, b_type));
-      ET_SWITCH_REAL_TYPES_AND(
-          Bool, out_type, ctx, "ne.Scalar_out", CTYPE_OUT, [&]() {
-            CTYPE_B val_b = 0;
-            utils::extract_scalar(b, &val_b);
-            apply_unary_map_fn(
-                [val_b](const CTYPE_A val_a) {
-                  CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
-                  CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
-                  bool value = a_casted != b_casted;
-                  return static_cast<CTYPE_OUT>(value);
-                },
-                a.const_data_ptr<CTYPE_A>(),
-                out.mutable_data_ptr<CTYPE_OUT>(),
-                out.numel());
-          });
-    });
-  });
-
-  return out;
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "ne.Scalar_out";
+  return internal::comparison_scalar_out<op_name>(ctx, a, b, out);
 }
 
 } // namespace native
diff --git a/kernels/portable/cpu/pattern/comparison_op.h b/kernels/portable/cpu/pattern/comparison_op.h
new file mode 100644
index 00000000000..e0d9bf4dcab
--- /dev/null
+++ b/kernels/portable/cpu/pattern/comparison_op.h
@@ -0,0 +1,147 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <executorch/kernels/portable/cpu/scalar_utils.h>
+#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
+#include <executorch/runtime/kernel/kernel_includes.h>
+
+namespace torch {
+namespace executor {
+namespace native {
+namespace internal {
+
+#define DEFINE_BINARY_OPERATOR_TEMPLATE(name, op) \
+  template <typename T>                           \
+  T name(const T val_a, const T val_b) {          \
+    return val_a op val_b;                        \
+  }
+
+DEFINE_BINARY_OPERATOR_TEMPLATE(eq, ==)
+DEFINE_BINARY_OPERATOR_TEMPLATE(ne, !=)
+DEFINE_BINARY_OPERATOR_TEMPLATE(ge, >=)
+DEFINE_BINARY_OPERATOR_TEMPLATE(le, <=)
+DEFINE_BINARY_OPERATOR_TEMPLATE(gt, >)
+DEFINE_BINARY_OPERATOR_TEMPLATE(lt, <)
+
+template <typename T>
+using comparison_fn = T (*)(const T, const T);
+
+template <typename T, const char* op_name>
+constexpr comparison_fn<T> get_comparison_fn() {
+  std::string_view op = op_name;
+  if (op == "eq.Tensor_out" || op == "eq.Scalar_out") {
+    return eq;
+  }
+  if (op == "ne.Tensor_out" || op == "ne.Scalar_out") {
+    return ne;
+  }
+  if (op == "ge.Tensor_out" || op == "ge.Scalar_out") {
+    return ge;
+  }
+  if (op == "le.Tensor_out" || op == "le.Scalar_out") {
+    return le;
+  }
+  if (op == "gt.Tensor_out" || op == "gt.Scalar_out") {
+    return gt;
+  }
+  if (op == "lt.Tensor_out" || op == "lt.Scalar_out") {
+    return lt;
+  }
+  return nullptr;
+};
+
+template <typename T, const char* op_name>
+struct ComparisonFnForOp {
+  static constexpr auto value = get_comparison_fn<T, op_name>();
+  static_assert(value != nullptr, "unknown op_name!");
+};
+
+template <const char* op_name>
+Tensor& comparison_tensor_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Tensor& b,
+    Tensor& out) {
+  // Common Dtype
+  ScalarType common_type = promoteTypes(a.scalar_type(), b.scalar_type());
+  if (executorch::runtime::isFloatingType(common_type) &&
+      a.scalar_type() != b.scalar_type()) {
+    common_type = ScalarType::Float;
+  }
+
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
+
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+
+  ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        ComparisonFnForOp<CTYPE_COMPUTE, op_name>::value,
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        b,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBBF16);
+  });
+
+  return out;
+}
+
+template <const char* op_name>
+Tensor& comparison_scalar_out(
+    KernelRuntimeContext& ctx,
+    const Tensor& a,
+    const Scalar& b,
+    Tensor& out) {
+  // Common Dtype
+  ScalarType common_type = utils::promote_type_with_scalar(a.scalar_type(), b);
+
+  // Check Dim Order
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
+
+  // Resize
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
+
+  // Compute Dtype
+  ScalarType compute_type = utils::get_compute_type(common_type);
+
+  ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
+    const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
+    utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
+        [val_b](const CTYPE_COMPUTE val_a) {
+          return ComparisonFnForOp<CTYPE_COMPUTE, op_name>::value(val_a, val_b);
+        },
+        ctx,
+        a,
+        utils::SupportedTensorDtypes::REALHBBF16,
+        out,
+        utils::SupportedTensorDtypes::REALHBBF16);
+  });
+
+  return out;
+}
+
+} // namespace internal
+} // namespace native
+} // namespace executor
+} // namespace torch
diff --git a/kernels/portable/cpu/pattern/targets.bzl b/kernels/portable/cpu/pattern/targets.bzl
index 250a8fb6a72..81b4345c75e 100644
--- a/kernels/portable/cpu/pattern/targets.bzl
+++ b/kernels/portable/cpu/pattern/targets.bzl
@@ -14,6 +14,7 @@ def define_common_targets():
         deps = [
             "//executorch/kernels/portable/cpu/pattern:pattern",
             "//executorch/kernels/portable/cpu/pattern:bitwise_op",
+            "//executorch/kernels/portable/cpu/pattern:comparison_op",
         ],
         visibility = ["//executorch/...", "@EXECUTORCH_CLIENTS"],
     )
@@ -30,6 +31,15 @@
         visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],
     )
 
+    runtime.cxx_library(
+        name = "comparison_op",
+        exported_headers = [
+            "comparison_op.h",
+        ],
+        compiler_flags = [],
+        visibility = ["//executorch/kernels/portable/cpu/...", "//executorch/kernels/optimized/cpu/..."],
+    )
+
     runtime.cxx_library(
         name = "pattern",
         srcs = [
diff --git a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
index 26e16b8fecc..8f66cb4acc6 100644
--- a/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
+++ b/shim/xplat/executorch/kernels/portable/op_registration_util.bzl
@@ -478,8 +478,9 @@ ATEN_OPS = (
         name = "op_eq",
         deps = [
            ":scalar_utils",
+            "//executorch/kernels/portable/cpu/pattern:comparison_op",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
-            "//executorch/kernels/portable/cpu/util:functional_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
     op_target(
@@ -576,8 +577,9 @@ ATEN_OPS = (
         name = "op_ge",
         deps = [
            ":scalar_utils",
+            "//executorch/kernels/portable/cpu/pattern:comparison_op",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
-            "//executorch/kernels/portable/cpu/util:functional_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
     op_target(
@@ -600,8 +602,9 @@ ATEN_OPS = (
         name = "op_gt",
         deps = [
            ":scalar_utils",
+            "//executorch/kernels/portable/cpu/pattern:comparison_op",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
-            "//executorch/kernels/portable/cpu/util:functional_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
     op_target(
@@ -648,8 +651,9 @@ ATEN_OPS = (
         name = "op_le",
         deps = [
            ":scalar_utils",
+            "//executorch/kernels/portable/cpu/pattern:comparison_op",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
-            "//executorch/kernels/portable/cpu/util:functional_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
     op_target(
@@ -733,8 +737,9 @@ ATEN_OPS = (
         name = "op_lt",
         deps = [
            ":scalar_utils",
+            "//executorch/kernels/portable/cpu/pattern:comparison_op",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
-            "//executorch/kernels/portable/cpu/util:functional_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
         ],
     ),
     op_target(
@@ -841,8 +846,9 @@ ATEN_OPS = (
         name = "op_ne",
         deps = [
            ":scalar_utils",
+            "//executorch/kernels/portable/cpu/pattern:comparison_op",
             "//executorch/kernels/portable/cpu/util:broadcast_util",
-            "//executorch/kernels/portable/cpu/util:functional_util",
+            "//executorch/kernels/portable/cpu/util:elementwise_util",
        ],
     ),
     op_target(
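
Note on the dispatch mechanism: the core trick in `comparison_op.h` is resolving the comparison function at compile time from the op-name string, which is passed as a `const char*` non-type template parameter — hence the `static constexpr const char op_name[]` arrays in each op file. A minimal standalone sketch of that mechanism follows, assuming C++17 (for `constexpr std::string_view` comparison); the `kEqTensorOut` name is illustrative, not from the patch:

```cpp
#include <string_view>

// A comparison function template and a pointer alias, mirroring the header.
template <typename T>
T eq(const T val_a, const T val_b) {
  return val_a == val_b;
}

template <typename T>
using comparison_fn = T (*)(const T, const T);

// Resolve the op-name string to a function pointer at compile time.
template <typename T, const char* op_name>
constexpr comparison_fn<T> get_comparison_fn() {
  if (std::string_view(op_name) == "eq.Tensor_out") {
    return eq;  // deduces eq<T> from the target pointer type
  }
  return nullptr;
}

// The array has static storage duration, so it is usable as a template
// argument; unknown names fail at compile time rather than at runtime.
static constexpr const char kEqTensorOut[] = "eq.Tensor_out";
static_assert(
    get_comparison_fn<int, kEqTensorOut>() != nullptr,
    "dispatch resolved at compile time");
```

Because the string is a template parameter, each op instantiates exactly one branch of the dispatch and the dead comparisons compile away, which is what lets six nearly identical op files collapse into one shared template.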