diff --git a/test/export/test_export.py b/test/export/test_export.py
index 078fbad3c7572..299e8be8188c3 100644
--- a/test/export/test_export.py
+++ b/test/export/test_export.py
@@ -8,8 +8,8 @@
 import torch._dynamo as torchdynamo
 from functorch.experimental.control_flow import map, cond
 from torch import Tensor
-from torch.export import Constraint
-from torch._export import DEFAULT_EXPORT_DYNAMO_CONFIG, dynamic_dim, export, capture_pre_autograd_graph
+from torch.export import Constraint, Dim, export
+from torch._export import DEFAULT_EXPORT_DYNAMO_CONFIG, dynamic_dim, capture_pre_autograd_graph
 from torch._export.constraints import constrain_as_size, constrain_as_value
 from torch._export.utils import (
     get_buffer,
@@ -1449,5 +1449,60 @@ def forward(self, x):
         ):
             exported_program(torch.rand(2, 3), torch.rand(2, 3))
 
+    def test_export_decomps_simple(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin = torch.nn.Linear(10, 1)
+
+            def forward(self, x):
+                return self.lin(x)
+
+        inp = (torch.randn(5, 10),)
+        m = M()
+        with unittest.mock.patch("torch._export.DECOMP_TABLE", None):
+            ep = export(m, inp)
+
+        FileCheck().check_count(
+            "torch.ops.aten.t.default", 1, exactly=True
+        ).run(ep.graph_module.code)
+        self.assertTrue(torch.allclose(ep(*inp), m(*inp)))
+
+        core_aten_ep = ep.run_decompositions()
+        FileCheck().check_count(
+            "torch.ops.aten.permute.default", 1, exactly=True
+        ).run(core_aten_ep.graph_module.code)
+        FileCheck().check_count(
+            "torch.ops.aten.t.default", 0, exactly=True
+        ).run(core_aten_ep.graph_module.code)
+        self.assertTrue(torch.allclose(core_aten_ep(*inp), m(*inp)))
+
+    def test_export_decomps_dynamic(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin = torch.nn.Linear(10, 1)
+
+            def forward(self, x):
+                return self.lin(x)
+
+        inp = (torch.randn(5, 10),)
+        m = M()
+        with unittest.mock.patch("torch._export.DECOMP_TABLE", None):
+            ep = export(m, inp, dynamic_shapes={"x": {0: Dim("batch")}})
+
+        core_aten_ep = ep.run_decompositions()
+
+        input_node = [node for node in core_aten_ep.graph.nodes if node.op == "placeholder"][-1]
+        self.assertTrue(isinstance(input_node.meta["val"].shape[0], torch.SymInt))
+
+        FileCheck().check_count(
+            "torch.ops.aten.permute.default", 1, exactly=True
+        ).run(core_aten_ep.graph_module.code)
+        FileCheck().check_count(
+            "torch.ops.aten.t.default", 0, exactly=True
+        ).run(core_aten_ep.graph_module.code)
+        self.assertTrue(torch.allclose(core_aten_ep(*inp), m(*inp)))
+
 if __name__ == '__main__':
     run_tests()
diff --git a/torch/_export/passes/lift_constant_tensor_pass.py b/torch/_export/passes/lift_constant_tensor_pass.py
index 5ad31703a0b3e..724931080d794 100644
--- a/torch/_export/passes/lift_constant_tensor_pass.py
+++ b/torch/_export/passes/lift_constant_tensor_pass.py
@@ -1,9 +1,12 @@
 import torch
-from torch._export import ExportedProgram
 from torch._guards import detect_fake_mode
 
 
-def lift_constant_tensor_pass(ep: ExportedProgram) -> ExportedProgram:
+def lift_constant_tensor_pass(ep):
+    """
+    Takes an ExportedProgram, lifts its constant tensors into buffers, and
+    returns the ExportedProgram, modified in place.
+    """
     if len([node for node in ep.graph.nodes if node.op == "placeholder"]) == 0:
         return ep
diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py
index 75a1795f443ad..6329e9b0029f4 100644
--- a/torch/export/exported_program.py
+++ b/torch/export/exported_program.py
@@ -8,6 +8,7 @@
 import torch
 import torch.fx._pytree as fx_pytree
 import torch.utils._pytree as pytree
+from torch._decomp import core_aten_decompositions
 from torch.fx._compatibility import compatibility
 from torch.fx.passes.infra.pass_base import PassResult
 
@@ -422,43 +423,161 @@ def module(self) -> torch.nn.Module:
         return unlift_exported_program_lifted_states(self)
 
-    def _transform(self, *passes: PassType) -> "ExportedProgram":
+    def run_decompositions(
+        self, decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None
+    ) -> "ExportedProgram":
+        """
+        Run a set of decompositions on the exported program and return a new
+        exported program. By default we run the Core ATen decompositions to
+        get the Core ATen IR.
+
+        For now, we do not decompose joint graphs.
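+
+        Example (a minimal sketch; ``mod`` stands for any exportable
+        ``nn.Module`` taking a single tensor input)::
+
+            ep = torch.export.export(mod, (torch.randn(5, 10),))
+            core_aten_ep = ep.run_decompositions()
+            # core_aten_ep now contains only Core ATen ops; e.g. a Linear's
+            # ``aten.t.default`` call becomes ``aten.permute.default``.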
+ """ if len([node for node in ep.graph.nodes if node.op == "placeholder"]) == 0: return ep diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index 75a1795f443ad..6329e9b0029f4 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -8,6 +8,7 @@ import torch import torch.fx._pytree as fx_pytree import torch.utils._pytree as pytree +from torch._decomp import core_aten_decompositions from torch.fx._compatibility import compatibility from torch.fx.passes.infra.pass_base import PassResult @@ -422,43 +423,151 @@ def module(self) -> torch.nn.Module: return unlift_exported_program_lifted_states(self) - def _transform(self, *passes: PassType) -> "ExportedProgram": + def run_decompositions( + self, decomp_table: Optional[Dict[torch._ops.OperatorBase, Callable]] = None + ) -> "ExportedProgram": + """ + Run a set of decompositions on the exported program and returns a new + exported program. By default we will run the Core ATen decompositions to + get the Core ATen IR. + + For now, we do not decompose joint graphs. + """ from torch._export.passes.add_runtime_assertions_for_constraints_pass import ( - RangeConstraint, + _AddRuntimeAssertionsForInlineConstraintsPass, + InputDim, + ) + from torch._export.passes.lift_constant_tensor_pass import ( + lift_constant_tensor_pass, + ) + from torch._export.passes.replace_sym_size_ops_pass import _ReplaceSymSizeOpPass + from torch._functorch.aot_autograd import aot_export_module + + def _get_placeholders(gm): + placeholders = [] + for node in gm.graph.nodes: + if node.op != "placeholder": + break + placeholders.append(node) + return placeholders + + decomp_table = decomp_table or core_aten_decompositions() + + old_placeholders = _get_placeholders(self.graph_module) + fake_args = [node.meta["val"] for node in old_placeholders] + + gm, graph_signature = aot_export_module( + self.graph_module, fake_args, decompositions=decomp_table, trace_joint=False + ) + + # Update the signatures with the new placeholder names in case they + # changed when calling aot_export + new_placeholders = _get_placeholders(gm) + assert len(new_placeholders) == len(old_placeholders) + old_new_placeholder_map = { + old_node.name: new_node.name + for old_node, new_node in zip(old_placeholders, new_placeholders) + } + old_outputs = list(self.graph.nodes)[-1].args[0] + new_outputs = list(gm.graph.nodes)[-1].args[0] + assert len(new_outputs) == len(old_outputs) + old_new_output_map = { + old_node.name: new_node.name + for old_node, new_node in zip(old_outputs, new_outputs) + } + + new_backward_signature = ( + ExportBackwardSignature( + copy.deepcopy( + self.graph_signature.backward_signature.gradients_to_parameters + ), + { + old_new_placeholder_map[inp]: param + for inp, param in self.graph_signature.backward_signature.gradients_to_parameters + }, + copy.deepcopy(self.graph_signature.backward_signature.loss_output), + ) + if self.graph_signature.backward_signature is not None + else None + ) + + new_graph_signature = ExportGraphSignature( + copy.deepcopy(self.graph_signature.parameters), + copy.deepcopy(self.graph_signature.buffers), + [old_new_placeholder_map[inp] for inp in self.graph_signature.user_inputs], + [old_new_output_map[out] for out in self.graph_signature.user_outputs], + { + old_new_placeholder_map[inp]: param + for inp, param in self.graph_signature.inputs_to_parameters.items() + }, + { + old_new_placeholder_map[inp]: buffer + for inp, buffer in self.graph_signature.inputs_to_buffers.items() + }, + 
+        new_backward_signature = (
+            ExportBackwardSignature(
+                copy.deepcopy(
+                    self.graph_signature.backward_signature.gradients_to_parameters
+                ),
+                {
+                    grad: old_new_placeholder_map[user_input]
+                    for grad, user_input in self.graph_signature.backward_signature.gradients_to_user_inputs.items()
+                },
+                copy.deepcopy(self.graph_signature.backward_signature.loss_output),
+            )
+            if self.graph_signature.backward_signature is not None
+            else None
+        )
+
+        new_graph_signature = ExportGraphSignature(
+            copy.deepcopy(self.graph_signature.parameters),
+            copy.deepcopy(self.graph_signature.buffers),
+            [old_new_placeholder_map[inp] for inp in self.graph_signature.user_inputs],
+            [old_new_output_map[out] for out in self.graph_signature.user_outputs],
+            {
+                old_new_placeholder_map[inp]: param
+                for inp, param in self.graph_signature.inputs_to_parameters.items()
+            },
+            {
+                old_new_placeholder_map[inp]: buffer
+                for inp, buffer in self.graph_signature.inputs_to_buffers.items()
+            },
+            copy.deepcopy(self.graph_signature.buffers_to_mutate),
+            new_backward_signature,
+            copy.deepcopy(self.graph_signature.assertion_dep_token),
         )
+        # NOTE: aot_export adds symint metadata for placeholders with int
+        # values; since these become specialized, we replace such metadata with
+        # the original values.
+        # Also, set the param/buffer metadata back on the placeholders.
+        for old_node, new_node in zip(old_placeholders, new_placeholders):
+            if not isinstance(old_node.meta["val"], torch.Tensor):
+                new_node.meta["val"] = old_node.meta["val"]
+
+            if (
+                new_node.target in new_graph_signature.inputs_to_parameters
+                or new_node.target in new_graph_signature.inputs_to_buffers
+            ):
+                for k, v in old_node.meta.items():
+                    new_node.meta[k] = v
+
+        # TODO: graph-level metadata is currently not preserved well through
+        # aot_export, so we copy it over manually. (Node-level meta is handled
+        # above.)
+        gm.meta.update(self.graph_module.meta)
+
+        new_range_constraints = _get_updated_range_constraints(gm)
+
+        new_equality_constraints = [
+            (
+                InputDim(old_new_placeholder_map[inp_dim1.input_name], inp_dim1.dim),
+                InputDim(old_new_placeholder_map[inp_dim2.input_name], inp_dim2.dim),
+            )
+            for inp_dim1, inp_dim2 in self.equality_constraints
+        ]
+
+        exported_program = ExportedProgram(
+            gm,
+            gm.graph,
+            new_graph_signature,
+            copy.deepcopy(self.call_spec),
+            self.state_dict,
+            new_range_constraints,
+            new_equality_constraints,
+            copy.deepcopy(self.module_call_graph),
+            self.example_inputs,
+            self.dialect,
+        )
+
+        if len(new_range_constraints) > 0 or len(new_equality_constraints) > 0:
+            exported_program = exported_program._transform(
+                _AddRuntimeAssertionsForInlineConstraintsPass(
+                    new_range_constraints, new_equality_constraints
+                )
+            )
+        exported_program = lift_constant_tensor_pass(exported_program)
+
+        return exported_program._transform(_ReplaceSymSizeOpPass())
+
+    def _transform(self, *passes: PassType) -> "ExportedProgram":
         pm = PassManager(list(passes))
         res = pm(self.graph_module)
         transformed_gm = res.graph_module if res is not None else self.graph_module
         assert transformed_gm is not None
 
-        def _get_updated_range_constraints(
-            gm: torch.fx.GraphModule,
-        ) -> Dict[sympy.Symbol, RangeConstraint]:
-            def get_shape_env(gm):
-                vals = [
-                    node.meta["val"]
-                    for node in gm.graph.nodes
-                    if node.meta.get("val", None) is not None
-                ]
-                from torch._guards import detect_fake_mode
-
-                fake_mode = detect_fake_mode(vals)
-                if fake_mode is not None:
-                    return fake_mode.shape_env
-                for v in vals:
-                    if isinstance(v, torch.SymInt):
-                        return v.node.shape_env
-
-            shape_env = get_shape_env(gm)
-            if shape_env is None:
-                return {}
-            range_constraints = {
-                k: RangeConstraint(v.lower, v.upper)
-                for k, v in shape_env.var_to_range.items()
-            }
-            return range_constraints
-
         def _get_updated_graph_signature(
             old_signature: ExportGraphSignature,
             new_gm: torch.fx.GraphModule,
@@ -562,3 +681,36 @@ def _validate(self):
             if not isinstance(gm, torch.fx.GraphModule):
                 continue
         verifier.check_valid(self.graph_module)
+
+
+def _get_updated_range_constraints(
+    gm: torch.fx.GraphModule,
+) -> Dict[sympy.Symbol, Any]:
+    from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
+        RangeConstraint,
+    )
+
+    def get_shape_env(gm):
+        vals = [
+            node.meta["val"]
+            for node in gm.graph.nodes
+            if node.meta.get("val", None) is not None
+        ]
+        from torch._guards import detect_fake_mode
+
+        fake_mode = detect_fake_mode(vals)
+        if fake_mode is not None:
+            return fake_mode.shape_env
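+        # No fake mode was detected; fall back to any SymInt value in the
+        # graph, since each SymInt holds a reference to its ShapeEnv.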
+        for v in vals:
+            if isinstance(v, torch.SymInt):
+                return v.node.shape_env
+
+    shape_env = get_shape_env(gm)
+    if shape_env is None:
+        return {}
+    range_constraints = {
+        k: RangeConstraint(v.lower, v.upper) for k, v in shape_env.var_to_range.items()
+    }
+    return range_constraints