@@ -9,6 +9,7 @@
from typing import cast

import torch
+from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm.tosa_quant_utils import dq_op
from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
from executorch.exir.pass_base import ExportPass, PassResult
@@ -52,12 +53,7 @@ def call(self, graph_module: torch.fx.GraphModule):
        NHWC_Order = (0, 2, 3, 1)
        HWCM_Order = (2, 3, 0, 1)
        for node in graph_module.graph.nodes:
-           if isinstance(
-               node.meta["val"], (tuple, torch.fx.immutable_collections.immutable_list)
-           ):
-               node_data = node.meta["val"][0].data
-           else:
-               node_data = node.meta["val"].data
+           node_data = get_first_fake_tensor(node).data

            if len(node_data.shape) == 4:
                dim_order = NHWC_Order
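For context on the two orderings above, a small illustration (shapes are made up; this is not output from the pass): 4-dimensional activations are annotated with NHWC order, while depthwise-conv2d weights get HWCM order.

import torch

NHWC_Order = (0, 2, 3, 1)
HWCM_Order = (2, 3, 0, 1)

act = torch.randn(1, 8, 16, 16)   # NCHW activation
w = torch.randn(8, 1, 3, 3)       # depthwise conv2d weight, (C_out, M, kH, kW)

print(act.permute(NHWC_Order).shape)  # torch.Size([1, 16, 16, 8])  -> NHWC
print(w.permute(HWCM_Order).shape)    # torch.Size([3, 3, 8, 1])    -> HWCM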
8 changes: 7 additions & 1 deletion backends/arm/_passes/arm_pass_manager.py
@@ -22,6 +22,7 @@
from executorch.backends.arm._passes.insert_squeeze_after_sum_pass import (
    InsertSqueezeAfterSumPass,
)
+from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
    ConvertMeanDimToAveragePool,
)
@@ -30,6 +31,9 @@
    ScalarsToAttributePass,
)
from executorch.backends.arm._passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
+from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
+    UnsqueezeScalarPlaceholdersPass,
+)
from executorch.exir import ExportedProgram
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.pass_manager import PassManager
@@ -45,10 +49,12 @@ def transform_to_backend_pipeline(
    ):
        """Apply passes before transforming program to backend"""
        self.add_pass(CastInt64ToInt32Pass(exported_program))
+       self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
        self.add_pass(SizeAdjustConv2DPass())
        self.add_pass(RemoveClonePass())
        self.add_pass(ConvertExpandCopyToRepeatPass())
        self.add_pass(ConvertMeanDimToAveragePool())
+       self.add_pass(MatchArgRanksPass(exported_program))
        self.add_pass(DecomposeDivPass())
        self.add_pass(InsertSqueezeAfterSumPass())
        self.add_pass(ConvertSplitToSlicePass())
@@ -61,6 +67,6 @@ def transform_to_backend_pipeline(
        return self._transform(exported_program.graph_module)

    def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
-       self.add_pass(DecomposeDivPass())
        self.add_pass(ScalarsToAttributePass())
+       self.add_pass(DecomposeDivPass())
        return self._transform(graph_module)
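A note on the reordering in transform_for_annotation_pipeline: DecomposeDivPass only rewrites the Tensor div overloads, so a scalar divisor has to be lifted to a tensor by ScalarsToAttributePass first. A minimal numeric sketch of the combined rewrite (illustrative only, not actual pass output):

import torch

x = torch.randn(2, 3)

# Step 1, ScalarsToAttributePass: the Python scalar in `x / 3.0` is
# materialized as a rank-matched tensor attribute.
scalar = torch.tensor(3.0).reshape(1, 1)

# Step 2, DecomposeDivPass: div(x, s) -> mul(x, reciprocal(s)).
y = torch.mul(x, torch.reciprocal(scalar))

assert torch.allclose(y, x / 3.0)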
20 changes: 20 additions & 0 deletions backends/arm/_passes/arm_pass_utils.py
@@ -7,9 +7,11 @@
from typing import Optional

import torch
+import torch.fx

from executorch.exir.dialects._ops import ops as exir_ops
from torch._ops import OpOverload
+from torch._subclasses.fake_tensor import FakeTensor


def create_node(
@@ -64,3 +66,21 @@ def insert_q_dq_pair(
    # node's first use
    q.args = (anchor,) + q_params
    return dq


+def get_first_fake_tensor(node: torch.fx.Node) -> FakeTensor:
+    """
+    Returns a FakeTensor from the meta field of 'node'.
+    If the node contains multiple fake tensors, the first one is returned.
+    """
+    if isinstance(
+        node.meta["val"], (tuple, torch.fx.immutable_collections.immutable_list)
+    ):
+        fake_tensor = node.meta["val"][0]
+    else:
+        fake_tensor = node.meta["val"]
+
+    assert isinstance(
+        fake_tensor, FakeTensor
+    ), f'Found {fake_tensor} in meta["val"] of {node}, expected to find FakeTensor.'
+    return fake_tensor
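A quick sketch of the tuple case this helper handles, assuming make_fx populates node.meta["val"] in the usual way (the example model is made up):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # topk is a multi-output op, so its node's meta["val"] is a sequence
    return torch.topk(x, 2)

gm = make_fx(f, tracing_mode="fake")(torch.randn(4))
for node in gm.graph.nodes:
    if node.op == "call_function":
        val = node.meta["val"]
        # get_first_fake_tensor(node) returns val[0] for sequences, val otherwise
        first = val[0] if isinstance(val, (tuple, list)) else val
        print(node.target, type(first).__name__)  # ... FakeTensor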
9 changes: 6 additions & 3 deletions backends/arm/_passes/decompose_div_pass.py
@@ -8,15 +8,18 @@
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

+edge_div_ops = (exir_ops.edge.aten.div.Tensor,)
+aten_div_ops = (torch.ops.aten.div.Tensor, torch.ops.aten.div_.Tensor)
+

def get_div_decomposition(op) -> tuple:
    """
    Returns the tuple (reciprocal_op, mul_op), where the ops depend on whether
    the div op is from exir_ops or torch.ops.aten.
    """
-   if op == exir_ops.edge.aten.div.Tensor:
+   if op in edge_div_ops:
        return (exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.mul.Tensor)
-   if op == torch.ops.aten.div.Tensor:
+   if op in aten_div_ops:
        return (torch.ops.aten.reciprocal.default, torch.ops.aten.mul.Tensor)
    raise RuntimeError(f"Can't get div decomposition for op {op}")

@@ -33,7 +36,7 @@ class DecomposeDivPass(ExportPass):
"""

def call_operator(self, op, args, kwargs, meta):
if op not in (exir_ops.edge.aten.div.Tensor, torch.ops.aten.div.Tensor):
if op not in (edge_div_ops + aten_div_ops):
return super().call_operator(op, args, kwargs, meta)

reciprocal_op, mul_op = get_div_decomposition(op)
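A small usage sketch of the new dispatch tables, assuming an environment where the ExecuTorch Arm backend is importable; the in-place aten.div_.Tensor now resolves to the same decomposition:

import torch
from executorch.backends.arm._passes.decompose_div_pass import get_div_decomposition

# Both the functional and the in-place aten div resolve to the same pair.
recip_op, mul_op = get_div_decomposition(torch.ops.aten.div_.Tensor)
print(recip_op, mul_op)  # aten.reciprocal.default  aten.mul.Tensor

# The decomposition itself is just div(a, b) == mul(a, reciprocal(b)):
a, b = torch.randn(2, 3), torch.rand(2, 3) + 0.5
assert torch.allclose(torch.div(a, b), torch.mul(a, torch.reciprocal(b)))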
126 changes: 126 additions & 0 deletions backends/arm/_passes/match_arg_ranks_pass.py
@@ -0,0 +1,126 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast

from executorch.backends.arm._passes.arm_pass_utils import (
    create_node,
    get_first_fake_tensor,
)

from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule, Node


class MatchArgRanksPass(ExportPass):
    """
    For ops in 'targeted_ops', make sure that the inputs share the same rank.
    New dimensions of size 1 are inserted at the beginning of the lower-rank
    input's shape until the ranks match.
    """

    def __init__(self, exported_program):
        super().__init__()
        self.exported_program = exported_program

    targeted_ops = [
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.sub.Tensor,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.div.Tensor,
    ]

    def _match_op_rank(self, graph_module, node, arg, max_rank):
        """
        In graph_module, insert a view between arg and node to make the
        rank of arg match the other args to node.
        """
        shape = get_first_fake_tensor(arg).shape
        rank = len(shape)
        new_shape = list([1] * (max_rank - rank) + list(shape))
        with graph_module.graph.inserting_before(node):
            view = create_node(
                graph_module.graph,
                exir_ops.edge.aten.view_copy.default,
                args=(arg, new_shape),
                kwargs={},
            )
            node.replace_input_with(arg, view)

    def _match_buffer_rank(self, arg, max_rank):
        """
        Change arg's fake tensor meta to match max_rank if:
        - arg is found in inputs_to_buffers or inputs_to_parameters.
        """
        fake_tensor = get_first_fake_tensor(arg)
        shape = fake_tensor.shape
        rank = len(shape)
        new_shape = list([1] * (max_rank - rank) + list(shape))

        buffer_name = None
        if arg.name in self.exported_program.graph_signature.inputs_to_buffers:
            buffer_name = self.exported_program.graph_signature.inputs_to_buffers[
                arg.name
            ]
        elif arg.name in self.exported_program.graph_signature.inputs_to_parameters:
            buffer_name = self.exported_program.graph_signature.inputs_to_parameters[
                arg.name
            ]
        if buffer_name:
            new_tensor = self.exported_program.state_dict[buffer_name].reshape(
                new_shape
            )
            self.exported_program.state_dict[buffer_name] = new_tensor
            arg.meta["val"] = fake_tensor.fake_mode.from_tensor(
                new_tensor, static_shapes=True
            )

    def call(self, graph_module: GraphModule) -> PassResult:
        for node in graph_module.graph.nodes:
            node = cast(Node, node)

            if node.op != "call_function" or node.target not in self.targeted_ops:
                continue

            # Calculate max rank of all inputs to node
            max_rank = 1
            for arg in node.args:
                if isinstance(arg, Node):
                    shape = get_first_fake_tensor(arg).shape
                    max_rank = max(max_rank, len(shape))

            # Adjust output shape of args if needed.
            for arg in node.args:
                if not isinstance(arg, Node):
                    continue
                shape = get_first_fake_tensor(arg).shape
                rank = len(shape)
                if rank == max_rank:
                    continue

                # If the argument is call_function, match shape by inserting view node.
                if arg.op == "call_function":
                    self._match_op_rank(graph_module, node, arg, max_rank)
                else:
                    # If the argument is a buffer or parameter, adjust shape by changing the fake tensor meta.
                    self._match_buffer_rank(arg, max_rank)

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)

    def ensures(self, graph_module):
        for node in graph_module.graph.nodes:
            if node.op != "call_function" or node.target not in self.targeted_ops:
                continue
            arg0_rank = node.args[0].meta["val"].dim()
            arg1_rank = node.args[1].meta["val"].dim()
            if arg0_rank != arg1_rank:
                raise ValueError(
                    "Arguments of arithmetic operators need to have the same rank!"
                )
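A numeric sketch of what the inserted view_copy achieves for a rank mismatch (illustrative shapes, not actual pass output):

import torch

a = torch.randn(2, 3)  # rank 2
b = torch.randn(3)     # rank 1

# The pass would insert roughly: b_view = view_copy(b, [1, 3])
max_rank = max(a.dim(), b.dim())
b_view = b.view([1] * (max_rank - b.dim()) + list(b.shape))

assert b_view.shape == (1, 3)
assert torch.allclose(a + b, a + b_view)  # broadcasting is unchanged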
8 changes: 6 additions & 2 deletions backends/arm/_passes/scalars_to_attribute_pass.py
@@ -7,7 +7,7 @@
from typing import cast, Union

import torch
-from executorch.backends.arm.tosa_mapping import extract_tensor_meta
+from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor

from executorch.exir.pass_base import ExportPass, PassResult
from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
@@ -22,10 +22,14 @@ class ScalarsToAttributePass(ExportPass):

    targeted_ops = [
        torch.ops.aten.add.Tensor,
+       torch.ops.aten.add_.Tensor,
        torch.ops.aten.sub.Tensor,
+       torch.ops.aten.sub_.Tensor,
        torch.ops.aten.rsub.Scalar,
        torch.ops.aten.mul.Tensor,
+       torch.ops.aten.mul_.Tensor,
        torch.ops.aten.div.Tensor,
+       torch.ops.aten.div_.Tensor,
    ]

    def call(self, graph_module: GraphModule) -> PassResult:
@@ -37,7 +41,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
            biggest_rank = 1
            for arg in n.args:
                if isinstance(arg, Node):
-                   _, shape, _ = extract_tensor_meta(arg.meta)
+                   shape = get_first_fake_tensor(arg).shape
                    biggest_rank = max(biggest_rank, len(shape))

            new_args = []
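With the in-place variants added to targeted_ops, expressions such as x += 2.0 are now rewritten as well. A rough sketch of the rank-matched attribute the pass creates for a scalar operand (names and shapes are illustrative):

import torch

x = torch.randn(2, 3)
y = x.clone()
y += 2.0  # traces to aten.add_.Tensor, now in targeted_ops

# The pass stores the scalar as an attribute tensor whose rank matches
# the biggest-ranked tensor argument (here rank 2):
scalar_attr = torch.tensor(2.0).reshape([1] * x.dim())  # shape (1, 1)
assert torch.allclose(y, x + scalar_attr)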
53 changes: 53 additions & 0 deletions backends/arm/_passes/unsqueeze_scalar_placeholders_pass.py
@@ -0,0 +1,53 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.pass_base import ExportPass, PassResult


class UnsqueezeScalarPlaceholdersPass(ExportPass):
    """
    Placeholders that have node.meta["val"].shape = () cause issues later in the lowering.
    This pass unsqueezes the placeholders to make sure shape is at least (1,).
    """

    def __init__(self, exported_program):
        self.exported_program = exported_program
        super().__init__()

    def call(self, graph_module: torch.fx.GraphModule):
        for node in graph_module.graph.nodes:
            if node.op != "placeholder":
                continue
            rank = node.meta["val"].dim()
            if rank == 0:
                if not (
                    node.name in self.exported_program.graph_signature.inputs_to_buffers
                    or node.name
                    in self.exported_program.graph_signature.inputs_to_parameters
                ):
                    continue
                tensor = self.exported_program.state_dict[node.name]
                if tensor.dim() == 0:
                    self.exported_program.state_dict[node.name] = tensor.unsqueeze(0)
                    node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
                        tensor.unsqueeze(0), static_shapes=True
                    )
                else:
                    node.meta["val"] = node.meta["val"].fake_mode.from_tensor(
                        tensor, static_shapes=True
                    )

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)

    def ensures(self, graph_module: torch.fx.GraphModule):
        for node in graph_module.graph.nodes:
            if node.op == "placeholder":
                rank = node.meta["val"].dim()
                if rank == 0:
                    raise ValueError("Placeholders of rank 0 are not supported!")
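A minimal sketch of the placeholder case this pass targets (the module is made up for illustration):

import torch

class ScalarBuffer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("c", torch.tensor(2.0))  # rank 0, shape ()

    def forward(self, x):
        return x + self.c

# After the pass, the buffer in the state dict would have shape (1,) and the
# placeholder's fake tensor meta would be regenerated to match:
m = ScalarBuffer()
print(m.state_dict()["c"].shape)               # torch.Size([])
print(m.state_dict()["c"].unsqueeze(0).shape)  # torch.Size([1])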
@@ -24,7 +24,7 @@ def _annotate_mul(

    annotated_partitions = []
    for node in gm.graph.nodes:
-       if node.target not in (torch.ops.aten.mul.Tensor,):
+       if node.target not in (torch.ops.aten.mul.Tensor, torch.ops.aten.mul_.Tensor):
            continue
        mul_node = node
        annotated_partitions.append([mul_node])
15 changes: 5 additions & 10 deletions backends/arm/quantizer/quantization_annotation/sub_annotator.py
@@ -6,8 +6,6 @@

# pyre-unsafe

-import itertools
-import operator
from typing import Callable, List, Optional

import torch
@@ -16,7 +14,6 @@
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch.ao.quantization.quantizer import QuantizationAnnotation
from torch.fx import GraphModule, Node
-from torch.fx.passes.utils.source_matcher_utils import get_source_partitions


@register_annotator("sub")
@@ -25,14 +22,12 @@ def _annotate_sub(
    quantization_config: QuantizationConfig,
    filter_fn: Optional[Callable[[Node], bool]] = None,
) -> Optional[List[List[Node]]]:
-   sub_partitions = get_source_partitions(
-       gm.graph, [operator.sub, torch.sub, operator.isub], filter_fn
-   )
-   sub_partitions = list(itertools.chain.from_iterable(sub_partitions.values()))
    annotated_partitions = []
-   for sub_partition in sub_partitions:
-       annotated_partitions.append(sub_partition.nodes)
-       sub_node = sub_partition.output_nodes[0]
+   for node in gm.graph.nodes:
+       if node.target not in (torch.ops.aten.sub.Tensor, torch.ops.aten.sub_.Tensor):
+           continue
+       annotated_partitions.append(node)
+       sub_node = node
        if arm_quantizer_utils.is_annotated(sub_node):
            continue

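The annotator now matches nodes directly on their ATen target instead of going through source partitions, which also catches the in-place variant. A tiny sketch of the matching predicate (the helper name is illustrative, not from the PR):

import torch

targets = (torch.ops.aten.sub.Tensor, torch.ops.aten.sub_.Tensor)

def wants_annotation(node) -> bool:  # illustrative helper, not from the PR
    return node.op == "call_function" and node.target in targets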