Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions examples/export/export_and_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import argparse

import executorch.exir as exir
import torch
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.test.backend_with_compiler_demo import (
Expand Down Expand Up @@ -157,11 +156,17 @@ def export_and_lower_the_whole_graph():

# Lower AddMulModule to the demo backend
print("Lowering to the demo backend...")
_ = to_backend(
BackendWithCompilerDemo.__name__, edge.exported_program, m.get_compile_spec()
lowered_module = to_backend(
BackendWithCompilerDemo.__name__, edge, m.get_compile_spec()
)

# TODO(chenlai): emit the lowered graph
buffer = lowered_module.buffer()

model_name = "whole"
filename = f"{model_name}.pte"
print(f"Saving exported program to {filename}")
with open(filename, "wb") as file:
file.write(buffer)


OPTIONS_TO_LOWER = {
Expand Down
6 changes: 5 additions & 1 deletion exir/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,14 @@ python_library(
deps = [
":delegate",
":graph_module",
":lib",
":schema",
":tracer",
"//caffe2:torch",
"//executorch/exir/backend:compile_spec_schema",
"//executorch/exir/emit:lib",
"//executorch/exir/passes:memory_planning_pass",
"//executorch/exir/passes:spec_prop_pass",
"//executorch/exir/serialize:lib",
],
)

Expand Down
23 changes: 23 additions & 0 deletions exir/backend/test/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,29 @@ python_unittest(
],
)

python_unittest(
name = "test_lowered_backend_module",
srcs = [
"test_lowered_backend_module.py",
],
supports_static_listing = True,
deps = [
"fbsource//third-party/pypi/hypothesis:hypothesis",
":backend_with_compiler_demo",
":qnn_backend_demo",
"//caffe2:torch",
"//executorch/exir:lib",
"//executorch/exir:schema",
"//executorch/exir/backend:backend_api",
"//executorch/exir/backend:compile_spec_schema",
"//executorch/exir/tests:models",
"//executorch/extension/pybindings:portable", # @manual
"//executorch/kernels/portable:custom_ops_generated_lib",
"//executorch/kernels/quantized:custom_ops_generated_lib",
"//executorch/runtime/executor/test:test_backend_compiler_lib",
],
)

python_unittest(
name = "test_graph_partition",
srcs = [
Expand Down
25 changes: 0 additions & 25 deletions exir/backend/test/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,31 +115,6 @@ def check_backend_delegate(
program.backend_delegate_data[processed.index].data, expected_processed
)

def test_simple(self):
class SinModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.sin(x)

sin_module = SinModule()
model_inputs = (torch.ones(1),)
expected_res = sin_module(*model_inputs)
edgeir_m = exir.capture(
sin_module, model_inputs, exir.CaptureConfig()
).to_edge()

lowered_sin_module = to_backend(
"BackendWithCompilerDemo", edgeir_m.exported_program, []
)
new_res = lowered_sin_module(*model_inputs)

self.assertTrue(torch.allclose(new_res, expected_res))

# TODO(tkaruturi): emitting single LoweredBackendModule
# program = exir.capture(graph_module).to_edge().to_exectorch().program

@vary_segments
def test_backend_with_compiler(self, extract_segments: bool):
class SinModule(torch.nn.Module):
Expand Down
220 changes: 220 additions & 0 deletions exir/backend/test/test_lowered_backend_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import executorch.exir.tests.models as models

import torch
from executorch import exir
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.backend.test.backend_with_compiler_demo import (
BackendWithCompilerDemo,
)
from executorch.exir.backend.test.qnn_backend_demo import QnnBackend
from executorch.exir.schema import DelegateCall, Program

from executorch.extension.pybindings.portable import ( # @manual
_load_for_executorch_from_buffer,
)
from hypothesis import given, settings, strategies as st


class TestBackendAPI(unittest.TestCase):
def validate_lowered_module_program(self, program: Program) -> None:
"""
For any program emitted from lowered_backend_module, we expect only one delegate call
"""
# there should only be one instruction
self.assertEqual(
len(program.execution_plan[0].chains[0].instructions),
1,
)

# the only instruction should be a delegate call
self.assertTrue(
isinstance(
program.execution_plan[0].chains[0].instructions[0].instr_args,
DelegateCall,
)
)

def get_program_from_wrapped_module(
self, lowered_module, example_inputs, capture_config, edge_compile_config
):
class WrappedModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.one_module = lowered_module

def forward(self, *args):
return self.one_module(*args)

return (
exir.capture(WrappedModule(), example_inputs, capture_config)
.to_edge(edge_compile_config)
.to_executorch()
.program
)

@given(
unlift=st.booleans(), # verify both lifted and unlifted graph
)
@settings(deadline=500000)
def test_emit_lowered_backend_module_end_to_end(self, unlift):
class SinModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x):
return torch.sin(x)

sin_module = SinModule()
model_inputs = (torch.ones(1),)
expected_res = sin_module(*model_inputs)
edgeir_m = exir.capture(
sin_module,
model_inputs,
exir.CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=unlift),
).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=True))
max_value = model_inputs[0].shape[0]
compile_specs = [CompileSpec("max_value", bytes([max_value]))]
lowered_sin_module = to_backend(
BackendWithCompilerDemo.__name__, edgeir_m.exported_program, compile_specs
)

