55 changes: 33 additions & 22 deletions backends/cadence/fusion_g3/operators/op_sub.cpp
@@ -34,13 +34,9 @@ Tensor& sub_out(
     const Tensor& b,
     const Scalar& alpha,
     Tensor& out) {
-  // Common Dtype
-  ScalarType common_type =
-      executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
 #ifdef OP_ARG_CHECK
   ScalarType alpha_type =
       torch::executor::native::utils::get_scalar_dtype(alpha);
-
   // Check alpha type
   ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
 
@@ -67,10 +63,6 @@ Tensor& sub_out(
       out);
 #endif
 
-  // Compute Dtype
-  ScalarType compute_type =
-      torch::executor::native::utils::get_compute_type(common_type);
-
   // @lint-ignore CLANGTIDY facebook-hte-CArray
   static constexpr const char op_name[] = "sub.out";
 
@@ -115,11 +107,15 @@ Tensor& sub_out(
     }
   }
 
-  if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) {
+  if (((broadcast == 1) && (max_dim > kTensorDimensionLimit)) ||
+      (!(((a.scalar_type() == ScalarType::Int) ||
+          (a.scalar_type() == ScalarType::Float)) &&
+         (a.scalar_type() == b.scalar_type()) &&
+         (a.scalar_type() == out.scalar_type())))) {
     optimized = 0;
   }
 
-  if ((compute_type == ScalarType::Int) && (optimized)) {
+  if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
     const int* const inp1_data = a.const_data_ptr<int>();
     const int* const inp2_data = b.const_data_ptr<int>();
     int* const out_data = out.mutable_data_ptr<int>();
@@ -161,7 +157,7 @@ Tensor& sub_out(
           alpha_val,
           out.numel());
     }
-  } else if ((compute_type == ScalarType::Float) && (optimized)) {
+  } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
     const float* const inp1_data = a.const_data_ptr<float>();
     const float* const inp2_data = b.const_data_ptr<float>();
     float* const out_data = out.mutable_data_ptr<float>();
@@ -204,6 +200,13 @@ Tensor& sub_out(
           out.numel());
     }
   } else {
+    // Common Dtype
+    ScalarType common_type =
+        executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
+    // Compute Dtype
+    ScalarType compute_type =
+        torch::executor::native::utils::get_compute_type(common_type);
+
     ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
       const CTYPE_COMPUTE val_alpha =
           torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(alpha);
@@ -232,14 +235,9 @@ Tensor& sub_scalar_out(
     const Scalar& b,
     const Scalar& alpha,
     Tensor& out) {
-  // Common Dtype
-  ScalarType common_type =
-      torch::executor::native::utils::promote_type_with_scalar(
-          a.scalar_type(), b);
 #ifdef OP_ARG_CHECK
   ScalarType alpha_type =
       torch::executor::native::utils::get_scalar_dtype(alpha);
-
   // Check alpha type
   ET_KERNEL_CHECK(ctx, alpha_type != ScalarType::Bool, InvalidArgument, out);
 
@@ -265,14 +263,20 @@ Tensor& sub_scalar_out(
       out);
 #endif
 
-  // Compute Dtype
-  ScalarType compute_type =
-      torch::executor::native::utils::get_compute_type(common_type);
-
   // @lint-ignore CLANGTIDY facebook-hte-CArray
   static constexpr const char op_name[] = "sub.Scalar_out";
 
-  if (compute_type == ScalarType::Int) {
+  bool optimized = 1;
+  ScalarType b_type = torch::executor::native::utils::get_scalar_dtype(b);
+
+  if (!(((a.scalar_type() == ScalarType::Int) ||
+         (a.scalar_type() == ScalarType::Float)) &&
+        (a.scalar_type() == b_type) &&
+        (a.scalar_type() == out.scalar_type()))) {
+    optimized = 0;
+  }
+
+  if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
     const int* const inp1_data = a.const_data_ptr<int>();
     int inp2_val;
     torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -291,7 +295,7 @@ Tensor& sub_scalar_out(
         inp2_val,
         alpha_val,
         out.numel());
-  } else if (compute_type == ScalarType::Float) {
+  } else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
     const float* const inp1_data = a.const_data_ptr<float>();
     float inp2_val;
     torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -311,6 +315,13 @@ Tensor& sub_scalar_out(
         alpha_val,
         out.numel());
   } else {
+    // Common Dtype
+    ScalarType common_type =
+        torch::executor::native::utils::promote_type_with_scalar(
+            a.scalar_type(), b);
+    // Compute Dtype
+    ScalarType compute_type =
+        torch::executor::native::utils::get_compute_type(common_type);
     ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
       const CTYPE_COMPUTE val_b =
           torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(b);
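In both sub_out and sub_scalar_out, the diff replaces the compute_type-based dispatch with a direct check on the tensor dtypes: the vectorized NNLib path now runs only when the inputs and the output all share one dtype, that dtype is Int or Float, and any broadcast stays within kTensorDimensionLimit; every other case falls back to the generic ET_SWITCH_REAL_TYPES loop, which is also where the common_type/compute_type promotion now lives. The standalone sketch below is not part of the PR; ScalarType, kTensorDimensionLimit, and can_use_optimized_sub are simplified stand-ins for the ExecuTorch definitions, shown only to isolate the gating predicate:

// Standalone sketch of the dtype gating added in this diff.
// Assumption: ScalarType, kTensorDimensionLimit, and the function name
// are simplified stand-ins, not the real ExecuTorch declarations.
#include <cstdio>

enum class ScalarType { Bool, Int, Float };

constexpr int kTensorDimensionLimit = 5; // stand-in value

// Mirrors the new check: fast path only when inputs and output share a
// supported dtype and any broadcast stays within the rank limit.
bool can_use_optimized_sub(
    ScalarType a,
    ScalarType b,
    ScalarType out,
    bool broadcast,
    int max_dim) {
  bool same_supported_dtype =
      ((a == ScalarType::Int) || (a == ScalarType::Float)) && (a == b) &&
      (a == out);
  bool broadcast_ok = !(broadcast && (max_dim > kTensorDimensionLimit));
  return same_supported_dtype && broadcast_ok;
}

int main() {
  // Int - Int -> Int, no broadcast: optimized kernel is eligible.
  std::printf(
      "%d\n",
      can_use_optimized_sub(
          ScalarType::Int, ScalarType::Int, ScalarType::Int, false, 4));
  // Float - Int: dtypes differ, so the generic fallback runs instead.
  std::printf(
      "%d\n",
      can_use_optimized_sub(
          ScalarType::Float, ScalarType::Int, ScalarType::Float, false, 4));
  return 0;
}

Run on its own, the first call prints 1 (matching Int operands take the fast kernel) and the second prints 0 (mixed Float/Int operands fall through to the type-promoting fallback), matching the branch structure the diff installs in both operators.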