equality assertions #102256

Closed
29 changes: 29 additions & 0 deletions test/export/test_passes.py
@@ -336,6 +336,35 @@ def false_fn(x, y):
with self.assertRaisesRegex(RuntimeError, "is outside of inline constraint \\[2, 5\\]."):
ep(torch.tensor(False), torch.tensor([6]), torch.tensor([6]))

def test_runtime_assert_equality_constraint(self):
class Adder(torch.nn.Module):
def __init__(self) -> None:
super().__init__()

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
return x + y

m = Adder()
x = torch.rand(3, 4)
y = torch.rand(3, 4)
exported = torch._export.export(
m, (x, y), constraints=[dynamic_dim(x, 1) == dynamic_dim(y, 1)]
)
exported = exported.add_runtime_assertions()

x = torch.rand(3, 5)
y = torch.rand(3, 6)
with self.assertRaisesRegex(
RuntimeError,
"Input arg1's dimension #1 size is not equal to input arg0's dimension #1",
):
exported(x, y)

y = torch.rand(3, 5)
dynamo_result = exported(x, y)
real_result = m(x, y)
self.assertTrue(torch._dynamo.utils.same(real_result, dynamo_result))


if __name__ == '__main__':
run_tests()
8 changes: 8 additions & 0 deletions torch/_dynamo/eval_frame.py
@@ -703,6 +703,14 @@ def serializable_spec(self):
"dim": self.dim,
"min": self.constraint_range.vr.lower,
"max": self.constraint_range.vr.upper,
"shared": (
None
if self.shared is None
else {
"t_id": self.shared.t_id,
"dim": self.shared.dim,
}
),
}

def __eq__(self, other):
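For illustration, a constraint created with dynamic_dim(x, 1) == dynamic_dim(y, 1) would serialize to something like the sketch below; the t_id values are made-up id()s and the min/max bounds are assumptions, not values taken from this PR.

# Hypothetical serialized constraint spec (ids and bounds are made up):
example_spec = {
    "t_id": 140034147356112,            # id() of the example input x
    "dim": 1,
    "min": 2,                           # assumed lower bound
    "max": 9223372036854775807,         # assumed "effectively unbounded" upper bound
    "shared": {"t_id": 140034147356544, "dim": 1},  # id() of y and the shared dim
}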
36 changes: 25 additions & 11 deletions torch/_export/graph_module.py
@@ -37,6 +37,12 @@
ConstraintExpr = Union[sympy.Expr, SympyBoolean]


@dataclasses.dataclass
class ConstraintsContainer:
ranges: Any
equalities: Any


@dataclasses.dataclass
class ExportMetadata:
"""The fields in this class are what used to be extra data from ExportGraphModule."""
@@ -247,31 +253,39 @@ def make_export_graph_module(

gm = fx.GraphModule(root, graph, class_name)

input_tracker = 0

# group by input id
input_shape_constraints_by_tensor_id = defaultdict(list)
for constraint in input_shape_constraints:
input_shape_constraints_by_tensor_id[constraint["t_id"]].append((constraint["dim"], constraint["min"], constraint["max"]))

input_shape_constraints_by_src_name: Dict[str, List[Tuple[int, ConstraintExpr, ConstraintExpr]]] = {}
tensor_id_to_input_names: Dict[int, List[str]] = defaultdict(list)
input_name_to_example_inputs: Dict[str, Any] = {}
if example_inputs is not None:
input_tracker = 0
for node in gm.graph.nodes:
if node.op == "placeholder":
example_input = example_inputs[input_tracker]
if id(example_input) in input_shape_constraints_by_tensor_id:
input_shape_constraints_by_src_name[node.name] = input_shape_constraints_by_tensor_id[id(example_input)]
tensor_id_to_input_names[id(example_input)].append(node.name)
input_name_to_example_inputs[node.name] = example_input
input_tracker += 1

input_shape_constraints_by_src_name: Dict[str, ConstraintsContainer] = defaultdict(
lambda: ConstraintsContainer([], [])
)
for constraint in input_shape_constraints:
for name in tensor_id_to_input_names[constraint["t_id"]]:
input_shape_constraints_by_src_name[name].ranges.append(
(constraint["dim"], constraint["min"], constraint["max"])
)
if constraint["shared"] is not None:
for name in tensor_id_to_input_names[constraint["shared"]["t_id"]]:
for other_name in tensor_id_to_input_names[constraint["t_id"]]:
input_shape_constraints_by_src_name[name].equalities.append(
(constraint["shared"]["dim"], other_name, constraint["dim"])
)

meta = ExportMetadata(
in_spec=in_spec,
out_spec=out_spec,
update_spec=0,
mutation=mutation if mutation else [],
input_shape_constraints=input_shape_constraints_by_src_name,
# copy because pickle cannot handle defaultdict with lambda factory
input_shape_constraints=dict(input_shape_constraints_by_src_name.items()),
inline_constraints=inline_constraints,
input_name_to_example_inputs=input_name_to_example_inputs,
)
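As a rough illustration of the grouped result (a sketch, not the exact output of this code; the input names "arg0"/"arg1" and the bounds are assumptions), one entry of input_shape_constraints_by_src_name might look like this:

# Hypothetical entry keyed by input name (names and bounds are made up):
from torch._export.graph_module import ConstraintsContainer

example_entry = ConstraintsContainer(
    ranges=[(1, 2, 9223372036854775807)],  # (dim, min, max): dim 1 constrained to [2, max]
    equalities=[(1, "arg1", 1)],           # (dim, other_name, other_dim): dim 1 == arg1's dim 1
)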
12 changes: 12 additions & 0 deletions torch/_export/pass_base.py
@@ -215,6 +215,9 @@ def call_method(

def run_node(self, n: torch.fx.Node) -> Argument:
self.node = n
if self.callback._processing_placeholders and n.op != "placeholder":
self.callback._processing_placeholders = False
self.callback.postprocess_placeholders()
self.callback.node_debug_str = n.format_node()
return super().run_node(n)

@@ -230,6 +233,9 @@ def __init__(self) -> None:
self.fake_tensor_mode: Optional[FakeTensorMode] = None
self._initialized = True
self.node_debug_str: typing.Optional[str] = None
# state that keeps track of when placeholders are still being processed
# (note that placeholders are always processed before other nodes)
self._processing_placeholders = True

def _fx(
self,
@@ -294,6 +300,12 @@ def placeholder(self, name: str, arg: Argument, meta: NodeMetadata) -> ProxyValue:
self.tracer.set_metadata(arg_proxy.node, arg)
return ProxyValue(arg, arg_proxy)

def postprocess_placeholders(self):
Contributor: Do we expect other passes to use this hook?

Contributor: Yeah, I'm hesitant to put this into pass base. I'd rather just write it into the AddRuntimeAssertionsPass -- you could just write it as a plain FX pass after call().

Contributor Author: IMO it is a useful hook. E.g., you can imagine passes that manipulate inputs, by adapting the original inputs here to a different set of values derived from the inputs, and use that to transform the rest of the body.

If you feel strongly about taking it out, I can do it after AddRuntimeAssertionsPass stabilizes a bit.

"""
Hook to post-process placeholders before they are passed to FX nodes.
"""
pass

def call_operator(
self,
op,
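To illustrate the kind of pass the hook discussion above has in mind, here is a minimal sketch of a callback that records placeholders and only emits a node relating two of them once all placeholders have been seen. The class name and behavior are hypothetical and not part of this PR.

# Hypothetical sketch of a pass using the postprocess_placeholders hook.
# Not part of this PR; the class name and the check it performs are made up.
import operator
from typing import List

import torch
from torch._export.pass_base import ExportPassBase, ProxyValue
from torch._export.pass_infra.node_metadata import NodeMetadata


class AssertSameBatchSizePass(ExportPassBase):
    def __init__(self) -> None:
        super().__init__()
        self._placeholders: List[ProxyValue] = []

    def placeholder(self, name: str, arg, meta) -> ProxyValue:
        proxy = super().placeholder(name, arg, meta)
        # Placeholders are always processed first, so just record them here.
        self._placeholders.append(proxy)
        return proxy

    def postprocess_placeholders(self):
        # All placeholders have been processed, so a node relating several
        # inputs (a binary relation) can be inserted at this point.
        if len(self._placeholders) < 2:
            return
        a, b = self._placeholders[0], self._placeholders[1]
        size_a = super().call_operator(torch.ops.aten.sym_size, (a, 0), {}, NodeMetadata({}))
        size_b = super().call_operator(torch.ops.aten.sym_size, (b, 0), {}, NodeMetadata({}))
        # Comparison node only; turning it into an async assertion would follow
        # the same pattern as _insert_assert_async in the pass below.
        super().call_operator(operator.eq, (size_a, size_b), {}, NodeMetadata({}))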
185 changes: 117 additions & 68 deletions torch/_export/passes/add_runtime_assertions_for_constraints_pass.py
@@ -1,15 +1,16 @@
from collections import defaultdict, namedtuple
from functools import partial
from typing import Dict, List, Tuple, Optional

import math
import operator
from collections import defaultdict
from dataclasses import dataclass
from functools import partial
from typing import Dict, List, Optional

import sympy

import torch
import torch.fx

from torch._export.graph_module import get_export_meta
from torch._export.graph_module import ConstraintsContainer, get_export_meta
from torch._export.pass_base import ExportPassBase, ProxyValue
from torch._export.pass_infra.node_metadata import NodeMetadata
from torch.fx.passes.infra.pass_base import PassResult
@@ -18,7 +19,24 @@
__all__ = ["AddRuntimeAssertionsForConstraintsPass"]


ConstraintSpec = namedtuple("ConstraintSpec", ["constraint_dim", "min_val", "max_val"])
@dataclass
class ConstraintSpec:
"""
Base class for constraint specs.
"""
dim: int

@dataclass
class RangeConstraintSpec(ConstraintSpec):
# encodes min_val <= _.size()[dim] <= max_val
min_val: int
max_val: int

@dataclass
class EqualityConstraintSpec(ConstraintSpec):
# encodes _.size()[dim] = other_name.size()[other_dim]
other_name: str
other_dim: int
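For a quick sense of the two spec kinds defined above, using the dataclasses just introduced (the values are illustrative only):

# Illustrative instances of the two constraint spec kinds (values made up):
range_spec = RangeConstraintSpec(dim=1, min_val=2, max_val=10)                 # 2 <= size(1) <= 10
equality_spec = EqualityConstraintSpec(dim=1, other_name="arg0", other_dim=1)  # size(1) == arg0.size(1)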

# Convert simple sympy Integers into concrete int to
# insert into graph
@@ -37,41 +55,47 @@ def __init__(self) -> None:
self.current_gm: Optional[torch.fx.GraphModule] = None

def _process_shape_constraints(self, constraints) -> Dict[str, List[ConstraintSpec]]:
constraints_name_to_constraint: Dict[str, List[ConstraintSpec]] = defaultdict(
list
)

constraint_name_to_dim: Dict[str, Dict[int, List[Tuple[int, int]]]] = defaultdict(
lambda: defaultdict(list)
input_name_to_dim_constraints: Dict[str, ConstraintsContainer] = defaultdict(
lambda: ConstraintsContainer(defaultdict(list), defaultdict(list))
)

for name in constraints:
for dim, min_val, max_val in constraints[name]:
min_max = (_convert_to_int(min_val), _convert_to_int(max_val))
constraint_name_to_dim[name][dim].append(min_max)
for name, shape_constraints in constraints.items():
for dim, min_val, max_val in shape_constraints.ranges:
input_name_to_dim_constraints[name].ranges[dim].append(
(_convert_to_int(min_val), _convert_to_int(max_val))
)
for dim, other_name, other_dim in shape_constraints.equalities:
input_name_to_dim_constraints[name].equalities[dim].append(
(other_name, other_dim)
)

# Merge the constraints into a single list of constraints
for name, dim_constraints in constraint_name_to_dim.items():
for dim, constraints in dim_constraints.items():
min_vals = [x[0] for x in constraints]
max_vals = [x[1] for x in constraints]
min_val = sorted(min_vals, reverse=True)[0]
max_val = sorted(max_vals, reverse=False)[0]
input_name_to_constraints: Dict[str, List[ConstraintSpec]] = defaultdict(list)
for name, dim_constraints in input_name_to_dim_constraints.items():
for dim, range_constraints in dim_constraints.ranges.items():
if range_constraints:
min_vals, max_vals = zip(*range_constraints)
min_val = max(min_vals)
max_val = min(max_vals)
assert min_val <= max_val
input_name_to_constraints[name].append(
RangeConstraintSpec(dim=dim, min_val=min_val, max_val=max_val)
)
for dim, eq_constraints in dim_constraints.equalities.items():
for other_name, other_dim in eq_constraints:
input_name_to_constraints[name].append(
EqualityConstraintSpec(dim=dim, other_name=other_name, other_dim=other_dim)
)

Comment on lines +72 to +87

Contributor: Tangential to this diff since I'll do some refactoring of the constraints, but should this belong in the initial processing of the constraints instead of in this pass? #102259 (comment)

Contributor Author: Yeah, for sure. It should be possible to completely replace the current input_constraints and input_name_to_example_inputs with the set of data structures we were talking about with symbol / (name, dim) indexing.

I mention input_name_to_example_inputs too because the only remaining use of it in the add runtime assertion pass is for specialization constraints.

I'm assuming you will do it, but I could do it here if you want.

assert min_val <= max_val

constraints_name_to_constraint[name].append(
ConstraintSpec(constraint_dim=dim, min_val=min_val, max_val=max_val)
)

return constraints_name_to_constraint
return input_name_to_constraints
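As a rough sketch of the (name, dim)-indexed structures mentioned in the review thread above (hypothetical names and values, not part of this PR):

# Hypothetical constraint bookkeeping keyed by (input name, dim); made up
# for illustration only.
from collections import defaultdict
from typing import Dict, List, Tuple

# (name, dim) -> list of (min, max) bounds
range_constraints: Dict[Tuple[str, int], List[Tuple[int, int]]] = defaultdict(list)
# (name, dim) -> list of (other_name, other_dim) that must match
equality_constraints: Dict[Tuple[str, int], List[Tuple[str, int]]] = defaultdict(list)

range_constraints[("arg0", 1)].append((2, 10))
equality_constraints[("arg1", 1)].append(("arg0", 1))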

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
self.current_gm = graph_module
assert isinstance(self.current_gm, torch.fx.GraphModule)
self.constraints = self._process_shape_constraints(get_export_meta(self.current_gm).input_shape_constraints)
self.input_name_to_example_inputs = get_export_meta(self.current_gm).input_name_to_example_inputs
self.inline_constraints = get_export_meta(self.current_gm).inline_constraints

self.input_name_to_args: Dict[str, ProxyValue] = {}
return super().call(graph_module)

def _insert_specialized_shapes_assert(self, arg, dims, name, current_inp):
@@ -111,55 +135,80 @@ def placeholder(self, name: str, arg, meta) -> ProxyValue:
arg = super().placeholder(name, arg, meta)
if name not in self.input_name_to_example_inputs:
return arg
current_inp = self.input_name_to_example_inputs[name]
# Record the arg mapped to name.
# This will be used when postprocessing placeholders.
self.input_name_to_args[name] = arg
return arg

def postprocess_placeholders(self):
# Add runtime asserts for input shape constraints. We do this here
# because we can handle both (unary) predicates and (binary) relations.
assert self.current_gm is not None
all_dims = set(range(current_inp.dim()))

# If no dynamism is specified, we assume all dimensions are specialized
if name not in self.constraints:
self._insert_specialized_shapes_assert(arg, all_dims, name, current_inp)
return arg
for name, arg in self.input_name_to_args.items():
current_inp = self.input_name_to_example_inputs[name]
all_dims = set(range(current_inp.dim()))

constraints = self.constraints[name]
# If no dynamism is specified, we assume all dimensions are specialized
if name not in self.constraints:
self._insert_specialized_shapes_assert(arg, all_dims, name, current_inp)
continue

constrained_dims = set()
# Add runtime asserts for user specified constraints for each
# individual dimensions (e.g not the relational constraints like
# x[1] == x[0])
for constraint in constraints:
constrained_dims.add(constraint.constraint_dim)
dim = super().call_operator(
torch.ops.aten.sym_size,
(arg, constraint.constraint_dim),
{},
NodeMetadata({}),
)
assert_msg = (
f"Input {name}'s dimension #{constraint.constraint_dim} size is "
f"outside of specified dynamic range [{constraint.min_val}, {constraint.max_val}]"
)
# TODO (tmanlaibaatar) we are making an assumption that graph generated for
# input dim N >=2 generalizes to N < 2. Ideally we should check that:
# 1. if we can generalize to N < 2, not add any assertion saying N >= 2
# 2. If we can't generalize to N < 2, add an assertion saying N >= 2
# Above can be achieved via a seperate pass.
self._assert_constraint(dim, constraint.min_val, constraint.max_val, assert_msg, low_threshold=2)

specialized_dims = all_dims - constrained_dims
# Make all non-constrained dims to be static
self._insert_specialized_shapes_assert(arg, specialized_dims, name, current_inp)

# TODO Add relational constraints
return arg
constraints = self.constraints[name]

constrained_dims = set()
for constraint in constraints:
constrained_dims.add(constraint.dim)
dim = super().call_operator(
torch.ops.aten.sym_size,
(arg, constraint.dim),
{},
NodeMetadata({}),
)
if isinstance(constraint, RangeConstraintSpec):
# Add runtime asserts for user-specified range constraints for each
# individual dimension.
assert_msg = (
f"Input {name}'s dimension #{constraint.dim} size is "
f"outside of specified dynamic range [{constraint.min_val}, {constraint.max_val}]"
)
# TODO (tmanlaibaatar) we are making an assumption that graph generated for
# input dim N >=2 generalizes to N < 2. Ideally we should check that:
# 1. if we can generalize to N < 2, not add any assertion saying N >= 2
# 2. If we can't generalize to N < 2, add an assertion saying N >= 2
# Above can be achieved via a separate pass.
self._assert_range_constraint(dim, constraint.min_val, constraint.max_val, assert_msg, low_threshold=2)
else:
assert isinstance(constraint, EqualityConstraintSpec)
# Add runtime asserts for user-specified equality constraints.
other_arg = self.input_name_to_args[constraint.other_name]
other_dim = super().call_operator(
torch.ops.aten.sym_size,
(other_arg, constraint.other_dim),
{},
NodeMetadata({}),
)
assert_msg = (
f"Input {name}'s dimension #{constraint.dim} size is "
f"not equal to input {constraint.other_name}'s dimension #{constraint.other_dim}"
)
self._assert_equality_constraint(dim, other_dim, assert_msg)

def _assert_constraint(self, proxy, lower, upper, assert_msg, low_threshold=2):
specialized_dims = all_dims - constrained_dims
# Make all non-constrained dims to be static
self._insert_specialized_shapes_assert(arg, specialized_dims, name, current_inp)


def _assert_range_constraint(self, proxy, lower, upper, assert_msg, low_threshold=2):
if lower > low_threshold:
self._insert_assert_async(operator.ge, proxy, lower, assert_msg)

if upper < math.inf:
self._insert_assert_async(operator.le, proxy, upper, assert_msg)

def _assert_equality_constraint(self, proxy1, proxy2, assert_msg):
self._insert_assert_async(operator.eq, proxy1, proxy2, assert_msg)

def _insert_assert_async(self, operator, l, r, assert_msg):
cmp = super().call_operator(operator, (l, r), {}, NodeMetadata({}))
cmp_tensor = super().call_operator(torch.ops.aten.scalar_tensor.default, (cmp,), {}, NodeMetadata({}))
@@ -194,7 +243,7 @@ def add_assertions(val):
lower = _convert_to_int(constraint.lower)
upper = _convert_to_int(constraint.upper)
assert_msg = f" is outside of inline constraint [{lower}, {upper}]."
call_backs.append(partial(self._assert_constraint, lower=lower, upper=upper, low_threshold=-1))
call_backs.append(partial(self._assert_range_constraint, lower=lower, upper=upper, low_threshold=-1))
messages.append(assert_msg)
elif isinstance(val, torch.Tensor):
for i, sym in enumerate(val.shape):