new_res = lowered_sin_module(*model_inputs)

self.assertTrue(torch.allclose(new_res[0], expected_res))
program = lowered_sin_module.program()
self.validate_lowered_module_program(program)
buff = lowered_sin_module.buffer()

executorch_module = _load_for_executorch_from_buffer(buff)
model_inputs = torch.ones(1)
model_outputs = executorch_module.forward([model_inputs])
self.assertEqual(
model_inputs,
torch.ones(1),
)
expected_res = 0.8333 * torch.ones(1)

self.assertTrue(
torch.allclose(model_outputs[0], expected_res, atol=1e-03, rtol=1e-03)
)

@given(
unlift=st.booleans(), # verify both lifted and unlifted graph
)
@settings(deadline=500000)
def test_emit_lowered_backend_module(self, unlift):
module_list = [
models.Emformer(),
models.Repeat(),
models.ElementwiseAdd(),
models.MLP(),
models.ModelWithUnusedArg(),
]

capture_config = (
exir.CaptureConfig(enable_aot=True) if unlift else exir.CaptureConfig()
)

edge_compile_config = exir.EdgeCompileConfig(
_check_ir_validity=False, _use_edge_ops=True
)

for model in module_list:
model_inputs = model.get_random_inputs()

edgeir_m = exir.capture(model, model_inputs, capture_config).to_edge(
edge_compile_config
)
lowered_model = to_backend(
QnnBackend.__name__, edgeir_m.exported_program, []
)
program = lowered_model.program()
reference_program = self.get_program_from_wrapped_module(
lowered_model, model_inputs, capture_config, edge_compile_config
)

# Check program is fairly equal to the reference program
self.assertEqual(
len(program.execution_plan[0].chains[0].instructions),
len(reference_program.execution_plan[0].chains[0].instructions),
)

self.assertEqual(
len(program.execution_plan[0].values),
len(reference_program.execution_plan[0].values),
)

self.assertEqual(
len(program.execution_plan[0].inputs),
len(reference_program.execution_plan[0].inputs),
)

self.assertEqual(
len(program.execution_plan[0].outputs),
len(reference_program.execution_plan[0].outputs),
)

# Ensure we can get the buffer
_ = lowered_model.buffer()
self.validate_lowered_module_program(program)

@given(
unlift=st.booleans(), # verify both lifted and unlifted graph
)
@settings(deadline=500000)
def test_emit_nested_lowered_backend_module(self, unlift):
module_list = [
models.Emformer(),
models.Repeat(),
models.ElementwiseAdd(),
models.MLP(),
models.ModelWithUnusedArg(),
]

capture_config = (
exir.CaptureConfig(enable_aot=True) if unlift else exir.CaptureConfig()
)

edge_compile_config = exir.EdgeCompileConfig(
_check_ir_validity=False, _use_edge_ops=True
)

for model in module_list:
model_inputs = model.get_random_inputs()

edgeir_m = exir.capture(model, model_inputs, capture_config).to_edge(
edge_compile_config
)
lowered_module = to_backend(
QnnBackend.__name__, edgeir_m.exported_program, []
)

# This module will include one operator and two delegate call
class WrappedModule(torch.nn.Module):
def __init__(self, lowered_module):
super().__init__()
self.one_module = lowered_module

def forward(self, *args):
return self.one_module(*args)

wrapped_module = WrappedModule(lowered_module)
wrapped_module_edge = exir.capture(
wrapped_module, model_inputs, capture_config
).to_edge(edge_compile_config)

nested_lowered_model = to_backend(
QnnBackend.__name__, wrapped_module_edge.exported_program, []
)

program = nested_lowered_model.program()
self.validate_lowered_module_program(program)
Loading