From 7c64f707d3dac3c13a25d404f17217a2da71e28d Mon Sep 17 00:00:00 2001 From: Saoirse Stewart Date: Tue, 28 Oct 2025 15:19:28 +0000 Subject: [PATCH 1/2] Arm backend: Update for missing operators for int16x8 * Update test infra to handle int16x8 Signed-off-by: Saoirse Stewart --- backends/arm/_passes/rewrite_upsample.py | 10 +++- backends/arm/operators/op_avg_pool2d.py | 5 +- backends/arm/operators/op_clamp.py | 19 ++++-- backends/arm/operators/op_constant_pad_nd.py | 6 ++ backends/arm/operators/op_eq.py | 2 +- backends/arm/operators/op_ge.py | 2 +- backends/arm/operators/op_gt.py | 2 +- backends/arm/operators/op_le.py | 2 +- backends/arm/operators/op_lt.py | 2 +- backends/arm/operators/op_max_pool2d.py | 2 +- backends/arm/operators/op_tosa_resize.py | 18 +++++- backends/arm/quantizer/__init__.py | 1 + backends/arm/quantizer/arm_quantizer.py | 1 + .../arm/test/ops/test_adaptive_avg_pool2d.py | 50 ++++++++++++++++ backends/arm/test/ops/test_avg_pool2d.py | 49 +++++++++++++++ backends/arm/test/ops/test_clamp.py | 54 +++++++++++++++++ backends/arm/test/ops/test_constant_pad_nd.py | 14 +++++ backends/arm/test/ops/test_eq.py | 60 +++++++++++++++++++ backends/arm/test/ops/test_ge.py | 60 +++++++++++++++++++ backends/arm/test/ops/test_gt.py | 60 +++++++++++++++++++ backends/arm/test/ops/test_le.py | 60 +++++++++++++++++++ backends/arm/test/ops/test_lt.py | 60 +++++++++++++++++++ backends/arm/test/ops/test_max_pool.py | 48 +++++++++++++++ .../arm/test/ops/test_upsample_bilinear2d.py | 40 ++++++++++++- .../arm/test/ops/test_upsample_nearest2d.py | 16 +++++ backends/arm/test/tester/test_pipeline.py | 39 +++++++++--- backends/arm/tosa/dialect/ops/resize.py | 6 ++ 27 files changed, 662 insertions(+), 26 deletions(-) diff --git a/backends/arm/_passes/rewrite_upsample.py b/backends/arm/_passes/rewrite_upsample.py index e0ef1dbcf4a..143ed1a724a 100644 --- a/backends/arm/_passes/rewrite_upsample.py +++ b/backends/arm/_passes/rewrite_upsample.py @@ -11,6 +11,7 @@ create_node, get_first_fake_tensor, ) +from executorch.backends.arm.tosa.mapping import TosaSpecialDtype from executorch.backends.arm.tosa.utils import get_resize_parameters from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult @@ -52,7 +53,9 @@ def call(self, graph_module): node.replace_all_uses_with(tosa_resize_node) graph_module.graph.erase_node(node) input_dtype = get_first_fake_tensor(x).dtype - if input_dtype == torch.int8 and resize_mode == "bilinear": + if ( + input_dtype == torch.int8 or input_dtype == torch.int16 + ) and resize_mode == "bilinear": input_size = get_first_fake_tensor(x).shape input_size_xy = input_size[2:] output_size = get_first_fake_tensor(node).shape @@ -71,6 +74,11 @@ def call(self, graph_module): exir_ops.backend.tosa.RESCALE.default, ) tosa_resize_node.replace_all_uses_with(rescale_node) + if input_dtype == torch.int16: + tosa_resize_node.meta[TosaSpecialDtype.meta_key()] = ( + TosaSpecialDtype.INT48 + ) + rescale_node.args = ( tosa_resize_node, output_dtype, diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index 83f5f5d45f3..83ead5d2ff8 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -115,7 +115,10 @@ def define_node( validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7]) validate_same_dtype(self.target, [inputs[0], output], ts) validate_valid_dtype( - self.target, [inputs[0], output], ts.DType.INT8, output.tosa_spec + self.target, + [inputs[0], 
output], + [ts.DType.INT8, ts.DType.INT16], + output.tosa_spec, ) accumulator_type = ts.DType.INT32 diff --git a/backends/arm/operators/op_clamp.py b/backends/arm/operators/op_clamp.py index 74722abe281..84299486a62 100644 --- a/backends/arm/operators/op_clamp.py +++ b/backends/arm/operators/op_clamp.py @@ -70,20 +70,27 @@ def define_node( validate_num_inputs(self.target, inputs, [2, 3]) validate_same_dtype(self.target, [inputs[0], output], ts) validate_valid_dtype( - self.target, [inputs[0], output], [ts.DType.INT8], output.tosa_spec + self.target, + [inputs[0], output], + [ts.DType.INT8, ts.DType.INT16], + output.tosa_spec, ) + is_int16 = output.dtype == ts.DType.INT16 + torch_dtype = torch.int16 if is_int16 else torch.int8 + # NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments - min_int8, max_int8 = self._get_min_max_arguments( + min_quant, max_quant = self._get_min_max_arguments( node, - torch.iinfo(torch.int8).min, - torch.iinfo(torch.int8).max, + torch.iinfo(torch_dtype).min, + torch.iinfo(torch_dtype).max, ) + np_dtype = np.int16 if is_int16 else np.int8 attr = ts.TosaSerializerAttribute() attr.ClampAttribute( - np.frombuffer(np.int8(min_int8).tobytes(), dtype=np.uint8).tolist(), - np.frombuffer(np.int8(max_int8).tobytes(), dtype=np.uint8).tolist(), + np.frombuffer(np_dtype(min_quant).tobytes(), dtype=np.uint8).tolist(), + np.frombuffer(np_dtype(max_quant).tobytes(), dtype=np.uint8).tolist(), ts.NanPropagationMode.PROPAGATE, ) diff --git a/backends/arm/operators/op_constant_pad_nd.py b/backends/arm/operators/op_constant_pad_nd.py index 3bda87af5ed..47d11fb5627 100644 --- a/backends/arm/operators/op_constant_pad_nd.py +++ b/backends/arm/operators/op_constant_pad_nd.py @@ -50,6 +50,7 @@ def define_node( [inputs[0], output], [ ts.DType.INT8, + ts.DType.INT16, ts.DType.INT32, ts.DType.FP32, ts.DType.BOOL, @@ -62,6 +63,11 @@ def define_node( qargs = input_qparams[0] pad_const_val = qargs.quantize_value(inputs[2].number).item() pad_const_dtype = ts.DType.INT8 + elif inputs[0].dtype == ts.DType.INT16: + input_qparams = get_input_qparams(node) + qargs = input_qparams[0] + pad_const_val = qargs.quantize_value(inputs[2].number).item() + pad_const_dtype = ts.DType.INT16 else: pad_const_val = inputs[2].number pad_const_dtype = inputs[0].dtype diff --git a/backends/arm/operators/op_eq.py b/backends/arm/operators/op_eq.py index 8fb789a9d01..7cfd497b1fe 100644 --- a/backends/arm/operators/op_eq.py +++ b/backends/arm/operators/op_eq.py @@ -47,7 +47,7 @@ def define_node( validate_valid_dtype( self.target, inputs, - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) diff --git a/backends/arm/operators/op_ge.py b/backends/arm/operators/op_ge.py index 5994cbc9c0f..5d6eeb75275 100644 --- a/backends/arm/operators/op_ge.py +++ b/backends/arm/operators/op_ge.py @@ -47,7 +47,7 @@ def define_node( validate_valid_dtype( self.target, inputs, - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) diff --git a/backends/arm/operators/op_gt.py b/backends/arm/operators/op_gt.py index 859e5c236d7..92879d549b1 100644 --- a/backends/arm/operators/op_gt.py +++ b/backends/arm/operators/op_gt.py @@ -47,7 +47,7 @@ def define_node( validate_valid_dtype( self.target, inputs, - 
[ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) diff --git a/backends/arm/operators/op_le.py b/backends/arm/operators/op_le.py index fb26b5b8606..2b1a023d624 100644 --- a/backends/arm/operators/op_le.py +++ b/backends/arm/operators/op_le.py @@ -47,7 +47,7 @@ def define_node( validate_valid_dtype( self.target, inputs, - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) diff --git a/backends/arm/operators/op_lt.py b/backends/arm/operators/op_lt.py index f5cf71420f4..4f3e1163c69 100644 --- a/backends/arm/operators/op_lt.py +++ b/backends/arm/operators/op_lt.py @@ -47,7 +47,7 @@ def define_node( validate_valid_dtype( self.target, inputs, - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32], + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32], output.tosa_spec, ) validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec) diff --git a/backends/arm/operators/op_max_pool2d.py b/backends/arm/operators/op_max_pool2d.py index 1cab28f9153..5690b82d97b 100644 --- a/backends/arm/operators/op_max_pool2d.py +++ b/backends/arm/operators/op_max_pool2d.py @@ -47,7 +47,7 @@ def define_node( validate_valid_dtype( self.target, [inputs[0], output], - [ts.DType.INT8, ts.DType.FP32], + [ts.DType.INT8, ts.DType.INT16, ts.DType.FP32], output.tosa_spec, ) diff --git a/backends/arm/operators/op_tosa_resize.py b/backends/arm/operators/op_tosa_resize.py index fb8e305839f..ac017847fa3 100644 --- a/backends/arm/operators/op_tosa_resize.py +++ b/backends/arm/operators/op_tosa_resize.py @@ -46,13 +46,27 @@ def define_node( resize_mode = ts.ResizeMode.NEAREST align_corners = False validate_same_dtype(self.target, [inputs[0], output], ts) + + valid_dtypes = [] + if self.tosa_spec.support_integer(): + valid_dtypes.extend( + [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.INT48] + ) + + if self.tosa_spec.support_float(): + valid_dtypes.extend( + [ + ts.DType.FP16, + ts.DType.FP32, + ] + ) + validate_valid_dtype( self.target, [inputs[0], output], - [ts.DType.INT8, ts.DType.INT32, ts.DType.FP16, ts.DType.FP32], + valid_dtypes, output.tosa_spec, ) - # tosa_shape output is NHWC, take HW input_size_yx = tuple([inputs[0].shape[dim] for dim in inputs[0].dim_order])[ 1:3 diff --git a/backends/arm/quantizer/__init__.py b/backends/arm/quantizer/__init__.py index e36c683416a..2018b845353 100644 --- a/backends/arm/quantizer/__init__.py +++ b/backends/arm/quantizer/__init__.py @@ -12,6 +12,7 @@ from .quantization_config import QuantizationConfig # noqa # usort: skip from .arm_quantizer import ( # noqa EthosUQuantizer, + get_symmetric_a16w8_quantization_config, get_symmetric_quantization_config, TOSAQuantizer, VgfQuantizer, diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index e52b30895dc..31960ffe2b8 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -51,6 +51,7 @@ "TOSAQuantizer", "EthosUQuantizer", "VgfQuantizer", + "get_symmetric_a16w8_quantization_config", "get_symmetric_quantization_config", ] diff --git a/backends/arm/test/ops/test_adaptive_avg_pool2d.py b/backends/arm/test/ops/test_adaptive_avg_pool2d.py index 4411ce7f746..3e4fbcaa833 100644 --- 
a/backends/arm/test/ops/test_adaptive_avg_pool2d.py
+++ b/backends/arm/test/ops/test_adaptive_avg_pool2d.py
@@ -136,6 +136,20 @@ def test_adaptive_avg_pool2d_tosa_INT(test_module):
     pipeline.run()
 
 
+@common.parametrize("test_module", test_modules)
+def test_adaptive_avg_pool2d_tosa_INT_a16w8(test_module):
+    """Test adaptive_avg_pool2d with int16 I/O quantization for TOSA INT."""
+    model, input_tensor = test_module()
+    pipeline = TosaPipelineINT[input_t](
+        model,
+        input_tensor,
+        aten_op=[],
+        exir_op=exir_op,
+        tosa_extensions=["int16"],
+    )
+    pipeline.run()
+
+
 @common.parametrize("test_module", test_modules)
 @common.XfailIfNoCorstone300
 def test_adaptive_avg_pool2d_u55_INT(test_module):
@@ -150,6 +164,27 @@ def test_adaptive_avg_pool2d_u55_INT(test_module):
     pipeline.run()
 
 
+# Exclude high_channel_count & output_1x1_from_19: they exceed the 2MB SRAM budget on U55
+u55_test_modules = dict(test_modules)  # copy so the shared test_modules is not mutated
+for key in ["high_channel_count", "output_1x1_from_19"]:
+    u55_test_modules.pop(key)
+
+
+@common.parametrize("test_module", u55_test_modules)
+@common.XfailIfNoCorstone300
+def test_adaptive_avg_pool2d_16a8w_u55_INT16(test_module):
+    """Test adaptive_avg_pool2d with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
+    model, input_tensor = test_module()
+    pipeline = EthosU55PipelineINT[input_t](
+        model,
+        input_tensor,
+        aten_ops=[],
+        exir_ops=exir_op,
+        a16w8_quantization=True,
+    )
+    pipeline.run()
+
+
 @common.parametrize("test_module", test_modules)
 @common.XfailIfNoCorstone320
 def test_adaptive_avg_pool2d_u85_INT(test_module):
@@ -164,6 +199,21 @@ def test_adaptive_avg_pool2d_u85_INT(test_module):
     pipeline.run()
 
 
+@common.parametrize("test_module", test_modules)
+@common.XfailIfNoCorstone320
+def test_adaptive_avg_pool2d_16a8w_u85_INT16(test_module):
+    """Test adaptive_avg_pool2d with 16A8W quantization on U85 (16-bit activations, 8-bit weights)"""
+    model, input_tensor = test_module()
+    pipeline = EthosU85PipelineINT[input_t](
+        model,
+        input_tensor,
+        aten_ops=[],
+        exir_ops=exir_op,
+        a16w8_quantization=True,
+    )
+    pipeline.run()
+
+
 @common.parametrize("test_module", test_modules)
 @common.SkipIfNoModelConverter
 def test_adaptive_avg_pool2d_vgf_FP(test_module):
diff --git a/backends/arm/test/ops/test_avg_pool2d.py b/backends/arm/test/ops/test_avg_pool2d.py
index 8310d1e40a4..9441a6fae61 100644
--- a/backends/arm/test/ops/test_avg_pool2d.py
+++ b/backends/arm/test/ops/test_avg_pool2d.py
@@ -141,6 +141,21 @@ def test_avg_pool2d_tosa_INT(test_module):
     pipeline.run()
 
 
+@common.parametrize("test_module", test_modules)
+def test_avg_pool2d_tosa_INT_a16w8(test_module):
+    """Test avg_pool2d operation with int16 I/O quantization for TOSA INT."""
+    model, input_tensor = test_module()
+    pipeline = TosaPipelineINT[input_t](
+        model,
+        input_tensor,
+        aten_op,
+        exir_op,
+        tosa_extensions=["int16"],
+        run_on_tosa_ref_model=conftest.is_option_enabled("tosa_ref_model"),
+    )
+    pipeline.run()
+
+
 @common.parametrize("test_module", test_modules)
 @common.XfailIfNoCorstone300
 def test_avg_pool2d_u55_INT(test_module):
@@ -155,6 +170,23 @@ def test_avg_pool2d_u55_INT(test_module):
     pipeline.run()
 
 
+@common.parametrize("test_module", test_modules)
+@common.XfailIfNoCorstone300
+def test_avg_pool2d_16a8w_u55_INT16(test_module):
+    """Test avg_pool2d with 16A8W quantization on U55 (16-bit activations, 8-bit weights)"""
+    model, input_tensor = test_module()
+    pipeline = EthosU55PipelineINT[input_t](
+        model,
+        input_tensor,
+        aten_op,
+        exir_op,
+        per_channel_quantization=False,
+        a16w8_quantization=True,
+        
use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + @common.parametrize("test_module", test_modules) @common.XfailIfNoCorstone320 def test_avg_pool2d_u85_INT(test_module): @@ -169,6 +201,23 @@ def test_avg_pool2d_u85_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_modules) +@common.XfailIfNoCorstone320 +def test_avg_pool2d_16a8w_u85_INT16(test_module): + """Test avg_pool2d with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + model, input_tensor = test_module() + pipeline = EthosU85PipelineINT[input_t]( + model, + input_tensor, + aten_op, + exir_op, + per_channel_quantization=False, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + @common.parametrize("test_module", test_modules) @common.SkipIfNoModelConverter def test_avg_pool2d_vgf_FP(test_module): diff --git a/backends/arm/test/ops/test_clamp.py b/backends/arm/test/ops/test_clamp.py index a5561802e44..88c12dd8d6c 100644 --- a/backends/arm/test/ops/test_clamp.py +++ b/backends/arm/test/ops/test_clamp.py @@ -84,6 +84,22 @@ def test_clamp_tosa_INT(test_data): pipeline.run() +@common.parametrize("test_data", test_data_suite) +def test_clamp_tosa_INT_a16w8(test_data): + """Test clamp operation with int16 I/O quantization for TOSA INT.""" + input_tensor, min_val, max_val = test_data() + model = Clamp(min_val, max_val) + pipeline = TosaPipelineINT[input_t]( + model, + (input_tensor,), + aten_op, + exir_op, + tosa_extensions=["int16"], + ) + pipeline.change_args("run_method_and_compare_outputs", qtol=1) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone300 def test_clamp_u55_INT(test_data): @@ -102,6 +118,25 @@ def test_clamp_u55_INT(test_data): pipeline.run() +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 +def test_clamp_16a8w_u55_INT16(test_data): + """Test clamp operation with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" + input_tensor, min_val, max_val = test_data() + model = Clamp(min_val, max_val) + pipeline = EthosU55PipelineINT[input_t]( + model, + (input_tensor,), + aten_op, + exir_op, + per_channel_quantization=False, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.change_args("run_method_and_compare_outputs", qtol=1) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone320 def test_clamp_u85_INT(test_data): @@ -120,6 +155,25 @@ def test_clamp_u85_INT(test_data): pipeline.run() +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 +def test_clamp_16a8w_u85_INT16(test_data): + """Test clamp operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + input_tensor, min_val, max_val = test_data() + model = Clamp(min_val, max_val) + pipeline = EthosU85PipelineINT[input_t]( + model, + (input_tensor,), + aten_op, + exir_op, + per_channel_quantization=False, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.change_args("run_method_and_compare_outputs", qtol=1) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter def test_clamp_vgf_FP(test_data): diff --git a/backends/arm/test/ops/test_constant_pad_nd.py b/backends/arm/test/ops/test_constant_pad_nd.py index d70249c31d1..437c4bee9ef 100644 --- a/backends/arm/test/ops/test_constant_pad_nd.py +++ b/backends/arm/test/ops/test_constant_pad_nd.py @@ -77,6 +77,20 @@ def 
test_constant_pad_nd_tosa_INT(test_data: Tuple): pipeline.run() +@common.parametrize("test_data", test_data_suite) +def test_constant_pad_nd_tosa_INT_a16w8(test_data: Tuple): + """Test constant_pad_nd op with int16 I/O quantization for TOSA INT.""" + test_data, padding, value = test_data() + pipeline = TosaPipelineINT[input_t1]( + ConstantPadND(padding, value), + (test_data,), + aten_op, + exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter def test_constant_pad_nd_vgf_FP(test_data: Tuple): diff --git a/backends/arm/test/ops/test_eq.py b/backends/arm/test/ops/test_eq.py index 8f783240a2c..c9ea7b0278c 100644 --- a/backends/arm/test/ops/test_eq.py +++ b/backends/arm/test/ops/test_eq.py @@ -121,6 +121,30 @@ def test_eq_scalar_tosa_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +def test_eq_tensor_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Equal.aten_op_Tensor, + Equal.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +def test_eq_scalar_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Equal.aten_op_Tensor, + Equal.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.XfailIfNoCorstone300 def test_eq_scalar_u55_INT_tensor(test_module): @@ -188,6 +212,42 @@ def test_eq_scalar_u85_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +@common.XfailIfNoCorstone320 +def test_eq_tensor_16a8w_u85_INT16(test_module): + """Test eq operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Equal.aten_op_Tensor, + Equal.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +@common.XfailIfNoCorstone320 +def test_eq_scalar_16a8w_u85_INT16(test_module): + """Test eq operation (scalar) with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Equal.aten_op_Tensor, + Equal.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter def test_eq_scalar_vgf_FP_tensor(test_module): diff --git a/backends/arm/test/ops/test_ge.py b/backends/arm/test/ops/test_ge.py index ede5be76eda..b3cc1df34c9 100644 --- a/backends/arm/test/ops/test_ge.py +++ b/backends/arm/test/ops/test_ge.py @@ -121,6 +121,30 @@ def test_ge_scalar_tosa_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +def test_ge_tensor_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + GreaterEqual.aten_op_tensor, + GreaterEqual.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +def test_ge_scalar_tosa_INT_a16w8(test_module): + pipeline = 
TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + GreaterEqual.aten_op_tensor, + GreaterEqual.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.XfailIfNoCorstone300 def test_ge_tensor_u55_INT(test_module): @@ -180,6 +204,42 @@ def test_ge_scalar_u85_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +@common.XfailIfNoCorstone320 +def test_ge_tensor_16a8w_u85_INT16(test_module): + """Test ge operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + GreaterEqual.aten_op_tensor, + GreaterEqual.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +@common.XfailIfNoCorstone320 +def test_ge_scalar_16a8w_u85_INT16(test_module): + """Test ge operation (scalar) with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + GreaterEqual.aten_op_tensor, + GreaterEqual.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter def test_ge_tensor_vgf_FP(test_module): diff --git a/backends/arm/test/ops/test_gt.py b/backends/arm/test/ops/test_gt.py index 0e50b6b78be..aee617f9767 100644 --- a/backends/arm/test/ops/test_gt.py +++ b/backends/arm/test/ops/test_gt.py @@ -122,6 +122,30 @@ def test_gt_scalar_tosa_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +def test_gt_tensor_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Greater.aten_op_tensor, + Greater.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +def test_gt_scalar_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Greater.aten_op_tensor, + Greater.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.XfailIfNoCorstone300 def test_gt_tensor_u55_INT(test_module): @@ -181,6 +205,42 @@ def test_gt_scalar_u85_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +@common.XfailIfNoCorstone320 +def test_gt_tensor_16a8w_u85_INT16(test_module): + """Test gt operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + Greater.aten_op_tensor, + Greater.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +@common.XfailIfNoCorstone320 +def test_gt_scalar_16a8w_u85_INT16(test_module): + """Test gt operation (scalar) with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + 
test_module().get_inputs(), + Greater.aten_op_tensor, + Greater.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter def test_gt_tensor_vgf_FP(test_module): diff --git a/backends/arm/test/ops/test_le.py b/backends/arm/test/ops/test_le.py index fd0e63e9beb..cc8ddfc4da2 100644 --- a/backends/arm/test/ops/test_le.py +++ b/backends/arm/test/ops/test_le.py @@ -122,6 +122,30 @@ def test_le_scalar_tosa_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +def test_le_tensor_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessEqual.aten_op_tensor, + LessEqual.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +def test_le_scalar_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessEqual.aten_op_tensor, + LessEqual.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.XfailIfNoCorstone300 def test_le_tensor_u55_INT_not_delegated(test_module): @@ -184,6 +208,42 @@ def test_le_scalar_u85_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +@common.XfailIfNoCorstone320 +def test_le_tensor_16a8w_u85_INT16(test_module): + """Test le operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessEqual.aten_op_tensor, + LessEqual.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +@common.XfailIfNoCorstone320 +def test_le_scalar_16a8w_u85_INT16(test_module): + """Test le operation (scalar) with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessEqual.aten_op_tensor, + LessEqual.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter def test_le_tensor_vgf_FP(test_module): diff --git a/backends/arm/test/ops/test_lt.py b/backends/arm/test/ops/test_lt.py index d0ed1a34185..22958208bcd 100644 --- a/backends/arm/test/ops/test_lt.py +++ b/backends/arm/test/ops/test_lt.py @@ -122,6 +122,30 @@ def test_lt_scalar_tosa_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +def test_lt_tensor_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessThan.aten_op_tensor, + LessThan.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +def test_lt_scalar_tosa_INT_a16w8(test_module): + pipeline = TosaPipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessThan.aten_op_tensor, + LessThan.exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) 
@common.XfailIfNoCorstone300 def test_lt_tensor_u55_INT_not_delegated(test_module): @@ -181,6 +205,42 @@ def test_lt_scalar_u85_INT(test_module): pipeline.run() +@common.parametrize("test_module", test_data_tensor) +@common.XfailIfNoCorstone320 +def test_lt_tensor_16a8w_u85_INT16(test_module): + """Test lt operation with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessThan.aten_op_tensor, + LessThan.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + +@common.parametrize("test_module", test_data_scalar) +@common.XfailIfNoCorstone320 +def test_lt_scalar_16a8w_u85_INT16(test_module): + """Test lt operation (scalar) with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + per_channel_quantization = False + + pipeline = EthosU85PipelineINT[input_t]( + test_module(), + test_module().get_inputs(), + LessThan.aten_op_tensor, + LessThan.exir_op, + per_channel_quantization=per_channel_quantization, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + @common.parametrize("test_module", test_data_tensor) @common.SkipIfNoModelConverter def test_lt_tensor_vgf_FP(test_module): diff --git a/backends/arm/test/ops/test_max_pool.py b/backends/arm/test/ops/test_max_pool.py index 559932848e4..21619afa7a3 100644 --- a/backends/arm/test/ops/test_max_pool.py +++ b/backends/arm/test/ops/test_max_pool.py @@ -133,6 +133,20 @@ def test_max_pool2d_tosa_INT(test_data: torch.Tensor): pipeline.run() +@common.parametrize("test_data", test_data_suite) +def test_max_pool2d_tosa_INT_a16w8(test_data: torch.Tensor): + """Test max_pool2d operation with int16 I/O quantization for TOSA INT.""" + test_data, model_params = test_data() + pipeline = TosaPipelineINT[input_t1]( + MaxPool2d(*model_params), + (test_data,), + aten_op, + exir_op, + tosa_extensions=["int16"], + ) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone300 def test_max_pool2d_u55_INT(test_data: torch.Tensor): @@ -145,6 +159,23 @@ def test_max_pool2d_u55_INT(test_data: torch.Tensor): ).run() +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone300 +def test_max_pool2d_16a8w_u55_INT16(test_data: torch.Tensor): + """Test max_pool2d with 16A8W quantization on U55 (16-bit activations, 8-bit weights)""" + test_data, model_params = test_data() + pipeline = EthosU55PipelineINT[input_t1]( + MaxPool2d(*model_params), + (test_data,), + aten_op, + exir_ops=[], + per_channel_quantization=False, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) @common.XfailIfNoCorstone320 def test_max_pool2d_u85_INT(test_data: torch.Tensor): @@ -157,6 +188,23 @@ def test_max_pool2d_u85_INT(test_data: torch.Tensor): ).run() +@common.parametrize("test_data", test_data_suite) +@common.XfailIfNoCorstone320 +def test_max_pool2d_16a8w_u85_INT16(test_data: torch.Tensor): + """Test max_pool2d with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + test_data, model_params = test_data() + pipeline = EthosU85PipelineINT[input_t1]( + MaxPool2d(*model_params), + (test_data,), + aten_op, + exir_ops=[], + per_channel_quantization=False, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.run() + + 
reject_data_suite = { "reject_1": lambda: (MaxPool2d(1, 4, 0), torch.rand(1, 10, 10, 10)), "reject_2": lambda: (MaxPool2d((1, 257), 1, 0), torch.rand(1, 16, 5, 300)), diff --git a/backends/arm/test/ops/test_upsample_bilinear2d.py b/backends/arm/test/ops/test_upsample_bilinear2d.py index 1edba708f1f..db440fcb3d4 100644 --- a/backends/arm/test/ops/test_upsample_bilinear2d.py +++ b/backends/arm/test/ops/test_upsample_bilinear2d.py @@ -7,7 +7,6 @@ import torch from executorch.backends.arm.test import common - from executorch.backends.arm.test.tester.test_pipeline import ( EthosU85PipelineINT, OpNotSupportedPipeline, @@ -196,6 +195,24 @@ def test_upsample_bilinear2d_vec_tosa_INT_Upsample( pipeline.run() +@common.parametrize("test_data", test_data_suite_tosa) +def test_upsample_bilinear2d_vec_tosa_INT_a16w8( + test_data: torch.Tensor, +): + """Test upsample_bilinear2d vector op with int16 I/O quantization for TOSA INT.""" + test_data, size, scale_factor, compare_outputs = test_data + pipeline = TosaPipelineINT[input_t1]( + Upsample(size, scale_factor), + (test_data,), + aten_op, + exir_op=[], + tosa_extensions=["int16"], + ) + if not compare_outputs: + pipeline.pop_stage(-1) + pipeline.run() + + @common.parametrize("test_data", test_data_u55) @common.XfailIfNoCorstone300 def test_upsample_bilinear2d_vec_U55_INT_Upsample_not_delegated( @@ -305,6 +322,27 @@ def test_upsample_bilinear2d_vec_U85_INT_UpsamplingBilinear2d( pipeline.run() +@common.parametrize("test_data", test_data_suite_Uxx) +@common.XfailIfNoCorstone320 +def test_upsample_bilinear2d_vec_U85_INT_a16w8( + test_data: input_t1, +): + """Test upsample_bilinear2d vec op with 16A8W quantization on U85 (16-bit activations, 8-bit weights)""" + data, size, scale_factor, compare_outputs = test_data + + pipeline = EthosU85PipelineINT[input_t1]( + UpsamplingBilinear2d(size, scale_factor), + (data,), + aten_op, + per_channel_quantization=False, + a16w8_quantization=True, + use_to_edge_transform_and_lower=True, + ) + if not compare_outputs: + pipeline.pop_stage(-1) + pipeline.run() + + @common.parametrize("test_data", test_data_suite_tosa) @common.SkipIfNoModelConverter def test_upsample_bilinear2d_vgf_FP_UpsamplingBilinear2d(test_data: torch.Tensor): diff --git a/backends/arm/test/ops/test_upsample_nearest2d.py b/backends/arm/test/ops/test_upsample_nearest2d.py index a39adefc168..e7da0643d0e 100644 --- a/backends/arm/test/ops/test_upsample_nearest2d.py +++ b/backends/arm/test/ops/test_upsample_nearest2d.py @@ -195,6 +195,22 @@ def test_upsample_nearest2d_vec_tosa_INT_interpolate(test_data: torch.Tensor): pipeline.run() +@common.parametrize("test_data", test_data_suite) +def test_upsample_nearest2d_vec_tosa_INT_a16w8(test_data: torch.Tensor): + """Test upsample_nearest2d vector op with int16 I/O quantization for TOSA INT.""" + test_data, size, scale_factor, compare_outputs = test_data() + pipeline = TosaPipelineINT[input_t1]( + Upsample(size, scale_factor), + (test_data,), + aten_op, + exir_op=[], + tosa_extensions=["int16"], + ) + if not compare_outputs: + pipeline.pop_stage(-1) + pipeline.run() + + @common.parametrize("test_data", test_data_suite) @common.SkipIfNoModelConverter def test_upsample_nearest2d_vgf_FP(test_data: torch.Tensor): diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index 54a8f08ee50..42d71476ff4 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -25,6 +25,7 @@ from executorch.backends.arm.quantizer import ( 
EthosUQuantizer, + get_symmetric_a16w8_quantization_config, get_symmetric_quantization_config, TOSAQuantizer, VgfQuantizer, @@ -362,9 +363,15 @@ def __init__( ) quantizer = TOSAQuantizer(tosa_profiles[tosa_version]) - quantization_config = get_symmetric_quantization_config( - is_per_channel=per_channel_quantization - ) + # choose 16A8W quantization config when int16 extension is requested + if "int16" in tosa_extensions: + quantization_config = get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization + ) + else: + quantization_config = get_symmetric_quantization_config( + is_per_channel=per_channel_quantization + ) if symmetric_io_quantization: quantizer.set_io(quantization_config) quant_stage = Quantize(quantizer, quantization_config) @@ -523,6 +530,7 @@ def __init__( run_on_fvp: bool = True, symmetric_io_quantization: bool = False, per_channel_quantization: bool = True, + a16w8_quantization: bool = False, use_to_edge_transform_and_lower: bool = True, custom_path: str = None, tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None, @@ -535,9 +543,15 @@ def __init__( tosa_debug_mode=tosa_debug_mode, ) quantizer = EthosUQuantizer(compile_spec) - quantization_config = get_symmetric_quantization_config( - is_per_channel=per_channel_quantization - ) + # choose int8 or int16 activation quantization + if a16w8_quantization: + quantization_config = get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization + ) + else: + quantization_config = get_symmetric_quantization_config( + is_per_channel=per_channel_quantization + ) if symmetric_io_quantization: quantizer.set_io(quantization_config) quant_stage = Quantize(quantizer, quantization_config) @@ -614,6 +628,7 @@ def __init__( run_on_fvp: bool = True, symmetric_io_quantization: bool = False, per_channel_quantization: bool = True, + a16w8_quantization: bool = False, use_to_edge_transform_and_lower: bool = True, custom_path: str = None, tosa_debug_mode: Optional[ArmCompileSpec.DebugMode] = None, @@ -626,9 +641,15 @@ def __init__( tosa_debug_mode=tosa_debug_mode, ) quantizer = EthosUQuantizer(compile_spec) - quantization_config = get_symmetric_quantization_config( - is_per_channel=per_channel_quantization - ) + # choose int8 or int16 activation quantization + if a16w8_quantization: + quantization_config = get_symmetric_a16w8_quantization_config( + is_per_channel=per_channel_quantization + ) + else: + quantization_config = get_symmetric_quantization_config( + is_per_channel=per_channel_quantization + ) if symmetric_io_quantization: quantizer.set_io(quantization_config) quant_stage = Quantize(quantizer, quantization_config) diff --git a/backends/arm/tosa/dialect/ops/resize.py b/backends/arm/tosa/dialect/ops/resize.py index 1f976d0f5e0..b40b1f74a75 100644 --- a/backends/arm/tosa/dialect/ops/resize.py +++ b/backends/arm/tosa/dialect/ops/resize.py @@ -43,6 +43,12 @@ def RESIZE( ) bilinear = resize_mode == "bilinear" output_dtype = torch.int32 if bilinear else torch.int8 + elif x.dtype == torch.int16: + if not tosa_spec.support_integer(): + raise TosaValueError( + f"Context TOSA spec {tosa_spec} doesn't support int16", op="RESIZE" + ) + output_dtype = x.dtype elif x.dtype in (torch.float16, torch.float32): if not tosa_spec.support_float(): raise TosaValueError( From 6c42da064b274405283a01bf394581a83eb6187c Mon Sep 17 00:00:00 2001 From: Saoirse Stewart Date: Tue, 28 Oct 2025 15:19:28 +0000 Subject: [PATCH 2/2] Update for merge conflicts Signed-off-by: Saoirse Stewart --- 
backends/arm/operators/op_avg_pool2d.py | 48 ++---------------- backends/arm/operators/op_clamp.py | 67 +++++-------------------- 2 files changed, 17 insertions(+), 98 deletions(-) diff --git a/backends/arm/operators/op_avg_pool2d.py b/backends/arm/operators/op_avg_pool2d.py index 83ead5d2ff8..b661aedeb24 100644 --- a/backends/arm/operators/op_avg_pool2d.py +++ b/backends/arm/operators/op_avg_pool2d.py @@ -117,52 +117,14 @@ def define_node( validate_valid_dtype( self.target, [inputs[0], output], - [ts.DType.INT8, ts.DType.INT16], + [ts.DType.INT8, ts.DType.INT16, ts.DType.FP32], output.tosa_spec, ) - accumulator_type = ts.DType.INT32 - - input_qargs = get_input_qparams(node) - input_zp = input_qargs[0].get_zp_per_tensor() - - output_qargs = get_output_qparams(node) - output_zp = output_qargs[0].get_zp_per_tensor() - - self._build_generic_avgpool2d( - node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type - ) - - -@register_node_visitor -class AvgPool2dVisitor_FP(AvgPool2dVisitor): - target = "aten.avg_pool2d.default" - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-1.0+FP"), - ] - - def __init__(self, *args): - super().__init__(*args) - - def define_node( - self, - node: torch.fx.Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7]) - validate_same_dtype(self.target, [inputs[0], output], ts) - validate_valid_dtype( - self.target, - [inputs[0], output], - [ts.DType.INT8, ts.DType.FP32], - output.tosa_spec, - ) - - if inputs[0].dtype == ts.DType.INT8: - super().define_node(node, tosa_graph, inputs, output) + if inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.INT16: + accumulator_type = ts.DType.INT32 + input_qargs = get_input_qparams(node) + input_zp = input_qargs[0].get_zp_per_tensor() if inputs[0].dtype == ts.DType.FP32: accumulator_type = ts.DType.FP32 diff --git a/backends/arm/operators/op_clamp.py b/backends/arm/operators/op_clamp.py index 84299486a62..ff31b4633f4 100644 --- a/backends/arm/operators/op_clamp.py +++ b/backends/arm/operators/op_clamp.py @@ -60,60 +60,17 @@ def cast_type(value: Any) -> int | float: return min_arg, max_arg - def define_node( - self, - node: Node, - tosa_graph: Any, - inputs: List[TosaArg], - output: TosaArg, - ) -> None: - validate_num_inputs(self.target, inputs, [2, 3]) - validate_same_dtype(self.target, [inputs[0], output], ts) - validate_valid_dtype( - self.target, - [inputs[0], output], - [ts.DType.INT8, ts.DType.INT16], - output.tosa_spec, - ) - - is_int16 = output.dtype == ts.DType.INT16 - torch_dtype = torch.int16 if is_int16 else torch.int8 - - # NOTE: Quantization of the min/max arguments is handled by QuantizeOperatorArguments - min_quant, max_quant = self._get_min_max_arguments( - node, - torch.iinfo(torch_dtype).min, - torch.iinfo(torch_dtype).max, - ) - - np_dtype = np.int16 if is_int16 else np.int8 - attr = ts.TosaSerializerAttribute() - attr.ClampAttribute( - np.frombuffer(np_dtype(min_quant).tobytes(), dtype=np.uint8).tolist(), - np.frombuffer(np_dtype(max_quant).tobytes(), dtype=np.uint8).tolist(), - ts.NanPropagationMode.PROPAGATE, - ) - - self._serialize_operator( - node, - tosa_graph, - ts.Op.CLAMP, - [inputs[0].name], - [output.name], - attr, - ) - - -@register_node_visitor -class ClampVisitor_FP(ClampVisitor_INT): - # inheriting 'target' from INT class - - tosa_specs = [ - TosaSpecification.create_from_string("TOSA-1.0+FP"), - ] - - def __init__(self, *args): - super().__init__(*args) + def 
_to_bytes(self, value: int | float, dtype: torch.dtype) -> List[int]:
+        if dtype == torch.float32:
+            return np.frombuffer(np.float32(value).tobytes(), dtype=np.uint8).tolist()
+        elif dtype == torch.float16:
+            return np.frombuffer(np.float16(value).tobytes(), dtype=np.uint8).tolist()
+        elif dtype == torch.int8:
+            return np.frombuffer(np.int8(value).tobytes(), dtype=np.uint8).tolist()
+        elif dtype == torch.int16:
+            return np.frombuffer(np.int16(value).tobytes(), dtype=np.uint8).tolist()
+        else:
+            raise ValueError(f"Unsupported dtype for to_bytes: {dtype}")
 
     def define_node(
         self,
@@ -127,7 +84,7 @@ def define_node(
         validate_num_inputs(self.target, inputs, [2, 3])
         validate_same_dtype(self.target, [inputs[0], output], ts)
         validate_valid_dtype(
             self.target,
             [inputs[0], output],
-            [ts.DType.FP16, ts.DType.FP32],
+            [ts.DType.INT8, ts.DType.INT16, ts.DType.FP16, ts.DType.FP32],
             output.tosa_spec,
         )
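
Usage sketch (illustrative, not part of the patch): a minimal example of how the
int16 path added above is exercised through the test infra. The toy module,
input shape, and empty op lists are assumptions for demonstration;
TosaPipelineINT and the tosa_extensions switch are the ones updated in
backends/arm/test/tester/test_pipeline.py.

    import torch

    from executorch.backends.arm.test.tester.test_pipeline import TosaPipelineINT

    input_t = tuple[torch.Tensor]


    class ToyClamp(torch.nn.Module):
        # Illustrative module; any graph the Arm quantizer can annotate works.
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.clamp(x, -0.5, 0.5)


    # tosa_extensions=["int16"] makes the pipeline build its quantizer with
    # get_symmetric_a16w8_quantization_config (16-bit activations, 8-bit
    # weights) instead of the default int8 config.
    pipeline = TosaPipelineINT[input_t](
        ToyClamp(),
        (torch.randn(1, 3, 8, 8),),
        aten_op=[],
        exir_op=[],
        tosa_extensions=["int16"],
    )
    pipeline.run()

On the Ethos-U pipelines the equivalent switch is a16w8_quantization=True, as
the new *_16a8w_u55/u85 tests above use.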