Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/annotate_decomposed_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
)
matmul_targets = {
exir_ops.edge.aten.bmm.default,
exir_ops.edge.aten.mm.default,
}
for partition in matmul_partitions:
quantized_input = all(
Expand Down
104 changes: 22 additions & 82 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,79 +155,26 @@ def _transform(self, graph_module: GraphModule):
with TosaLoweringContext(self.tosa_spec):
return self(graph_module).graph_module

def _tosa_INT_pipeline(
def _tosa_pipeline(
self, exported_program: ExportedProgram, graph_module: GraphModule
) -> GraphModule:
self.add_pass(AnnotateOutputDimOrderPass())
self.add_pass(FuseQuantizedActivationPass())
self.add_pass(RemoveGetItemPass())
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(ConvertMmToBmmPass())
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
self.add_pass(ConvertFullLikeToFullPass())
self.add_pass(ConvertToClampPass())
self.add_pass(ConvertMinMaxPass())
self.add_pass(ConvertAnyDefaultDimDimsPass())
self.add_pass(MatchArgDtypePass())
if self.tosa_spec.is_U55_subset:
self.add_pass(CastToInt32Pass())

self.add_pass(CastBoolToInt8Pass())
self.add_pass(ReplaceScalarWithTensorByProfilePass())
self.add_pass(DecomposeGroupNormPass())
self.add_pass(DecomposeLayerNormPass())
self.add_pass(DecomposeBatchNormNoStatsPass())
self.add_pass(DecomposeVarPass())
self.add_pass(
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
)
self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeOperatorArguments())
self.add_pass(ConvertELUParamsPass())
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(QuantizeOperatorArguments())
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
self.add_pass(FuseDuplicateUsersPass())
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
self.add_pass(MatchArgRanksPass(exported_program))
if self.tosa_spec.is_U55_subset:
self.add_pass(BroadcastArgsPass())
self.add_pass(DecomposeLinearPass())
self.add_pass(DecomposeAdaptiveAvgPool2dPass())
self.add_pass(DecomposeAvgPool2d())
self.add_pass(ComputeConstantOpsAOT(exported_program))

self.add_pass(DecomposeGroupedConv())

self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(UnsqueezeBeforeRepeatPass())
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
self.add_pass(DecomposeCumsumPass(exported_program))
self.add_pass(Conv1dUnsqueezePass())
self.add_pass(DecomposeMaxPool2DPass())
self.add_pass(SizeAdjustInputPass())
self.add_pass(DecomposeSelectPass())
self.add_pass(ConvertSqueezesToViewPass())

self.add_pass(FuseViewCopyTransform())
self.add_pass(FuseConstantArgsPass(exported_program))
self.add_pass(InsertTableOpsPass(exported_program))
# If we have a conv2d with int16 activation split up into a convolution
# and an addition, to work-around the lack of support for int48 in torch
# needs to happen before RewriteConv2dPass, but after the table ops are inserted
# to be able to validate that conv2d has right dtype arguments.
self.add_pass(DecomposeConv2dWithInt16ActivationPass())
self.add_pass(RewriteConv2dPass(exported_program))

self.add_pass(RewriteMatmulPass())
self.add_pass(RewriteUpsamplePass())
self.add_pass(FuseEqualPlaceholdersPass(exported_program))

self.add_pass(InsertRescaleInt32Pass())
self.add_pass(DecomposeSumPass())
self.add_pass(ToTosaMemoryFormatPass(exported_program))
self.add_pass(RemoveNoopPass())
self.add_pass(InsertRescalePass())

self.validate_constraints_mandatory()
return self._transform(graph_module)

def _tosa_FP_pipeline(
self, exported_program: ExportedProgram, graph_module: GraphModule
) -> GraphModule:
self.add_pass(AnnotateOutputDimOrderPass())
self.add_pass(FuseDuplicateUsersPass())
self.add_pass(DecomposeExpm1Pass())
self.add_pass(DecomposeLogitPass())
self.add_pass(DecomposeMaskedFill())
Expand All @@ -252,32 +199,20 @@ def _tosa_FP_pipeline(
self.add_pass(DecomposeRemainderPass())
self.add_pass(DecomposeDivTensorModePass())
self.add_pass(DecomposeEmbeddingPass())
self.add_pass(FuseQuantizedActivationPass())
self.add_pass(RemoveGetItemPass())
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(FuseBatchnorm2DPass(exported_program))
self.add_pass(ConvertMmToBmmPass())
self.add_pass(DecomposeGluPass())
self.add_pass(DecomposeLinearPass())
self.add_pass(DecomposeLeakyReLUPass())
self.add_pass(DecomposeGroupNormPass())
self.add_pass(DecomposeLayerNormPass())
self.add_pass(DecomposeBatchNormNoStatsPass())
self.add_pass(DecomposeVarPass())
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
self.add_pass(DecomposeNotEqualPass())
self.add_pass(DecomposeDivPass())
self.add_pass(DecomposeAddSubAlphaPass())
self.add_pass(DecomposeSoftmaxPass())
self.add_pass(DecomposeGeluPass())
self.add_pass(ConvertFullLikeToFullPass())
self.add_pass(ConvertToClampPass())
self.add_pass(ConvertMinMaxPass())
self.add_pass(ConvertAnyDefaultDimDimsPass())
self.add_pass(MatchArgDtypePass())
self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeOperatorArguments())
self.add_pass(FoldAndAnnotateQParamsPass(exported_program)) # type: ignore[call-arg]
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
self.add_pass(MatchArgRanksPass(exported_program))
self.add_pass(DecomposeAdaptiveAvgPool2dPass())
Expand All @@ -290,22 +225,26 @@ def _tosa_FP_pipeline(
self.add_pass(DecomposeGroupedConv())
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(UnsqueezeBeforeRepeatPass())
self.add_pass(DecomposeSumPass())
self.add_pass(DecomposeCumsumPass(exported_program))
self.add_pass(Conv1dUnsqueezePass())
self.add_pass(DecomposeMaxPool2DPass())
self.add_pass(SizeAdjustInputPass())
self.add_pass(DecomposeSelectPass())
self.add_pass(ConvertSqueezesToViewPass())
self.add_pass(CastToInt32Pass())
self.add_pass(BroadcastArgsPass())

self.add_pass(FuseViewCopyTransform())
self.add_pass(FuseConstantArgsPass(exported_program))
self.add_pass(RewriteConv2dPass(exported_program))
self.add_pass(DecomposeConv2dWithInt16ActivationPass())
self.add_pass(CastInt64BuffersToInt32Pass(exported_program))
self.add_pass(RewriteUpsamplePass())
self.add_pass(InsertTableOpsPass(exported_program))
self.add_pass(RewriteUpsamplePass())
self.add_pass(RewriteConv2dPass(exported_program))
self.add_pass(RewriteMatmulPass())
self.add_pass(FuseEqualPlaceholdersPass(exported_program))
self.add_pass(InsertRescaleInt32Pass())
self.add_pass(DecomposeSumPass())
self.add_pass(ToTosaMemoryFormatPass(exported_program))
self.add_pass(RemoveNoopPass())
self.add_pass(InsertRescalePass())
Expand All @@ -317,10 +256,11 @@ def transform_to_backend_pipeline(
self, exported_program: ExportedProgram, graph_module: GraphModule
):
"""Apply passes before transforming program to backend"""
if self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"):
return self._tosa_FP_pipeline(exported_program, graph_module)
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+INT"):
return self._tosa_INT_pipeline(exported_program, graph_module)
if self.tosa_spec in (
TosaSpecification.create_from_string("TOSA-1.0+FP"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
):
return self._tosa_pipeline(exported_program, graph_module)
else:
raise NotImplementedError(
f"No pass pipeline implemented for {self.tosa_spec=}"
Expand Down
4 changes: 4 additions & 0 deletions backends/arm/_passes/broadcast_args_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
create_node,
get_first_fake_tensor,
)
from executorch.backends.arm.tosa.specification import get_context_spec

from executorch.exir.dialects._ops import ops as exir_ops

Expand All @@ -34,6 +35,9 @@ class BroadcastArgsPass(ArmPass):
}

def call(self, graph_module: GraphModule) -> PassResult:
tosa_spec = get_context_spec()
if not tosa_spec.is_U55_subset:
return PassResult(graph_module, False)
for node in graph_module.graph.nodes:
if node.op != "call_function" or node.target not in self.targeted_ops:
continue
Expand Down
10 changes: 9 additions & 1 deletion backends/arm/_passes/cast_to_int32_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
import torch

from executorch.backends.arm._passes.arm_pass import ArmPass

from executorch.backends.arm.tosa.specification import get_context_spec
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from executorch.exir.pass_base import ExportPass, PassResult


class CastToInt32Pass(ArmPass):
Expand All @@ -22,6 +24,12 @@ class CastToInt32Pass(ArmPass):
exir_ops.edge.aten.bitwise_right_shift.Tensor,
}

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
tosa_spec = get_context_spec()
if not tosa_spec.is_U55_subset:
return PassResult(graph_module, False)
return super().call(graph_module)

def call_operator(self, op, args, kwargs, meta):
if op not in self.targeted_ops:
return super().call_operator(op, args, kwargs, meta)
Expand Down
7 changes: 7 additions & 0 deletions backends/arm/_passes/convert_elu_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.backends.arm.constants import DQ_OPS
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

Expand All @@ -30,6 +31,12 @@ def call(self, graph_module: torch.fx.GraphModule):
op="call_function", target=exir_ops.edge.aten.elu.default
)
for node in node_list:
input_node = node.all_input_nodes[0]
is_quantized = (
input_node.op == "call_function" and input_node.target in DQ_OPS
)
if not is_quantized:
continue
with graph.inserting_after(node):
replace_node = create_node(graph, exir_ops.edge.aten.elu.default)
old_args = list(node.args)
Expand Down
8 changes: 8 additions & 0 deletions backends/arm/_passes/convert_int_pow_to_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ def call_operator(self, op, args, kwargs, meta):
if op != exir_ops.edge.aten.pow.Tensor_Scalar:
return super().call_operator(op, args, kwargs, meta)

