From 0d2a39d40bb32335516cdd962c358401db2015aa Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Fri, 25 Aug 2023 10:37:40 -0700 Subject: [PATCH 1/2] Emit lowered module Summary: It has been a pending task for a while, as a follow up on https://fb.workplace.com/groups/536346827621174/permalink/665126474743208/ we want the lowered backend module to be **runnerable**, **emittable**, and **retracable**. This diff makes the lowered backend module emittable without the need to composite with other modules It will the easiest the flow for backend developer to try lower one op to a backend via delegate. Differential Revision: https://www.internalfb.com/diff/D47803806?entry_point=27 fbshipit-source-id: 65eb5e3b80ef1bfae2593e713b930887c68dcfd8 --- exir/TARGETS | 6 +- exir/backend/test/TARGETS | 23 ++ exir/backend/test/test_backends.py | 25 -- .../test/test_lowered_backend_module.py | 220 ++++++++++++++++++ exir/lowered_backend_module.py | 181 +++++++++++++- exir/passes/spec_prop_pass.py | 10 +- 6 files changed, 431 insertions(+), 34 deletions(-) create mode 100644 exir/backend/test/test_lowered_backend_module.py diff --git a/exir/TARGETS b/exir/TARGETS index 20958142bc6..5961775cb13 100644 --- a/exir/TARGETS +++ b/exir/TARGETS @@ -120,10 +120,14 @@ python_library( deps = [ ":delegate", ":graph_module", - ":lib", + ":schema", ":tracer", "//caffe2:torch", "//executorch/exir/backend:compile_spec_schema", + "//executorch/exir/emit:lib", + "//executorch/exir/passes:memory_planning_pass", + "//executorch/exir/passes:spec_prop_pass", + "//executorch/exir/serialize:lib", ], ) diff --git a/exir/backend/test/TARGETS b/exir/backend/test/TARGETS index fe0718921c9..223fc22f41f 100644 --- a/exir/backend/test/TARGETS +++ b/exir/backend/test/TARGETS @@ -145,6 +145,29 @@ python_unittest( ], ) +python_unittest( + name = "test_lowered_backend_module", + srcs = [ + "test_lowered_backend_module.py", + ], + supports_static_listing = True, + deps = [ + "fbsource//third-party/pypi/hypothesis:hypothesis", + ":backend_with_compiler_demo", + ":qnn_backend_demo", + "//caffe2:torch", + "//executorch/exir:lib", + "//executorch/exir:schema", + "//executorch/exir/backend:backend_api", + "//executorch/exir/backend:compile_spec_schema", + "//executorch/exir/tests:models", + "//executorch/extension/pybindings:portable", # @manual + "//executorch/kernels/portable:custom_ops_generated_lib", + "//executorch/kernels/quantized:custom_ops_generated_lib", + "//executorch/runtime/executor/test:test_backend_compiler_lib", + ], +) + python_unittest( name = "test_graph_partition", srcs = [ diff --git a/exir/backend/test/test_backends.py b/exir/backend/test/test_backends.py index c79ccc728bb..94b7d9121bf 100644 --- a/exir/backend/test/test_backends.py +++ b/exir/backend/test/test_backends.py @@ -115,31 +115,6 @@ def check_backend_delegate( program.backend_delegate_data[processed.index].data, expected_processed ) - def test_simple(self): - class SinModule(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, x): - return torch.sin(x) - - sin_module = SinModule() - model_inputs = (torch.ones(1),) - expected_res = sin_module(*model_inputs) - edgeir_m = exir.capture( - sin_module, model_inputs, exir.CaptureConfig() - ).to_edge() - - lowered_sin_module = to_backend( - "BackendWithCompilerDemo", edgeir_m.exported_program, [] - ) - new_res = lowered_sin_module(*model_inputs) - - self.assertTrue(torch.allclose(new_res, expected_res)) - - # TODO(tkaruturi): emitting single LoweredBackendModule - # program = exir.capture(graph_module).to_edge().to_exectorch().program - @vary_segments def test_backend_with_compiler(self, extract_segments: bool): class SinModule(torch.nn.Module): diff --git a/exir/backend/test/test_lowered_backend_module.py b/exir/backend/test/test_lowered_backend_module.py new file mode 100644 index 00000000000..91078f0ed7f --- /dev/null +++ b/exir/backend/test/test_lowered_backend_module.py @@ -0,0 +1,220 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import executorch.exir.tests.models as models + +import torch +from executorch import exir +from executorch.exir.backend.backend_api import to_backend +from executorch.exir.backend.compile_spec_schema import CompileSpec +from executorch.exir.backend.test.backend_with_compiler_demo import ( + BackendWithCompilerDemo, +) +from executorch.exir.backend.test.qnn_backend_demo import QnnBackend +from executorch.exir.schema import DelegateCall, Program + +from executorch.extension.pybindings.portable import ( # @manual + _load_for_executorch_from_buffer, +) +from hypothesis import given, settings, strategies as st + + +class TestBackendAPI(unittest.TestCase): + def validate_lowered_module_program(self, program: Program) -> None: + """ + For any program emitted from lowered_backend_module, we expect only one delegate call + """ + # there should only be one instruction + self.assertEqual( + len(program.execution_plan[0].chains[0].instructions), + 1, + ) + + # the only instruction should be a delegate call + self.assertTrue( + isinstance( + program.execution_plan[0].chains[0].instructions[0].instr_args, + DelegateCall, + ) + ) + + def get_program_from_wrapped_module( + self, lowered_module, example_inputs, capture_config, edge_compile_config + ): + class WrappedModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.one_module = lowered_module + + def forward(self, *args): + return self.one_module(*args) + + return ( + exir.capture(WrappedModule(), example_inputs, capture_config) + .to_edge(edge_compile_config) + .to_executorch() + .program + ) + + @given( + unlift=st.booleans(), # verify both lifted and unlifted graph + ) + @settings(deadline=500000) + def test_emit_lowered_backend_module_end_to_end(self, unlift): + class SinModule(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return torch.sin(x) + + sin_module = SinModule() + model_inputs = (torch.ones(1),) + expected_res = sin_module(*model_inputs) + edgeir_m = exir.capture( + sin_module, + model_inputs, + exir.CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=unlift), + ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=True)) + max_value = model_inputs[0].shape[0] + compile_specs = [CompileSpec("max_value", bytes([max_value]))] + lowered_sin_module = to_backend( + BackendWithCompilerDemo.__name__, edgeir_m.exported_program, compile_specs + ) + + new_res = lowered_sin_module(*model_inputs) + + self.assertTrue(torch.allclose(new_res[0], expected_res)) + program = lowered_sin_module.program() + self.validate_lowered_module_program(program) + buff = lowered_sin_module.buffer() + + executorch_module = _load_for_executorch_from_buffer(buff) + model_inputs = torch.ones(1) + model_outputs = executorch_module.forward([model_inputs]) + self.assertEqual( + model_inputs, + torch.ones(1), + ) + expected_res = 0.8333 * torch.ones(1) + + self.assertTrue( + torch.allclose(model_outputs[0], expected_res, atol=1e-03, rtol=1e-03) + ) + + @given( + unlift=st.booleans(), # verify both lifted and unlifted graph + ) + @settings(deadline=500000) + def test_emit_lowered_backend_module(self, unlift): + module_list = [ + models.Emformer(), + models.Repeat(), + models.ElementwiseAdd(), + models.MLP(), + models.ModelWithUnusedArg(), + ] + + capture_config = ( + exir.CaptureConfig(enable_aot=True) if unlift else exir.CaptureConfig() + ) + + edge_compile_config = exir.EdgeCompileConfig( + _check_ir_validity=False, _use_edge_ops=True + ) + + for model in module_list: + model_inputs = model.get_random_inputs() + + edgeir_m = exir.capture(model, model_inputs, capture_config).to_edge( + edge_compile_config + ) + lowered_model = to_backend( + QnnBackend.__name__, edgeir_m.exported_program, [] + ) + program = lowered_model.program() + reference_program = self.get_program_from_wrapped_module( + lowered_model, model_inputs, capture_config, edge_compile_config + ) + + # Check program is fairly equal to the reference program + self.assertEqual( + len(program.execution_plan[0].chains[0].instructions), + len(reference_program.execution_plan[0].chains[0].instructions), + ) + + self.assertEqual( + len(program.execution_plan[0].values), + len(reference_program.execution_plan[0].values), + ) + + self.assertEqual( + len(program.execution_plan[0].inputs), + len(reference_program.execution_plan[0].inputs), + ) + + self.assertEqual( + len(program.execution_plan[0].outputs), + len(reference_program.execution_plan[0].outputs), + ) + + # Ensure we can get the buffer + _ = lowered_model.buffer() + self.validate_lowered_module_program(program) + + @given( + unlift=st.booleans(), # verify both lifted and unlifted graph + ) + @settings(deadline=500000) + def test_emit_nested_lowered_backend_module(self, unlift): + module_list = [ + models.Emformer(), + models.Repeat(), + models.ElementwiseAdd(), + models.MLP(), + models.ModelWithUnusedArg(), + ] + + capture_config = ( + exir.CaptureConfig(enable_aot=True) if unlift else exir.CaptureConfig() + ) + + edge_compile_config = exir.EdgeCompileConfig( + _check_ir_validity=False, _use_edge_ops=True + ) + + for model in module_list: + model_inputs = model.get_random_inputs() + + edgeir_m = exir.capture(model, model_inputs, capture_config).to_edge( + edge_compile_config + ) + lowered_module = to_backend( + QnnBackend.__name__, edgeir_m.exported_program, [] + ) + + # This module will include one operator and two delegate call + class WrappedModule(torch.nn.Module): + def __init__(self, lowered_module): + super().__init__() + self.one_module = lowered_module + + def forward(self, *args): + return self.one_module(*args) + + wrapped_module = WrappedModule(lowered_module) + wrapped_module_edge = exir.capture( + wrapped_module, model_inputs, capture_config + ).to_edge(edge_compile_config) + + nested_lowered_model = to_backend( + QnnBackend.__name__, wrapped_module_edge.exported_program, [] + ) + + program = nested_lowered_model.program() + self.validate_lowered_module_program(program) diff --git a/exir/lowered_backend_module.py b/exir/lowered_backend_module.py index 96190ecac0f..91a4dbb8ade 100644 --- a/exir/lowered_backend_module.py +++ b/exir/lowered_backend_module.py @@ -7,18 +7,29 @@ # pyre-strict import copy -from typing import Dict, List, Tuple, Union +import operator +from typing import Dict, List, Optional, Tuple, Union import torch import torch.utils._pytree as pytree -from executorch.exir import CallSpec, ExportGraphSignature from executorch.exir.backend.compile_spec_schema import CompileSpec -from executorch.exir.delegate import executorch_call_delegate +from executorch.exir.delegate import executorch_call_delegate, get_lowered_module_name +from executorch.exir.emit import emit_program from executorch.exir.graph_module import _get_submodule +from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass +from executorch.exir.passes.spec_prop_pass import make_spec, SpecPropPass +from executorch.exir.schema import Program +from executorch.exir.serialize import serialize_to_flatbuffer + from executorch.exir.tracer import Value -from torch._export.exported_program import ExportedProgram + +from torch._export.exported_program import ( + CallSpec, + ExportedProgram, + ExportGraphSignature, +) from torch._subclasses import FakeTensor from torch.fx.passes.utils.fuser_utils import ( erase_nodes, @@ -77,6 +88,168 @@ def compile_specs(self) -> List[CompileSpec]: def original_module(self) -> ExportedProgram: return self._original_module + # TODO(chenlai): consolidate the seriailization config with serialize_to_flatbuffer api + def buffer( + self, + extract_segments: bool = False, + segment_alignment: int = 4096, + constant_tensor_alignment: Optional[int] = None, + delegate_alignment: Optional[int] = None, + ) -> bytes: + out = serialize_to_flatbuffer( + program=self.program(), + extract_segments=extract_segments, + segment_alignment=segment_alignment, + constant_tensor_alignment=constant_tensor_alignment, + delegate_alignment=delegate_alignment, + ) + return out + + # TODO(chenlai): re-consider recapture instead of manually constructing the program because + # the meta data construction is done manually. + def program(self, emit_stacktrace: bool = False) -> Program: + """ + The idea in this function is to create a module based on the original module. The original module will + look something like following: + + opcode name target args kwargs + ------------- ------------------- ---------------- ------------------------------------------ -------- + placeholder arg0_1 arg0_1 () {} + placeholder arg1_1 arg1_1 () {} + call_function aten_repeat_default * (arg1_1, [4, 1]) {} + call_function aten_mul_tensor * (aten_repeat_default, aten_repeat_default) {} + call_function aten_add_tensor * (arg1_1, arg1_1) {} + output output output ([aten_mul_tensor, aten_add_tensor],) {} + + if the whole module is lowered, the resulting lowered module look like + + opcode name target args kwargs + ------------- ------------------------ --------------------------- ---------------------------------- -------- + placeholder arg0_1 arg0_1 () {} + placeholder arg1_1 arg1_1 () {} + get_attr lowered_module_0 lowered_module_0 () {} + call_function executorch_call_delegate executorch_call_delegate (lowered_module_0, arg0_1, arg1_1) {} + call_function getitem (executorch_call_delegate, 0) {} + call_function getitem_1 (executorch_call_delegate, 1) {} + output output_1 output ([getitem, getitem_1],) {} + + We'll remove all call_function nodes, insert an call_delegate node, inserting getitems nodes to get the result for call_delegate node + and return the list of getitems as the output + """ + lowered_exported_program = copy.deepcopy(self.original_module) + + # The real input nodes are the ones not buffer or parameter + all_input_nodes = [ + node + for node in lowered_exported_program.graph.nodes + if ( + node.op == "placeholder" + and node.name + not in lowered_exported_program.graph_signature.inputs_to_buffers + and node.name + not in lowered_exported_program.graph_signature.inputs_to_parameters + ) + ] + + output_node = [ + node for node in lowered_exported_program.graph.nodes if node.op == "output" + ] + assert len(output_node) == 1, "There should be only one output node" + + # Step 1. Cleaning up the graph before inserting the call_delegate node + # Remove the original output node + lowered_exported_program.graph.erase_node(output_node[0]) + + # Remove all the everything else except the input + for node in reversed(lowered_exported_program.graph.nodes): + if node.op != "placeholder": + lowered_exported_program.graph.erase_node(node) + + # Find placeholders that are parameters or buffers, remove them from the main graph + for node in lowered_exported_program.graph.nodes: + if node.op == "placeholder" and ( + node.name in lowered_exported_program.graph_signature.inputs_to_buffers + or node.name + in lowered_exported_program.graph_signature.inputs_to_parameters + ): + lowered_exported_program.graph.erase_node(node) + + # Step 2. Start constructing the graph + lowered_name = get_lowered_module_name( + lowered_exported_program.graph_module, self + ) + # Insert the lowered module to the graph module as an attibute + lowered_node = lowered_exported_program.graph.get_attr(lowered_name) + + # Insert a call_delegate node to the graph module, with arguments from the arg list + delegate_node = lowered_exported_program.graph.call_function( + executorch_call_delegate, (lowered_node, *all_input_nodes) + ) + # Get the output list. Since the output node is a tuple of list, like ([aten_mul_tensor, aten_add_tensor],) + # We add some handling logic to get the list `[aten_mul_tensor, aten_add_tensor]` properly + original_output_nodes = [ + node for node in self.original_module.graph.nodes if node.op == "output" + ][0].args[0] + + delegate_node.meta["spec"] = tuple( + [make_spec(node.meta["val"]) for node in original_output_nodes] + ) + + # The getitem nodes that are going to be inserted to the lowered graph module + getitem_nodes = [] + for i in range(len(original_output_nodes)): + getitem_node = lowered_exported_program.graph.call_function( + operator.getitem, + args=(delegate_node, i), + ) + getitem_nodes.append(getitem_node) + lowered_exported_program.graph.output(getitem_nodes) + + lowered_exported_program.graph_module.recompile() + lowered_exported_program.graph.lint() + + # Users output will be the get items nodes instead + lowered_exported_program.graph_signature.user_outputs = [ + getitem_node.name for getitem_node in getitem_nodes + ] + # All data are consumed by the delegates so they should be removed from the state dict. + inputs_to_parameters = ( + lowered_exported_program.graph_signature.inputs_to_parameters + ) + inputs_to_buffers = lowered_exported_program.graph_signature.inputs_to_buffers + lowered_exported_program.graph_signature.user_inputs = [ + user_input + for user_input in lowered_exported_program.graph_signature.user_inputs + if user_input in inputs_to_parameters or user_input in inputs_to_buffers + ] + lowered_exported_program.graph_signature.buffers = {} + lowered_exported_program.graph_signature.parameters = {} + lowered_exported_program.graph_signature.inputs_to_parameters = {} + lowered_exported_program.graph_signature.inputs_to_buffers = {} + + # Double check the ExportedProgram data(especially everything except graph) is good + exported_program = ExportedProgram( + root=lowered_exported_program.graph_module, + graph=lowered_exported_program.graph, + graph_signature=lowered_exported_program.graph_signature, + # TODO: May need to set lowered_exported_program.call_spec = CallSpec(None, None) + # somewhere as we should pass it a list of tensors to the lowered module and output a + # list of tensors. Putting call_spec=lowered_exported_program.call_spec is correct here as the + # inputs/outputs to the toplevel program will be in the format of the eager module. + call_spec=lowered_exported_program.call_spec, + state_dict={}, # None because all data are consumed by delegate + range_constraints=lowered_exported_program.range_constraints, + equality_constraints=lowered_exported_program.equality_constraints, + module_call_graph=lowered_exported_program.module_call_graph, + ) + exported_program = exported_program.transform( + SpecPropPass(), MemoryPlanningPass("greedy") + ) + emitted_program = emit_program( + exported_program, emit_stacktrace=emit_stacktrace + ).program + return emitted_program + # Used to patch each delegated function with a call_delegate call # @staticmethod def forward( diff --git a/exir/passes/spec_prop_pass.py b/exir/passes/spec_prop_pass.py index c366ae9c9cc..6cb28077150 100644 --- a/exir/passes/spec_prop_pass.py +++ b/exir/passes/spec_prop_pass.py @@ -86,10 +86,12 @@ def call_delegate(self, lowered_module, args, kwargs, meta): args_data, kwargs_data = pytree.tree_map_only( ProxyValue, lambda x: x.data, (args, kwargs) ) - meta["spec"] = pytree.tree_map( - make_spec, - executorch_call_delegate(lowered_module, *args_data), - ) + # If spec is missing, re-genenrate it with args data + if "spec" not in meta: + meta["spec"] = pytree.tree_map( + make_spec, + executorch_call_delegate(lowered_module, *args_data), + ) return super().call_delegate(lowered_module, args, kwargs, meta) # pyre-ignore From ae77fe2c27b905a038b281ca706a7db556730ebd Mon Sep 17 00:00:00 2001 From: Chen Lai Date: Fri, 25 Aug 2023 10:37:59 -0700 Subject: [PATCH 2/2] Update the example flow to include emit program directly from lowered module (#134) Summary: Pull Request resolved: https://github.com/pytorch/executorch/pull/134 As title, we can emit the program directly from lowered module Reviewed By: mergennachin, chakriu, digantdesai, larryliu0820 Differential Revision: D47855787 fbshipit-source-id: b14ef020187f71e72099f42e7d6429e610231137 --- examples/export/export_and_delegate.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/export/export_and_delegate.py b/examples/export/export_and_delegate.py index a54f9dafedb..3ddb9eb215d 100644 --- a/examples/export/export_and_delegate.py +++ b/examples/export/export_and_delegate.py @@ -8,7 +8,6 @@ import argparse -import executorch.exir as exir import torch from executorch.exir.backend.backend_api import to_backend from executorch.exir.backend.test.backend_with_compiler_demo import ( @@ -157,11 +156,17 @@ def export_and_lower_the_whole_graph(): # Lower AddMulModule to the demo backend print("Lowering to the demo backend...") - _ = to_backend( - BackendWithCompilerDemo.__name__, edge.exported_program, m.get_compile_spec() + lowered_module = to_backend( + BackendWithCompilerDemo.__name__, edge, m.get_compile_spec() ) - # TODO(chenlai): emit the lowered graph + buffer = lowered_module.buffer() + + model_name = "whole" + filename = f"{model_name}.pte" + print(f"Saving exported program to {filename}") + with open(filename, "wb") as file: + file.write(buffer) OPTIONS_TO_LOWER = {