Skip to content

Commit

Permalink
group constraints by arg (#101815)
Browse files Browse the repository at this point in the history
Before, we would emit a soup of specializations / constraints without any obvious order to guide readability.

With this diff, we group such results by arg, and add comments preceding each group. Empirically, the results read much better.

Differential Revision: [D45995199](https://our.internmc.facebook.com/intern/diff/D45995199/)

Pull Request resolved: #101815
Approved by: https://github.com/tugsbayasgalan
  • Loading branch information
avikchaudhuri authored and pytorchmergebot committed May 20, 2023
1 parent b5ee34e commit 03de158
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 52 deletions.
101 changes: 57 additions & 44 deletions test/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,47 +879,47 @@ def test_dim_constraints_solve_full(self):
from torch._dynamo.source import LocalSource, TensorProperty, TensorPropertySource

src0 = TensorPropertySource(
base=LocalSource(local_name="x0"), prop=TensorProperty.SIZE, idx=0
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0
)
src2 = TensorPropertySource(
base=LocalSource(local_name="x2"), prop=TensorProperty.SIZE, idx=0
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=0
)
src3 = TensorPropertySource(
base=LocalSource(local_name="x3"), prop=TensorProperty.SIZE, idx=0
base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=0
)
src4 = TensorPropertySource(
base=LocalSource(local_name="x4"), prop=TensorProperty.SIZE, idx=0
base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=0
)

src1 = TensorPropertySource(
base=LocalSource(local_name="x1"), prop=TensorProperty.SIZE, idx=2
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=2
)
src7 = TensorPropertySource(
base=LocalSource(local_name="x7"), prop=TensorProperty.SIZE, idx=3
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=3
)

src5 = TensorPropertySource(
base=LocalSource(local_name="x5"), prop=TensorProperty.SIZE, idx=1
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=1
)
src8 = TensorPropertySource(
base=LocalSource(local_name="x8"), prop=TensorProperty.SIZE, idx=1
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=1
)

src6 = TensorPropertySource(
base=LocalSource(local_name="x6"), prop=TensorProperty.SIZE, idx=1
base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=1
)
src9 = TensorPropertySource(
base=LocalSource(local_name="x9"), prop=TensorProperty.SIZE, idx=1
base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=1
)
src10 = TensorPropertySource(
base=LocalSource(local_name="x10"), prop=TensorProperty.SIZE, idx=1
base=LocalSource(local_name="e"), prop=TensorProperty.SIZE, idx=1
)

src11 = TensorPropertySource(
base=LocalSource(local_name="x11"), prop=TensorProperty.SIZE, idx=1
base=LocalSource(local_name="f"), prop=TensorProperty.SIZE, idx=1
)
src12 = TensorPropertySource(
base=LocalSource(local_name="x12"), prop=TensorProperty.SIZE, idx=2
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=2
)

s0 = Symbol("s0", positive=True, integer=True)
Expand Down Expand Up @@ -1752,48 +1752,61 @@ def test_dim_constraints_solve_full(self):

dim_constraints.solve()
self.assertEqual(dim_constraints._static_results, {
"L['x3'].size()[0] == 8",
"L['x4'].size()[0] == 8",
"L['x1'].size()[2] == 96",
"L['x11'].size()[1] == 1",
"L['x7'].size()[3] == 96",
"L['x12'].size()[2] == 3",
"L['x8'].size()[1] == 22",
"L['x2'].size()[0] == 8",
"L['x5'].size()[1] == 22",
"L['x0'].size()[0] == 8",
"L['c'].size()[0] == 8",
"L['d'].size()[0] == 8",
"L['a'].size()[2] == 96",
"L['f'].size()[1] == 1",
"L['a'].size()[3] == 96",
"L['b'].size()[2] == 3",
"L['b'].size()[1] == 22",
"L['b'].size()[0] == 8",
"L['a'].size()[1] == 22",
"L['a'].size()[0] == 8",
})
self.assertEqual(dim_constraints._dynamic_results, {
"dynamic_dim(L['x10'], 1) == dynamic_dim(L['x6'], 1)",
"2 <= dynamic_dim(L['x6'], 1)",
"dynamic_dim(L['x9'], 1) == dynamic_dim(L['x6'], 1)",
"dynamic_dim(L['e'], 1) == dynamic_dim(L['c'], 1)",
"2 <= dynamic_dim(L['c'], 1)",
"dynamic_dim(L['d'], 1) == dynamic_dim(L['c'], 1)",
})

def dummy_f(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x11, x12):
def dummy_fn(a, b, c, d, e, f):
pass

action_code = dim_constraints.prettify_results(inspect.signature(dummy_f))
action_code = dim_constraints.prettify_results(inspect.signature(dummy_fn))
static_code, dynamic_code = re.findall(r"```(.*?)```", action_code, re.DOTALL)
print(static_code)
expected_static = '''
def specializations(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x11, x12):
assert x0.size()[0] == 8
assert x1.size()[2] == 96
assert x11.size()[1] == 1
assert x12.size()[2] == 3
assert x2.size()[0] == 8
assert x3.size()[0] == 8
assert x4.size()[0] == 8
assert x5.size()[1] == 22
assert x7.size()[3] == 96
assert x8.size()[1] == 22
def specializations(a, b, c, d, e, f):
# a:
assert a.size()[0] == 8
assert a.size()[1] == 22
assert a.size()[2] == 96
assert a.size()[3] == 96
# b:
assert b.size()[0] == 8
assert b.size()[1] == 22
assert b.size()[2] == 3
# c:
assert c.size()[0] == 8
# d:
assert d.size()[0] == 8
# f:
assert f.size()[1] == 1
'''
expected_dynamic = '''
def specify_constraints(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x11, x12):
def specify_constraints(a, b, c, d, e, f):
return [
dynamic_dim(x6, 1),
dynamic_dim(x10, 1) == dynamic_dim(x6, 1),
dynamic_dim(x9, 1) == dynamic_dim(x6, 1),
# c:
dynamic_dim(c, 1),
# d:
dynamic_dim(d, 1) == dynamic_dim(c, 1),
# e:
dynamic_dim(e, 1) == dynamic_dim(c, 1),
]
'''

Expand Down
63 changes: 55 additions & 8 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1750,33 +1750,80 @@ def prettify_results(self, original_signature: inspect.Signature):
# LocalSource.name() wraps the name with L[""]. We use regular
# expression to do the replacement to avoid traversing up
# the source hierarchy manually.
def unwrap_local_source(source_name):
return re.sub(r"L\['(.+?)'\]", r'\1', source_name)
def extract_and_rewrite_local(dc):
match = re.search(r"L\['(.+?)'\]", dc)
if match is None:
return
arg = match.expand(r'\1')
dc = re.sub(r"L\['(.+?)'\]", r'\1', dc)
return arg, dc

def group(results, args_index):
groups = defaultdict(list)
for dc in results:
local = extract_and_rewrite_local(dc)
if local is None:
# This can happen, e.g., with `assume_constant_result`.
# In that case, we drop the constraint.
# TODO(avik) Maybe we should generate an assertion here?
continue
arg, dc = local
if arg in args_index:
groups[args_index[arg]].append(dc)
else:
raise ValueError(f"Cannot find param `{arg}` in {signature}")
sorted_groups = []
for idx, dcs in sorted(groups.items()):
_, arg = idx
sorted_groups.append((arg, sorted(dcs)))
return sorted_groups

# Instead of 2 <= dynamic_dim(...) simply suggest dynamic_dim(...).
# There is no change in behavior since 2 is the default lower bound.
def remove_default_lower_bound(dc):
return re.sub(r"2 <= dynamic_dim(.+)", r"dynamic_dim\1", dc)

signature = original_signature.replace(return_annotation=inspect.Signature.empty)
args_index = {}
for i, arg in enumerate(signature.parameters.keys()):
args_index[arg] = (i, arg)

def print_results(grouped, indent, result_fn):
nonlocal buf

space = False
for arg, results in grouped:
if space:
buf += "\n"
else:
space = True
buf += f"\n{indent}# {arg}:"
for result in results:
buf += f"\n{indent}{result_fn(result)}"

buf = ""
indent = 4 * " "
if self._static_results:
sorted_static_results = [unwrap_local_source(res) for res in sorted(self._static_results)]
grouped_static_results = group(self._static_results, args_index)
buf += "\nThe following dimensions have been specialized and CANNOT be dynamic."
buf += f"\n```\ndef specializations{str(signature)}:"
for result in sorted_static_results:
buf += f"\n{indent}assert {result}"
print_results(
grouped_static_results,
indent,
lambda result: f"assert {result}",
)
buf += "\n```\n"
if self._dynamic_results:
sorted_dynamic_results = sorted(self._dynamic_results)
grouped_dynamic_results = group(self._dynamic_results, args_index)
buf += "\nThe following dimensions CAN be dynamic."
buf += "\nYou can use the following code to specify the constraints they must satisfy:"
buf += f"\n```\ndef specify_constraints{str(signature)}:"
buf += f"\n{indent}return ["
for result in sorted_dynamic_results:
buf += f"\n{indent*2}{remove_default_lower_bound(unwrap_local_source(result))},"
print_results(
grouped_dynamic_results,
indent * 2,
lambda result: f"{remove_default_lower_bound(result)},",
)
buf += f"\n{indent}]\n```\n"
return buf

Expand Down

0 comments on commit 03de158

Please sign in to comment.