Skip to content

Commit

Permalink
group constraints by arg
Browse files Browse the repository at this point in the history
Pull Request resolved: #101815

Before, we would emit a soup of specializations / constraints without any obvious order to guide readability.

With this diff, we group such results by arg, and add comments preceding each group. Empirically, the results read much better.
ghstack-source-id: 189734399

Differential Revision: [D45995199](https://our.internmc.facebook.com/intern/diff/D45995199/)
  • Loading branch information
avikchaudhuri committed May 19, 2023
1 parent 18f6f30 commit 5f8750f
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 19 deletions.
36 changes: 30 additions & 6 deletions test/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1772,31 +1772,55 @@ def test_dim_constraints_solve_full(self):
"dynamic_dim(L['x9'], 1) == dynamic_dim(L['x6'], 1)",
})

def dummy_f(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x11, x12):
def dummy_f(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12):
pass

action_code = dim_constraints.prettify_results(inspect.signature(dummy_f))
static_code, dynamic_code = re.findall(r"```(.*?)```", action_code, re.DOTALL)
print(static_code)
expected_static = '''
def specializations(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x11, x12):
def specializations(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12):
# x0:
assert x0.size()[0] == 8
# x1:
assert x1.size()[2] == 96
assert x11.size()[1] == 1
assert x12.size()[2] == 3
# x2:
assert x2.size()[0] == 8
# x3:
assert x3.size()[0] == 8
# x4:
assert x4.size()[0] == 8
# x5:
assert x5.size()[1] == 22
# x7:
assert x7.size()[3] == 96
# x8:
assert x8.size()[1] == 22
# x11:
assert x11.size()[1] == 1
# x12:
assert x12.size()[2] == 3
'''
expected_dynamic = '''
def specify_constraints(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x11, x12):
def specify_constraints(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12):
return [
# x6:
dynamic_dim(x6, 1),
dynamic_dim(x10, 1) == dynamic_dim(x6, 1),
# x9:
dynamic_dim(x9, 1) == dynamic_dim(x6, 1),
# x10:
dynamic_dim(x10, 1) == dynamic_dim(x6, 1),
]
'''

Expand Down
9 changes: 4 additions & 5 deletions torch/_dynamo/eval_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,8 @@ def export(
if pre_autograd:
assert aten_graph, "pre_autograd=True can only be used when aten_graph=True"
f = innermost_fn(f)
call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f
original_signature = inspect.signature(call_to_inspect)

if functionalize and not aten_graph:
raise UserError(
Expand Down Expand Up @@ -884,7 +886,7 @@ def result_capturing_wrapper(*graph_inputs):
dim_constraints = shape_env.dim_constraints
assert dim_constraints is not None
dim_constraints.solve()
msg = dim_constraints.prettify_results(inspect.signature(f))
msg = dim_constraints.prettify_results(original_signature)
if constraint_violation_error:
constraint_violation_error.args = (
constraint_violation_error.args[0] + msg,
Expand Down Expand Up @@ -1060,10 +1062,7 @@ def signature_to_fullargspec(sig: inspect.Signature):

# Make dynamo graph to have same input/output spec as user code
def argument_names(f: Callable[..., Any], *args, **kwargs) -> List[str]:
call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f

sig = inspect.signature(call_to_inspect)
fullargspec = signature_to_fullargspec(sig)
fullargspec = signature_to_fullargspec(original_signature)

# 1. Map `args` 1-to-1 to positional arguments in original signature.
input_strs = fullargspec.args[: len(args)]
Expand Down
63 changes: 55 additions & 8 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1740,33 +1740,80 @@ def prettify_results(self, original_signature: inspect.Signature):
# LocalSource.name() wraps the name with L[""]. We use regular
# expression to do the replacement to avoid traversing up
# the source hierarchy manually.
def unwrap_local_source(source_name):
return re.sub(r"L\['(.+?)'\]", r'\1', source_name)
def extract_and_rewrite_local(dc):
match = re.search(r"L\['(.+?)'\]", dc)
if match is None:
return
arg = match.expand(r'\1')
dc = re.sub(r"L\['(.+?)'\]", r'\1', dc)
return arg, dc

def group(results, args_index):
groups = defaultdict(list)
for dc in results:
local = extract_and_rewrite_local(dc)
if local is None:
# This can happen, e.g., with `assume_constant_result`.
# In that case, we drop the constraint.
# TODO(avik) Maybe we should generate an assertion here?
continue
arg, dc = local
if arg in args_index:
groups[args_index[arg]].append(dc)
else:
raise ValueError(f"Cannot find param `{arg}` in {signature}")
sorted_groups = []
for idx, dcs in sorted(groups.items()):
_, arg = idx
sorted_groups.append((arg, sorted(dcs)))
return sorted_groups

# Instead of 2 <= dynamic_dim(...) simply suggest dynamic_dim(...).
# There is no change in behavior since 2 is the default lower bound.
def remove_default_lower_bound(dc):
return re.sub(r"2 <= dynamic_dim(.+)", r"dynamic_dim\1", dc)

signature = original_signature.replace(return_annotation=inspect.Signature.empty)
args_index = {}
for i, arg in enumerate(signature.parameters.keys()):
args_index[arg] = (i, arg)

def print_results(grouped, buf, indent, result_fn):
space = False
for arg, results in grouped:
if space:
buf += "\n"
else:
space = True
buf += f"\n{indent}# {arg}:"
for result in results:
buf += f"\n{indent}{result_fn(result)},"

buf = ""
indent = 4 * " "
if self._static_results:
sorted_static_results = [unwrap_local_source(res) for res in sorted(self._static_results)]
grouped_static_results = group(self._static_results, args_index)
buf += "\nThe following dimensions have been specialized and CANNOT be dynamic."
buf += f"\n```\ndef specializations{str(signature)}:"
for result in sorted_static_results:
buf += f"\n{indent}assert {result}"
print_results(
grouped_static_results,
buf,
indent,
lambda result: f"assert {result}",
)
buf += "\n```\n"
if self._dynamic_results:
sorted_dynamic_results = sorted(self._dynamic_results)
grouped_dynamic_results = group(self._dynamic_results, args_index)
buf += "\nThe following dimensions CAN be dynamic."
buf += "\nYou can use the following code to specify the constraints they must satisfy:"
buf += f"\n```\ndef specify_constraints{str(signature)}:"
buf += f"\n{indent}return ["
for result in sorted_dynamic_results:
buf += f"\n{indent*2}{remove_default_lower_bound(unwrap_local_source(result))},"
print_results(
grouped_dynamic_results,
buf,
indent*2,
lambda result: remove_default_lower_bound(result)
)
buf += f"\n{indent}]\n```\n"
return buf

Expand Down

0 comments on commit 5f8750f

Please sign in to comment.