From 339f3fac8ed3fb4383fd03954ee5d7eb9301d1d3 Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Wed, 24 Apr 2024 16:44:35 -0700 Subject: [PATCH] [export] kill _process_constraints() (#123985) Summary: The process for populating range_constraints follows separate methods for non-strict (`make_constraints`), and strict (`_process_constraints`). The strict method is somewhat more convoluted, and the analysis that Dynamo performs for strict is already present as part of the non-strict process in make_constraints (produce_guards(), running the export constraint solver). This PR kills _process_constraints() and replaces calls with make_constraints, without duplicating the work that Dynamo already does. Pull Request resolved: https://github.com/pytorch/pytorch/pull/123985 Reviewed By: avikchaudhuri Differential Revision: D56086223 Pulled By: pianpwk --- torch/_export/__init__.py | 9 ++- torch/_export/non_strict_utils.py | 104 ++++++++++++++++------------- torch/export/_trace.py | 26 ++++++-- torch/export/dynamic_shapes.py | 107 +----------------------------- torch/export/exported_program.py | 20 ++---- 5 files changed, 91 insertions(+), 175 deletions(-) diff --git a/torch/_export/__init__.py b/torch/_export/__init__.py index 650e1c0711ad..5591b40e2f9e 100644 --- a/torch/_export/__init__.py +++ b/torch/_export/__init__.py @@ -28,6 +28,7 @@ from torch._dynamo.exc import UserError, UserErrorType from torch._dynamo.source import ConstantSource from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass +from torch._export.non_strict_utils import make_constraints from torch._functorch.aot_autograd import aot_export_module, GraphSignature from torch._functorch.eager_transforms import functionalize from torch._guards import detect_fake_mode @@ -39,7 +40,6 @@ from torch.export._tree_utils import reorder_kwargs from torch.export._unlift import _create_stateful_graph_module from torch.export.dynamic_shapes import ( - _process_constraints, Constraint, dims, dynamic_dim, @@ -175,7 +175,12 @@ def capture_pre_autograd_graph( _restore_state_dict(f, m) flat_args, _ = pytree.tree_flatten((args, kwargs or {})) - range_constraints = _process_constraints(fake_mode, m, 0, flat_args) + range_constraints = make_constraints( + fake_mode, + m, + dynamic_shapes, + 0, + ) module = _create_stateful_graph_module( m, diff --git a/torch/_export/non_strict_utils.py b/torch/_export/non_strict_utils.py index 425f37ba6411..56812fe19105 100644 --- a/torch/_export/non_strict_utils.py +++ b/torch/_export/non_strict_utils.py @@ -16,8 +16,7 @@ from torch._subclasses.fake_tensor import FakeTensorMode from torch.export import Constraint from torch.export.dynamic_shapes import _Dim -from torch.export.exported_program import InputKind -from torch.export.graph_signature import CustomObjArgument, InputSpec, TensorArgument +from torch.export.graph_signature import CustomObjArgument from torch.fx.experimental.symbolic_shapes import ( ConstraintViolationError, DimDynamic, @@ -174,51 +173,47 @@ def make_fake_inputs(nn_module, args, kwargs, dynamic_shapes): return fake_mode, fake_args, fake_kwargs, equalities_inputs, original_signature -def make_constraints( +def _flatten_dynamic_shapes( + dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]] +): + def _is_dynamic_shape_leaf(x): + if isinstance(x, dict): + x = list(x.values()) + return x is None or all(isinstance(y, (_Dim, int)) or y is None for y in x) + + if isinstance(dynamic_shapes, (list, tuple)): + flat_dynamic_shapes = [] + for item in dynamic_shapes: + 
flat_shapes, _ = tree_flatten(
+                item, is_leaf=_is_dynamic_shape_leaf
+            )
+            flat_dynamic_shapes += flat_shapes
+    else:
+        flat_dynamic_shapes, _ = tree_flatten(
+            dynamic_shapes, is_leaf=_is_dynamic_shape_leaf
+        )
+    return flat_dynamic_shapes
+
+
+def produce_guards_and_solve_constraints(
     fake_mode: FakeTensorMode,
+    gm: torch.fx.GraphModule,
     equalities_inputs: EqualityConstraint,
-    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]],
-    input_specs: List[InputSpec],
     original_signature: inspect.Signature,
-    gm: torch.fx.GraphModule,
 ):
     """
     Given a fake mode, sources pairs corresponding to equal dynamic shape dimensions,
     and a graph module, produce guards on the fake mode's shape env (raising constraint
-    violations if any), solve (to suggest simplifications or fixes), and return the
-    resulting range constraints and equality constraints.
-    """
-    # TODO(avik): refactor Dynamo to avoid duplication of the following code
-    # between non-strict and strict.
-    # Specifically, here (non-strict) we do the following post-tracing steps:
-    # - Produce guards.
-    # - Solve constraints.
-    # - Install shape metadata in IR.
-    # In strict, these steps are spread across multiple files:
-    # - guards.py produces guards.
-    # - eval_frame.py solves constraints
-    # - _trace.py installs shape metadata in IR.
-
-    inline_constraints = gm.meta.get("inline_constraints", [])
-    range_constraints = {
-        symbol: inline_constraints[symbol] for symbol in inline_constraints
-    }
-    if dynamic_shapes == []:
-        return range_constraints
-
-    def _is_dynamic_shape_leaf(x):
-        if x is None:
-            return True
-        if isinstance(x, dict):
-            x = list(x.values())
-        return all(isinstance(y, (_Dim, int)) or y is None for y in x)
-
-    flat_dynamic_shapes, _ = tree_flatten(
-        dynamic_shapes, is_leaf=_is_dynamic_shape_leaf
-    )
+    violations if any), solve (to suggest simplifications or fixes).
+    Dynamo already performs this, so this is only needed for non-strict mode.
+    Additional inputs:
+        equalities_inputs: the equality constraints to use for guards
+        original_signature: the signature of the forward method
+    """
     shape_env = fake_mode.shape_env
     assert shape_env.tracked_fakes is not None
+
     placeholders = [tf.fake for tf in shape_env.tracked_fakes]
     sources = [tf.source for tf in shape_env.tracked_fakes]
     input_contexts = [tf.symbolic_context for tf in shape_env.tracked_fakes]
@@ -255,23 +250,41 @@ def _is_dynamic_shape_leaf(x):
     if constraint_violation_error:
         raise constraint_violation_error
 
-    user_tensor_input_names = {
-        spec.arg.name
-        for spec in input_specs
-        if spec.kind == InputKind.USER_INPUT and isinstance(spec.arg, TensorArgument)
+
+def make_constraints(
+    fake_mode: FakeTensorMode,
+    gm: torch.fx.GraphModule,
+    dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
+    num_lifted_inputs: int,
+):
+    """
+    Given a fake mode's shape env and user-specified dynamic shapes,
+    return the resulting range constraints and equality constraints.
+ + Additional args: + num_lifted_inputs: the number of non-user-input placeholder nodes in the graph + (used only to enumerate the user-input nodes) + """ + + shape_env = fake_mode.shape_env + inline_constraints = gm.meta.get("inline_constraints", []) + range_constraints = { + symbol: inline_constraints[symbol] for symbol in inline_constraints } + if not dynamic_shapes: + return range_constraints + flat_dynamic_shapes = _flatten_dynamic_shapes(dynamic_shapes) input_dims = defaultdict(list) free_symbols = set() - input_index = 0 - for node in gm.graph.nodes: - if node.name not in user_tensor_input_names: + for input_index, node in enumerate(gm.graph.nodes): + if input_index < num_lifted_inputs or node.op != "placeholder": continue if _is_constant_argument(node.meta["val"]) or isinstance( node.meta["val"], CustomObjArgument ): continue - shape_spec = flat_dynamic_shapes[input_index] + shape_spec = flat_dynamic_shapes[input_index - num_lifted_inputs] for i, d in enumerate(node.meta["val"].shape): if isinstance(d, torch.SymInt): # Look up the range constraint for the symbol corresponding to this shape dimension @@ -290,7 +303,6 @@ def _is_dynamic_shape_leaf(x): ] input_dims[d.node.expr].append(InputDim(input_name=node.name, dim=i)) free_symbols.update(d.node.expr.free_symbols) - input_index += 1 for symbol in free_symbols: if symbol not in range_constraints: diff --git a/torch/export/_trace.py b/torch/export/_trace.py index e27bf9a016af..918e10cb906c 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -18,6 +18,7 @@ make_constraints, make_fake_inputs, make_fake_params_buffers, + produce_guards_and_solve_constraints, ) from torch._export.passes.add_runtime_assertions_for_constraints_pass import ( _AddRuntimeAssertionsForInlineConstraintsPass, @@ -50,7 +51,6 @@ from ._safeguard import AutogradStateOpsFailSafeguard -from .dynamic_shapes import _process_constraints from .exported_program import ( _disable_prexisiting_fake_mode, ExportedProgram, @@ -1005,18 +1005,30 @@ def forward(self, *args, **kwargs): for k, v in fake_mode.shape_env.var_to_range.items() if free_unbacked_symbols(k) } + num_lifted = len( + [ + spec + for spec in ep_non_strict.sig.input_specs + if spec.kind != InputKind.USER_INPUT + ] + ) try: - range_constraints = make_constraints( + produce_guards_and_solve_constraints( fake_mode, + ep_non_strict.gm, equalities_inputs, - dynamic_shapes if dynamic_shapes else [], - ep_non_strict.sig.input_specs, original_signature, - ep_non_strict.gm, ) except (ConstraintViolationError, ValueRangeError) as e: raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: TRY200 + range_constraints = make_constraints( + fake_mode, + ep_non_strict.gm, + dynamic_shapes, + num_lifted, + ) + assert out_spec is not None gm = ep_non_strict.gm @@ -1216,11 +1228,11 @@ def forward(self, *args, **kwargs): ), len(export_graph_signature.input_specs), ) - range_constraints = _process_constraints( + range_constraints = make_constraints( dynamo_fake_mode, gm, + dynamic_shapes, num_lifted, - flat_args, ) # Do some cleanups on the graph module to restore the state dict to the diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index 4571781147fc..806a7bef9a30 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -5,10 +5,9 @@ import sys import weakref from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, 
TYPE_CHECKING, Union import torch -from torch._subclasses.fake_tensor import FakeTensor from torch.utils._pytree import SUPPORTED_NODES from .exported_program import ExportedProgram @@ -560,7 +559,6 @@ def _process_dynamic_shapes( kwargs: Optional[Dict[str, Any]] = None, dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None, ) -> Optional[List[Constraint]]: - from collections import defaultdict from collections.abc import Mapping, Sequence from torch._dynamo.exc import UserError, UserErrorType @@ -810,106 +808,3 @@ def _create_static_dim(tensor, i, value): constraints.append(primary) return constraints # type: ignore[return-value] - - -def _process_constraints( - fake_mode, - graph_module: torch.fx.GraphModule, - num_lifted_params_buffers: int, - example_inputs: List[torch.Tensor], -) -> Dict: - """ - Process the constraints stored in the graph module to return something more readable. - - Args: - graph_module (torch.fx.GraphModule): GraphModule returned from - dynamo.export, which contains the "input_shape_constraints" and - "inline_constraints" metadata - - example_inputs: Flattened list of example inputs used to export the graph module - - Returns: - range_constraints (Dict[sympy.Symbol, ValueRanges]): Mapping of - symbols (from SymInts) appearing in the fake tensors in - node.meta["val"] to their range constraints, which are a tuple - containing (lower, upper) constraints. - """ - from torch._export.passes.add_runtime_assertions_for_constraints_pass import ( - InputDim, - ) - - # Import sympy locally - from torch.fx.experimental.symbolic_shapes import SymInt - from torch.utils._sympy.value_ranges import ValueRanges - - input_shape_constraints = graph_module.meta.get("input_shape_constraints", []) - inline_constraints = graph_module.meta.get("inline_constraints", []) - - # Create dict mapping tensor_id to node names - tensor_id_to_nodes: Dict[int, List[str]] = defaultdict(list) - # Create dict mapping placeholder node names to their nodes - placeholder_nodes: Dict[str, torch.fx.Node] = {} - for i, node in enumerate(graph_module.graph.nodes): - if node.op != "placeholder": - # All placeholder nodes should be together in the beginning of the - # graph - break - if i >= num_lifted_params_buffers: - example_input = example_inputs[i - num_lifted_params_buffers] - tensor_id_to_nodes[id(example_input)].append(node.name) - placeholder_nodes[node.name] = node - - # Create dict mapping (node name, dim) a list of range (lower, upper) - # constraints - multi_range_constraints: Dict[InputDim, List[ValueRanges]] = defaultdict(list) - for constraint in input_shape_constraints: - for node in tensor_id_to_nodes[constraint["t_id"]]: - # skip static shape constraints - if constraint["min"] == constraint["max"]: - continue - node_dim = InputDim(node, constraint["dim"]) - - # Accumulate range constraints - multi_range_constraints[node_dim].append( - ValueRanges(constraint["min"], constraint["max"]) - ) - - # Create dict mapping symbol to a singular range (lower, upper) - range_constraints: Dict[Any, ValueRanges] = {} - - # Add inline constraints to range_constraints - range_constraints = { - symbol: inline_constraints[symbol] for symbol in inline_constraints - } - - free_symbols: Set["Symbol"] = set() - # Add input range constraints to range_constraints - for input_dim, multi_range_constraint in multi_range_constraints.items(): # type: ignore[assignment] - # Simplify the range constraints into a single range constraint - # Ex. 
ranges [2, 10] and [3, 11] would get merged to [3, 10] - min_vals = [rc.lower for rc in multi_range_constraint] - max_vals = [rc.upper for rc in multi_range_constraint] - min_val = max(min_vals) # type: ignore[type-var] - max_val = min(max_vals) # type: ignore[type-var] - assert min_val <= max_val # type: ignore[operator] - - # Add input node range constraints - val = placeholder_nodes[input_dim.input_name].meta["val"] - assert isinstance(val, FakeTensor) - symint = val.shape[input_dim.dim] - assert isinstance( - symint, SymInt - ), f"Expected SymInt but got {symint}: {type(symint)}" - symbol = symint.node.expr - range_constraints[symbol] = ValueRanges(min_val, max_val) - free_symbols.update(symbol.free_symbols) - - for symbol in free_symbols: - if symbol not in range_constraints: - # Placeholders can have symbolic shapes that are derived expressions. - # The above code will record direct range constraints for them - # so that we can do runtime assertions. In addition, for serde checks - # we want to record range constraints for their root symbols. - range_constraints[symbol] = fake_mode.shape_env.var_to_range[symbol] - - return range_constraints diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index 64844ea20d9a..7a829df90025 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -651,8 +651,7 @@ def update_arg(old_arg, new_ph): new_range_constraints = _get_updated_range_constraints( gm, - self._num_lifted_params_buffers(), - pytree.tree_leaves(self.example_inputs), + self.range_constraints, _is_executorch=False, ) @@ -764,8 +763,7 @@ def _get_updated_graph_signature( state_dict=self.state_dict, range_constraints=_get_updated_range_constraints( transformed_gm, - self._num_lifted_params_buffers(), - pytree.tree_leaves(self.example_inputs), + self.range_constraints, _is_executorch=False, ), module_call_graph=copy.deepcopy(self._module_call_graph), @@ -812,8 +810,7 @@ def _update( def _get_updated_range_constraints( gm: torch.fx.GraphModule, - num_lifted: Optional[int] = None, - example_inputs: Optional[List[Any]] = None, + old_range_constraints: "Optional[Dict[sympy.Symbol, Any]]" = None, _is_executorch: bool = True, ) -> "Dict[sympy.Symbol, Any]": def get_shape_env(gm): @@ -833,8 +830,7 @@ def get_shape_env(gm): # FIXME(tmanlaibaatar) Remove this whole branch once https://github.com/pytorch/pytorch/pull/123764 if _is_executorch: - assert num_lifted is None - assert example_inputs is None + assert old_range_constraints is None shape_env, _ = get_shape_env(gm) if shape_env is None: return {} @@ -851,17 +847,13 @@ def get_shape_env(gm): range_constraints[k] = v return range_constraints - assert num_lifted is not None - assert example_inputs is not None + assert old_range_constraints is not None shape_env, fake_mode = get_shape_env(gm) if shape_env is None: return {} - from torch.export.dynamic_shapes import _process_constraints - - range_constraints = _process_constraints(fake_mode, gm, num_lifted, example_inputs) - + range_constraints = copy.copy(old_range_constraints) range_constraints = { k: v for k, v in range_constraints.items() if k not in shape_env.replacements }
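
For context, a minimal usage sketch (illustrative only, not part of this patch) of the behavior make_constraints() serves: dynamic shapes specified by the user at export time end up as range_constraints on the ExportedProgram, for both strict and non-strict export. The module, dim name, and bounds below are assumed for illustration.

# Illustrative sketch; the module, dim name, and bounds are made-up examples.
import torch
from torch.export import Dim, export


class Add(torch.nn.Module):
    def forward(self, x):
        return x + 1


# Mark dim 0 of `x` as dynamic, with assumed bounds [2, 32].
batch = Dim("batch", min=2, max=32)
ep = export(Add(), (torch.randn(4, 8),), dynamic_shapes={"x": {0: batch}})

# make_constraints() is what populates ep.range_constraints: a mapping from each
# shape symbol (e.g. s0) to its ValueRanges, derived from the Dim bounds above.
print(ep.range_constraints)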