9 changes: 7 additions & 2 deletions torch/_export/__init__.py
@@ -28,6 +28,7 @@
from torch._dynamo.exc import UserError, UserErrorType
from torch._dynamo.source import ConstantSource
from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass
from torch._export.non_strict_utils import make_constraints
from torch._functorch.aot_autograd import aot_export_module, GraphSignature
from torch._functorch.eager_transforms import functionalize
from torch._guards import detect_fake_mode
@@ -39,7 +40,6 @@
from torch.export._tree_utils import reorder_kwargs
from torch.export._unlift import _create_stateful_graph_module
from torch.export.dynamic_shapes import (
_process_constraints,
Constraint,
dims,
dynamic_dim,
@@ -175,7 +175,12 @@ def capture_pre_autograd_graph(
_restore_state_dict(f, m)

flat_args, _ = pytree.tree_flatten((args, kwargs or {}))
range_constraints = _process_constraints(fake_mode, m, 0, flat_args)
range_constraints = make_constraints(
fake_mode,
m,
dynamic_shapes,
0,
)

module = _create_stateful_graph_module(
m,
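For context, a minimal sketch of how this updated path would be exercised from user code — illustrative only, and assuming capture_pre_autograd_graph accepts the dynamic_shapes argument used in the hunk above:

import torch
from torch.export import Dim
from torch._export import capture_pre_autograd_graph

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

# Mark dim 0 of x as dynamic; range constraints for the traced module now
# come from make_constraints over this spec rather than from example inputs.
batch = Dim("batch", min=2, max=64)
gm = capture_pre_autograd_graph(M(), (torch.randn(4, 3),), dynamic_shapes={"x": {0: batch}})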
104 changes: 58 additions & 46 deletions torch/_export/non_strict_utils.py
@@ -16,8 +16,7 @@
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.export import Constraint
from torch.export.dynamic_shapes import _Dim
from torch.export.exported_program import InputKind
from torch.export.graph_signature import CustomObjArgument, InputSpec, TensorArgument
from torch.export.graph_signature import CustomObjArgument
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
DimDynamic,
@@ -174,51 +173,47 @@ def make_fake_inputs(nn_module, args, kwargs, dynamic_shapes):
return fake_mode, fake_args, fake_kwargs, equalities_inputs, original_signature


def make_constraints(
def _flatten_dynamic_shapes(
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]]
):
def _is_dynamic_shape_leaf(x):
if isinstance(x, dict):
x = list(x.values())
return x is None or all(isinstance(y, (_Dim, int)) or y is None for y in x)

if isinstance(dynamic_shapes, (list, tuple)):
flat_dynamic_shapes = []
for item in dynamic_shapes:
flat_shapes, _ = tree_flatten(
dynamic_shapes, is_leaf=_is_dynamic_shape_leaf
)
flat_dynamic_shapes += flat_shapes
else:
flat_dynamic_shapes, _ = tree_flatten(
dynamic_shapes, is_leaf=_is_dynamic_shape_leaf
)
return flat_dynamic_shapes
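As an illustration of the spec this helper flattens (names and nesting here are hypothetical), a nested per-argument dynamic_shapes dict and its expected per-tensor flat form:

from torch.export import Dim

batch = Dim("batch")
# Per-argument spec as a user would pass it to export(...):
dynamic_shapes = {
    "x": {0: batch},           # dim 0 of x is dynamic
    "ys": [{0: batch}, None],  # a list of two tensors: dynamic, then static
}
# Expected flat form, one leaf per tensor in pytree order:
#   [{0: batch}, {0: batch}, None]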


def produce_guards_and_solve_constraints(
fake_mode: FakeTensorMode,
gm: torch.fx.GraphModule,
equalities_inputs: EqualityConstraint,
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]],
input_specs: List[InputSpec],
original_signature: inspect.Signature,
gm: torch.fx.GraphModule,
):
"""
Given a fake mode, source pairs corresponding to equal dynamic shape dimensions,
and a graph module, produce guards on the fake mode's shape env (raising constraint
violations if any), solve (to suggest simplifications or fixes), and return the
resulting range constraints and equality constraints.
"""
# TODO(avik): refactor Dynamo to avoid duplication of the following code
# between non-strict and strict.
# Specifically, here (non-strict) we do the following post-tracing steps:
# - Produce guards.
# - Solve constraints.
# - Install shape metadata in IR.
# In strict, these steps are spread across multiple files:
# - guards.py produces guards.
# - eval_frame.py solves constraints
# - _trace.py installs shape metadata in IR.

inline_constraints = gm.meta.get("inline_constraints", [])
range_constraints = {
symbol: inline_constraints[symbol] for symbol in inline_constraints
}
if dynamic_shapes == []:
return range_constraints

def _is_dynamic_shape_leaf(x):
if x is None:
return True
if isinstance(x, dict):
x = list(x.values())
return all(isinstance(y, (_Dim, int)) or y is None for y in x)

flat_dynamic_shapes, _ = tree_flatten(
dynamic_shapes, is_leaf=_is_dynamic_shape_leaf
)
violations if any), solve (to suggest simplifications or fixes).
Dynamo already performs this, so this is for non-strict mode.

Additional inputs:
equalities_inputs: the equality constraints to use for guards
original_signature: the signature of the forward method
"""
shape_env = fake_mode.shape_env
assert shape_env.tracked_fakes is not None

placeholders = [tf.fake for tf in shape_env.tracked_fakes]
sources = [tf.source for tf in shape_env.tracked_fakes]
input_contexts = [tf.symbolic_context for tf in shape_env.tracked_fakes]
@@ -255,23 +250,41 @@ def _is_dynamic_shape_leaf(x):
if constraint_violation_error:
raise constraint_violation_error

user_tensor_input_names = {
spec.arg.name
for spec in input_specs
if spec.kind == InputKind.USER_INPUT and isinstance(spec.arg, TensorArgument)

def make_constraints(
fake_mode: FakeTensorMode,
gm: torch.fx.GraphModule,
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
num_lifted_inputs: int,
):
"""
Given a fake mode's shape env and user-specified dynamic shapes,
return the resulting range constraints and equality constraints.

Additional args:
num_lifted_inputs: the number of non-user-input placeholder nodes in the graph
(used only to enumerate the user-input nodes)
"""

shape_env = fake_mode.shape_env
inline_constraints = gm.meta.get("inline_constraints", [])
range_constraints = {
symbol: inline_constraints[symbol] for symbol in inline_constraints
}
if not dynamic_shapes:
return range_constraints

flat_dynamic_shapes = _flatten_dynamic_shapes(dynamic_shapes)
input_dims = defaultdict(list)
free_symbols = set()
input_index = 0
for node in gm.graph.nodes:
if node.name not in user_tensor_input_names:
for input_index, node in enumerate(gm.graph.nodes):
if input_index < num_lifted_inputs or node.op != "placeholder":
continue
if _is_constant_argument(node.meta["val"]) or isinstance(
node.meta["val"], CustomObjArgument
):
continue
shape_spec = flat_dynamic_shapes[input_index]
shape_spec = flat_dynamic_shapes[input_index - num_lifted_inputs]
for i, d in enumerate(node.meta["val"].shape):
if isinstance(d, torch.SymInt):
# Look up the range constraint for the symbol corresponding to this shape dimension
@@ -290,7 +303,6 @@ def _is_dynamic_shape_leaf(x):
]
input_dims[d.node.expr].append(InputDim(input_name=node.name, dim=i))
free_symbols.update(d.node.expr.free_symbols)
input_index += 1

for symbol in free_symbols:
if symbol not in range_constraints:
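For orientation, the kind of value make_constraints returns — a mapping from shape symbols to ranges — sketched with a hypothetical symbol and bounds:

import sympy
from torch.utils._sympy.value_ranges import ValueRanges

# s0 stands in for the symbol backing a dynamic placeholder dimension.
s0 = sympy.Symbol("s0", positive=True, integer=True)
range_constraints = {s0: ValueRanges(2, 64)}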
26 changes: 19 additions & 7 deletions torch/export/_trace.py
@@ -18,6 +18,7 @@
make_constraints,
make_fake_inputs,
make_fake_params_buffers,
produce_guards_and_solve_constraints,
)
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
_AddRuntimeAssertionsForInlineConstraintsPass,
@@ -50,7 +51,6 @@

from ._safeguard import AutogradStateOpsFailSafeguard

from .dynamic_shapes import _process_constraints
from .exported_program import (
_disable_prexisiting_fake_mode,
ExportedProgram,
@@ -1005,18 +1005,30 @@ def forward(self, *args, **kwargs):
for k, v in fake_mode.shape_env.var_to_range.items()
if free_unbacked_symbols(k)
}
num_lifted = len(
[
spec
for spec in ep_non_strict.sig.input_specs
if spec.kind != InputKind.USER_INPUT
]
)
try:
range_constraints = make_constraints(
produce_guards_and_solve_constraints(
fake_mode,
ep_non_strict.gm,
equalities_inputs,
dynamic_shapes if dynamic_shapes else [],
ep_non_strict.sig.input_specs,
original_signature,
ep_non_strict.gm,
)
except (ConstraintViolationError, ValueRangeError) as e:
raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: TRY200

range_constraints = make_constraints(
fake_mode,
ep_non_strict.gm,
dynamic_shapes,
num_lifted,
)

assert out_spec is not None

gm = ep_non_strict.gm
@@ -1216,11 +1228,11 @@ def forward(self, *args, **kwargs):
),
len(export_graph_signature.input_specs),
)
range_constraints = _process_constraints(
range_constraints = make_constraints(
dynamo_fake_mode,
gm,
dynamic_shapes,
num_lifted,
flat_args,
)

# Do some cleanups on the graph module to restore the state dict to the
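A minimal end-to-end sketch of the non-strict path these call sites serve, assuming the torch.export API on this branch (a dynamic_shapes dict and strict=False):

import torch
from torch.export import export, Dim

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)

    def forward(self, x):
        return self.linear(x)

batch = Dim("batch", min=1, max=32)
ep = export(M(), (torch.randn(4, 3),), dynamic_shapes={"x": {0: batch}}, strict=False)
# Guards are produced and solved by produce_guards_and_solve_constraints;
# the per-symbol ranges computed by make_constraints land on the program:
print(ep.range_constraints)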
107 changes: 1 addition & 106 deletions torch/export/dynamic_shapes.py
@@ -5,10 +5,9 @@
import sys
import weakref
from collections import defaultdict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union

import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._pytree import SUPPORTED_NODES

from .exported_program import ExportedProgram
@@ -560,7 +559,6 @@ def _process_dynamic_shapes(
kwargs: Optional[Dict[str, Any]] = None,
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
) -> Optional[List[Constraint]]:
from collections import defaultdict
from collections.abc import Mapping, Sequence

from torch._dynamo.exc import UserError, UserErrorType
@@ -810,106 +808,3 @@ def _create_static_dim(tensor, i, value):
constraints.append(primary)

return constraints # type: ignore[return-value]


def _process_constraints(
fake_mode,
graph_module: torch.fx.GraphModule,
num_lifted_params_buffers: int,
example_inputs: List[torch.Tensor],
) -> Dict:
"""
Process the constraints stored in the graph module to return something more readable.

Args:
graph_module (torch.fx.GraphModule): GraphModule returned from
dynamo.export, which contains the "input_shape_constraints" and
"inline_constraints" metadata

example_inputs: Flattened list of example inputs used to export the graph module

Returns:
range_constraints (Dict[sympy.Symbol, ValueRanges]): Mapping of
symbols (from SymInts) appearing in the fake tensors in
node.meta["val"] to their range constraints, which are a tuple
containing (lower, upper) constraints.
"""
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
InputDim,
)

# Import sympy locally
from torch.fx.experimental.symbolic_shapes import SymInt
from torch.utils._sympy.value_ranges import ValueRanges

input_shape_constraints = graph_module.meta.get("input_shape_constraints", [])
inline_constraints = graph_module.meta.get("inline_constraints", [])

# Create dict mapping tensor_id to node names
tensor_id_to_nodes: Dict[int, List[str]] = defaultdict(list)
# Create dict mapping placeholder node names to their nodes
placeholder_nodes: Dict[str, torch.fx.Node] = {}
for i, node in enumerate(graph_module.graph.nodes):
if node.op != "placeholder":
# All placeholder nodes should be together in the beginning of the
# graph
break
if i >= num_lifted_params_buffers:
example_input = example_inputs[i - num_lifted_params_buffers]
tensor_id_to_nodes[id(example_input)].append(node.name)
placeholder_nodes[node.name] = node

# Create dict mapping (node name, dim) a list of range (lower, upper)
# constraints
multi_range_constraints: Dict[InputDim, List[ValueRanges]] = defaultdict(list)
for constraint in input_shape_constraints:
for node in tensor_id_to_nodes[constraint["t_id"]]:
# skip static shape constraints
if constraint["min"] == constraint["max"]:
continue
node_dim = InputDim(node, constraint["dim"])

# Accumulate range constraints
multi_range_constraints[node_dim].append(
ValueRanges(constraint["min"], constraint["max"])
)

# Create dict mapping symbol to a singular range (lower, upper)
range_constraints: Dict[Any, ValueRanges] = {}

# Add inline constraints to range_constraints
range_constraints = {
symbol: inline_constraints[symbol] for symbol in inline_constraints
}

free_symbols: Set["Symbol"] = set()
# Add input range constraints to range_constraints
for input_dim, multi_range_constraint in multi_range_constraints.items(): # type: ignore[assignment]
# Simplify the range constraints into a single range constraint
# Ex. ranges [2, 10] and [3, 11] would get merged to [3, 10]
min_vals = [rc.lower for rc in multi_range_constraint]
max_vals = [rc.upper for rc in multi_range_constraint]
min_val = max(min_vals) # type: ignore[type-var]
max_val = min(max_vals) # type: ignore[type-var]
assert min_val <= max_val # type: ignore[operator]

# Add input node range constraints
val = placeholder_nodes[input_dim.input_name].meta["val"]
assert isinstance(val, FakeTensor)
symint = val.shape[input_dim.dim]
assert isinstance(
symint, SymInt
), f"Expected SymInt but got {symint}: {type(symint)}"
symbol = symint.node.expr
range_constraints[symbol] = ValueRanges(min_val, max_val)
free_symbols.update(symbol.free_symbols)

for symbol in free_symbols:
if symbol not in range_constraints:
# Placeholders can have symbolic shapes that are derived expressions.
# The above code will record direct range constraints for them
# so that we can do runtime assertions. In addition, for serde checks
# we want to record range constraints for their root symbols.
range_constraints[symbol] = fake_mode.shape_env.var_to_range[symbol]

return range_constraints
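The deleted helper intersected the per-dimension ranges it collected (see the [2, 10] / [3, 11] example in the removed comments above); a tiny standalone sketch of that merge step with hypothetical values:

from torch.utils._sympy.value_ranges import ValueRanges

multi_range = [ValueRanges(2, 10), ValueRanges(3, 11)]
merged = ValueRanges(
    max(rc.lower for rc in multi_range),  # tightest lower bound
    min(rc.upper for rc in multi_range),  # tightest upper bound
)
# merged covers [3, 10].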