Skip to content

Commit 4efd79c

Browse files
Arm backend: Move rescales from ADD & SUB visitors to pass (#15378)
Move the insertion of INT8/INT32 RESCALE ops from the SUB node visitor to the pass InsertRescaleInt32Pass. This is in practice a refactoring patch, but the output TOSA files still become different enough to cause Ethos-U55/U85 tests to fail in test_var.py and test_conv_relu_residual_add.py. However, the issue was fixed in https://gitlab.arm.com/artificial-intelligence/ethos-u/ethos-u-vela/-/commit/642f7517d3a6bd053032e1942822f6e38ccd546f so we temporarily set the failing tests to xfail until the version of the Ethos-U Vela compiler depended upon is bumped to one that includes the fix. ### Test plan test_insert_rescale_i32_pass.py has been modified to test the change. Signed-off-by: Martin Lindström <Martin.Lindstroem@arm.com> Co-authored-by: Oscar Andersson <Oscar.Andersson@arm.com>
1 parent 7ce78c0 commit 4efd79c

File tree

7 files changed

+82
-570
lines changed

7 files changed

+82
-570
lines changed

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,12 @@ def call(self, graph_module: GraphModule) -> PassResult:
7676

7777

7878
class InsertRescaleInt32Pass(ArmPass):
79-
"""
80-
Numerous TOSA ops require inputs and outputs to be 32-bit integers in their
79+
"""Numerous TOSA ops require inputs and outputs to be 32-bit integers in their
8180
quantized implementations. This pass treats such operator nodes by
82-
inserting rescale ops before and after them if needed. Note that extra logic
83-
that handles the scales and zero points must be in place because the affected
84-
TOSA have naive implementations that do not account for the quantization
85-
parameters.
81+
inserting rescale ops before and after them if needed. Note that extra
82+
logic that handles the scales and zero points are in place here because the
83+
affected TOSA ops have naive implementations that do not account for the
84+
quantization parameters.
8685
"""
8786

8887
# SUM must be decomposed after this pass to prevent insertion of RESCALE
@@ -93,6 +92,7 @@ class InsertRescaleInt32Pass(ArmPass):
9392

9493
included_targets = [
9594
exir_ops.edge.aten.abs.default,
95+
exir_ops.edge.aten.add.Tensor,
9696
exir_ops.edge.aten.eq.Tensor,
9797
exir_ops.edge.aten.ge.Tensor,
9898
exir_ops.edge.aten.gt.Tensor,
@@ -101,6 +101,7 @@ class InsertRescaleInt32Pass(ArmPass):
101101
exir_ops.edge.aten.maximum.default,
102102
exir_ops.edge.aten.minimum.default,
103103
exir_ops.edge.aten.mul.Tensor,
104+
exir_ops.edge.aten.sub.Tensor,
104105
exir_ops.edge.aten.sum.dim_IntList,
105106
]
106107

@@ -142,6 +143,34 @@ def _get_inputs_rescaled_qparams(
142143
qparams = {
143144
i: self._int32_qargs(min_scale) for i in range(len(input_qparams))
144145
}
146+
elif target in [
147+
exir_ops.edge.aten.add.Tensor,
148+
exir_ops.edge.aten.sub.Tensor,
149+
]:
150+
if input_qparams[0].dtype != input_qparams[1].dtype:
151+
raise ValueError(
152+
"Mismatch in dtype args: {input_qparams[0].dtype} != {input_qparams[1].dtype}"
153+
)
154+
155+
# We are handling two INT8 or two INT16 numbers. For INT8, if the
156+
# zero point is non-null, the result will be in the range [-255;
157+
# 255], therefore we need 9 bits for the result. We have a 32-bit
158+
# accumulator, so we can divide the scale by (1 << 20) which is
159+
# equivalent to shifting the INT8 operands 20 bits to the left
160+
# before rescaling them both to 2 * max(lhs, rhs).
161+
#
162+
# For INT16, similar logic can be applied, but we instead end up
163+
# with a left shift of 12.
164+
lhs_scale, rhs_scale = (
165+
qp.get_scale_per_tensor() for qp in input_qparams.values()
166+
)
167+
max_scale_2x = 2 * max(lhs_scale, rhs_scale)
168+
169+
# Select shift based on input dtype.
170+
shift_bits = 12 if input_qparams[0].dtype == torch.int16 else 20
171+
172+
scale = max_scale_2x / (1 << shift_bits)
173+
qparams = {i: self._int32_qargs(scale) for i in range(len(input_qparams))}
145174
elif target in [
146175
exir_ops.edge.aten.mul.Tensor,
147176
exir_ops.edge.aten.sum.dim_IntList,
@@ -168,6 +197,8 @@ def _get_output_qparams(
168197
exir_ops.edge.aten.maximum.default,
169198
exir_ops.edge.aten.minimum.default,
170199
exir_ops.edge.aten.sum.dim_IntList,
200+
exir_ops.edge.aten.add.Tensor,
201+
exir_ops.edge.aten.sub.Tensor,
171202
]:
172203
# The op has not altered the scale; the output scale is equal to
173204
# the operands' scales.

backends/arm/operators/op_add.py

Lines changed: 7 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66

77
from typing import Any, List
88

9-
import executorch.backends.arm.tosa.quant_utils as tqutils
10-
import executorch.backends.arm.tosa.utils as tutils
119
import tosa_serializer as ts
1210

1311
from executorch.backends.arm.operators.node_visitor import (
@@ -19,22 +17,20 @@
1917
validate_same_dtype,
2018
validate_valid_dtype,
2119
)
22-
from executorch.backends.arm.tosa import TosaSpecification
2320
from executorch.backends.arm.tosa.mapping import TosaArg
21+
from executorch.backends.arm.tosa.specification import TosaSpecification
2422
from torch.fx import Node
2523

2624

2725
@register_node_visitor
28-
class AddVisitor_INT(NodeVisitor):
26+
class AddVisitor(NodeVisitor):
2927
target = "aten.add.Tensor"
3028

3129
tosa_specs = [
3230
TosaSpecification.create_from_string("TOSA-1.0+INT"),
31+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
3332
]
3433

35-
def __init__(self, *args):
36-
super().__init__(*args)
37-
3834
def define_node(
3935
self,
4036
node: Node,
@@ -44,113 +40,21 @@ def define_node(
4440
) -> None:
4541
validate_num_inputs(self.target, inputs, 2)
4642
validate_same_dtype(self.target, [*inputs, output], ts)
47-
valid_dtypes = []
48-
if self.tosa_spec.support_integer():
49-
valid_dtypes.extend([ts.DType.INT8, ts.DType.INT16, ts.DType.INT32])
50-
if self.tosa_spec.support_float():
51-
valid_dtypes.extend([ts.DType.INT32])
52-
5343
validate_valid_dtype(
5444
self.target,
5545
[*inputs, output],
56-
valid_dtypes,
46+
[ts.DType.INT32, ts.DType.FP32],
5747
output.tosa_spec,
5848
)
59-
scale_back = 1.0
60-
if inputs[0].dtype == ts.DType.INT8:
61-
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
62-
tosa_graph, inputs, node, self.tosa_spec
63-
)
64-
elif inputs[0].dtype == ts.DType.INT16:
65-
rescaled_inputs, scale_back = (
66-
tqutils.insert_rescale_ops_int16_to_int32_maxscale(
67-
tosa_graph, inputs, node, self.tosa_spec
68-
)
69-
)
70-
else:
71-
# input[0].dtype == ts.DType.INT16 or ts.DType.INT32
72-
# Non quantized input, natively support by TOSA.ADD
73-
rescaled_inputs = inputs
74-
75-
if output.dtype in [ts.DType.INT8, ts.DType.INT16]:
76-
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
77-
add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
78-
else:
79-
# output.dtype == ts.DType.INT16 or ts.DType.INT32
80-
add_output = output
8149

82-
input1, input2 = rescaled_inputs
8350
attr = ts.TosaSerializerAttribute()
8451
attr.AddAttribute()
85-
# Do the INT32 Add
52+
8653
self._serialize_operator(
8754
node,
8855
tosa_graph,
8956
ts.Op.ADD,
90-
[input1.name, input2.name],
91-
[add_output.name],
57+
[inputs[0].name, inputs[1].name],
58+
[output.name],
9259
attr,
9360
)
94-
95-
if output.dtype == ts.DType.INT8:
96-
# Scale output back to 8 bit
97-
# pyre-ignore
98-
tqutils.insert_rescale_op_to_int8(
99-
tosa_graph,
100-
add_output,
101-
scale_back,
102-
node,
103-
compute_rescale=False,
104-
tosa_spec=self.tosa_spec,
105-
) # type: ignore[possibly-undefined]
106-
elif output.dtype == ts.DType.INT16:
107-
tqutils.insert_rescale_op_to_int16(
108-
tosa_graph,
109-
add_output,
110-
scale_back,
111-
node,
112-
compute_rescale=False,
113-
tosa_spec=self.tosa_spec,
114-
) # type: ignore[possibly-undefined]
115-
116-
117-
@register_node_visitor
118-
class AddVisitor_FP(AddVisitor_INT):
119-
# inheriting 'target' from INT class
120-
121-
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
122-
123-
def __init__(self, *args):
124-
super().__init__(*args)
125-
126-
def define_node(
127-
self,
128-
node: Node,
129-
tosa_graph: Any,
130-
inputs: List[TosaArg],
131-
output: TosaArg,
132-
) -> None:
133-
validate_num_inputs(self.target, inputs, 2)
134-
validate_same_dtype(self.target, [*inputs, output], ts)
135-
136-
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT16, ts.DType.INT32]:
137-
# Call the inherited define_node for handling integers
138-
super().define_node(node, tosa_graph, inputs, output)
139-
else:
140-
# FP32 Add lowering
141-
validate_valid_dtype(
142-
self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec
143-
)
144-
145-
input1, input2 = inputs
146-
attr = ts.TosaSerializerAttribute()
147-
attr.AddAttribute()
148-
# FP lowering
149-
self._serialize_operator(
150-
node,
151-
tosa_graph,
152-
ts.Op.ADD,
153-
[input1.name, input2.name],
154-
[output.name],
155-
attr,
156-
)

backends/arm/operators/op_sub.py

Lines changed: 8 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66

77
from typing import Any, List
88

9-
import executorch.backends.arm.tosa.quant_utils as tqutils
10-
import executorch.backends.arm.tosa.utils as tutils
119
import tosa_serializer as ts
1210

1311
from executorch.backends.arm.operators.node_visitor import (
@@ -19,22 +17,20 @@
1917
validate_same_dtype,
2018
validate_valid_dtype,
2119
)
22-
from executorch.backends.arm.tosa import TosaSpecification
2320
from executorch.backends.arm.tosa.mapping import TosaArg
21+
from executorch.backends.arm.tosa.specification import TosaSpecification
2422
from torch.fx import Node
2523

2624

2725
@register_node_visitor
28-
class SubVisitor_INT(NodeVisitor):
26+
class SubVisitor(NodeVisitor):
2927
target = "aten.sub.Tensor"
3028

3129
tosa_specs = [
3230
TosaSpecification.create_from_string("TOSA-1.0+INT"),
31+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
3332
]
3433

35-
def __init__(self, *args):
36-
super().__init__(*args)
37-
3834
def define_node(
3935
self,
4036
node: Node,
@@ -47,106 +43,21 @@ def define_node(
4743
validate_valid_dtype(
4844
self.target,
4945
[*inputs, output],
50-
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32],
46+
[ts.DType.INT32, ts.DType.FP32],
5147
output.tosa_spec,
5248
)
5349

54-
scale_back = 1.0
55-
if inputs[0].dtype == ts.DType.INT8:
56-
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32_maxscale(
57-
tosa_graph, inputs, node, self.tosa_spec
58-
)
59-
elif inputs[0].dtype == ts.DType.INT16:
60-
rescaled_inputs, scale_back = (
61-
tqutils.insert_rescale_ops_int16_to_int32_maxscale(
62-
tosa_graph, inputs, node, self.tosa_spec
63-
)
64-
)
65-
else:
66-
# input[0].dtype == ts.DType.INT32
67-
# Non quantized input, natively support by TOSA.SUB
68-
rescaled_inputs = inputs
69-
70-
if output.dtype in [ts.DType.INT8, ts.DType.INT16]:
71-
broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
72-
sub_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
73-
else:
74-
# output.dtype == ts.DType.INT32
75-
sub_output = output
76-
77-
# Do the INT32 Sub
7850
attr = ts.TosaSerializerAttribute()
7951
attr.SubAttribute()
52+
8053
self._serialize_operator(
8154
node,
8255
tosa_graph,
8356
ts.Op.SUB,
8457
[
85-
rescaled_inputs[0].name,
86-
rescaled_inputs[1].name,
58+
inputs[0].name,
59+
inputs[1].name,
8760
],
88-
[sub_output.name],
61+
[output.name],
8962
attr,
9063
)
91-
92-
if output.dtype == ts.DType.INT8:
93-
# Scale output back to 8 bit
94-
# pyre-ignore
95-
tqutils.insert_rescale_op_to_int8(
96-
tosa_graph,
97-
sub_output,
98-
scale_back,
99-
node,
100-
compute_rescale=False,
101-
tosa_spec=self.tosa_spec,
102-
) # type: ignore[possibly-undefined]
103-
elif output.dtype == ts.DType.INT16:
104-
tqutils.insert_rescale_op_to_int16(
105-
tosa_graph,
106-
sub_output,
107-
scale_back,
108-
node,
109-
compute_rescale=False,
110-
tosa_spec=self.tosa_spec,
111-
) # type: ignore[possibly-undefined]
112-
113-
114-
@register_node_visitor
115-
class SubVisitor_FP(SubVisitor_INT):
116-
# inheriting 'target' from INT class
117-
118-
tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]
119-
120-
def __init__(self, *args):
121-
super().__init__(*args)
122-
123-
def define_node(
124-
self,
125-
node: Node,
126-
tosa_graph: Any,
127-
inputs: List[TosaArg],
128-
output: TosaArg,
129-
) -> None:
130-
validate_num_inputs(self.target, inputs, 2)
131-
validate_same_dtype(self.target, [*inputs, output], ts)
132-
133-
if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
134-
# Call the inherited define_node for handling integers
135-
super().define_node(node, tosa_graph, inputs, output)
136-
else:
137-
# FP32 Sub lowering
138-
validate_valid_dtype(
139-
self.target, [*inputs, output], ts.DType.FP32, output.tosa_spec
140-
)
141-
142-
# MI lowering
143-
attr = ts.TosaSerializerAttribute()
144-
attr.SubAttribute()
145-
self._serialize_operator(
146-
node,
147-
tosa_graph,
148-
ts.Op.SUB,
149-
[inputs[0].name, inputs[1].name],
150-
[output.name],
151-
attr,
152-
)

backends/arm/test/misc/test_conv_relu_residual_add.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,13 @@ def test_tosa_INT(per_channel_quantization):
7676
pipeline.run()
7777

7878

79+
# TODO: Xfail until the Ethos-U Vela compiler ships commit
80+
# 642f7517d3a6bd053032e1942822f6e38ccd546f. That patch fixes the bug that
81+
# causes this test to fail.
82+
@pytest.mark.xfail(
83+
reason=("Blocked by Vela commit 642f7517d3a6bd053032e1942822f6e38ccd546f"),
84+
strict=True,
85+
)
7986
@pytest.mark.slow
8087
@common.XfailIfNoCorstone300
8188
@common.parametrize("per_channel_quantization", quant_test_data)

0 commit comments

Comments
 (0)