Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
188 changes: 60 additions & 128 deletions kernels/optimized/cpu/op_mul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include <ATen/cpu/vec/vec.h>
#include <executorch/kernels/optimized/cpu/binary_ops.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/runtime/core/exec_aten/util/tensor_util.h> // IWYU pragma: export
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
Expand All @@ -22,76 +22,35 @@ namespace native {
using Tensor = executorch::aten::Tensor;
using ScalarType = executorch::aten::ScalarType;

namespace {

template <
bool can_cast,
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct MulInner;

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct MulInner<true, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT> {
static void run(const Tensor& a, const Tensor& b, Tensor& out) {
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
// NOLINTNEXTLINE(facebook-hte-ConstantArgumentPassByValue)
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = a_casted * b_casted;

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
}
};

struct ReportCanCastBug {
static void run(const Tensor&, const Tensor&, Tensor&) {
ET_DCHECK_MSG(false, "BUG: canCast should have been checked above");
}
};

template <
typename CTYPE_A,
typename CTYPE_B,
typename CTYPE_IN,
typename CTYPE_OUT>
struct MulInner<false, CTYPE_A, CTYPE_B, CTYPE_IN, CTYPE_OUT>
: public ReportCanCastBug {};

} // namespace

Tensor& opt_mul_out(
KernelRuntimeContext& ctx,
const Tensor& a,
const Tensor& b,
Tensor& out) {
(void)ctx;

ScalarType a_type = a.scalar_type();
ScalarType b_type = b.scalar_type();
ScalarType out_type = out.scalar_type();
ScalarType common_type = promoteTypes(a_type, b_type);

ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);

ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);

ET_KERNEL_CHECK(
ctx,
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
InvalidArgument,
out);

// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "mul.out";

if (b.numel() == 1) {
if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
a_type != ScalarType::BFloat16) {
ET_KERNEL_CHECK(
ctx,
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
InvalidArgument,
out);

ET_SWITCH_REALB_TYPES(a_type, ctx, "mul.out", CTYPE, [&]() {
ET_SWITCH_REALB_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
ET_SWITCH_REALB_TYPES(a_type, ctx, op_name, CTYPE, [&]() {
ET_SWITCH_REALB_TYPES(b_type, ctx, op_name, CTYPE_B, [&]() {
CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
CTYPE b_casted = static_cast<CTYPE>(b_val);

Expand All @@ -111,17 +70,11 @@ Tensor& opt_mul_out(

auto selected_optimized_path = select_optimized_path(a, b, out);
if (selected_optimized_path == ElementwiseOptimizedPath::kTreatAs1d) {
ET_KERNEL_CHECK(
ctx,
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
InvalidArgument,
out);

if (executorch::runtime::isComplexType(out_type)) {
ET_KERNEL_CHECK(
ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);

ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
using Vec = at::vec::Vectorized<CTYPE>;
at::vec::map2<CTYPE>(
[](Vec x, Vec y) { return x * y; },
Expand All @@ -131,7 +84,7 @@ Tensor& opt_mul_out(
out.numel());
});
} else {
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
using Vec = at::vec::Vectorized<CTYPE>;
at::vec::map2<CTYPE>(
[](Vec x, Vec y) { return x * y; },
Expand All @@ -146,63 +99,47 @@ Tensor& opt_mul_out(
ET_KERNEL_CHECK(
ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);

ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
auto mul_lambda = [](auto x, auto y) { return x * y; };
torch::executor::handle_broadcast_elementwise<CTYPE>(
ctx, mul_lambda, a, b, out, selected_optimized_path);
});
} else {
ET_SWITCH_REALB_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
ET_SWITCH_REALB_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
auto mul_lambda = [](auto x, auto y) { return x * y; };
torch::executor::handle_broadcast_elementwise<CTYPE>(
ctx, mul_lambda, a, b, out, selected_optimized_path);
});
}
} else {
ScalarType common_type =
promoteTypes(a_type, b_type, /*half_to_float*/ true);
ET_KERNEL_CHECK(ctx, canCast(common_type, out_type), InvalidArgument, out);

ET_KERNEL_CHECK(
ctx,
resize_to_broadcast_target_size(a, b, out) == Error::Ok,
InvalidArgument,
out);

if (executorch::runtime::isComplexType(a_type) ||
executorch::runtime::isComplexType(b_type) ||
executorch::runtime::isComplexType(out_type)) {
ET_KERNEL_CHECK(
ctx, a_type == b_type && a_type == out_type, InvalidArgument, out);

ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, "mul.out", CTYPE, [&]() {
ET_SWITCH_COMPLEXH_TYPES(out_type, ctx, op_name, CTYPE, [&]() {
apply_binary_elementwise_fn<CTYPE, CTYPE, CTYPE>(
[](const CTYPE val_a, const CTYPE val_b) { return val_a * val_b; },
a,
b,
out);
});
} else {
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REALHBBF16_TYPES(
out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = a_casted * b_casted;

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
});
});
ScalarType compute_type = utils::internal::get_compute_type(common_type);

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_bitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBBF16>(
[](const auto val_a, const auto val_b) { return val_a * val_b; },
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
b,
utils::SupportedTensorDtypes::REALHBBF16,
out);
});
}
}
Expand All @@ -215,26 +152,24 @@ Tensor& opt_mul_scalar_out(
const Tensor& a,
const Scalar& b,
Tensor& out) {
(void)ctx;

ScalarType a_type = a.scalar_type();
ScalarType common_type =
utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
ScalarType common_type = utils::promote_type_with_scalar(a_type, b);
ScalarType out_type = out.scalar_type();

ET_CHECK(common_type == out_type);
ET_KERNEL_CHECK(ctx, common_type == out_type, InvalidArgument, out);

if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
common_type = ScalarType::Float;
}
ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);

