diff --git a/kernels/optimized/cpu/op_le.cpp b/kernels/optimized/cpu/op_le.cpp
index 51fca9b0063..60696f1d2f1 100644
--- a/kernels/optimized/cpu/op_le.cpp
+++ b/kernels/optimized/cpu/op_le.cpp
@@ -27,24 +27,25 @@ Tensor& opt_le_tensor_out(
     const Tensor& a,
     const Tensor& b,
     Tensor& out) {
-  (void)ctx;
-
   ScalarType a_type = a.scalar_type();
   ScalarType out_type = out.scalar_type();
 
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_to_broadcast_target_size(a, b, out) == Error::Ok,
+      InvalidArgument,
+      out);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "le.Tensor_out";
+
   // Check for optimized broadcast paths
   auto selected_optimized_path = select_optimized_path(a, b, out);
   if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
-    // Resize for dynamic shape
-    auto error = resize_to_broadcast_target_size(a, b, out);
-    ET_KERNEL_CHECK_MSG(
-        ctx,
-        error == Error::Ok,
-        InvalidArgument,
-        out,
-        "Failed to resize output tensor.");
-
-    ET_SWITCH_REALB_TYPES(a_type, ctx, "le.Tensor_out", CTYPE, [&]() {
+    ET_SWITCH_REALB_TYPES(a_type, ctx, op_name, CTYPE, [&]() {
       using Vec = at::vec::Vectorized<CTYPE>;
       at::vec::map2<CTYPE>(
           [](Vec x, Vec y) { return x.le(y); },
@@ -55,16 +56,13 @@ Tensor& opt_le_tensor_out(
     });
   } else if (selected_optimized_path != ElementwiseOptimizedPath::kNone) {
     // Handle optimized broadcast cases
-    ET_SWITCH_REALB_TYPES(out_type, ctx, "le.Tensor_out", CTYPE, [&]() {
+    ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
       auto le_lambda = [](auto x, auto y) { return x.le(y); };
       torch::executor::handle_broadcast_elementwise<CTYPE>(
           ctx, le_lambda, a, b, out, selected_optimized_path);
     });
   } else {
-    // @lint-ignore CLANGTIDY facebook-hte-CArray
-    static constexpr const char op_name[] = "le.Tensor_out";
-    return internal::comparison_tensor_out<std::less_equal, op_name>(
-        ctx, a, b, out);
+    internal::comparison_tensor_out<std::less_equal, op_name>(ctx, a, b, out);
   }
 
   return out;
@@ -75,66 +73,37 @@ Tensor& opt_le_scalar_out(
     const Tensor& a,
     const Scalar& b,
     Tensor& out) {
-  (void)ctx;
-
-  // Resize for dynamic shape
-  auto error = resize_tensor(out, a.sizes());
-  ET_KERNEL_CHECK_MSG(
-      ctx,
-      error == Error::Ok,
-      InvalidArgument,
-      out,
-      "Failed to resize output tensor.");
-
   ScalarType a_type = a.scalar_type();
   ScalarType b_type = utils::get_scalar_dtype(b);
-  ScalarType common_type = promoteTypes(a_type, b_type);
+  ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
   ScalarType out_type = out.scalar_type();
 
-  if (a_type == common_type && a_type == out_type) {
-    ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "le.Scalar_out", CTYPE, [&]() {
-      ET_SWITCH_REAL_TYPES_AND(
-          Bool, b_type, ctx, "le.Scalar_out", CTYPE_B, [&]() {
-            CTYPE_B b_val = 0;
-            ET_EXTRACT_SCALAR(b, b_val);
-            CTYPE b_casted = static_cast<CTYPE>(b_val);
-            using Vec = at::vec::Vectorized<CTYPE>;
-            at::vec::map<CTYPE>(
-                [b_casted](Vec x) { return x.le(Vec(b_casted)); },
-                out.mutable_data_ptr<CTYPE>(),
-                a.const_data_ptr<CTYPE>(),
-                a.numel());
-          });
+  ET_KERNEL_CHECK(
+      ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);
+
+  ET_KERNEL_CHECK(
+      ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);
+
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char op_name[] = "le.Scalar_out";
+
+  if (a_type == common_type && a_type == out_type &&
+      a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
+    ET_SWITCH_REALB_TYPES(a_type, ctx, op_name, CTYPE, [&]() {
+      ET_SWITCH_REALB_TYPES(b_type, ctx, op_name, CTYPE_B, [&]() {
+        CTYPE_B b_val = 0;
+        ET_EXTRACT_SCALAR(b, b_val);
+        CTYPE b_casted = static_cast<CTYPE>(b_val);
+        using Vec = at::vec::Vectorized<CTYPE>;
+        at::vec::map<CTYPE>(
+            [b_casted](Vec x) { return x.le(Vec(b_casted)); },
+            out.mutable_data_ptr<CTYPE>(),
+            a.const_data_ptr<CTYPE>(),
+            a.numel());
+      });
     });
   } else {
-    ET_SWITCH_REAL_TYPES_AND(
-        Bool, a_type, ctx, "le.Scalar_out", CTYPE_A, [&]() {
-          ET_SWITCH_REAL_TYPES_AND(
-              Bool, b_type, ctx, "le.Scalar_out", CTYPE_B, [&]() {
-                ET_SWITCH_REAL_TYPES_AND(
-                    Bool, common_type, ctx, "le.Scalar_out", CTYPE_IN, [&]() {
-                      ET_SWITCH_REAL_TYPES_AND(
-                          Bool,
-                          out_type,
-                          ctx,
-                          "le.Scalar_out",
-                          CTYPE_OUT,
-                          [&]() {
-                            CTYPE_B b_val = 0;
-                            ET_EXTRACT_SCALAR(b, b_val);
-                            CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
-                            const size_t n = a.numel();
-                            const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
-                            CTYPE_OUT* out_data =
-                                out.mutable_data_ptr<CTYPE_OUT>();
-                            for (auto i = 0; i < n; ++i) {
-                              out_data[i] = static_cast<CTYPE_OUT>(
-                                  static_cast<CTYPE_IN>(a_data[i]) <= b_casted);
-                            }
-                          });
-                    });
-              });
-        });
+    internal::comparison_scalar_out<std::less_equal, op_name>(ctx, a, b, out);
  }
 
   return out;
diff --git a/kernels/test/op_le_test.cpp b/kernels/test/op_le_test.cpp
index 4a9b97dfe8a..1baf098f9dd 100644
--- a/kernels/test/op_le_test.cpp
+++ b/kernels/test/op_le_test.cpp
@@ -67,11 +67,11 @@ TEST_F(OpLeScalarOutTest, AllRealInputBoolOutputSupport) {
 #define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \
   test_le_scalar_out<ctype_in, dtype_in, ctype_out, dtype_out>();
 
-#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in)            \
-  ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
+#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in)                 \
+  ET_FORALL_REALHBF16_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
   test_le_scalar_out<ctype_in, dtype_in, bool, ScalarType::Bool>();
 
-  ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES)
+  ET_FORALL_REALHBF16_TYPES(TEST_FORALL_OUT_TYPES)
 
 #undef TEST_FORALL_OUT_TYPES
 #undef TEST_ENTRY
@@ -124,11 +124,11 @@ TEST_F(OpLeTensorOutTest, AllDtypesSupported) {
 #define TEST_ENTRY(ctype_in, dtype_in, ctype_out, dtype_out) \
   test_dtype<ctype_in, dtype_in, ctype_out, dtype_out>();
 
-#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in)            \
-  ET_FORALL_REAL_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
+#define TEST_FORALL_OUT_TYPES(ctype_in, dtype_in)                 \
+  ET_FORALL_REALHBF16_TYPES_WITH2(ctype_in, dtype_in, TEST_ENTRY) \
   test_dtype<ctype_in, dtype_in, bool, ScalarType::Bool>();
 
-  ET_FORALL_REAL_TYPES(TEST_FORALL_OUT_TYPES);
+  ET_FORALL_REALHBF16_TYPES(TEST_FORALL_OUT_TYPES);
 
 #undef TEST_FORALL_OUT_TYPES
 #undef TEST_ENTRY