28 changes: 15 additions & 13 deletions backends/arm/_passes/arm_pass_manager.py
@@ -153,15 +153,15 @@ def _transform(self, graph_module: GraphModule):
with TosaLoweringContext(self.tosa_spec):
return self(graph_module).graph_module

def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
def _tosa_INT_pipeline(
self, exported_program: ExportedProgram, graph_module: GraphModule
) -> GraphModule:
self.add_pass(AnnotateOutputDimOrderPass())
self.add_pass(FuseQuantizedActivationPass())
self.add_pass(RemoveGetItemPass())
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(ConvertMmToBmmPass())
self.add_pass(
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
)
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
self.add_pass(ConvertFullLikeToFullPass())
self.add_pass(ConvertToClampPass())
self.add_pass(ConvertMinMaxPass())
@@ -218,9 +218,11 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(InsertRescalePass())

self.validate_constraints_mandatory()
return self._transform(exported_program.graph_module)
return self._transform(graph_module)

def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
def _tosa_FP_pipeline(
self, exported_program: ExportedProgram, graph_module: GraphModule
) -> GraphModule:
self.add_pass(AnnotateOutputDimOrderPass())
self.add_pass(DecomposeExpm1Pass())
self.add_pass(DecomposeLogitPass())
@@ -255,9 +257,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(DecomposeLayerNormPass())
self.add_pass(DecomposeBatchNormNoStatsPass())
self.add_pass(DecomposeVarPass())
self.add_pass(
DecomposeMeanDimPass(exported_program.graph_module, self.tosa_spec)
)
self.add_pass(DecomposeMeanDimPass(graph_module, self.tosa_spec))
self.add_pass(DecomposeNotEqualPass())
self.add_pass(DecomposeDivPass())
self.add_pass(DecomposeAddSubAlphaPass())
@@ -305,14 +305,16 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
self.add_pass(InsertRescalePass())

self.validate_constraints_mandatory()
return self._transform(exported_program.graph_module)
return self._transform(graph_module)

def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
def transform_to_backend_pipeline(
self, exported_program: ExportedProgram, graph_module: GraphModule
):
"""Apply passes before transforming program to backend"""
if self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"):
return self._tosa_FP_pipeline(exported_program)
return self._tosa_FP_pipeline(exported_program, graph_module)
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+INT"):
return self._tosa_INT_pipeline(exported_program)
return self._tosa_INT_pipeline(exported_program, graph_module)
else:
raise NotImplementedError(
f"No pass pipeline implemented for {self.tosa_spec=}"
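For context, the updated entry point is exercised from backends/arm/tosa/backend.py further down; a condensed sketch of the new calling convention (hedged: tosa_spec, edge_program, and graph_module are assumed to already be in scope, as they are inside _preprocess_module):

from executorch.backends.arm._passes import ArmPassManager

# The pipeline no longer assumes graph_module is exported_program.graph_module,
# so the same entry point can transform control-flow submodules.
graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline(
    exported_program=edge_program,
    graph_module=graph_module,
)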
10 changes: 4 additions & 6 deletions backends/arm/process_node.py
@@ -158,7 +158,7 @@ def process_inputs_to_buffers(
buffer_values = np.transpose(buffer_values, tosa_arg.dim_order)

tosa_graph.addConst(
buffer_values.shape, tosa_arg.dtype, buffer_values, name=node.name
buffer_values.shape, tosa_arg.dtype, buffer_values, name=tosa_arg.name
)


@@ -215,11 +215,9 @@ def process_placeholder(
raise RuntimeError(f"Placeholder '{node.name}' is of unknown type.")


def process_output(
node: torch.fx.Node,
tosa_graph: Any,
):
def process_output(node: torch.fx.Node, tosa_graph: Any, tosa_spec: TosaSpecification):
for output in cast(tuple[torch.fx.Node, ...], node.args[0]):
output_arg = TosaArg(output, tosa_spec)
tosa_graph.addOutputTensor(
tosa_graph.currRegion.currBasicBlock.tensors[output.name]
tosa_graph.currRegion.currBasicBlock.tensors[output_arg.name]
)
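The new tosa_spec parameter lets the output lookup go through TosaArg, whose name may carry a per-submodule suffix (see the mapping.py change below); a condensed before/after of the lookup key:

# Before: keyed on the raw FX node name.
tosa_graph.currRegion.currBasicBlock.tensors[output.name]

# After: keyed on the TosaArg name, which appends any TOSA_TENSOR_NAME_META
# suffix so that outputs inside control-flow submodules resolve to the
# region-local tensor.
tosa_graph.currRegion.currBasicBlock.tensors[TosaArg(output, tosa_spec).name]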
148 changes: 94 additions & 54 deletions backends/arm/tosa/backend.py
@@ -24,10 +24,13 @@
process_placeholder,
)
from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec
from executorch.backends.arm.tosa.mapping import TOSA_TENSOR_NAME_META
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.graph_module import get_control_flow_submodules
from torch.export.exported_program import ExportedProgram
from torch.fx import Graph, Node
from torch.fx import Graph, GraphModule, Node


# TOSA backend debug functionality
logger = logging.getLogger(__name__)
@@ -52,13 +55,39 @@ def bfs_mark(start_nodes: List[Node], idx: int, seen: Set[Node]):
# Walk backwards so we touch every producer
q.extend(n.all_input_nodes)

out = next(n for n in ep_graph.nodes if n.op == "output")
out = ep_graph.output_node()
# First argument of output node is tuple of outputs
output_list = cast(tuple, out.args[0])
seen: Set[Node] = set()
for idx, val in enumerate(out.args[0]):
for idx, val in enumerate(output_list):
bfs_mark([val], idx, seen)
return node2external_id


def _sort_outputs(graph_module: GraphModule, node_to_id_map: dict[str, int]):
def _external_id(n: Node, node_2_id, fallback: int) -> int:
return node_2_id.get(n.name, fallback)

out_node = graph_module.graph.output_node()
out_list = cast(tuple, out_node.args[0])
_counter = count()

    # Sort output nodes by external id; unmapped nodes fall back to a counter value.
def _sort_key(t: Node) -> int:
return _external_id(t, node_to_id_map, next(_counter))

orig_ord = tuple(sorted(out_list, key=_sort_key))

current_order = tuple(out_list)
if orig_ord != current_order:
replacement = list(orig_ord) if isinstance(out_node.args[0], list) else orig_ord
out_node.args = (replacement,)
graph_module.graph.lint()
graph_module.recompile()

return graph_module


def arm_get_first_delegation_tag(graph_module) -> str:
"""Get the first delegation tag from the graph_module or return empty string."""
for node in graph_module.graph.nodes:
@@ -93,9 +122,9 @@ def _preprocess( # noqa: C901
artifact_path = compile_spec.get_intermediate_path()
tosa_spec = compile_spec.tosa_spec
dump_debug_info = compile_spec.tosa_debug_mode

# Assign to every node external id
node_2_id = _annotate_external_ids(edge_program.graph)
debug_hook = None
if dump_debug_info is not None:
debug_hook = DebugHook(dump_debug_info)

logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}")

@@ -116,45 +145,66 @@
f"doesn't match specification {tosa_spec}"
)

TOSABackend._preprocess_module(
edge_program.graph_module,
edge_program,
compile_spec,
tosa_graph,
debug_hook,
)
# Serialize and return the TOSA flatbuffer.
binary = tosa_graph.serialize()

if artifact_path:
tag = arm_get_first_delegation_tag(edge_program.graph_module)
debug_tosa_dump(
binary,
artifact_path,
suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"),
)

if debug_hook is not None:
if debug_hook.mode == ArmCompileSpec.DebugMode.JSON:
json_output = debug_hook.serialize()
with open(f"{artifact_path}/debug.json", "w") as f:
f.write(json_output)

return PreprocessResult(processed_bytes=binary)

@staticmethod
def _preprocess_module( # noqa: C901
graph_module: GraphModule,
edge_program: ExportedProgram,
compile_spec: TosaCompileSpec,
tosa_graph: ts.TosaSerializer,
debug_hook: DebugHook | None,
submodule_name: str | None = None,
):
"""Convert 'graph_module' to a tosa_graph"""
tosa_spec = compile_spec.tosa_spec
node_to_id_map = _annotate_external_ids(graph_module.graph)
artifact_path = compile_spec.get_intermediate_path()

# TODO: Fix the need to lazily import this.
from executorch.backends.arm._passes import ArmPassManager

graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline( # type: ignore
exported_program=edge_program
exported_program=edge_program, graph_module=graph_module
)

debug_hook = None
if dump_debug_info is not None:
debug_hook = DebugHook(dump_debug_info)

# TODO: Fix the need to lazily import this.
from executorch.backends.arm.operators.node_visitor import get_node_visitors

node_visitors = get_node_visitors(edge_program, tosa_spec, debug_hook)
graph_module = _sort_outputs(graph_module, node_to_id_map)

# Re-shuffle output nodes to preserve author's order
def _external_id(n: Node, node_2_id, fallback: int) -> int:
return node_2_id.get(n.name, fallback)

out_node = next(n for n in graph_module.graph.nodes if n.op == "output")
_counter = count()

# sort nodes by the key that is id
def _sort_key(t: Node) -> int:
return _external_id(t, node_2_id, next(_counter))
if submodule_name is not None:
tosa_graph.startRegion(submodule_name)
tosa_graph.currRegion.addBasicBlock(submodule_name)
suffix = f"_{submodule_name}"
for loop_node in graph_module.graph.nodes:
loop_node.meta[TOSA_TENSOR_NAME_META] = suffix

orig_ord = tuple(sorted(out_node.args[0], key=_sort_key))

current_order = tuple(out_node.args[0])
if orig_ord != current_order:
replacement = (
list(orig_ord) if isinstance(out_node.args[0], list) else orig_ord
)
out_node.args = (replacement,)
graph_module.graph.lint()
graph_module.recompile()

input_count = 0
for node in graph_module.graph.nodes:
node = cast(Node, node)
try:
@@ -164,37 +214,27 @@ def _sort_key(t: Node) -> int:
if len(node.users) == 0:
continue
process_placeholder(node, tosa_graph, edge_program, tosa_spec)
if node.name in edge_program.graph_signature.user_inputs:
input_count += 1
elif node.op == "output":
process_output(node, tosa_graph)
process_output(node, tosa_graph, tosa_spec)
else:
# This will only happen if an unpartitioned graph is passed without
# any checking of compatibility.
raise RuntimeError(f"{node.name} is unsupported op {node.op}")
except Exception:
debug_fail(node, graph_module, tosa_graph.serialize(), artifact_path)
debug_fail(node, graph_module, tosa_graph, artifact_path)
raise

# Serialize and return the TOSA flatbuffer.
binary = tosa_graph.serialize()

if artifact_path:
tag = arm_get_first_delegation_tag(graph_module)
debug_tosa_dump(
binary,
artifact_path,
suffix="{}".format(f"_{tag}" if tag else "") + (f"_{tosa_spec}"),
# Recursively preprocess control-flow submodules.
for name, submodule, _ in get_control_flow_submodules(graph_module):
TOSABackend._preprocess_module(
submodule,
edge_program,
compile_spec,
tosa_graph,
debug_hook,
submodule_name=name,
)

if debug_hook is not None:
if debug_hook.mode == ArmCompileSpec.DebugMode.JSON:
json_output = debug_hook.serialize()
with open(f"{artifact_path}/debug.json", "w") as f:
f.write(json_output)

return PreprocessResult(processed_bytes=binary)

@staticmethod
def filter_tosa_compile_specs(
compile_spec: ArmCompileSpec,
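A minimal sketch of the output re-sorting on a toy FX module (hedged: _sort_outputs is module-private and imported here for illustration only; the node names 'add' and 'mul' are what symbolic_trace happens to generate):

import torch
from torch.fx import symbolic_trace

from executorch.backends.arm.tosa.backend import _sort_outputs


def f(x):
    return x + 1, x * 2


gm = symbolic_trace(f)
# Pretend the author originally listed the outputs as (x * 2, x + 1): external
# id 0 belongs to the mul node, id 1 to the add node.
gm = _sort_outputs(gm, {"mul": 0, "add": 1})
print([n.name for n in gm.graph.output_node().args[0]])  # ['mul', 'add']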
4 changes: 3 additions & 1 deletion backends/arm/tosa/mapping.py
@@ -17,6 +17,8 @@
import tosa_serializer as ts
from executorch.backends.arm.tosa.specification import TosaSpecification

TOSA_TENSOR_NAME_META = "tosa_tensor_name"

UNSUPPORTED_DTYPES = (
torch.float64,
torch.double,
@@ -144,7 +146,7 @@ def __process_node(self, argument: torch.fx.Node):
argument (torch.fx.Node): FX node to inspect.

"""
self.name: str = argument.name
self.name = argument.name + argument.meta.get(TOSA_TENSOR_NAME_META, "")
output_dtype, self.shape, self.dim_order = extract_tensor_meta(
argument.meta, self.tosa_spec
)
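A self-contained illustration of the naming rule above (the '_loop_body' suffix is a hypothetical submodule name; in backend.py, _preprocess_module stamps such a suffix on every node of a control-flow submodule):

TOSA_TENSOR_NAME_META = "tosa_tensor_name"

# Top-level node: no suffix has been stamped, so the name is unchanged.
meta = {}
assert "x" + meta.get(TOSA_TENSOR_NAME_META, "") == "x"

# Node inside a control-flow submodule: the stamped suffix makes the
# serialized tensor name unique per region.
meta = {TOSA_TENSOR_NAME_META: "_loop_body"}
assert "x" + meta.get(TOSA_TENSOR_NAME_META, "") == "x_loop_body"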