Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from executorch.backends.arm._passes.decompose_var_pass import DecomposeVarPass
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
FoldAndAnnotateQParamsPass,
QuantizeFullArgument,
)
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
KeepDimsFalseToSqueezePass,
Expand Down Expand Up @@ -84,12 +85,16 @@ def transform_to_backend_pipeline(
self.add_pass(Conv1dUnsqueezePass(exported_program))
self.add_pass(DecomposeSoftmaxesPass())
self.add_pass(DecomposeLinearPass())
self.add_pass(QuantizeFullArgument())
self.add_pass(
FoldAndAnnotateQParamsPass(
[
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.avg_pool2d.default,
exir_ops.edge.aten.convolution.default,
exir_ops.edge.aten.full.default,
]
)
)
Expand Down
54 changes: 52 additions & 2 deletions backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule, Node

q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default


def get_input_qparams(node: Node) -> dict[int, QuantArgs]:
"""
Expand Down Expand Up @@ -77,8 +80,6 @@ def __init__(self, targeted_ops: Iterable[Callable]):
self.targeted_ops = targeted_ops

def call(self, graph_module: GraphModule) -> PassResult:
q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default

# Loop over the graph nodes and find any node in the 'targeted_ops' list.
for n in graph_module.graph.nodes:
Expand All @@ -98,6 +99,22 @@ def call(self, graph_module: GraphModule) -> PassResult:
for i, arg in enumerate(n.args):
if not isinstance(arg, Node):
continue

# Make sure arg has requires_grad set to False
# For parameters that are not quantized, sometimes (i.e. convolution)
# the Parameter(FakeTensor(...)) has requires_grad set to True, which
# causes the retracing of the graph to fail with:
#
# E RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/autograd/functions/utils.h":74, please report a bug to PyTorch.
# E
# E While executing %aten_convolution_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%quantized_decomposed_quantize_per_tensor_default, %b__frozen_param0, %p__param_constant1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
# E Original traceback:
# E File "/Users/perast01/src/executorch/backends/arm/test/ops/test_conv2d.py", line 110, in forward
# E x = conv(x)
#
if arg.op == "placeholder":
arg.meta["val"].requires_grad = False

if arg.target != dq_op:
continue

Expand Down Expand Up @@ -129,3 +146,36 @@ def call(self, graph_module: GraphModule) -> PassResult:

graph_module.recompile()
return PassResult(graph_module, True)


class QuantizeFullArgument(ExportPass):
"""
Make sure the fill_value for full.default is quantized. This pass needs to be run before
the folding pass above to make sure that the retraced output of the full.default op is
the right dtype.
"""

def call(self, graph_module: GraphModule) -> PassResult:
modified = False
# Loop over the graph nodes and find any node in the 'targeted_ops' list.
for n in graph_module.graph.nodes:
n = cast(Node, n)
if n.target != exir_ops.edge.aten.full.default:
continue

# Make sure we have a quantized operator
user = list(n.users)[0]
if user.target != q_op:
continue

qargs = QuantArgs.from_operator(user.target, user.args)
if "dtype" not in n.kwargs.keys() or n.kwargs["dtype"] != qargs.dtype:
# replace the node arg with a quantized dito and also set dtype
# to get the right output according to the Edge IR specification:
# exir/dialects/edge/edge.yaml:3596
quantized_full_value = qargs.quantize_value(n.args[1]).item()
n.update_arg(1, quantized_full_value)
n.update_kwarg("dtype", qargs.dtype)
modified = True

return PassResult(graph_module, modified)
100 changes: 87 additions & 13 deletions backends/arm/operators/op_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,44 +8,118 @@

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
get_output_qparams,
)
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_utils import build_avg_pool_2d_common
from executorch.backends.arm.tosa_specification import TosaSpecification


@register_node_visitor
class AvgPool2dVisitor(NodeVisitor):
class AvgPool2dVisitor_0_80_BI(NodeVisitor):
target = "aten.avg_pool2d.default"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
]

def __init__(self, *args):
super().__init__(*args)

def define_node(
def _build_generic_avgpool2d(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
input_zp: int,
output_zp: int,
accumulator_type,
) -> None:
input_tensor = inputs[0]

kernel_size_list = inputs[1].special
stride_size_list = inputs[2].special
try:
pad_size_list = inputs[3].special
except IndexError:
pad_size_list = [0, 0, 0, 0]

build_avg_pool_2d_common(
node,
tosa_graph,
input_tensor,
kernel_size_list,
stride_size_list,
pad_size_list,
is_quant_node,
output,
attr = ts.TosaSerializerAttribute()
attr.PoolAttribute(
kernel=kernel_size_list,
stride=stride_size_list,
pad=pad_size_list,
input_zp=input_zp,
output_zp=output_zp,
accum_dtype=accumulator_type,
)

tosa_graph.addOperator(
ts.TosaOp.Op().AVG_POOL2D,
[input_tensor.name],
[output.name],
attr,
)

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
) -> None:
input_tensor = inputs[0]
assert input_tensor.dtype == ts.DType.INT8

accumulator_type = ts.DType.INT32

input_qargs = get_input_qparams(node)
input_zp = input_qargs[0].zp

output_qargs = get_output_qparams(node)
output_zp = output_qargs[0].zp

self._build_generic_avgpool2d(
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
)


@register_node_visitor
class AvgPool2dVisitor_0_80_MI(AvgPool2dVisitor_0_80_BI):
# inheriting 'target' from BI class

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
]

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
) -> None:
assert (
inputs[0].dtype == ts.DType.INT8 or inputs[0].dtype == ts.DType.FP32
), "Only FP32 and INT8 supported"

if inputs[0].dtype == ts.DType.INT8:
super().define_node(node, tosa_graph, inputs, output, is_quant_node)

if inputs[0].dtype == ts.DType.FP32:
accumulator_type = ts.DType.FP32
# Initilize zero point to zero.
input_zp = 0
output_zp = 0

self._build_generic_avgpool2d(
node, tosa_graph, inputs, output, input_zp, output_zp, accumulator_type
)
5 changes: 5 additions & 0 deletions backends/arm/operators/op_batch_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_utils import promote_shape, tosa_shape
from serializer.tosa_serializer import TosaOp

Expand All @@ -21,6 +22,10 @@
class BatchNormVisitor(NodeVisitor):
target = "aten._native_batch_norm_legit_no_training.default"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
]

def __init__(self, *args):
super().__init__(*args)

Expand Down
48 changes: 26 additions & 22 deletions backends/arm/operators/op_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
get_input_qparams,
get_output_qparams,
)
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_quant_utils import (
build_rescale_conv_output,
get_quant_arg_downstream,
get_quant_arg_upstream,
)
from executorch.backends.arm.tosa_quant_utils import build_rescale_conv_output
from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape

from serializer.tosa_serializer import TosaOp
Expand Down Expand Up @@ -57,9 +57,6 @@ def define_node(
) -> None:
input, weight, bias, stride, pad, dilation, _, _, group = inputs

# Currently only int8 is supported in quantized types.
actual_out_type = ts.DType.INT8 if is_quant_node else output.dtype

# Get the attributes of convolution.
attr = ts.TosaSerializerAttribute()
pad_attr = [val for val in pad.special for _ in (0, 1)]
Expand All @@ -82,9 +79,11 @@ def define_node(
dilation_attr[1],
)

input_zp = (
get_quant_arg_upstream(node.all_input_nodes[0]).zp if is_quant_node else 0
)
input_zp = 0
if inputs[0].dtype == ts.DType.INT8:
# int8 input requires quantization information
input_qparams = get_input_qparams(node)
input_zp = input_qparams[0].zp

attr.ConvAttribute(
pad=pad_attr,
Expand All @@ -100,16 +99,22 @@ def define_node(
# Create a zero bias tensor if not presented
out_channels = weight.shape[0]
bias_name = "bias" + node.name.split("default", 1)[1]
bias_type = output.dtype
if output.dtype == ts.DType.INT8:
# Conv is quantized to int8, but the TOSA operator has
# output type int32, and the bias must be the same type
# as the TOSA output type
bias_type = ts.DType.INT32
bias = tosa_graph.addConst(
[out_channels],
ts.DType.INT32 if is_quant_node else output.dtype,
bias_type,
[0] * out_channels,
name=bias_name,
)

# The output type is int32 when input type is int8.
conv2d_output_name = output.name
if is_quant_node:
if output.dtype == ts.DType.INT8:
conv2d_res = tosa_graph.addIntermediate(
tosa_shape(output.shape, output.dim_order), ts.DType.INT32
)
Expand All @@ -132,7 +137,7 @@ def define_node(

weight_reshaped = tosa_graph.addIntermediate(
weight_post_shape,
ts.DType.INT8 if is_quant_node else weight.dtype,
weight.dtype,
)
build_reshape(
tosa_graph, weight.name, weight_post_shape, weight_reshaped.name
Expand All @@ -157,20 +162,19 @@ def define_node(

# For quantized convolution, rescale the output value back to the same
# integer value domain of the next op. Otherwise return float32 output.
if is_quant_node:
if inputs[0].dtype == ts.DType.INT8:
# Get scale_factor from input, weight, and output.
input_scale = get_quant_arg_upstream(node.all_input_nodes[0]).scale
weight_scale = get_quant_arg_upstream(node.all_input_nodes[1]).scale
output_qargs = get_quant_arg_downstream(list(node.users)[0])

input_scale = input_qparams[0].scale
weight_scale = input_qparams[1].scale
output_qargs = get_output_qparams(node)
build_rescale_conv_output(
tosa_graph,
# pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined.
conv2d_res,
output.name,
actual_out_type,
output.dtype,
input_scale,
weight_scale,
output_qargs.scale,
output_qargs.zp,
output_qargs[0].scale,
output_qargs[0].zp,
)
6 changes: 6 additions & 0 deletions backends/arm/operators/op_div.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_utils import tosa_shape
from serializer.tosa_serializer import TosaOp

Expand All @@ -21,6 +22,11 @@
class DivVisitor(NodeVisitor):
target = "aten.div.Tensor"

# Only supported for MI
tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
]

def __init__(self, *args):
super().__init__(*args)

Expand Down
Loading
Loading