From 2a23a8985a99b618cf591627212502899b96e976 Mon Sep 17 00:00:00 2001 From: Jacob Szwejbka Date: Fri, 19 Apr 2024 12:51:03 -0700 Subject: [PATCH] remove exir.capture from test_lowered_backend_module (#3169) Summary: title Differential Revision: D56368215 --- .../test/test_lowered_backend_module.py | 63 +++++++++---------- 1 file changed, 30 insertions(+), 33 deletions(-) diff --git a/exir/backend/test/test_lowered_backend_module.py b/exir/backend/test/test_lowered_backend_module.py index 743c20964ec..65b098f9550 100644 --- a/exir/backend/test/test_lowered_backend_module.py +++ b/exir/backend/test/test_lowered_backend_module.py @@ -10,6 +10,7 @@ import torch from executorch import exir +from executorch.exir import to_edge from executorch.exir.backend.backend_api import to_backend from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.backend.test.backend_with_compiler_demo import ( @@ -22,6 +23,7 @@ _load_for_executorch_from_buffer, ) from hypothesis import given, settings, strategies as st +from torch.export import export class TestBackendAPI(unittest.TestCase): @@ -44,7 +46,7 @@ def validate_lowered_module_program(self, program: Program) -> None: ) def get_program_from_wrapped_module( - self, lowered_module, example_inputs, capture_config, edge_compile_config + self, lowered_module, example_inputs, edge_compile_config ): class WrappedModule(torch.nn.Module): def __init__(self): @@ -55,17 +57,16 @@ def forward(self, *args): return self.one_module(*args) return ( - exir.capture(WrappedModule(), example_inputs, capture_config) - .to_edge(edge_compile_config) + to_edge( + export(WrappedModule(), example_inputs), + compile_config=edge_compile_config, + ) .to_executorch() - .program + .executorch_program ) - @given( - unlift=st.booleans(), # verify both lifted and unlifted graph - ) @settings(deadline=500000) - def test_emit_lowered_backend_module_end_to_end(self, unlift): + def test_emit_lowered_backend_module_end_to_end(self): class SinModule(torch.nn.Module): def __init__(self): super().__init__() @@ -76,15 +77,19 @@ def forward(self, x): sin_module = SinModule() model_inputs = (torch.ones(1),) expected_res = sin_module(*model_inputs) - edgeir_m = exir.capture( - sin_module, - model_inputs, - exir.CaptureConfig(pt2_mode=True, enable_aot=True, _unlift=unlift), - ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False, _use_edge_ops=True)) + edgeir_m = to_edge( + export( + sin_module, + model_inputs, + ), + compile_config=exir.EdgeCompileConfig( + _check_ir_validity=False, _use_edge_ops=True + ), + ) max_value = model_inputs[0].shape[0] compile_specs = [CompileSpec("max_value", bytes([max_value]))] lowered_sin_module = to_backend( - BackendWithCompilerDemo.__name__, edgeir_m.exported_program, compile_specs + BackendWithCompilerDemo.__name__, edgeir_m.exported_program(), compile_specs ) new_res = lowered_sin_module(*model_inputs) @@ -120,10 +125,6 @@ def test_emit_lowered_backend_module(self, unlift): models.ModelWithUnusedArg(), ] - capture_config = ( - exir.CaptureConfig(enable_aot=True) if unlift else exir.CaptureConfig() - ) - edge_compile_config = exir.EdgeCompileConfig( _check_ir_validity=False, _use_edge_ops=True ) @@ -131,15 +132,15 @@ def test_emit_lowered_backend_module(self, unlift): for model in module_list: model_inputs = model.get_random_inputs() - edgeir_m = exir.capture(model, model_inputs, capture_config).to_edge( - edge_compile_config + edgeir_m = to_edge( + export(model, model_inputs), compile_config=edge_compile_config ) lowered_model = to_backend( - QnnBackend.__name__, edgeir_m.exported_program, [] + QnnBackend.__name__, edgeir_m.exported_program(), [] ) program = lowered_model.program() reference_program = self.get_program_from_wrapped_module( - lowered_model, model_inputs, capture_config, edge_compile_config + lowered_model, model_inputs, edge_compile_config ) # Check program is fairly equal to the reference program @@ -180,10 +181,6 @@ def test_emit_nested_lowered_backend_module(self, unlift): models.ModelWithUnusedArg(), ] - capture_config = ( - exir.CaptureConfig(enable_aot=True) if unlift else exir.CaptureConfig() - ) - edge_compile_config = exir.EdgeCompileConfig( _check_ir_validity=False, _use_edge_ops=True ) @@ -191,11 +188,11 @@ def test_emit_nested_lowered_backend_module(self, unlift): for model in module_list: model_inputs = model.get_random_inputs() - edgeir_m = exir.capture(model, model_inputs, capture_config).to_edge( - edge_compile_config + edgeir_m = to_edge( + export(model, model_inputs), compile_config=edge_compile_config ) lowered_module = to_backend( - QnnBackend.__name__, edgeir_m.exported_program, [] + QnnBackend.__name__, edgeir_m.exported_program(), [] ) # This module will include one operator and two delegate call @@ -208,12 +205,12 @@ def forward(self, *args): return self.one_module(*args) wrapped_module = WrappedModule(lowered_module) - wrapped_module_edge = exir.capture( - wrapped_module, model_inputs, capture_config - ).to_edge(edge_compile_config) + wrapped_module_edge = to_edge( + export(wrapped_module, model_inputs), compile_config=edge_compile_config + ) nested_lowered_model = to_backend( - QnnBackend.__name__, wrapped_module_edge.exported_program, [] + QnnBackend.__name__, wrapped_module_edge.exported_program(), [] ) program = nested_lowered_model.program()