Skip to content

Commit acb2b69

Browse files
hsharma35facebook-github-bot
authored andcommitted
Fix relu test + use xt macros.
Summary: TensorFactory needs to outlive tensors generated by it. This diff fixes bug in relu tests. Reviewed By: zonglinpeng Differential Revision: D87752226
1 parent b71b3b1 commit acb2b69

File tree

2 files changed

+17
-14
lines changed

2 files changed

+17
-14
lines changed

backends/cadence/hifi/operators/op_quantized_relu_out.cpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
#include <executorch/backends/cadence/common/xt_macros.h>
910
#include <executorch/backends/cadence/hifi/kernels/kernels.h>
1011
#include <executorch/runtime/kernel/kernel_includes.h>
1112

@@ -34,7 +35,10 @@ void quantized_relu_per_tensor_out(
3435
const uint8_t* p_in = input.const_data_ptr<uint8_t>();
3536
uint8_t* p_out = output.mutable_data_ptr<uint8_t>();
3637

37-
WORD32 ret_val = xa_nn_vec_relu_asym8u_asym8u(
38+
XT_KERNEL_CHECK(
39+
ctx,
40+
,
41+
xa_nn_vec_relu_asym8u_asym8u,
3842
p_out,
3943
p_in,
4044
_in_zero_point,
@@ -45,15 +49,16 @@ void quantized_relu_per_tensor_out(
4549
255,
4650
input.numel());
4751

48-
ET_CHECK_MSG(ret_val == 0, "An internal error occured");
49-
5052
} else if (input.scalar_type() == executorch::aten::ScalarType::Char) {
51-
const int8_t _in_zero_point = static_cast<int8_t>(in_zero_point);
52-
const int8_t _out_zero_point = static_cast<int8_t>(out_zero_point);
53+
const int _in_zero_point = static_cast<int>(in_zero_point);
54+
const int _out_zero_point = static_cast<int>(out_zero_point);
5355
const int8_t* p_in = input.const_data_ptr<int8_t>();
5456
int8_t* p_out = output.mutable_data_ptr<int8_t>();
5557

56-
WORD32 ret_val = xa_nn_vec_relu_asym8s_asym8s(
58+
XT_KERNEL_CHECK(
59+
ctx,
60+
,
61+
xa_nn_vec_relu_asym8s_asym8s,
5762
p_out,
5863
p_in,
5964
_in_zero_point,
@@ -64,8 +69,6 @@ void quantized_relu_per_tensor_out(
6469
127,
6570
input.numel());
6671

67-
ET_CHECK_MSG(ret_val == 0, "An internal error occured");
68-
6972
} else {
7073
ET_CHECK_MSG(
7174
false,

backends/cadence/hifi/operators/tests/test_op_quantized_relu_out.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,14 @@ class HiFiQuantizedReluTest : public OperatorTest {
5757

5858
TEST_F(HiFiQuantizedReluTest, MultiDimensionalTest) {
5959
TensorFactory<ScalarType::Char> tf_chars;
60+
TensorFactory<ScalarType::Int> tf_ints;
6061
const std::vector<int32_t> sizes{2, 3, 5, 6};
6162
Tensor quantized_input = tf_chars.full(sizes, -128);
6263
Tensor quantized_output = tf_chars.full(sizes, 100);
6364
Tensor in_zero_point = tf_chars.full({1}, 127);
6465
int64_t out_zero_point = -128;
65-
Tensor out_multiplier =
66-
TensorFactory<ScalarType::Int>().full({1}, 1077952640);
67-
Tensor out_shift = TensorFactory<ScalarType::Int>().full({1}, 5);
66+
Tensor out_multiplier = tf_ints.full({1}, 1077952640);
67+
Tensor out_shift = tf_ints.full({1}, 5);
6868

6969
quantized_relu_out(
7070
quantized_input,
@@ -80,14 +80,14 @@ TEST_F(HiFiQuantizedReluTest, MultiDimensionalTest) {
8080

8181
TEST_F(HiFiQuantizedReluTest, OneDimensionalTest) {
8282
TensorFactory<ScalarType::Char> tf_chars;
83+
TensorFactory<ScalarType::Int> tf_ints;
8384
const std::vector<int32_t> sizes{56};
8485
Tensor quantized_input = tf_chars.full(sizes, -128);
8586
Tensor quantized_output = tf_chars.full(sizes, 100);
8687
Tensor in_zero_point = tf_chars.full({1}, 127);
8788
int64_t out_zero_point = -128;
88-
Tensor out_multiplier =
89-
TensorFactory<ScalarType::Int>().full({1}, 1077952640);
90-
Tensor out_shift = TensorFactory<ScalarType::Int>().full({1}, 5);
89+
Tensor out_multiplier = tf_ints.full({1}, 1077952640);
90+
Tensor out_shift = tf_ints.full({1}, 5);
9191

9292
quantized_relu_out(
9393
quantized_input,

0 commit comments

Comments
 (0)