diff --git a/examples/export/export_and_delegate.py b/examples/export/export_and_delegate.py
index a54f9dafedb..3ddb9eb215d 100644
--- a/examples/export/export_and_delegate.py
+++ b/examples/export/export_and_delegate.py
@@ -8,7 +8,6 @@
 
 import argparse
 
-import executorch.exir as exir
 import torch
 from executorch.exir.backend.backend_api import to_backend
 from executorch.exir.backend.test.backend_with_compiler_demo import (
@@ -157,11 +156,17 @@ def export_and_lower_the_whole_graph():
 
     # Lower AddMulModule to the demo backend
     print("Lowering to the demo backend...")
-    _ = to_backend(
-        BackendWithCompilerDemo.__name__, edge.exported_program, m.get_compile_spec()
+    lowered_module = to_backend(
+        BackendWithCompilerDemo.__name__, edge, m.get_compile_spec()
     )
-    # TODO(chenlai): emit the lowered graph
+    buffer = lowered_module.buffer()
+
+    model_name = "whole"
+    filename = f"{model_name}.pte"
+    print(f"Saving exported program to {filename}")
+    with open(filename, "wb") as file:
+        file.write(buffer)
 
 
 OPTIONS_TO_LOWER = {
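
(Reviewer note: a quick way to exercise the newly written `whole.pte` is to load it back through the portable pybindings that the new unit test below also uses. A minimal sketch, assuming a single example input tensor; AddMulModule's real input shapes may differ.)

```python
import torch

from executorch.extension.pybindings.portable import _load_for_executorch_from_buffer

with open("whole.pte", "rb") as f:
    buf = f.read()

executorch_module = _load_for_executorch_from_buffer(buf)
# Hypothetical input; substitute AddMulModule's actual example inputs.
outputs = executorch_module.forward([torch.ones(2, 2)])
print(outputs[0])
```
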
diff --git a/exir/TARGETS b/exir/TARGETS
index 20958142bc6..5961775cb13 100644
--- a/exir/TARGETS
+++ b/exir/TARGETS
@@ -120,10 +120,14 @@ python_library(
     deps = [
         ":delegate",
         ":graph_module",
-        ":lib",
+        ":schema",
         ":tracer",
         "//caffe2:torch",
         "//executorch/exir/backend:compile_spec_schema",
+        "//executorch/exir/emit:lib",
+        "//executorch/exir/passes:memory_planning_pass",
+        "//executorch/exir/passes:spec_prop_pass",
+        "//executorch/exir/serialize:lib",
     ],
 )
 
diff --git a/exir/backend/test/TARGETS b/exir/backend/test/TARGETS
index fe0718921c9..223fc22f41f 100644
--- a/exir/backend/test/TARGETS
+++ b/exir/backend/test/TARGETS
@@ -145,6 +145,29 @@ python_unittest(
     ],
 )
 
+python_unittest(
+    name = "test_lowered_backend_module",
+    srcs = [
+        "test_lowered_backend_module.py",
+    ],
+    supports_static_listing = True,
+    deps = [
+        "fbsource//third-party/pypi/hypothesis:hypothesis",
+        ":backend_with_compiler_demo",
+        ":qnn_backend_demo",
+        "//caffe2:torch",
+        "//executorch/exir:lib",
+        "//executorch/exir:schema",
+        "//executorch/exir/backend:backend_api",
+        "//executorch/exir/backend:compile_spec_schema",
+        "//executorch/exir/tests:models",
+        "//executorch/extension/pybindings:portable",  # @manual
+        "//executorch/kernels/portable:custom_ops_generated_lib",
+        "//executorch/kernels/quantized:custom_ops_generated_lib",
+        "//executorch/runtime/executor/test:test_backend_compiler_lib",
+    ],
+)
+
 python_unittest(
     name = "test_graph_partition",
     srcs = [
diff --git a/exir/backend/test/test_backends.py b/exir/backend/test/test_backends.py
index c79ccc728bb..94b7d9121bf 100644
--- a/exir/backend/test/test_backends.py
+++ b/exir/backend/test/test_backends.py
@@ -115,31 +115,6 @@ def check_backend_delegate(
             program.backend_delegate_data[processed.index].data, expected_processed
         )
 
-    def test_simple(self):
-        class SinModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, x):
-                return torch.sin(x)
-
-        sin_module = SinModule()
-        model_inputs = (torch.ones(1),)
-        expected_res = sin_module(*model_inputs)
-        edgeir_m = exir.capture(
-            sin_module, model_inputs, exir.CaptureConfig()
-        ).to_edge()
-
-        lowered_sin_module = to_backend(
-            "BackendWithCompilerDemo", edgeir_m.exported_program, []
-        )
-        new_res = lowered_sin_module(*model_inputs)
-
-        self.assertTrue(torch.allclose(new_res, expected_res))
-
-        # TODO(tkaruturi): emitting single LoweredBackendModule
-        # program = exir.capture(graph_module).to_edge().to_exectorch().program
-
     @vary_segments
     def test_backend_with_compiler(self, extract_segments: bool):
         class SinModule(torch.nn.Module):
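
(Reviewer note: the deleted `test_simple` and its TODO about emitting a single `LoweredBackendModule` are superseded by the new `test_lowered_backend_module.py` below. In outline, the flow the new tests cover; a sketch assembled from that file, not additional API:)

```python
import torch
from executorch import exir
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.compile_spec_schema import CompileSpec

class SinModule(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x)

edge = exir.capture(SinModule(), (torch.ones(1),), exir.CaptureConfig()).to_edge()
lowered = to_backend(
    "BackendWithCompilerDemo",
    edge.exported_program,
    [CompileSpec("max_value", bytes([1]))],
)
program = lowered.program()   # emits a Program containing a single DelegateCall
pte_bytes = lowered.buffer()  # flatbuffer-serialized form of that program
```
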
diff --git a/exir/backend/test/test_lowered_backend_module.py b/exir/backend/test/test_lowered_backend_module.py
new file mode 100644
index 00000000000..91078f0ed7f
--- /dev/null
+++ b/exir/backend/test/test_lowered_backend_module.py
@@ -0,0 +1,220 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import executorch.exir.tests.models as models
+
+import torch
+from executorch import exir
+from executorch.exir.backend.backend_api import to_backend
+from executorch.exir.backend.compile_spec_schema import CompileSpec
+from executorch.exir.backend.test.backend_with_compiler_demo import (
+    BackendWithCompilerDemo,
+)
+from executorch.exir.backend.test.qnn_backend_demo import QnnBackend
+from executorch.exir.schema import DelegateCall, Program
+
+from executorch.extension.pybindings.portable import (  # @manual
+    _load_for_executorch_from_buffer,
+)
+from hypothesis import given, settings, strategies as st
+
+
+class TestBackendAPI(unittest.TestCase):
+    def validate_lowered_module_program(self, program: Program) -> None:
+        """
+        For any program emitted from a lowered_backend_module, we expect exactly one delegate call.
+        """
+        # There should be only one instruction
+        self.assertEqual(
+            len(program.execution_plan[0].chains[0].instructions),
+            1,
+        )
+
+        # The only instruction should be a delegate call
+        self.assertTrue(
+            isinstance(
+                program.execution_plan[0].chains[0].instructions[0].instr_args,
+                DelegateCall,
+            )
+        )
+
+    def get_program_from_wrapped_module(
+        self, lowered_module, example_inputs, capture_config, edge_compile_config
+    ):
+        class WrappedModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.one_module = lowered_module
+
+            def forward(self, *args):
+                return self.one_module(*args)
+
+        return (
+            exir.capture(WrappedModule(), example_inputs, capture_config)
+            .to_edge(edge_compile_config)
+            .to_executorch()
+            .program
+        )
+
+    @given(
+        unlift=st.booleans(),  # verify both lifted and unlifted graphs
+    )
+    @settings(deadline=500000)
+    def test_emit_lowered_backend_module_end_to_end(self, unlift):
+        class SinModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.sin(x)
+
+        sin_module = SinModule()
+        model_inputs = (torch.ones(1),)
+        expected_res = sin_module(*model_inputs)
+        edgeir_m = exir.capture(
+            sin_module,
+            model_inputs,
+            exir.CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=unlift),
+        ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=True))
+        max_value = model_inputs[0].shape[0]
+        compile_specs = [CompileSpec("max_value", bytes([max_value]))]
+        lowered_sin_module = to_backend(
+            BackendWithCompilerDemo.__name__, edgeir_m.exported_program, compile_specs
+        )
+
+        new_res = lowered_sin_module(*model_inputs)
+
+        self.assertTrue(torch.allclose(new_res[0], expected_res))
+        program = lowered_sin_module.program()
+        self.validate_lowered_module_program(program)
+        buff = lowered_sin_module.buffer()
+
+        executorch_module = _load_for_executorch_from_buffer(buff)
+        model_inputs = torch.ones(1)
+        model_outputs = executorch_module.forward([model_inputs])
+        self.assertEqual(
+            model_inputs,
+            torch.ones(1),
+        )
+        expected_res = 0.8333 * torch.ones(1)
+
+        self.assertTrue(
+            torch.allclose(model_outputs[0], expected_res, atol=1e-03, rtol=1e-03)
+        )
+
+    @given(
+        unlift=st.booleans(),  # verify both lifted and unlifted graphs
+    )
+    @settings(deadline=500000)
+    def test_emit_lowered_backend_module(self, unlift):
+        module_list = [
+            models.Emformer(),
+            models.Repeat(),
+            models.ElementwiseAdd(),
+            models.MLP(),
+            models.ModelWithUnusedArg(),
+        ]
+
+        capture_config = (
+            exir.CaptureConfig(enable_aot=True) if unlift else exir.CaptureConfig()
+        )
+
+        edge_compile_config = exir.EdgeCompileConfig(
+            _check_ir_validity=False, _use_edge_ops=True
+        )
+
+        for model in module_list:
+            model_inputs = model.get_random_inputs()
+
+            edgeir_m = exir.capture(model, model_inputs, capture_config).to_edge(
+                edge_compile_config
+            )
+            lowered_model = to_backend(
+                QnnBackend.__name__, edgeir_m.exported_program, []
+            )
+            program = lowered_model.program()
+            reference_program = self.get_program_from_wrapped_module(
+                lowered_model, model_inputs, capture_config, edge_compile_config
+            )
+
+            # Check that the program closely matches the reference program
+            self.assertEqual(
+                len(program.execution_plan[0].chains[0].instructions),
+                len(reference_program.execution_plan[0].chains[0].instructions),
+            )
+
+            self.assertEqual(
+                len(program.execution_plan[0].values),
+                len(reference_program.execution_plan[0].values),
+            )
+
+            self.assertEqual(
+                len(program.execution_plan[0].inputs),
+                len(reference_program.execution_plan[0].inputs),
+            )
+
+            self.assertEqual(
+                len(program.execution_plan[0].outputs),
+                len(reference_program.execution_plan[0].outputs),
+            )
+
+            # Ensure we can get the buffer
+            _ = lowered_model.buffer()
+            self.validate_lowered_module_program(program)
+
+    @given(
+        unlift=st.booleans(),  # verify both lifted and unlifted graphs
+    )
+    @settings(deadline=500000)
+    def test_emit_nested_lowered_backend_module(self, unlift):
+        module_list = [
+            models.Emformer(),
+            models.Repeat(),
+            models.ElementwiseAdd(),
+            models.MLP(),
+            models.ModelWithUnusedArg(),
+        ]
+
+        capture_config = (
+            exir.CaptureConfig(enable_aot=True) if unlift else exir.CaptureConfig()
+        )
+
+        edge_compile_config = exir.EdgeCompileConfig(
+            _check_ir_validity=False, _use_edge_ops=True
+        )
+
+        for model in module_list:
+            model_inputs = model.get_random_inputs()
+
+            edgeir_m = exir.capture(model, model_inputs, capture_config).to_edge(
+                edge_compile_config
+            )
+            lowered_module = to_backend(
+                QnnBackend.__name__, edgeir_m.exported_program, []
+            )
+
+            # Wrap the lowered module so the wrapped graph can be captured and lowered again
+            class WrappedModule(torch.nn.Module):
+                def __init__(self, lowered_module):
+                    super().__init__()
+                    self.one_module = lowered_module
+
+                def forward(self, *args):
+                    return self.one_module(*args)
+
+            wrapped_module = WrappedModule(lowered_module)
+            wrapped_module_edge = exir.capture(
+                wrapped_module, model_inputs, capture_config
+            ).to_edge(edge_compile_config)
+
+            nested_lowered_model = to_backend(
+                QnnBackend.__name__, wrapped_module_edge.exported_program, []
+            )
+
+            program = nested_lowered_model.program()
+            self.validate_lowered_module_program(program)
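
(Reviewer note: the tests toggle between lifted and unlifted capture via hypothesis. For reference, a sketch of the two configurations as used above; the exact flag defaults may vary across executorch versions:)

```python
from executorch import exir

# Lifted (AOT) capture: parameters/buffers are lifted to placeholder inputs,
# which is why program() below has to strip them from the graph signature.
lifted = exir.CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=False)

# Unlifted capture: parameters/buffers remain attributes on the module.
unlifted = exir.CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=True)

edge_config = exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=True)
```
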
diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py
index 96190ecac0f..91a4dbb8ade 100644
--- a/exir/lowered_backend_module.py
+++ b/exir/lowered_backend_module.py
@@ -7,18 +7,29 @@
 # pyre-strict
 
 import copy
-from typing import Dict, List, Tuple, Union
+import operator
+from typing import Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.utils._pytree as pytree
-from executorch.exir import CallSpec, ExportGraphSignature
 from executorch.exir.backend.compile_spec_schema import CompileSpec
-from executorch.exir.delegate import executorch_call_delegate
+from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name
+from executorch.exir.emit import emit_program
 from executorch.exir.graph_module import _get_submodule
+from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass
+from executorch.exir.passes.spec_prop_pass import make_spec, SpecPropPass
+from executorch.exir.schema import Program
+from executorch.exir.serialize import serialize_to_flatbuffer
+
 from executorch.exir.tracer import Value
-from torch._export.exported_program import ExportedProgram
+
+from torch._export.exported_program import (
+    CallSpec,
+    ExportedProgram,
+    ExportGraphSignature,
+)
 from torch._subclasses import FakeTensor
 from torch.fx.passes.utils.fuser_utils import (
     erase_nodes,
@@ -77,6 +88,168 @@ def compile_specs(self) -> List[CompileSpec]:
 
     def original_module(self) -> ExportedProgram:
         return self._original_module
 
+    # TODO(chenlai): consolidate the serialization config with the serialize_to_flatbuffer API
+    def buffer(
+        self,
+        extract_segments: bool = False,
+        segment_alignment: int = 4096,
+        constant_tensor_alignment: Optional[int] = None,
+        delegate_alignment: Optional[int] = None,
+    ) -> bytes:
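+        """
+        Serializes the emitted program to a flatbuffer binary. The segment and
+        alignment options are forwarded unchanged to serialize_to_flatbuffer;
+        the returned bytes are what gets written to a .pte file.
+        """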
+        out = serialize_to_flatbuffer(
+            program=self.program(),
+            extract_segments=extract_segments,
+            segment_alignment=segment_alignment,
+            constant_tensor_alignment=constant_tensor_alignment,
+            delegate_alignment=delegate_alignment,
+        )
+        return out
+
+    # TODO(chenlai): reconsider recapturing instead of constructing the program manually,
+    # since the metadata construction here is done by hand.
+    def program(self, emit_stacktrace: bool = False) -> Program:
+        """
+        The idea in this function is to create a module based on the original module.
+        The original module will look something like the following:
+
+        opcode         name                 target            args                                        kwargs
+        -------------  -------------------  ----------------  ------------------------------------------  --------
+        placeholder    arg0_1               arg0_1            ()                                          {}
+        placeholder    arg1_1               arg1_1            ()                                          {}
+        call_function  aten_repeat_default  *                 (arg1_1, [4, 1])                            {}
+        call_function  aten_mul_tensor      *                 (aten_repeat_default, aten_repeat_default)  {}
+        call_function  aten_add_tensor      *                 (arg1_1, arg1_1)                            {}
+        output         output               output            ([aten_mul_tensor, aten_add_tensor],)       {}
+
+        If the whole module is lowered, the resulting lowered module looks like:
+
+        opcode         name                      target                    args                                kwargs
+        -------------  ------------------------  ------------------------  ----------------------------------  --------
+        placeholder    arg0_1                    arg0_1                    ()                                  {}
+        placeholder    arg1_1                    arg1_1                    ()                                  {}
+        get_attr       lowered_module_0          lowered_module_0          ()                                  {}
+        call_function  executorch_call_delegate  executorch_call_delegate  (lowered_module_0, arg0_1, arg1_1)  {}
+        call_function  getitem                                             (executorch_call_delegate, 0)       {}
+        call_function  getitem_1                                           (executorch_call_delegate, 1)       {}
+        output         output_1                  output                    ([getitem, getitem_1],)             {}
+
+        We remove all call_function nodes, insert a call_delegate node, insert getitem
+        nodes to fetch the results from the call_delegate node, and return the list of
+        getitem nodes as the output.
+        """
+        lowered_exported_program = copy.deepcopy(self.original_module)
+
+        # The real input nodes are the placeholders that are not buffers or parameters
+        all_input_nodes = [
+            node
+            for node in lowered_exported_program.graph.nodes
+            if (
+                node.op == "placeholder"
+                and node.name
+                not in lowered_exported_program.graph_signature.inputs_to_buffers
+                and node.name
+                not in lowered_exported_program.graph_signature.inputs_to_parameters
+            )
+        ]
+
+        output_node = [
+            node for node in lowered_exported_program.graph.nodes if node.op == "output"
+        ]
+        assert len(output_node) == 1, "There should be only one output node"
+
+        # Step 1. Clean up the graph before inserting the call_delegate node
+        # Remove the original output node
+        lowered_exported_program.graph.erase_node(output_node[0])
+
+        # Remove everything else except the inputs
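+        # (Iterate in reverse so that users are erased before the nodes they
+        # consume; fx.Graph.erase_node raises if a node still has users.)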
+        for node in reversed(lowered_exported_program.graph.nodes):
+            if node.op != "placeholder":
+                lowered_exported_program.graph.erase_node(node)
+
+        # Find placeholders that are parameters or buffers and remove them from the main graph
+        for node in lowered_exported_program.graph.nodes:
+            if node.op == "placeholder" and (
+                node.name in lowered_exported_program.graph_signature.inputs_to_buffers
+                or node.name
+                in lowered_exported_program.graph_signature.inputs_to_parameters
+            ):
+                lowered_exported_program.graph.erase_node(node)
+
+        # Step 2. Start constructing the graph
+        lowered_name = get_lowered_module_name(
+            lowered_exported_program.graph_module, self
+        )
+        # Insert the lowered module into the graph module as an attribute
+        lowered_node = lowered_exported_program.graph.get_attr(lowered_name)
+
+        # Insert a call_delegate node into the graph module, with arguments from the arg list
+        delegate_node = lowered_exported_program.graph.call_function(
+            executorch_call_delegate, (lowered_node, *all_input_nodes)
+        )
+        # Get the output list. Since the output node wraps the outputs in a tuple of a list,
+        # like ([aten_mul_tensor, aten_add_tensor],), we add some handling logic to extract
+        # the list `[aten_mul_tensor, aten_add_tensor]` properly.
+        original_output_nodes = [
+            node for node in self.original_module.graph.nodes if node.op == "output"
+        ][0].args[0]
+
+        delegate_node.meta["spec"] = tuple(
+            [make_spec(node.meta["val"]) for node in original_output_nodes]
+        )
+
+        # The getitem nodes that are going to be inserted into the lowered graph module
+        getitem_nodes = []
+        for i in range(len(original_output_nodes)):
+            getitem_node = lowered_exported_program.graph.call_function(
+                operator.getitem,
+                args=(delegate_node, i),
+            )
+            getitem_nodes.append(getitem_node)
+        lowered_exported_program.graph.output(getitem_nodes)
+
+        lowered_exported_program.graph_module.recompile()
+        lowered_exported_program.graph.lint()
+
+        # The user outputs are now the getitem nodes
+        lowered_exported_program.graph_signature.user_outputs = [
+            getitem_node.name for getitem_node in getitem_nodes
+        ]
+        # All data is consumed by the delegate, so it should be removed from the state dict.
+        inputs_to_parameters = (
+            lowered_exported_program.graph_signature.inputs_to_parameters
+        )
+        inputs_to_buffers = lowered_exported_program.graph_signature.inputs_to_buffers
+        lowered_exported_program.graph_signature.user_inputs = [
+            user_input
+            for user_input in lowered_exported_program.graph_signature.user_inputs
+            if user_input in inputs_to_parameters or user_input in inputs_to_buffers
+        ]
+        lowered_exported_program.graph_signature.buffers = {}
+        lowered_exported_program.graph_signature.parameters = {}
+        lowered_exported_program.graph_signature.inputs_to_parameters = {}
+        lowered_exported_program.graph_signature.inputs_to_buffers = {}
+
+        # Double-check that the ExportedProgram data (especially everything except the graph) is good
+        exported_program = ExportedProgram(
+            root=lowered_exported_program.graph_module,
+            graph=lowered_exported_program.graph,
+            graph_signature=lowered_exported_program.graph_signature,
+            # TODO: We may need to set lowered_exported_program.call_spec = CallSpec(None, None)
+            # somewhere, as we should pass a list of tensors to the lowered module and get a
+            # list of tensors back. Putting call_spec=lowered_exported_program.call_spec is
+            # correct here, as the inputs/outputs to the top-level program will be in the
+            # format of the eager module.
+            call_spec=lowered_exported_program.call_spec,
+            state_dict={},  # Empty because all data is consumed by the delegate
+            range_constraints=lowered_exported_program.range_constraints,
+            equality_constraints=lowered_exported_program.equality_constraints,
+            module_call_graph=lowered_exported_program.module_call_graph,
+        )
+        exported_program = exported_program.transform(
+            SpecPropPass(), MemoryPlanningPass("greedy")
+        )
+        emitted_program = emit_program(
+            exported_program, emit_stacktrace=emit_stacktrace
+        ).program
+        return emitted_program
+
     # Used to patch each delegated function with a call_delegate call
     # @staticmethod
     def forward(
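
(Reviewer note: a quick consumer-side check of `program()`'s output, using the same schema accessors the new unit test relies on; a sketch, assuming `lowered_module` is the LoweredBackendModule returned by `to_backend`:)

```python
from executorch.exir.schema import DelegateCall

program = lowered_module.program()
instructions = program.execution_plan[0].chains[0].instructions

# The whole graph was delegated, so the plan should be a single delegate call.
assert len(instructions) == 1
assert isinstance(instructions[0].instr_args, DelegateCall)
```
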
diff --git a/exir/passes/spec_prop_pass.py b/exir/passes/spec_prop_pass.py
index c366ae9c9cc..6cb28077150 100644
--- a/exir/passes/spec_prop_pass.py
+++ b/exir/passes/spec_prop_pass.py
@@ -86,10 +86,12 @@ def call_delegate(self, lowered_module, args, kwargs, meta):
         args_data, kwargs_data = pytree.tree_map_only(
             ProxyValue, lambda x: x.data, (args, kwargs)
         )
-        meta["spec"] = pytree.tree_map(
-            make_spec,
-            executorch_call_delegate(lowered_module, *args_data),
-        )
+        # If the spec is missing, re-generate it from the args data
+        if "spec" not in meta:
+            meta["spec"] = pytree.tree_map(
+                make_spec,
+                executorch_call_delegate(lowered_module, *args_data),
+            )
         return super().call_delegate(lowered_module, args, kwargs, meta)
 
     # pyre-ignore
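
(Reviewer note: the guard above exists because `program()` in lowered_backend_module.py now precomputes `delegate_node.meta["spec"]` from the original module's output metadata, and re-deriving it here would mean actually executing the delegate via `executorch_call_delegate`. A minimal standalone illustration of the contract, with a hypothetical `ensure_spec` helper that is not part of the executorch API:)

```python
from typing import Any, Callable, Dict

def ensure_spec(meta: Dict[str, Any], derive: Callable[[], Any]) -> None:
    # Mirrors the guard added in SpecPropPass.call_delegate: only derive the
    # spec (potentially expensive, since it runs the delegate) when no earlier
    # step has attached one.
    if "spec" not in meta:
        meta["spec"] = derive()

meta: Dict[str, Any] = {"spec": ("precomputed",)}
ensure_spec(meta, lambda: ("recomputed",))
assert meta["spec"] == ("precomputed",)  # the precomputed spec is preserved
```
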