From 19acbca7f9df8f5e9cbf47c7f6f0f93e19f02db3 Mon Sep 17 00:00:00 2001 From: Angela Yi Date: Tue, 10 Oct 2023 16:26:20 -0700 Subject: [PATCH] Get export APIs ready for PTC (#566) Summary: X-link: https://github.com/pytorch/pytorch/pull/110410 https://docs.google.com/document/d/1QJJEGnj2nHGPODlw38BEG3KLLCOTfdOVjPrNQbz_LM8/edit#bookmark=id.lp80wfshq130 Changes: * `torch.export` will return a functional ATen graph but not lowered to core aten decompositions (CompositeImplicitAutograd decomps still run) * `exported_program.run_decompositions(decomposition_table)` will optionally take a decomposition table, and run decompositions on the exported program, returning a new exported program. By default we will run the Core ATen decomposition table. Calling convention for Executorch stays the same: ``` pre_autograd_graph = capture_pre_autograd_graph(f, args, ...) aten_graph_no_decomps = torch.export.export(pre_autograd_graph, args, ...) # Within to_edge we decompose to core aten and then convert to edge edge_graph = exir.to_edge(aten_graph_no_decomps) ``` Reviewed By: larryliu0820, guangy10 Differential Revision: D49742989 --- exir/capture/_capture.py | 6 ++---- exir/program/TARGETS | 1 + exir/program/_program.py | 6 ++++++ exir/program/test/test_program.py | 4 ++-- 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/exir/capture/_capture.py b/exir/capture/_capture.py index 53cc6191ba2..2525c894f11 100644 --- a/exir/capture/_capture.py +++ b/exir/capture/_capture.py @@ -8,7 +8,6 @@ import warnings from collections import namedtuple from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union -from unittest.mock import patch import torch import torch._export @@ -125,9 +124,8 @@ def capture( # noqa: C901 "Functionalization is required for enable_aot.", ) - # TODO remove this later - with patch("torch._export.DECOMP_TABLE", _default_decomposition_table()): - ep = export(f, args, constraints=constraints) + ep = export(f, args, constraints=constraints) + ep = ep.run_decompositions(_default_decomposition_table()) # pyre-ignore[6] ep = ep._transform(ReplaceViewOpsWithViewCopyOpsPass()) if not config._unlift: return ExirExportedProgram(ep, False) diff --git a/exir/program/TARGETS b/exir/program/TARGETS index 4acfdf39883..9215fb6f7a4 100644 --- a/exir/program/TARGETS +++ b/exir/program/TARGETS @@ -21,6 +21,7 @@ python_library( "//executorch/exir:pass_manager", "//executorch/exir:print_program", "//executorch/exir:schema", + "//executorch/exir:tracer", "//executorch/exir/_serialize:lib", "//executorch/exir/backend:backend_api", "//executorch/exir/backend:partitioner", diff --git a/exir/program/_program.py b/exir/program/_program.py index ab42454ff2d..3c124b753ef 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -27,6 +27,7 @@ from executorch.exir.passes.spec_prop_pass import SpecPropPass from executorch.exir.print_program import pretty_print, print_program from executorch.exir.schema import Program +from executorch.exir.tracer import _default_decomposition_table from executorch.exir.verification.verifier import ( EXIRATenDialectVerifier, EXIREdgeDialectVerifier, @@ -589,6 +590,11 @@ def to_edge( edge_programs: Dict[str, ExportedProgram] = {} for name, program in aten_programs.items(): + # Decompose to Core ATen + program = program.run_decompositions( + _default_decomposition_table() # pyre-ignore[6] + ) + if config._check_ir_validity: try: EXIRATenDialectVerifier()(program.graph_module) diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py index fc224aaad30..c514592fa9b 100644 --- a/exir/program/test/test_program.py +++ b/exir/program/test/test_program.py @@ -44,11 +44,11 @@ def foo(x: torch.Tensor) -> torch.Tensor: torch.ones(1), torch.zeros(1), ), - ) + ).run_decompositions() programs["foo"] = export( foo, (torch.ones(1),), - ) + ).run_decompositions() return programs