group constraints by arg
Pull Request resolved: #101815

Before, we would emit a soup of specializations and constraints in no particular order, which hurt readability.

With this diff, we group such results by argument and add a comment naming the argument before each group. Empirically, the results read much better.
ghstack-source-id: 189742576

Differential Revision: [D45995199](https://our.internmc.facebook.com/intern/diff/D45995199/)
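For intuition, the grouping works roughly like this (a minimal standalone sketch, not the code in this diff; `group_by_arg` and the sample constraint strings are made up for illustration):

```python
import re
from collections import defaultdict

def group_by_arg(results, arg_order):
    """Group constraint strings like "L['a'].size()[0] == 8" by the
    argument they mention, emitting groups in signature order."""
    groups = defaultdict(list)
    for dc in results:
        match = re.search(r"L\['(.+?)'\]", dc)
        if match is None:
            continue  # constraint mentions no argument; skipped in this sketch
        # Strip the L['...'] wrapper so the text names the parameter directly.
        groups[match.group(1)].append(re.sub(r"L\['(.+?)'\]", r"\1", dc))
    return [(arg, sorted(groups[arg])) for arg in arg_order if arg in groups]

constraints = {
    "L['b'].size()[0] == 8",
    "L['a'].size()[1] == 22",
    "L['a'].size()[0] == 8",
}
for arg, dcs in group_by_arg(constraints, ["a", "b"]):
    print(f"# {arg}:")
    for dc in dcs:
        print(f"    assert {dc}")
# # a:
#     assert a.size()[0] == 8
#     assert a.size()[1] == 22
# # b:
#     assert b.size()[0] == 8
```

The real implementation is `DimConstraints.prettify_results` in the `symbolic_shapes.py` diff below, which additionally orders groups by parameter position in the signature and errors on constraints that mention unknown parameters.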
avikchaudhuri committed May 19, 2023
1 parent 18f6f30 commit 9c1901d
Showing 3 changed files with 116 additions and 57 deletions.
101 changes: 57 additions & 44 deletions test/test_dynamic_shapes.py
@@ -882,47 +882,47 @@ def test_dim_constraints_solve_full(self):
         from torch._dynamo.source import LocalSource, TensorProperty, TensorPropertySource
 
         src0 = TensorPropertySource(
-            base=LocalSource(local_name="x0"), prop=TensorProperty.SIZE, idx=0
+            base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0
         )
         src2 = TensorPropertySource(
-            base=LocalSource(local_name="x2"), prop=TensorProperty.SIZE, idx=0
+            base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=0
         )
         src3 = TensorPropertySource(
-            base=LocalSource(local_name="x3"), prop=TensorProperty.SIZE, idx=0
+            base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=0
         )
         src4 = TensorPropertySource(
-            base=LocalSource(local_name="x4"), prop=TensorProperty.SIZE, idx=0
+            base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=0
         )
 
         src1 = TensorPropertySource(
-            base=LocalSource(local_name="x1"), prop=TensorProperty.SIZE, idx=2
+            base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=2
         )
         src7 = TensorPropertySource(
-            base=LocalSource(local_name="x7"), prop=TensorProperty.SIZE, idx=3
+            base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=3
         )
 
         src5 = TensorPropertySource(
-            base=LocalSource(local_name="x5"), prop=TensorProperty.SIZE, idx=1
+            base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=1
         )
         src8 = TensorPropertySource(
-            base=LocalSource(local_name="x8"), prop=TensorProperty.SIZE, idx=1
+            base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=1
         )
 
         src6 = TensorPropertySource(
-            base=LocalSource(local_name="x6"), prop=TensorProperty.SIZE, idx=1
+            base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=1
         )
         src9 = TensorPropertySource(
-            base=LocalSource(local_name="x9"), prop=TensorProperty.SIZE, idx=1
+            base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=1
         )
         src10 = TensorPropertySource(
-            base=LocalSource(local_name="x10"), prop=TensorProperty.SIZE, idx=1
+            base=LocalSource(local_name="e"), prop=TensorProperty.SIZE, idx=1
         )
 
         src11 = TensorPropertySource(
-            base=LocalSource(local_name="x11"), prop=TensorProperty.SIZE, idx=1
+            base=LocalSource(local_name="f"), prop=TensorProperty.SIZE, idx=1
         )
         src12 = TensorPropertySource(
-            base=LocalSource(local_name="x12"), prop=TensorProperty.SIZE, idx=2
+            base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=2
         )
 
         s0 = Symbol("s0", positive=True, integer=True)
@@ -1755,48 +1755,61 @@ def test_dim_constraints_solve_full(self):
 
         dim_constraints.solve()
         self.assertEqual(dim_constraints._static_results, {
-            "L['x3'].size()[0] == 8",
-            "L['x4'].size()[0] == 8",
-            "L['x1'].size()[2] == 96",
-            "L['x11'].size()[1] == 1",
-            "L['x7'].size()[3] == 96",
-            "L['x12'].size()[2] == 3",
-            "L['x8'].size()[1] == 22",
-            "L['x2'].size()[0] == 8",
-            "L['x5'].size()[1] == 22",
-            "L['x0'].size()[0] == 8",
+            "L['c'].size()[0] == 8",
+            "L['d'].size()[0] == 8",
+            "L['a'].size()[2] == 96",
+            "L['f'].size()[1] == 1",
+            "L['a'].size()[3] == 96",
+            "L['b'].size()[2] == 3",
+            "L['b'].size()[1] == 22",
+            "L['b'].size()[0] == 8",
+            "L['a'].size()[1] == 22",
+            "L['a'].size()[0] == 8",
         })
         self.assertEqual(dim_constraints._dynamic_results, {
-            "dynamic_dim(L['x10'], 1) == dynamic_dim(L['x6'], 1)",
-            "2 <= dynamic_dim(L['x6'], 1)",
-            "dynamic_dim(L['x9'], 1) == dynamic_dim(L['x6'], 1)",
+            "dynamic_dim(L['e'], 1) == dynamic_dim(L['c'], 1)",
+            "2 <= dynamic_dim(L['c'], 1)",
+            "dynamic_dim(L['d'], 1) == dynamic_dim(L['c'], 1)",
         })
 
-        def dummy_f(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x11, x12):
+        def dummy_fn(a, b, c, d, e, f):
             pass
 
-        action_code = dim_constraints.prettify_results(inspect.signature(dummy_f))
+        action_code = dim_constraints.prettify_results(inspect.signature(dummy_fn))
         static_code, dynamic_code = re.findall(r"```(.*?)```", action_code, re.DOTALL)
         print(static_code)
         expected_static = '''
-def specializations(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x11, x12):
-    assert x0.size()[0] == 8
-    assert x1.size()[2] == 96
-    assert x11.size()[1] == 1
-    assert x12.size()[2] == 3
-    assert x2.size()[0] == 8
-    assert x3.size()[0] == 8
-    assert x4.size()[0] == 8
-    assert x5.size()[1] == 22
-    assert x7.size()[3] == 96
-    assert x8.size()[1] == 22
+def specializations(a, b, c, d, e, f):
+    # a:
+    assert a.size()[0] == 8
+    assert a.size()[1] == 22
+    assert a.size()[2] == 96
+    assert a.size()[3] == 96
+
+    # b:
+    assert b.size()[0] == 8
+    assert b.size()[1] == 22
+    assert b.size()[2] == 3
+
+    # c:
+    assert c.size()[0] == 8
+
+    # d:
+    assert d.size()[0] == 8
+
+    # f:
+    assert f.size()[1] == 1
 '''
         expected_dynamic = '''
-def specify_constraints(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x11, x12):
+def specify_constraints(a, b, c, d, e, f):
     return [
-        dynamic_dim(x6, 1),
-        dynamic_dim(x10, 1) == dynamic_dim(x6, 1),
-        dynamic_dim(x9, 1) == dynamic_dim(x6, 1),
+        # c:
+        dynamic_dim(c, 1),
+
+        # d:
+        dynamic_dim(d, 1) == dynamic_dim(c, 1),
+
+        # e:
+        dynamic_dim(e, 1) == dynamic_dim(c, 1),
     ]
 '''

9 changes: 4 additions & 5 deletions torch/_dynamo/eval_frame.py
@@ -781,6 +781,8 @@ def export(
     if pre_autograd:
         assert aten_graph, "pre_autograd=True can only be used when aten_graph=True"
     f = innermost_fn(f)
+    call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f
+    original_signature = inspect.signature(call_to_inspect)
 
     if functionalize and not aten_graph:
         raise UserError(
@@ -884,7 +886,7 @@ def result_capturing_wrapper(*graph_inputs):
         dim_constraints = shape_env.dim_constraints
         assert dim_constraints is not None
         dim_constraints.solve()
-        msg = dim_constraints.prettify_results(inspect.signature(f))
+        msg = dim_constraints.prettify_results(original_signature)
         if constraint_violation_error:
             constraint_violation_error.args = (
                 constraint_violation_error.args[0] + msg,
@@ -1060,10 +1062,7 @@ def signature_to_fullargspec(sig: inspect.Signature):
 
     # Make dynamo graph to have same input/output spec as user code
    def argument_names(f: Callable[..., Any], *args, **kwargs) -> List[str]:
-        call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f
-
-        sig = inspect.signature(call_to_inspect)
-        fullargspec = signature_to_fullargspec(sig)
+        fullargspec = signature_to_fullargspec(original_signature)
 
         # 1. Map `args` 1-to-1 to positional arguments in original signature.
         input_strs = fullargspec.args[: len(args)]
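A note on why the signature is hoisted: `export` may be handed an `nn.Module` rather than a plain function, and inspecting the module itself resolves through `nn.Module.__call__`, which hides the user-facing parameter names. A minimal illustration (a sketch, not code from this diff; the exact `(*args, **kwargs)`-style output can vary by PyTorch version):

```python
import inspect
import torch

class M(torch.nn.Module):
    def forward(self, a, b):
        return a + b

m = M()
# Inspecting the module reflects nn.Module.__call__ and loses the real
# parameter names (prints something like "(*args, **kwargs)"):
print(inspect.signature(m))
# Inspecting forward recovers the names that error messages should use:
print(inspect.signature(m.forward))  # (a, b)
```

Computing `original_signature` once, before `f` is wrapped, also lets `prettify_results` and `argument_names` share the same signature.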
63 changes: 55 additions & 8 deletions torch/fx/experimental/symbolic_shapes.py
@@ -1740,33 +1740,80 @@ def prettify_results(self, original_signature: inspect.Signature):
         # LocalSource.name() wraps the name with L[""]. We use regular
         # expression to do the replacement to avoid traversing up
         # the source hierarchy manually.
-        def unwrap_local_source(source_name):
-            return re.sub(r"L\['(.+?)'\]", r'\1', source_name)
+        def extract_and_rewrite_local(dc):
+            match = re.search(r"L\['(.+?)'\]", dc)
+            if match is None:
+                return
+            arg = match.expand(r'\1')
+            dc = re.sub(r"L\['(.+?)'\]", r'\1', dc)
+            return arg, dc
+
+        def group(results, args_index):
+            groups = defaultdict(list)
+            for dc in results:
+                local = extract_and_rewrite_local(dc)
+                if local is None:
+                    # This can happen, e.g., with `assume_constant_result`.
+                    # In that case, we drop the constraint.
+                    # TODO(avik) Maybe we should generate an assertion here?
+                    continue
+                arg, dc = local
+                if arg in args_index:
+                    groups[args_index[arg]].append(dc)
+                else:
+                    raise ValueError(f"Cannot find param `{arg}` in {signature}")
+            sorted_groups = []
+            for idx, dcs in sorted(groups.items()):
+                _, arg = idx
+                sorted_groups.append((arg, sorted(dcs)))
+            return sorted_groups
 
         # Instead of 2 <= dynamic_dim(...) simply suggest dynamic_dim(...).
         # There is no change in behavior since 2 is the default lower bound.
         def remove_default_lower_bound(dc):
             return re.sub(r"2 <= dynamic_dim(.+)", r"dynamic_dim\1", dc)
 
         signature = original_signature.replace(return_annotation=inspect.Signature.empty)
+        args_index = {}
+        for i, arg in enumerate(signature.parameters.keys()):
+            args_index[arg] = (i, arg)
 
         buf = ""
+
+        def print_results(grouped, indent, result_fn):
+            nonlocal buf
+
+            space = False
+            for arg, results in grouped:
+                if space:
+                    buf += "\n"
+                else:
+                    space = True
+                buf += f"\n{indent}# {arg}:"
+                for result in results:
+                    buf += f"\n{indent}{result_fn(result)}"
 
         indent = 4 * " "
         if self._static_results:
-            sorted_static_results = [unwrap_local_source(res) for res in sorted(self._static_results)]
+            grouped_static_results = group(self._static_results, args_index)
             buf += "\nThe following dimensions have been specialized and CANNOT be dynamic."
             buf += f"\n```\ndef specializations{str(signature)}:"
-            for result in sorted_static_results:
-                buf += f"\n{indent}assert {result}"
+            print_results(
+                grouped_static_results,
+                indent,
+                lambda result: f"assert {result}",
+            )
             buf += "\n```\n"
         if self._dynamic_results:
-            sorted_dynamic_results = sorted(self._dynamic_results)
+            grouped_dynamic_results = group(self._dynamic_results, args_index)
             buf += "\nThe following dimensions CAN be dynamic."
             buf += "\nYou can use the following code to specify the constraints they must satisfy:"
             buf += f"\n```\ndef specify_constraints{str(signature)}:"
             buf += f"\n{indent}return ["
-            for result in sorted_dynamic_results:
-                buf += f"\n{indent*2}{remove_default_lower_bound(unwrap_local_source(result))},"
+            print_results(
+                grouped_dynamic_results,
+                indent*2,
+                lambda result: f"{remove_default_lower_bound(result)},",
+            )
             buf += f"\n{indent}]\n```\n"
         return buf

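As a quick sanity check of the lower-bound rewrite above (`remove_default_lower_bound` is copied from the diff; the sample strings are made up):

```python
import re

def remove_default_lower_bound(dc):
    # 2 is the default lower bound, so "2 <= dynamic_dim(...)" can be
    # suggested as plain "dynamic_dim(...)" with no change in behavior.
    return re.sub(r"2 <= dynamic_dim(.+)", r"dynamic_dim\1", dc)

print(remove_default_lower_bound("2 <= dynamic_dim(c, 1)"))
# dynamic_dim(c, 1)
print(remove_default_lower_bound("dynamic_dim(d, 1) == dynamic_dim(c, 1)"))
# dynamic_dim(d, 1) == dynamic_dim(c, 1)  (unchanged)
```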
