18 changes: 16 additions & 2 deletions backends/arm/_passes/arm_pass_utils.py
@@ -31,11 +31,25 @@
from torch.export.graph_signature import InputKind


def is_submodule_node(node: torch.fx.Node):
if node.op not in ("get_attr", "placeholder"):
return False
try:
node.graph.owning_module.get_submodule(node.target)
except AttributeError:
return False
return True


def is_get_attr_node(node: torch.fx.Node) -> bool:
"""
Returns true if the given node is a get attr node for a tensor of the model
Returns true if the given node is a get attr node for a tensor of the model.
"""
return isinstance(node, torch.fx.Node) and node.op == "get_attr"
return (
isinstance(node, torch.fx.Node)
and node.op == "get_attr"
and not is_submodule_node(node)
)


def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
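The new helper matters because, after tracing, both plain tensor attributes and nested GraphModules are reached through get_attr nodes. Below is a minimal standalone sketch (module and attribute names are illustrative, not from this PR) of the distinction is_submodule_node and the tightened is_get_attr_node rely on: get_submodule() raises AttributeError when the target resolves to a tensor rather than a module.

```python
import torch
import torch.fx as fx


class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A plain tensor attribute: reached via get_attr, but not a submodule.
        self.register_buffer("scale", torch.tensor(2.0))

    def forward(self, x):
        return x * self.scale


gm = fx.symbolic_trace(Outer())
for node in gm.graph.nodes:
    if node.op != "get_attr":
        continue
    try:
        # get_submodule() raises AttributeError for tensor targets such as
        # "scale"; that failure is exactly what is_submodule_node builds on.
        gm.get_submodule(str(node.target))
        print(node.target, "-> submodule")
    except AttributeError:
        print(node.target, "-> tensor attribute")
```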
2 changes: 2 additions & 0 deletions backends/arm/_passes/cast_int64_pass.py
@@ -41,6 +41,8 @@ def _to_int32(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
if len(node.users) == 0:
continue
if "val" not in node.meta:
continue
fake_tensor = node.meta["val"]
if not isinstance(fake_tensor, torch._subclasses.fake_tensor.FakeTensor):
continue
2 changes: 2 additions & 0 deletions backends/arm/_passes/to_tosa_memory_format_pass.py
@@ -299,6 +299,8 @@ def remove_dim_order_kwargs(

def call(self, graph_module: torch.fx.GraphModule):
for node in graph_module.graph.nodes:
if "val" not in node.meta:
continue
node_data = get_first_fake_tensor(node).data

self.remove_dim_order_kwargs(graph_module, node)
147 changes: 139 additions & 8 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -7,12 +7,15 @@
import itertools
import operator
import typing
from typing import final, Optional, Sequence, Type
from typing import cast, final, Optional, Sequence, Type

import torch
import torch.fx as fx

from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
from executorch.backends.arm._passes.arm_pass_utils import (
get_first_fake_tensor,
is_submodule_node,
)
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
FuseQuantizedActivationPass,
@@ -31,6 +34,7 @@
TOSA_PRO_INT_SupportList,
)
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.arm.tosa.specification import Tosa_1_00
from executorch.exir import ExportedProgram
from executorch.exir.backend.utils import WhyNoPartitionReporter
from executorch.exir.dialects._ops import ops as exir_ops
@@ -110,7 +114,9 @@ def tosa_support_factory(
Additional checks can be supplied to avoid partitioning additional nodes.
"""
# Positive checks: Add nodes to partitioning
positive_checks: list[OperatorSupportBase] = []
positive_checks: list[OperatorSupportBase] = [
CondSupported(exported_program, tosa_spec, reporter)
]

if tosa_spec.support_integer():
positive_checks.append(TOSAProINTSupportList())
@@ -350,7 +356,8 @@ def inside_int32_bounds(self, node: torch.fx.Node) -> bool:
def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:

if is_submodule_node(node):
return True
vals = node.meta["val"]
tensor_list = vals if isinstance(vals, (list, tuple)) else [vals]

@@ -390,7 +397,11 @@ def is_node_supported(

# Ops with int64 inputs are only partitioned if input nodes are constant and will be partitioned.
# If it is not partitioned, the partition will get an int64 input and fail.
for input_node in node.all_input_nodes:
for input_node in (
input_node
for input_node in node.all_input_nodes
if input_node.op != "get_attr"
):
tensor_in = get_first_fake_tensor(input_node)
if tensor_in.dtype != torch.int64:
continue
@@ -426,8 +437,13 @@ def __init__(
def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:

for input_node in node.all_input_nodes:
if is_submodule_node(node):
return True
for input_node in (
input_node
for input_node in node.all_input_nodes
if input_node.op != "get_attr"
):
tensor = get_first_fake_tensor(input_node)
if tensor.dtype == torch.float64:
self.reporter.report_reject(
@@ -449,7 +465,13 @@ def __init__(self, reporter: WhyNoPartitionReporter, max_rank: int):
def is_node_supported(
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:
input_nodes = node.all_input_nodes
if is_submodule_node(node):
return True
input_nodes = (
input_node
for input_node in node.all_input_nodes
if input_node.op != "get_attr"
)
# check if any input node has an unsupported rank
for input_node in input_nodes:
input_node_shape = get_first_fake_tensor(input_node).shape
@@ -484,3 +506,112 @@ def is_node_supported(
)
return False
return True


class CondSupported(OperatorSupportBase):
"""Checks whether the cond operator, and it's submodule args, should be partitioned."""

def __init__(
self,
exported_program: ExportedProgram,
tosa_spec: TosaSpecification,
reporter: WhyNoPartitionReporter,
):
self.exported_program = exported_program
self.reporter = reporter
self.tosa_spec = tosa_spec
super().__init__()

def _fully_partitioned(self, submodule: fx.GraphModule) -> bool:
partition_tag = None
for submodule_node in submodule.graph.nodes:
if submodule_node.op == "call_function":
# Input Q ops and output DQ ops will be de-tagged even if the submodule is fully supported.
if (
submodule_node.target in Q_OPS
and list(submodule_node.all_input_nodes)[0].op == "placeholder"
):
continue
if (
submodule_node.target in DQ_OPS
and list(submodule_node.users)[0].op == "output"
):
continue
if "delegation_tag" not in submodule_node.meta:
return False
if partition_tag is None:
partition_tag = submodule_node.meta["delegation_tag"]
elif submodule_node.meta["delegation_tag"] != partition_tag:
return False
return True

def _cond_submodules_fully_partitioned(self, node: fx.Node) -> bool:
"""Returns whether the submodule arguments to a cond node were fully partitioned.
Updates "val" meta of the submodules if they are.
"""
cond_submodules = (
(
self.exported_program.graph_module.get_submodule(
str(cast(torch.fx.Node, submodule_node).target)
),
cast(torch.fx.Node, submodule_node),
)
for submodule_node in node.args[1:3]
)
for submodule, submodule_node in cond_submodules:
submodule = cast(torch.fx.GraphModule, submodule)

if self._fully_partitioned(submodule):
submodule_node.meta["val"] = submodule.graph.output_node().meta["val"]
else:
return False
return True

def is_node_supported( # noqa: C901
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
) -> bool:
if is_submodule_node(node):
if not isinstance(self.tosa_spec, Tosa_1_00):
self.reporter.report_reject(
node, "Control flow extension not supported for TOSA version <1.0"
)
return False
if not self.tosa_spec.support_extension("cf"):
self.reporter.report_reject(
node,
f"TOSA spec {self.tosa_spec} does not support control flow extension.",
)
return False
for user in node.users:
if user.target != torch.ops.higher_order.cond:
self.reporter.report_reject(
node, f"Submodule had unsupported user {user}"
)
return False
if not self._cond_submodules_fully_partitioned(user):
self.reporter.report_reject(
node, "One submodule was not fully partitioned"
)
return False
return True
if node.target == torch.ops.higher_order.cond:
if not isinstance(self.tosa_spec, Tosa_1_00):
self.reporter.report_reject(
node, "Control flow extension not supported for TOSA version <1.0"
)
return False
if not self.tosa_spec.support_extension("cf"):
self.reporter.report_reject(
node,
f"TOSA spec {self.tosa_spec} does not support control flow extension.",
)
return False

if not self._cond_submodules_fully_partitioned(node):
self.reporter.report_reject(
node, "Submodule was not fully partitioned."
)
return False
return True

return False
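For orientation, here is a rough sketch (assumes a recent PyTorch that exposes torch.cond; not Arm-specific code) of the graph shape CondSupported inspects: after export, torch.cond becomes a torch.ops.higher_order.cond call whose second and third arguments are get_attr nodes naming the true/false branch submodules, and those are the submodules whose nodes must all carry one delegation tag for the cond to stay partitioned.

```python
import torch


class CondModel(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return x + 1

        def false_fn(x):
            return x - 1

        return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))


ep = torch.export.export(CondModel(), (torch.randn(4),))
for node in ep.graph_module.graph.nodes:
    if node.target == torch.ops.higher_order.cond:
        pred, true_graph, false_graph, operands = node.args
        # true_graph / false_graph are get_attr nodes whose targets name the
        # branch GraphModules checked by _cond_submodules_fully_partitioned.
        print(true_graph.target, false_graph.target)
```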
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
@@ -16,6 +16,7 @@
op_cat,
op_ceil,
op_clamp,
op_cond_if,
op_constant_pad_nd,
op_cos,
op_eq,
61 changes: 61 additions & 0 deletions backends/arm/operators/op_cond_if.py
@@ -0,0 +1,61 @@
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import Any, cast, List

import tosa_serializer as ts

from executorch.backends.arm.operators.node_visitor import ( # type: ignore
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
validate_valid_dtype,
)
from executorch.backends.arm.tosa.mapping import TosaArg # type: ignore
from executorch.backends.arm.tosa.specification import Tosa_1_00
from torch.fx import Node


@register_node_visitor
class CondVisitor(NodeVisitor):
target = "cond"

tosa_specs = NodeVisitor.tosa_specs

def define_node(
self,
node: Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:

validate_num_inputs(self.target, inputs, 4)
validate_valid_dtype(self.target, [inputs[0]], ts.DType.BOOL, self.tosa_spec)
if not isinstance(self.tosa_spec, Tosa_1_00):
raise ValueError("Trying to lower cond, but TOSA version is <1.0.")
if not self.tosa_spec.support_extension("cf"):
raise ValueError(
f"Trying to lower cond, but TOSA specification {self.tosa_spec} does not support the cf extension."
)

attr = ts.TosaSerializerAttribute()
if_graph, else_graph = (cast(Node, arg).target for arg in node.args[1:3])
attr.CondIfAttribute(if_graph, else_graph)

self._serialize_operator(
node,
tosa_graph,
ts.Op.COND_IF,
[
inputs[0].name,
*(subgraph_input.name for subgraph_input in inputs[-1].special),
],
[output.name],
attr,
)
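As a companion to the visitor above, a hedged standalone sketch (the helper name cond_lowering_info is made up for illustration) of the information pulled off the FX node before COND_IF is emitted: the predicate name, the branch graph names taken from node.args[1:3], and the operand names that become the remaining COND_IF inputs.

```python
from typing import List, Tuple, cast

import torch
from torch.fx import Node


def cond_lowering_info(node: Node) -> Tuple[str, str, str, List[str]]:
    """Illustrative only: mirrors how CondVisitor reads an exported cond node."""
    assert node.target == torch.ops.higher_order.cond
    pred, true_graph, false_graph, operands = node.args
    return (
        cast(Node, pred).name,  # condition tensor (validated as BOOL)
        str(cast(Node, true_graph).target),  # if-branch graph name
        str(cast(Node, false_graph).target),  # else-branch graph name
        [cast(Node, op).name for op in cast(list, operands)],  # subgraph inputs
    )
```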
10 changes: 5 additions & 5 deletions backends/arm/operators/op_index_tensor.py
@@ -165,14 +165,14 @@ def define_node(
# channels and thus the stride-shift.
data = np.full(index_shape, int(values_strides[i] / C))
mul_const = tosa_graph.addConst(index_shape, index_dtype, data)
tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_{i}_shift")
tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{output.name}_{i}_shift")
attr = ts.TosaSerializerAttribute()
attr.MulAttribute()
self._serialize_operator(
node,
tosa_graph,
ts.Op.MUL,
[index_name, mul_const.name, f"{node.name}_{i}_shift"],
[index_name, mul_const.name, f"{output.name}_{i}_shift"],
[stride_shifted_indices.name],
attr,
)
@@ -186,7 +186,7 @@
stride_shifted_indices.name,
gather_idx_shape,
reshaped_idxs.name,
shape_name_override=f"{node.name}_{i}_shape",
shape_name_override=f"{output.name}_{i}_shape",
)

# Guarantees that the accumulation tensor is properly
@@ -218,7 +218,7 @@
values.name,
gather_vals_shape,
reshaped_input.name,
shape_name_override=f"{node.name}_index_shape",
shape_name_override=f"{output.name}_index_shape",
)

gather_out_shape = (N, W, C)
@@ -244,5 +244,5 @@
gather_out.name,
list(output_shape),
output.name,
shape_name_override=f"{node.name}_output_shape",
shape_name_override=f"{output.name}_output_shape",
)
4 changes: 2 additions & 2 deletions backends/arm/operators/op_mul.py
@@ -48,14 +48,14 @@ def define_node(
output.tosa_spec,
)

tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{node.name}_shift")
tosa_graph.addConst([1], ts.DType.INT8, 0, name=f"{output.name}_shift")
attr = ts.TosaSerializerAttribute()
attr.MulAttribute()
self._serialize_operator(
node,
tosa_graph,
ts.Op.MUL,
[inputs[0].name, inputs[1].name, f"{node.name}_shift"],
[inputs[0].name, inputs[1].name, f"{output.name}_shift"],
[output.name],
attr,
)
2 changes: 1 addition & 1 deletion backends/arm/operators/op_repeat.py
@@ -56,7 +56,7 @@ def define_node(
(len(multiples),),
ts.DType.SHAPE,
list(tosa_shape(multiples, output.dim_order)),
name=node.name + "_multiples",
name=output.name + "_multiples",
)

attr = ts.TosaSerializerAttribute()