105 changes: 60 additions & 45 deletions backends/cadence/fusion_g3/operators/op_add.cpp
@@ -35,21 +35,7 @@ Tensor& add_out(
const Tensor& b,
const Scalar& alpha,
Tensor& out) {
// Common Dtype
ScalarType common_type =
executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());

#ifdef OP_ARG_CHECK
// Check Common Dtype
ET_KERNEL_CHECK(
ctx,
(canCast(common_type, out.scalar_type()) &&
torch::executor::check_alpha_type(
torch::executor::native::utils::get_scalar_dtype(alpha),
common_type)),
InvalidArgument,
out);

// Check Dim Order
ET_KERNEL_CHECK(
ctx,
@@ -65,10 +51,6 @@ Tensor& add_out(
out);
#endif

// Compute Dtype
ScalarType compute_type =
torch::executor::native::utils::get_compute_type(common_type);

static constexpr const char op_name[] = "add.out";

int kTensorDimensionLimit = 5;
@@ -77,12 +59,12 @@ Tensor& add_out(
int inp2_shape[kTensorDimensionLimit];
int out_shape[kTensorDimensionLimit];

bool broadcast = 0;
bool broadcast = false;

int max_dim = a.dim() > b.dim() ? a.dim() : b.dim();
max_dim = out.dim() > max_dim ? out.dim() : max_dim;

bool optimized = 1;
bool optimized = true;

/* Added change to work with input dimensions more than 5 */
for (int i = 0; i < max_dim; i++) {
@@ -109,15 +91,19 @@ Tensor& add_out(
for (int i = 0; i < out.dim(); i++) {
if (((inp1_shape[i]) != (out_shape[i])) ||
((inp2_shape[i]) != (out_shape[i]))) {
broadcast = 1;
broadcast = true;
}
}

if ((broadcast == 1) && (max_dim > kTensorDimensionLimit)) {
optimized = 0;
if (((broadcast) && (max_dim > kTensorDimensionLimit)) ||
(!(((a.scalar_type() == ScalarType::Int) ||
(a.scalar_type() == ScalarType::Float)) &&
(a.scalar_type() == b.scalar_type()) &&
(a.scalar_type() == out.scalar_type())))) {
optimized = false;
}

if ((compute_type == ScalarType::Int) && (optimized)) {
if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
const int* const inp1_data = a.const_data_ptr<int>();
const int* const inp2_data = b.const_data_ptr<int>();
int* const out_data = out.mutable_data_ptr<int>();
@@ -169,7 +155,7 @@ Tensor& add_out(
alpha_val,
out.numel());
}
} else if ((compute_type == ScalarType::Float) && (optimized)) {
} else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
const float* const inp1_data = a.const_data_ptr<float>();
const float* const inp2_data = b.const_data_ptr<float>();
float* const out_data = out.mutable_data_ptr<float>();
@@ -222,6 +208,23 @@ Tensor& add_out(
out.numel());
}
} else {
// Common Dtype
ScalarType common_type =
executorch::runtime::promoteTypes(a.scalar_type(), b.scalar_type());
// Compute Dtype
ScalarType compute_type =
torch::executor::native::utils::get_compute_type(common_type);

// Check Common Dtype
ET_KERNEL_CHECK(
ctx,
(canCast(common_type, out.scalar_type()) &&
torch::executor::check_alpha_type(
torch::executor::native::utils::get_scalar_dtype(alpha),
common_type)),
InvalidArgument,
out);

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const CTYPE_COMPUTE val_alpha =
torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(alpha);
@@ -249,22 +252,7 @@ Tensor& add_scalar_out(
const Scalar& b,
const Scalar& alpha,
Tensor& out) {
// Common Dtype
ScalarType common_type =
torch::executor::native::utils::promote_type_with_scalar(
a.scalar_type(), b);

#ifdef OP_ARG_CHECK
// Check Common Dtype
ET_KERNEL_CHECK(
ctx,
(common_type == out.scalar_type() &&
torch::executor::check_alpha_type(
torch::executor::native::utils::get_scalar_dtype(alpha),
common_type)),
InvalidArgument,
out);

// Check Dim Order
ET_KERNEL_CHECK(
ctx,
@@ -279,14 +267,23 @@ Tensor& add_scalar_out(
InvalidArgument,
out);
#endif
// Compute Dtype
ScalarType compute_type =
torch::executor::native::utils::get_compute_type(common_type);

// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "add.Scalar_out";

if (compute_type == ScalarType::Int) {
bool optimized = true;

if (!(((a.scalar_type() == ScalarType::Int) ||
(a.scalar_type() == ScalarType::Float)) &&
(a.scalar_type() == out.scalar_type()))) {
optimized = false;
}

if ((b.isFloatingPoint()) && (a.scalar_type() == ScalarType::Int)) {
optimized = false;
}

if ((a.scalar_type() == ScalarType::Int) && (optimized)) {
const int* const inp1_data = a.const_data_ptr<int>();
int inp2_val;
torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -306,7 +303,7 @@ Tensor& add_scalar_out(
alpha_val,
out.numel());

} else if (compute_type == ScalarType::Float) {
} else if ((a.scalar_type() == ScalarType::Float) && (optimized)) {
const float* const inp1_data = a.const_data_ptr<float>();
float inp2_val;
torch::executor::native::utils::extract_scalar(b, &inp2_val);
@@ -327,6 +324,24 @@ Tensor& add_scalar_out(
out.numel());

} else {
// Common Dtype
ScalarType common_type =
torch::executor::native::utils::promote_type_with_scalar(
a.scalar_type(), b);
// Compute Dtype
ScalarType compute_type =
torch::executor::native::utils::get_compute_type(common_type);

// Check Common Dtype
ET_KERNEL_CHECK(
ctx,
(common_type == out.scalar_type() &&
torch::executor::check_alpha_type(
torch::executor::native::utils::get_scalar_dtype(alpha),
common_type)),
InvalidArgument,
out);

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
torch::executor::native::utils::
apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
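For reference, a minimal standalone sketch of the dtype gate these add_out changes introduce: the vectorized fast path is taken only when both inputs and the output share an Int or Float dtype and, when broadcasting is required, the rank fits the 5-D kernel limit; every other case falls through to the generic path, where the dtype promotion and the common-dtype check now live. The ScalarType enum, TensorView struct, and use_fast_path helper below are simplified, hypothetical stand-ins, not the actual ExecuTorch types or API.

// Hypothetical, simplified stand-ins for the real ExecuTorch tensor types.
enum class ScalarType { Char, Short, Int, Float };

struct TensorView {
  ScalarType dtype;
  int dim;
};

// Mirrors the gating logic added to add_out: the fast path is used only when
// every operand shares an Int or Float dtype and, if broadcasting is needed,
// the rank stays within the 5-D limit of the vectorized kernels.
bool use_fast_path(
    const TensorView& a,
    const TensorView& b,
    const TensorView& out,
    bool broadcast,
    int max_dim,
    int kTensorDimensionLimit = 5) {
  const bool dtype_ok =
      (a.dtype == ScalarType::Int || a.dtype == ScalarType::Float) &&
      a.dtype == b.dtype && a.dtype == out.dtype;
  const bool rank_ok = !broadcast || (max_dim <= kTensorDimensionLimit);
  return dtype_ok && rank_ok;
}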
22 changes: 16 additions & 6 deletions backends/cadence/fusion_g3/operators/op_cat.cpp
@@ -46,11 +46,6 @@ Tensor& cat_out(
int kTensorDimensionLimit = executorch::runtime::kTensorDimensionLimit;

#ifdef OP_ARG_CHECK
ET_KERNEL_CHECK(
ctx,
torch::executor::check_cat_args(tensors, dim, out),
InvalidArgument,
out);

Tensor::SizesType expected_out_size[kTensorDimensionLimit];
size_t expected_out_dim = 0;
@@ -106,7 +101,16 @@ Tensor& cat_out(
out_shapes[i] = out_size[i];
}

if ((out.scalar_type() == ScalarType::Int) ||
bool optimized = true;

for (int i = 0; i < tensors.size(); i++) {
if (out.scalar_type() != tensors[i].scalar_type()) {
optimized = false;
break;
}
}

if ((optimized) && (out.scalar_type() == ScalarType::Int) ||
(out.scalar_type() == ScalarType::Short) ||
(out.scalar_type() == ScalarType::Char) ||
(out.scalar_type() == ScalarType::UInt32) ||
@@ -125,6 +129,12 @@ Tensor& cat_out(
(int)dim,
get_element_size(out.scalar_type()));
} else {
ET_KERNEL_CHECK(
ctx,
torch::executor::check_cat_args(tensors, dim, out),
InvalidArgument,
out);

const size_t outer = executorch::runtime::getLeadingDims(out, dim);
const size_t dim_stride = executorch::runtime::getTrailingDims(out, dim);
const size_t ninputs = tensors.size();
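Similarly, a sketch of the check added to cat_out, reusing the hypothetical TensorView stand-in from the sketch above: the memcpy-based fast path is only valid when every input tensor already matches the output dtype, so the scan bails out to the portable implementation (which is where check_cat_args is now invoked) on the first mismatch. The helper name is illustrative only.

#include <vector>

// Hypothetical helper mirroring the dtype scan added to cat_out; TensorView is
// the simplified stand-in defined in the previous sketch.
bool all_inputs_match_out_dtype(
    const std::vector<TensorView>& tensors,
    const TensorView& out) {
  for (const TensorView& t : tensors) {
    if (t.dtype != out.dtype) {
      return false;  // fall back to the portable cat implementation
    }
  }
  return true;
}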
32 changes: 19 additions & 13 deletions backends/cadence/fusion_g3/operators/op_dequantize.cpp
@@ -117,16 +117,22 @@ Tensor& dequantize_impl(
}
}
} else {
if (*zero_point_data != 0) // tesor
if (*zero_point_data != 0) // tensor
{
is_asym_dequant |= 1;
}
}
}
float* out_data = out.mutable_data_ptr<float>();

bool optimized = true;

if (out.scalar_type() != ScalarType::Float) {
optimized = false;
}

if (is_asym_dequant) {
if (input.scalar_type() == ScalarType::Byte) {
if ((input.scalar_type() == ScalarType::Byte) && (optimized)) {
const uint8_t* input_data = input.const_data_ptr<uint8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -139,7 +145,7 @@ Tensor& dequantize_impl(
axis,
zero_point_data,
scale_data);
} else if (input.scalar_type() == ScalarType::Char) {
} else if ((input.scalar_type() == ScalarType::Char) && (optimized)) {
const int8_t* input_data = input.const_data_ptr<int8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -152,7 +158,7 @@ Tensor& dequantize_impl(
axis,
zero_point_data,
scale_data);
} else if (input.scalar_type() == ScalarType::UInt16) {
} else if ((input.scalar_type() == ScalarType::UInt16) && (optimized)) {
const uint16_t* input_data = input.const_data_ptr<uint16_t>();
XT_KERNEL_CHECK(
ctx,
@@ -165,7 +171,7 @@ Tensor& dequantize_impl(
axis,
zero_point_data,
scale_data);
} else if (input.scalar_type() == ScalarType::Short) {
} else if ((input.scalar_type() == ScalarType::Short) && (optimized)) {
const int16_t* input_data = input.const_data_ptr<int16_t>();
XT_KERNEL_CHECK(
ctx,
@@ -178,7 +184,7 @@ Tensor& dequantize_impl(
axis,
zero_point_data,
scale_data);
} else if (input.scalar_type() == (ScalarType)Bits4u) {
} else if ((input.scalar_type() == (ScalarType)Bits4u) && (optimized)) {
const uint8_t* input_data = input.const_data_ptr<uint8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -191,7 +197,7 @@ Tensor& dequantize_impl(
axis,
zero_point_data,
scale_data);
} else if (input.scalar_type() == (ScalarType)Bits4) {
} else if ((input.scalar_type() == (ScalarType)Bits4) && (optimized)) {
const int8_t* input_data = input.const_data_ptr<int8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -338,7 +344,7 @@ Tensor& dequantize_impl(
}
}
} else {
if (input.scalar_type() == ScalarType::Byte) {
if ((input.scalar_type() == ScalarType::Byte) && (optimized)) {
const uint8_t* input_data = input.const_data_ptr<uint8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -350,7 +356,7 @@ Tensor& dequantize_impl(
input.dim(),
axis,
scale_data);
} else if (input.scalar_type() == ScalarType::Char) {
} else if ((input.scalar_type() == ScalarType::Char) && (optimized)) {
const int8_t* input_data = input.const_data_ptr<int8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -362,7 +368,7 @@ Tensor& dequantize_impl(
input.dim(),
axis,
scale_data);
} else if (input.scalar_type() == ScalarType::UInt16) {
} else if ((input.scalar_type() == ScalarType::UInt16) && (optimized)) {
const uint16_t* input_data = input.const_data_ptr<uint16_t>();
XT_KERNEL_CHECK(
ctx,
@@ -374,7 +380,7 @@ Tensor& dequantize_impl(
input.dim(),
axis,
scale_data);
} else if (input.scalar_type() == ScalarType::Short) {
} else if ((input.scalar_type() == ScalarType::Short) && (optimized)) {
const int16_t* input_data = input.const_data_ptr<int16_t>();
XT_KERNEL_CHECK(
ctx,
@@ -386,7 +392,7 @@ Tensor& dequantize_impl(
input.dim(),
axis,
scale_data);
} else if (input.scalar_type() == (ScalarType)Bits4u) {
} else if ((input.scalar_type() == (ScalarType)Bits4u) && (optimized)) {
const uint8_t* input_data = input.const_data_ptr<uint8_t>();
XT_KERNEL_CHECK(
ctx,
@@ -398,7 +404,7 @@ Tensor& dequantize_impl(
input.dim(),
axis,
scale_data);
} else if (input.scalar_type() == (ScalarType)Bits4) {
} else if ((input.scalar_type() == (ScalarType)Bits4) && (optimized)) {
const int8_t* input_data = input.const_data_ptr<int8_t>();
XT_KERNEL_CHECK(
ctx,
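Finally, a sketch of the gate applied throughout dequantize_impl, again with hypothetical stand-ins: each vectorized branch now also requires the output tensor to be Float, so non-float outputs (and the input dtypes the sketch leaves out, such as Byte, UInt16, and the Bits4 variants) take the existing generic dequantization path. The helper name and the reduced dtype set are assumptions of this sketch, not the real kernel interface.

// Hypothetical helper mirroring the per-branch condition in dequantize_impl:
// the vectorized dequantize kernels are entered only for supported integer
// input dtypes and only when the output buffer is Float.
bool use_fast_dequant(ScalarType input_dtype, ScalarType out_dtype) {
  const bool output_ok = (out_dtype == ScalarType::Float);
  const bool input_ok =
      (input_dtype == ScalarType::Char) || (input_dtype == ScalarType::Short);
  // Byte, UInt16, and the 4-bit dtypes handled by the real kernels are
  // omitted here because the stand-in enum above does not define them.
  return output_ok && input_ok;
}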