equality assertions #102256
Changes from all commits: da4349c, 3c04e94, ebaf377, faf76d0, 9b62acf, e98b836
@@ -1,15 +1,16 @@
-from collections import defaultdict, namedtuple
-from functools import partial
-from typing import Dict, List, Tuple, Optional
+import math
+import operator
+from collections import defaultdict
+from dataclasses import dataclass
+from functools import partial
+from typing import Dict, List, Optional
 
 import sympy
 
 import torch
 import torch.fx
 
-from torch._export.graph_module import get_export_meta
+from torch._export.graph_module import ConstraintsContainer, get_export_meta
 from torch._export.pass_base import ExportPassBase, ProxyValue
 from torch._export.pass_infra.node_metadata import NodeMetadata
 from torch.fx.passes.infra.pass_base import PassResult
@@ -18,7 +19,24 @@
 __all__ = ["AddRuntimeAssertionsForConstraintsPass"]
 
 
-ConstraintSpec = namedtuple("ConstraintSpec", ["constraint_dim", "min_val", "max_val"])
+@dataclass
+class ConstraintSpec:
+    """
+    Base class for constraint specs.
+    """
+    dim: int
+
+@dataclass
+class RangeConstraintSpec(ConstraintSpec):
+    # encodes min_val <= _.size()[dim] <= max_val
+    min_val: int
+    max_val: int
+
+@dataclass
+class EqualityConstraintSpec(ConstraintSpec):
+    # encodes _.size()[dim] = other_name.size()[other_dim]
+    other_name: str
+    other_dim: int
 
 # Convert simple sympy Integers into concrete int to
 # insert into graph
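As a reading aid, here is a small self-contained sketch (not part of the diff; the `describe` helper and the example values are made up) showing how the two concrete spec types are told apart with `isinstance`, which is how the pass later decides whether to emit a range or an equality assertion:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ConstraintSpec:
    dim: int


@dataclass
class RangeConstraintSpec(ConstraintSpec):
    min_val: int
    max_val: int


@dataclass
class EqualityConstraintSpec(ConstraintSpec):
    other_name: str
    other_dim: int


def describe(name: str, spec: ConstraintSpec) -> str:
    # Mirrors the isinstance() dispatch used when emitting assertions.
    if isinstance(spec, RangeConstraintSpec):
        return f"{spec.min_val} <= {name}.size({spec.dim}) <= {spec.max_val}"
    assert isinstance(spec, EqualityConstraintSpec)
    return f"{name}.size({spec.dim}) == {spec.other_name}.size({spec.other_dim})"


specs: List[ConstraintSpec] = [
    RangeConstraintSpec(dim=0, min_val=2, max_val=1024),
    EqualityConstraintSpec(dim=1, other_name="y", other_dim=0),
]
print([describe("x", s) for s in specs])
# ['2 <= x.size(0) <= 1024', 'x.size(1) == y.size(0)']
```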
@@ -37,41 +55,47 @@ def __init__(self) -> None:
         self.current_gm: Optional[torch.fx.GraphModule] = None
 
     def _process_shape_constraints(self, constraints) -> Dict[str, List[ConstraintSpec]]:
-        constraints_name_to_constraint: Dict[str, List[ConstraintSpec]] = defaultdict(
-            list
-        )
-
-        constraint_name_to_dim: Dict[str, Dict[int, List[Tuple[int, int]]]] = defaultdict(
-            lambda: defaultdict(list)
+        input_name_to_dim_constraints: Dict[str, ConstraintsContainer] = defaultdict(
+            lambda: ConstraintsContainer(defaultdict(list), defaultdict(list))
         )
 
-        for name in constraints:
-            for dim, min_val, max_val in constraints[name]:
-                min_max = (_convert_to_int(min_val), _convert_to_int(max_val))
-                constraint_name_to_dim[name][dim].append(min_max)
+        for name, shape_constraints in constraints.items():
+            for dim, min_val, max_val in shape_constraints.ranges:
+                input_name_to_dim_constraints[name].ranges[dim].append(
+                    (_convert_to_int(min_val), _convert_to_int(max_val))
+                )
+            for dim, other_name, other_dim in shape_constraints.equalities:
+                input_name_to_dim_constraints[name].equalities[dim].append(
+                    (other_name, other_dim)
+                )
 
         # Merge the constraints into a single list of constraints
-        for name, dim_constraints in constraint_name_to_dim.items():
-            for dim, constraints in dim_constraints.items():
-                min_vals = [x[0] for x in constraints]
-                max_vals = [x[1] for x in constraints]
-                min_val = sorted(min_vals, reverse=True)[0]
-                max_val = sorted(max_vals, reverse=False)[0]
+        input_name_to_constraints: Dict[str, List[ConstraintSpec]] = defaultdict(list)
+        for name, dim_constraints in input_name_to_dim_constraints.items():
+            for dim, range_constraints in dim_constraints.ranges.items():
+                if range_constraints:
+                    min_vals, max_vals = zip(*range_constraints)
+                    min_val = max(min_vals)
+                    max_val = min(max_vals)
+                    assert min_val <= max_val
+                    input_name_to_constraints[name].append(
+                        RangeConstraintSpec(dim=dim, min_val=min_val, max_val=max_val)
+                    )
+            for dim, eq_constraints in dim_constraints.equalities.items():
+                for other_name, other_dim in eq_constraints:
+                    input_name_to_constraints[name].append(
+                        EqualityConstraintSpec(dim=dim, other_name=other_name, other_dim=other_dim)
+                    )
 
Comment on lines +72 to +87

Tangential to this diff since I'll do some refactoring of the constraints, but should this belong in the initial processing of the constraints instead of in this pass? #102259 (comment)

Yeah, for sure. It should be possible to completely replace the current [...] I mention. I'm assuming you will do it, but I could do it here if you want.

-                assert min_val <= max_val
-
-                constraints_name_to_constraint[name].append(
-                    ConstraintSpec(constraint_dim=dim, min_val=min_val, max_val=max_val)
-                )
-
-        return constraints_name_to_constraint
+        return input_name_to_constraints
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         self.current_gm = graph_module
         assert isinstance(self.current_gm, torch.fx.GraphModule)
         self.constraints = self._process_shape_constraints(get_export_meta(self.current_gm).input_shape_constraints)
         self.input_name_to_example_inputs = get_export_meta(self.current_gm).input_name_to_example_inputs
         self.inline_constraints = get_export_meta(self.current_gm).inline_constraints
 
+        self.input_name_to_args: Dict[str, ProxyValue] = {}
         return super().call(graph_module)
 
     def _insert_specialized_shapes_assert(self, arg, dims, name, current_inp):
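To make the range-merging step above concrete, here is a minimal standalone sketch (the constraint values are made up): when several range constraints target the same dimension, the pass keeps the tightest range that satisfies all of them.

```python
from typing import List, Tuple

# Hypothetical per-dimension (min_val, max_val) pairs, as collected in
# input_name_to_dim_constraints[name].ranges[dim].
range_constraints: List[Tuple[int, int]] = [(2, 1024), (4, 512), (3, 768)]

# Same merge as in _process_shape_constraints: the largest lower bound and the
# smallest upper bound, which must still form a valid interval.
min_vals, max_vals = zip(*range_constraints)
min_val, max_val = max(min_vals), min(max_vals)
assert min_val <= max_val
print(min_val, max_val)  # 4 512
```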
@@ -111,55 +135,80 @@ def placeholder(self, name: str, arg, meta) -> ProxyValue:
         arg = super().placeholder(name, arg, meta)
         if name not in self.input_name_to_example_inputs:
             return arg
-        current_inp = self.input_name_to_example_inputs[name]
+        # Record the arg mapped to name.
+        # This will be used when postprocessing placeholders.
+        self.input_name_to_args[name] = arg
+        return arg
 
-        all_dims = set(range(current_inp.dim()))
+    def postprocess_placeholders(self):
+        # Add runtime asserts for input shape constraints. We do this here
+        # because we can handle both (unary) predicates and (binary) relations.
+        assert self.current_gm is not None
 
-        # If no dynamism is specified, we assume all dimensions are specialized
-        if name not in self.constraints:
-            self._insert_specialized_shapes_assert(arg, all_dims, name, current_inp)
-            return arg
+        for name, arg in self.input_name_to_args.items():
+            current_inp = self.input_name_to_example_inputs[name]
+            all_dims = set(range(current_inp.dim()))
 
-        constraints = self.constraints[name]
+            # If no dynamism is specified, we assume all dimensions are specialized
+            if name not in self.constraints:
+                self._insert_specialized_shapes_assert(arg, all_dims, name, current_inp)
+                continue
 
-        constrained_dims = set()
-        # Add runtime asserts for user specified constraints for each
-        # individual dimensions (e.g not the relational constraints like
-        # x[1] == x[0])
-        for constraint in constraints:
-            constrained_dims.add(constraint.constraint_dim)
-            dim = super().call_operator(
-                torch.ops.aten.sym_size,
-                (arg, constraint.constraint_dim),
-                {},
-                NodeMetadata({}),
-            )
-            assert_msg = (
-                f"Input {name}'s dimension #{constraint.constraint_dim} size is "
-                f"outside of specified dynamic range [{constraint.min_val}, {constraint.max_val}]"
-            )
-            # TODO (tmanlaibaatar) we are making an assumption that graph generated for
-            # input dim N >=2 generalizes to N < 2. Ideally we should check that:
-            # 1. if we can generalize to N < 2, not add any assertion saying N >= 2
-            # 2. If we can't generalize to N < 2, add an assertion saying N >= 2
-            # Above can be achieved via a seperate pass.
-            self._assert_constraint(dim, constraint.min_val, constraint.max_val, assert_msg, low_threshold=2)
+            constraints = self.constraints[name]
 
-        specialized_dims = all_dims - constrained_dims
-        # Make all non-constrained dims to be static
-        self._insert_specialized_shapes_assert(arg, specialized_dims, name, current_inp)
+            constrained_dims = set()
+            for constraint in constraints:
+                constrained_dims.add(constraint.dim)
+                dim = super().call_operator(
+                    torch.ops.aten.sym_size,
+                    (arg, constraint.dim),
+                    {},
+                    NodeMetadata({}),
+                )
+                if isinstance(constraint, RangeConstraintSpec):
+                    # Add runtime asserts for user-specified range constraints for each
+                    # individual dimension.
+                    assert_msg = (
+                        f"Input {name}'s dimension #{constraint.dim} size is "
+                        f"outside of specified dynamic range [{constraint.min_val}, {constraint.max_val}]"
+                    )
+                    # TODO (tmanlaibaatar) we are making an assumption that graph generated for
+                    # input dim N >=2 generalizes to N < 2. Ideally we should check that:
+                    # 1. if we can generalize to N < 2, not add any assertion saying N >= 2
+                    # 2. If we can't generalize to N < 2, add an assertion saying N >= 2
+                    # Above can be achieved via a seperate pass.
+                    self._assert_range_constraint(dim, constraint.min_val, constraint.max_val, assert_msg, low_threshold=2)
+                else:
+                    assert isinstance(constraint, EqualityConstraintSpec)
+                    # Add runtime asserts for user-specified equality constraints.
+                    other_arg = self.input_name_to_args[constraint.other_name]
+                    other_dim = super().call_operator(
+                        torch.ops.aten.sym_size,
+                        (other_arg, constraint.other_dim),
+                        {},
+                        NodeMetadata({}),
+                    )
+                    assert_msg = (
+                        f"Input {name}'s dimension #{constraint.dim} size is "
+                        f"not equal to input {constraint.other_name}'s dimension #{constraint.other_dim}"
+                    )
+                    self._assert_equality_constraint(dim, other_dim, assert_msg)
 
-        # TODO Add relational constraints
-        return arg
+            specialized_dims = all_dims - constrained_dims
+            # Make all non-constrained dims to be static
+            self._insert_specialized_shapes_assert(arg, specialized_dims, name, current_inp)
 
-    def _assert_constraint(self, proxy, lower, upper, assert_msg, low_threshold=2):
+    def _assert_range_constraint(self, proxy, lower, upper, assert_msg, low_threshold=2):
         if lower > low_threshold:
             self._insert_assert_async(operator.ge, proxy, lower, assert_msg)
 
         if upper < math.inf:
             self._insert_assert_async(operator.le, proxy, upper, assert_msg)
 
+    def _assert_equality_constraint(self, proxy1, proxy2, assert_msg):
+        self._insert_assert_async(operator.eq, proxy1, proxy2, assert_msg)
+
     def _insert_assert_async(self, operator, l, r, assert_msg):
         cmp = super().call_operator(operator, (l, r), {}, NodeMetadata({}))
         cmp_tensor = super().call_operator(torch.ops.aten.scalar_tensor.default, (cmp,), {}, NodeMetadata({}))
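For readers unfamiliar with the assertion lowering, the following is a hypothetical eager-mode analogue (not the pass itself; the helper name is invented) of the graph nodes inserted for an `EqualityConstraintSpec`: read both sizes, compare them, wrap the boolean in a scalar tensor, and assert on it.

```python
import operator

import torch


def check_size_equality(x: torch.Tensor, y: torch.Tensor, dim: int, other_dim: int) -> None:
    # Analogue of _assert_equality_constraint + _insert_assert_async:
    # sym_size -> operator.eq -> scalar_tensor -> assert_async.
    cmp = operator.eq(x.size(dim), y.size(other_dim))
    cmp_tensor = torch.scalar_tensor(cmp)
    torch._assert_async(cmp_tensor)  # fails at runtime if the sizes differ


check_size_equality(torch.randn(3, 4), torch.randn(4, 5), dim=1, other_dim=0)  # passes
# check_size_equality(torch.randn(3, 4), torch.randn(5, 5), dim=1, other_dim=0)  # would assert
```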
@@ -194,7 +243,7 @@ def add_assertions(val):
             lower = _convert_to_int(constraint.lower)
             upper = _convert_to_int(constraint.upper)
             assert_msg = f" is outside of inline constraint [{lower}, {upper}]."
-            call_backs.append(partial(self._assert_constraint, lower=lower, upper=upper, low_threshold=-1))
+            call_backs.append(partial(self._assert_range_constraint, lower=lower, upper=upper, low_threshold=-1))
             messages.append(assert_msg)
         elif isinstance(val, torch.Tensor):
             for i, sym in enumerate(val.shape):

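The inline-constraint path above defers the check through `functools.partial`; a tiny standalone analogue of that callback pattern (the helper name and values are invented):

```python
import math
from functools import partial


def assert_range(value, lower, upper, assert_msg, low_threshold=2):
    # Stand-in for _assert_range_constraint: skip bounds that carry no information.
    if lower > low_threshold:
        assert value >= lower, assert_msg
    if upper < math.inf:
        assert value <= upper, assert_msg


# Bounds are bound early; the runtime value is supplied later, when the
# corresponding graph value is actually visited.
call_backs = [partial(assert_range, lower=0, upper=10, assert_msg="out of inline constraint", low_threshold=-1)]
for cb in call_backs:
    cb(7)   # passes
    # cb(42) would raise AssertionError
```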
Review thread:

Do we expect other passes to use this hook?

Yeah, I'm hesitant to put this into pass base. I'd rather just write it into the AddRuntimeAssertionsPass -- you could just write it as a plain FX pass after call().

IMO it is a useful hook. E.g., you can imagine passes that manipulate inputs, by adapting the original inputs here to a different set of values derived from the inputs, and use that to transform the rest of the body. If you feel strongly about taking it out, I can do it after AddRuntimeAssertionsPass stabilizes a bit.
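For context on the "plain FX pass after call()" suggestion, here is a minimal sketch of what such a standalone pass could look like; the function name, the choice of dimensions, and the decision to compare the first two placeholders are all invented for illustration.

```python
import operator

import torch
import torch.fx


def insert_dim0_equality_assert(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Walk the already-built graph and insert an equality assertion between
    # dim 0 of the first two placeholders, using the same node recipe as the
    # pass: sym_size -> eq -> scalar_tensor -> _assert_async.
    placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"]
    if len(placeholders) < 2:
        return gm
    a, b = placeholders[0], placeholders[1]
    with gm.graph.inserting_after(placeholders[-1]):
        size_a = gm.graph.call_function(torch.ops.aten.sym_size.int, (a, 0))
        size_b = gm.graph.call_function(torch.ops.aten.sym_size.int, (b, 0))
        cmp = gm.graph.call_function(operator.eq, (size_a, size_b))
        cmp_t = gm.graph.call_function(torch.ops.aten.scalar_tensor.default, (cmp,))
        gm.graph.call_function(torch.ops.aten._assert_async.default, (cmp_t,))
    gm.recompile()
    return gm


class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y


gm = insert_dim0_equality_assert(torch.fx.symbolic_trace(M()))
gm(torch.randn(3, 4), torch.randn(3, 4))  # passes; a dim-0 mismatch would assert
```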