[export] Add run_decompositions() function to ExportedProgram
Summary:
https://docs.google.com/document/d/1QJJEGnj2nHGPODlw38BEG3KLLCOTfdOVjPrNQbz_LM8/edit#bookmark=id.lp80wfshq130

`exported_program.run_decompositions(decomposition_table)` optionally takes a decomposition table, runs those decompositions on the exported program, and returns a new exported program. By default, the Core ATen decomposition table is used.
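
A minimal usage sketch of the new API (mirroring the tests added below; note that until the follow-up diff lands, export() still runs decompositions by default, which is why the tests patch `torch._export.DECOMP_TABLE` to None):

```python
import torch
from torch.export import export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.lin(x)

ep = export(M(), (torch.randn(5, 10),))
# With no argument, the Core ATen decomposition table is used and a new
# ExportedProgram is returned.
core_aten_ep = ep.run_decompositions()
```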

Splitting this diff from the following one (D49742989) to make the Executorch migration easier:
1. Land this diff
2. Wait for a pytorch nightly to include this diff
3. Update executorch's pytorch nightly pin
4. Land the following diff, which makes export() no longer run decompositions and updates the executorch code to use run_decompositions()
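
Since the follow-up diff moves the Executorch code onto run_decompositions() (step 4), here is a hedged sketch of passing a non-default decomposition table. It assumes `torch._decomp.get_decompositions` for building a table from a list of ops; the op choice is illustrative and not taken from this diff:

```python
import torch
from torch._decomp import get_decompositions

# `ep` is the ExportedProgram from the sketch above.
# Build a table containing only the registered decomposition for aten.t,
# instead of the full Core ATen set.
decomp_table = get_decompositions([torch.ops.aten.t])
custom_ep = ep.run_decompositions(decomp_table)
```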

Test Plan: Tested in the following diff

Differential Revision: D49743208
angelayi authored and facebook-github-bot committed Sep 28, 2023
1 parent 3de4299 commit 487e933
Showing 3 changed files with 232 additions and 33 deletions.
60 changes: 58 additions & 2 deletions test/export/test_export.py
@@ -8,8 +8,8 @@
import torch._dynamo as torchdynamo
from functorch.experimental.control_flow import map, cond
from torch import Tensor
from torch.export import Constraint
from torch._export import DEFAULT_EXPORT_DYNAMO_CONFIG, dynamic_dim, export, capture_pre_autograd_graph
from torch.export import Constraint, Dim, export
from torch._export import DEFAULT_EXPORT_DYNAMO_CONFIG, dynamic_dim, capture_pre_autograd_graph
from torch._export.constraints import constrain_as_size, constrain_as_value
from torch._export.utils import (
get_buffer,
@@ -1449,5 +1449,61 @@ def forward(self, x):
):
exported_program(torch.rand(2, 3), torch.rand(2, 3))

def test_export_decomps_simple(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = torch.nn.Linear(10, 1)

def forward(self, x):
return self.lin(x)

inp = (torch.randn(5, 10),)
m = M()
with unittest.mock.patch("torch._export.DECOMP_TABLE", None):
ep = export(m, inp)


FileCheck().check_count(
"torch.ops.aten.t.default", 1, exactly=True
).run(ep.graph_module.code)
self.assertTrue(torch.allclose(ep(*inp), m(*inp)))

core_aten_ep = ep.run_decompositions()
FileCheck().check_count(
"torch.ops.aten.permute.default", 1, exactly=True
).run(core_aten_ep.graph_module.code)
FileCheck().check_count(
"torch.ops.aten.t.default", 0, exactly=True
).run(core_aten_ep.graph_module.code)
self.assertTrue(torch.allclose(core_aten_ep(*inp), m(*inp)))

def test_export_decomps_dynamic(self):
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = torch.nn.Linear(10, 1)

def forward(self, x):
return self.lin(x)

inp = (torch.randn(5, 10),)
m = M()
with unittest.mock.patch("torch._export.DECOMP_TABLE", None):
ep = export(m, inp, dynamic_shapes={"x": {0: Dim("batch")}})

core_aten_ep = ep.run_decompositions()

input_node = [node for node in core_aten_ep.graph.nodes if node.op == "placeholder"][-1]
self.assertTrue(isinstance(input_node.meta["val"].shape[0], torch.SymInt))

FileCheck().check_count(
"torch.ops.aten.permute.default", 1, exactly=True
).run(core_aten_ep.graph_module.code)
FileCheck().check_count(
"torch.ops.aten.t.default", 0, exactly=True
).run(core_aten_ep.graph_module.code)
self.assertTrue(torch.allclose(core_aten_ep(*inp), m(*inp)))

if __name__ == '__main__':
run_tests()
7 changes: 5 additions & 2 deletions torch/_export/passes/lift_constant_tensor_pass.py
@@ -1,9 +1,12 @@
import torch
from torch._export import ExportedProgram
from torch._guards import detect_fake_mode


def lift_constant_tensor_pass(ep: ExportedProgram) -> ExportedProgram:
def lift_constant_tensor_pass(ep):
"""
Takes an ExportedProgram and returns the ExportedProgram modified in-place,
with the constant tensors as buffers.
"""
if len([node for node in ep.graph.nodes if node.op == "placeholder"]) == 0:
return ep

198 changes: 169 additions & 29 deletions torch/export/exported_program.py
@@ -8,6 +8,7 @@
import torch
import torch.fx._pytree as fx_pytree
import torch.utils._pytree as pytree
from torch._decomp import core_aten_decompositions
from torch.fx._compatibility import compatibility

from torch.fx.passes.infra.pass_base import PassResult
@@ -422,43 +423,151 @@ def module(self) -> torch.nn.Module:

return unlift_exported_program_lifted_states(self)

def _transform(self, *passes: PassType) -> "ExportedProgram":
def run_decompositions(
self, decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None
) -> "ExportedProgram":
"""
Run a set of decompositions on the exported program and returns a new
exported program. By default we will run the Core ATen decompositions to
get the Core ATen IR.
For now, we do not decompose joint graphs.
"""
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
RangeConstraint,
_AddRuntimeAssertionsForInlineConstraintsPass,
InputDim,
)
from torch._export.passes.lift_constant_tensor_pass import (
lift_constant_tensor_pass,
)
from torch._export.passes.replace_sym_size_ops_pass import _ReplaceSymSizeOpPass
from torch._functorch.aot_autograd import aot_export_module

def _get_placeholders(gm):
placeholders = []
for node in gm.graph.nodes:
if node.op != "placeholder":
break
placeholders.append(node)
return placeholders

decomp_table = decomp_table or core_aten_decompositions()

old_placeholders = _get_placeholders(self.graph_module)
fake_args = [node.meta["val"] for node in old_placeholders]

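# Re-trace the exported graph through aot_export_module so the requested
# decompositions are applied; trace_joint=False because joint graphs are not
# decomposed here (see the docstring above).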
gm, graph_signature = aot_export_module(
self.graph_module, fake_args, decompositions=decomp_table, trace_joint=False
)

# Update the signatures with the new placeholder names in case they
# changed when calling aot_export
new_placeholders = _get_placeholders(gm)
assert len(new_placeholders) == len(old_placeholders)
old_new_placeholder_map = {
old_node.name: new_node.name
for old_node, new_node in zip(old_placeholders, new_placeholders)
}
old_outputs = list(self.graph.nodes)[-1].args[0]
new_outputs = list(gm.graph.nodes)[-1].args[0]
assert len(new_outputs) == len(old_outputs)
old_new_output_map = {
old_node.name: new_node.name
for old_node, new_node in zip(old_outputs, new_outputs)
}

new_backward_signature = (
ExportBackwardSignature(
copy.deepcopy(
self.graph_signature.backward_signature.gradients_to_parameters
),
{
old_new_placeholder_map[inp]: param
for inp, param in self.graph_signature.backward_signature.gradients_to_parameters.items()
},
copy.deepcopy(self.graph_signature.backward_signature.loss_output),
)
if self.graph_signature.backward_signature is not None
else None
)

new_graph_signature = ExportGraphSignature(
copy.deepcopy(self.graph_signature.parameters),
copy.deepcopy(self.graph_signature.buffers),
[old_new_placeholder_map[inp] for inp in self.graph_signature.user_inputs],
[old_new_output_map[out] for out in self.graph_signature.user_outputs],
{
old_new_placeholder_map[inp]: param
for inp, param in self.graph_signature.inputs_to_parameters.items()
},
{
old_new_placeholder_map[inp]: buffer
for inp, buffer in self.graph_signature.inputs_to_buffers.items()
},
copy.deepcopy(self.graph_signature.buffers_to_mutate),
new_backward_signature,
copy.deepcopy(self.graph_signature.assertion_dep_token),
)

# NOTE: aot_export adds symint metadata for placeholders with int
# values; since these become specialized, we replace such metadata with
# the original values.
# Also, set the param/buffer metadata back to the placeholders.
for old_node, new_node in zip(old_placeholders, new_placeholders):
if not isinstance(old_node.meta["val"], torch.Tensor):
new_node.meta["val"] = old_node.meta["val"]

if (
new_node.target in new_graph_signature.inputs_to_parameters
or new_node.target in new_graph_signature.inputs_to_buffers
):
for k, v in old_node.meta.items():
new_node.meta[k] = v

# TODO unfortunately preserving graph-level metadata is not
# working well with aot_export. So we manually copy it.
# (The node-level meta is addressed above.)
gm.meta.update(self.graph_module.meta)

new_range_constraints = _get_updated_range_constraints(gm)

new_equality_constraints = [
(
InputDim(old_new_placeholder_map[inp_dim1.input_name], inp_dim1.dim),
InputDim(old_new_placeholder_map[inp_dim2.input_name], inp_dim2.dim),
)
for inp_dim1, inp_dim2 in self.equality_constraints
]

exported_program = ExportedProgram(
gm,
gm.graph,
new_graph_signature,
copy.deepcopy(self.call_spec),
self.state_dict,
new_range_constraints,
new_equality_constraints,
copy.deepcopy(self.module_call_graph),
self.example_inputs,
self.dialect,
)

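# Re-run the inline-constraint runtime-assertion pass and constant-tensor
# lifting on the newly constructed program, since aot_export produced a
# fresh graph.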
if len(new_range_constraints) > 0 or len(new_equality_constraints) > 0:
exported_program = exported_program._transform(
_AddRuntimeAssertionsForInlineConstraintsPass(
new_range_constraints, new_equality_constraints
)
)
exported_program = lift_constant_tensor_pass(exported_program)

return exported_program._transform(_ReplaceSymSizeOpPass())

def _transform(self, *passes: PassType) -> "ExportedProgram":
pm = PassManager(list(passes))
res = pm(self.graph_module)
transformed_gm = res.graph_module if res is not None else self.graph_module
assert transformed_gm is not None

def _get_updated_range_constraints(
gm: torch.fx.GraphModule,
) -> Dict[sympy.Symbol, RangeConstraint]:
def get_shape_env(gm):
vals = [
node.meta["val"]
for node in gm.graph.nodes
if node.meta.get("val", None) is not None
]
from torch._guards import detect_fake_mode

fake_mode = detect_fake_mode(vals)
if fake_mode is not None:
return fake_mode.shape_env
for v in vals:
if isinstance(v, torch.SymInt):
return v.node.shape_env

shape_env = get_shape_env(gm)
if shape_env is None:
return {}
range_constraints = {
k: RangeConstraint(v.lower, v.upper)
for k, v in shape_env.var_to_range.items()
}
return range_constraints

def _get_updated_graph_signature(
old_signature: ExportGraphSignature,
new_gm: torch.fx.GraphModule,
@@ -562,3 +671,34 @@ def _validate(self):
if not isinstance(gm, torch.fx.GraphModule):
continue
verifier.check_valid(self.graph_module)


def _get_updated_range_constraints(
gm: torch.fx.GraphModule,
) -> Dict[sympy.Symbol, Any]:
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
RangeConstraint,
)

def get_shape_env(gm):
vals = [
node.meta["val"]
for node in gm.graph.nodes
if node.meta.get("val", None) is not None
]
from torch._guards import detect_fake_mode

fake_mode = detect_fake_mode(vals)
if fake_mode is not None:
return fake_mode.shape_env
for v in vals:
if isinstance(v, torch.SymInt):
return v.node.shape_env

shape_env = get_shape_env(gm)
if shape_env is None:
return {}
range_constraints = {
k: RangeConstraint(v.lower, v.upper) for k, v in shape_env.var_to_range.items()
}
return range_constraints