// Resize for dynamic shape
auto error = resize_tensor(out, a.sizes());
ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");
ET_KERNEL_CHECK(
ctx, resize_tensor(out, a.sizes()) == Error::Ok, InvalidArgument, out);

// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "mul.Scalar_out";

if (a_type == common_type && a_type == out_type &&
a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
ET_SWITCH_REALB_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE, [&]() {
ET_SWITCH_REALB_TYPES(a_type, ctx, op_name, CTYPE, [&]() {
CTYPE b_casted = utils::scalar_to<CTYPE>(b);

using Vec = at::vec::Vectorized<CTYPE>;
Expand All @@ -245,22 +180,19 @@ Tensor& opt_mul_scalar_out(
out.numel());
});
} else {
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "mul.Scalar_out", CTYPE_A, [&]() {
ET_SWITCH_REALB_TYPES(
common_type, ctx, "mul.Scalar_out", CTYPE_IN, [&]() {
ET_SWITCH_REALHBBF16_TYPES(
out_type, ctx, "mul.Scalar_out", CTYPE_OUT, [&]() {
CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);

const size_t n = a.numel();
const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
for (auto i = 0; i < n; ++i) {
out_data[i] = static_cast<CTYPE_OUT>(
static_cast<CTYPE_IN>(a_data[i]) * b_casted);
}
});
});
ScalarType compute_type = utils::internal::get_compute_type(common_type);

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
utils::apply_unitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::SAME_AS_COMMON>(
[val_b](const auto val_a) { return val_a * val_b; },
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
out);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,8 @@ OPTIMIZED_ATEN_OPS = (
":binary_ops",
"//executorch/kernels/portable/cpu:scalar_utils",
"//executorch/kernels/portable/cpu/util:broadcast_util",
"//executorch/kernels/portable/cpu/util:dtype_util",
"//executorch/kernels/portable/cpu/util:elementwise_util",
"//executorch/runtime/core/exec_aten/util:tensor_util",
"//executorch/runtime/core/portable_type/c10/c10:aten_headers_for_executorch",
],
Expand Down
Loading