is_quantized = (
len(meta.data.get("input_qparams", {})) > 0
and len(meta.data.get("output_qparams", {})) > 0
)
if is_quantized:
# If quantized, node should be replace by table op
return super().call_operator(op, args, kwargs, meta)

x = args[0]
exp = args[1]

Expand Down
8 changes: 8 additions & 0 deletions backends/arm/_passes/decompose_acosh_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
if op is not edge_acosh_op:
return super().call_operator(op, args, kwargs, meta, updated)

is_quantized = (
len(meta.data.get("input_qparams", {})) > 0
and len(meta.data.get("output_qparams", {})) > 0
)
if is_quantized:
# If quantized, node should be replace by table op
return super().call_operator(op, args, kwargs, meta, updated)

log_op, sqrt_op, mul_op, sub_op, add_op, add_op_scalar = (
exir_ops.edge.aten.log.default,
exir_ops.edge.aten.sqrt.default,
Expand Down
9 changes: 9 additions & 0 deletions backends/arm/_passes/decompose_asin_and_acos_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,15 @@ def _combine_branches(
def call_operator(self, op, args, kwargs, meta):
if op not in (edge_asin_op + edge_acos_op):
return super().call_operator(op, args, kwargs, meta)

is_quantized = (
len(meta.data.get("input_qparams", {})) > 0
and len(meta.data.get("output_qparams", {})) > 0
)
if is_quantized:
# If quantized, node should be replace by table op
return super().call_operator(op, args, kwargs, meta)

logging.info(
f"Approximating {op}. This may introduce small numerical errors. For details, see {__file__}."
)
Expand Down
8 changes: 8 additions & 0 deletions backends/arm/_passes/decompose_asinh_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ def call_operator(self, op, args, kwargs, meta):
if op not in edge_asinh_op:
return super().call_operator(op, args, kwargs, meta)

is_quantized = (
len(meta.data.get("input_qparams", {})) > 0
and len(meta.data.get("output_qparams", {})) > 0
)
if is_quantized:
# If quantized, node should be replace by table op
return super().call_operator(op, args, kwargs, meta)

log_op, sqrt_op, mul_op, add_op_scalar, add_op = (
exir_ops.edge.aten.log.default,
exir_ops.edge.aten.sqrt.default,
Expand Down
8 changes: 8 additions & 0 deletions backends/arm/_passes/decompose_atan_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,14 @@ def call_operator(self, op, args, kwargs, meta):
if op is not edge_atan:
return super().call_operator(op, args, kwargs, meta, updated=False)

is_quantized = (
len(meta.data.get("input_qparams", {})) > 0
and len(meta.data.get("output_qparams", {})) > 0
)
if is_quantized:
# If quantized, node should be replace by table op
return super().call_operator(op, args, kwargs, meta)

logging.info(
f"Approximating atan. This may introduce small numerical errors. For details, see {__file__}."
)
Expand Down
8 changes: 8 additions & 0 deletions backends/arm/_passes/decompose_atanh_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ def call_operator(self, op, args, kwargs, meta):
if op is not edge_atanh:
return super().call_operator(op, args, kwargs, meta, updated=False)

is_quantized = (
len(meta.data.get("input_qparams", {})) > 0
and len(meta.data.get("output_qparams", {})) > 0
)
if is_quantized:
# If quantized, node should be replace by table op
return super().call_operator(op, args, kwargs, meta)

ops = _get_atanh_ops(op)
(
op_mul_tensor,
Expand Down
8 changes: 8 additions & 0 deletions backends/arm/_passes/decompose_cosh_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ def call_operator(self, op, args, kwargs, meta, updated=False):
if op is not edge_cosh:
return super().call_operator(op, args, kwargs, meta, updated)

is_quantized = (
len(meta.data.get("input_qparams", {})) > 0
and len(meta.data.get("output_qparams", {})) > 0
)
if is_quantized:
# If quantized, node should be replace by table op
return super().call_operator(op, args, kwargs, meta)

x = args

exp_op, mul_op, neg_op, add_op = (
Expand Down
21 changes: 19 additions & 2 deletions backends/arm/_passes/decompose_elu_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,14 @@ def call_operator(self, op, args, kwargs, meta):
if op not in edge_elu_ops:
return super().call_operator(op, args, kwargs, meta, updated=False)

is_quantized = (
len(meta.data.get("input_qparams", {})) > 0
and len(meta.data.get("output_qparams", {})) > 0
)
if is_quantized:
# If quantized, node should be replace by table op
return super().call_operator(op, args, kwargs, meta)

(
expm1_op,
ge_op,
Expand All @@ -75,8 +83,17 @@ def call_operator(self, op, args, kwargs, meta):
alpha = args[1] if len(args) > 1 else 1.0

if alpha == 0:
relu_op = exir_ops.edge.aten.relu.default
return super().call_operator(relu_op, (input,), {}, meta, updated=True)
relu_op = exir_ops.edge.aten.clamp.default
return super().call_operator(
relu_op,
(
input,
0,
),
{},
meta,
updated=True,
)

expm1_node = super().call_operator(expm1_op, (input,), {}, meta, updated=True)
mul_node = super().call_operator(
Expand Down
8 changes: 8 additions & 0 deletions backends/arm/_passes/decompose_expm1_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,14 @@ def call_operator(self, op, args, kwargs, meta):
if op not in edge_expm1_ops:
return super().call_operator(op, args, kwargs, meta, updated=False)

is_quantized = (
len(meta.data.get("input_qparams", {})) > 0
and len(meta.data.get("output_qparams", {})) > 0
)
if is_quantized:
# If quantized, node should be replace by table op
return super().call_operator(op, args, kwargs, meta)

(
op_pow,
op_div,
Expand Down
Loading
Loading