19 changes: 11 additions & 8 deletions backends/cadence/hifi/operators/op_quantized_relu_out.cpp
@@ -6,6 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/backends/cadence/common/xt_macros.h>
#include <executorch/backends/cadence/hifi/kernels/kernels.h>
#include <executorch/runtime/kernel/kernel_includes.h>

@@ -34,7 +35,10 @@ void quantized_relu_per_tensor_out(
const uint8_t* p_in = input.const_data_ptr<uint8_t>();
uint8_t* p_out = output.mutable_data_ptr<uint8_t>();

WORD32 ret_val = xa_nn_vec_relu_asym8u_asym8u(
XT_KERNEL_CHECK(
ctx,
,
xa_nn_vec_relu_asym8u_asym8u,
p_out,
p_in,
_in_zero_point,
@@ -45,15 +49,16 @@
255,
input.numel());

ET_CHECK_MSG(ret_val == 0, "An internal error occured");

} else if (input.scalar_type() == executorch::aten::ScalarType::Char) {
const int8_t _in_zero_point = static_cast<int8_t>(in_zero_point);
const int8_t _out_zero_point = static_cast<int8_t>(out_zero_point);
const int _in_zero_point = static_cast<int>(in_zero_point);
const int _out_zero_point = static_cast<int>(out_zero_point);
const int8_t* p_in = input.const_data_ptr<int8_t>();
int8_t* p_out = output.mutable_data_ptr<int8_t>();

WORD32 ret_val = xa_nn_vec_relu_asym8s_asym8s(
XT_KERNEL_CHECK(
ctx,
,
xa_nn_vec_relu_asym8s_asym8s,
p_out,
p_in,
_in_zero_point,
@@ -64,8 +69,6 @@
127,
input.numel());

ET_CHECK_MSG(ret_val == 0, "An internal error occured");

} else {
ET_CHECK_MSG(
false,
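For context, this hunk replaces the manual WORD32 ret_val / ET_CHECK_MSG(ret_val == 0, ...) pattern with the XT_KERNEL_CHECK macro brought in by the new xt_macros.h include. The sketch below is only an illustration of the assumed semantics, not the actual definition (which lives in executorch/backends/cadence/common/xt_macros.h and may differ): invoke the NNLib kernel, and on a nonzero return code report the failure through the kernel runtime context and return the supplied value. The empty second argument in the calls above would then expand to a bare return; which suits these void operators.

// Minimal sketch of an XT_KERNEL_CHECK-style wrapper; an assumption for
// illustration only, not the actual definition from xt_macros.h.
#define XT_KERNEL_CHECK(ctx, retval, kernel, ...)            \
  do {                                                       \
    /* Run the vendor kernel and capture its status code. */ \
    const WORD32 xt_ret_ = kernel(__VA_ARGS__);              \
    /* On failure, report through ctx and return retval. */  \
    ET_KERNEL_CHECK_MSG(                                     \
        ctx,                                                 \
        xt_ret_ == 0,                                        \
        InvalidArgument,                                     \
        retval,                                              \
        "Kernel " #kernel " failed");                        \
  } while (0)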
@@ -57,14 +57,14 @@ class HiFiQuantizedReluTest : public OperatorTest {

TEST_F(HiFiQuantizedReluTest, MultiDimensionalTest) {
TensorFactory<ScalarType::Char> tf_chars;
TensorFactory<ScalarType::Int> tf_ints;
const std::vector<int32_t> sizes{2, 3, 5, 6};
Tensor quantized_input = tf_chars.full(sizes, -128);
Tensor quantized_output = tf_chars.full(sizes, 100);
Tensor in_zero_point = tf_chars.full({1}, 127);
int64_t out_zero_point = -128;
Tensor out_multiplier =
TensorFactory<ScalarType::Int>().full({1}, 1077952640);
Tensor out_shift = TensorFactory<ScalarType::Int>().full({1}, 5);
Tensor out_multiplier = tf_ints.full({1}, 1077952640);
Tensor out_shift = tf_ints.full({1}, 5);

quantized_relu_out(
quantized_input,
@@ -80,14 +80,14 @@ TEST_F(HiFiQuantizedReluTest, MultiDimensionalTest) {

TEST_F(HiFiQuantizedReluTest, OneDimensionalTest) {
TensorFactory<ScalarType::Char> tf_chars;
TensorFactory<ScalarType::Int> tf_ints;
const std::vector<int32_t> sizes{56};
Tensor quantized_input = tf_chars.full(sizes, -128);
Tensor quantized_output = tf_chars.full(sizes, 100);
Tensor in_zero_point = tf_chars.full({1}, 127);
int64_t out_zero_point = -128;
Tensor out_multiplier =
TensorFactory<ScalarType::Int>().full({1}, 1077952640);
Tensor out_shift = TensorFactory<ScalarType::Int>().full({1}, 5);
Tensor out_multiplier = tf_ints.full({1}, 1077952640);
Tensor out_shift = tf_ints.full({1}, 5);

quantized_relu_out(
quantized_input,