Skip to content

Commit

Permalink
Revert "group constraints by arg (#101815)"
Browse files Browse the repository at this point in the history
This reverts commit 03de158.
Reverted #101815 on behalf of https://github.com/malfet due to it broke ExecuTorch and author was well aware about it"
  • Loading branch information
malfet committed May 22, 2023
1 parent a630328 commit 496212f
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 112 deletions.
101 changes: 44 additions & 57 deletions test/test_dynamic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,47 +879,47 @@ def test_dim_constraints_solve_full(self):
from torch._dynamo.source import LocalSource, TensorProperty, TensorPropertySource

src0 = TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0
base=LocalSource(local_name="x0"), prop=TensorProperty.SIZE, idx=0
)
src2 = TensorPropertySource(
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=0
base=LocalSource(local_name="x2"), prop=TensorProperty.SIZE, idx=0
)
src3 = TensorPropertySource(
base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=0
base=LocalSource(local_name="x3"), prop=TensorProperty.SIZE, idx=0
)
src4 = TensorPropertySource(
base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=0
base=LocalSource(local_name="x4"), prop=TensorProperty.SIZE, idx=0
)

src1 = TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=2
base=LocalSource(local_name="x1"), prop=TensorProperty.SIZE, idx=2
)
src7 = TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=3
base=LocalSource(local_name="x7"), prop=TensorProperty.SIZE, idx=3
)

src5 = TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=1
base=LocalSource(local_name="x5"), prop=TensorProperty.SIZE, idx=1
)
src8 = TensorPropertySource(
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=1
base=LocalSource(local_name="x8"), prop=TensorProperty.SIZE, idx=1
)

src6 = TensorPropertySource(
base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=1
base=LocalSource(local_name="x6"), prop=TensorProperty.SIZE, idx=1
)
src9 = TensorPropertySource(
base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=1
base=LocalSource(local_name="x9"), prop=TensorProperty.SIZE, idx=1
)
src10 = TensorPropertySource(
base=LocalSource(local_name="e"), prop=TensorProperty.SIZE, idx=1
base=LocalSource(local_name="x10"), prop=TensorProperty.SIZE, idx=1
)

src11 = TensorPropertySource(
base=LocalSource(local_name="f"), prop=TensorProperty.SIZE, idx=1
base=LocalSource(local_name="x11"), prop=TensorProperty.SIZE, idx=1
)
src12 = TensorPropertySource(
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=2
base=LocalSource(local_name="x12"), prop=TensorProperty.SIZE, idx=2
)

s0 = Symbol("s0", positive=True, integer=True)
Expand Down Expand Up @@ -1752,61 +1752,48 @@ def test_dim_constraints_solve_full(self):

dim_constraints.solve()
self.assertEqual(dim_constraints._static_results, {
"L['c'].size()[0] == 8",
"L['d'].size()[0] == 8",
"L['a'].size()[2] == 96",
"L['f'].size()[1] == 1",
"L['a'].size()[3] == 96",
"L['b'].size()[2] == 3",
"L['b'].size()[1] == 22",
"L['b'].size()[0] == 8",
"L['a'].size()[1] == 22",
"L['a'].size()[0] == 8",
"L['x3'].size()[0] == 8",
"L['x4'].size()[0] == 8",
"L['x1'].size()[2] == 96",
"L['x11'].size()[1] == 1",
"L['x7'].size()[3] == 96",
"L['x12'].size()[2] == 3",
"L['x8'].size()[1] == 22",
"L['x2'].size()[0] == 8",
"L['x5'].size()[1] == 22",
"L['x0'].size()[0] == 8",
})
self.assertEqual(dim_constraints._dynamic_results, {
"dynamic_dim(L['e'], 1) == dynamic_dim(L['c'], 1)",
"2 <= dynamic_dim(L['c'], 1)",
"dynamic_dim(L['d'], 1) == dynamic_dim(L['c'], 1)",
"dynamic_dim(L['x10'], 1) == dynamic_dim(L['x6'], 1)",
"2 <= dynamic_dim(L['x6'], 1)",
"dynamic_dim(L['x9'], 1) == dynamic_dim(L['x6'], 1)",
})

def dummy_fn(a, b, c, d, e, f):
def dummy_f(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x11, x12):
pass

action_code = dim_constraints.prettify_results(inspect.signature(dummy_fn))
action_code = dim_constraints.prettify_results(inspect.signature(dummy_f))
static_code, dynamic_code = re.findall(r"```(.*?)```", action_code, re.DOTALL)
print(static_code)
expected_static = '''
def specializations(a, b, c, d, e, f):
# a:
assert a.size()[0] == 8
assert a.size()[1] == 22
assert a.size()[2] == 96
assert a.size()[3] == 96
# b:
assert b.size()[0] == 8
assert b.size()[1] == 22
assert b.size()[2] == 3
# c:
assert c.size()[0] == 8
# d:
assert d.size()[0] == 8
# f:
assert f.size()[1] == 1
def specializations(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x11, x12):
assert x0.size()[0] == 8
assert x1.size()[2] == 96
assert x11.size()[1] == 1
assert x12.size()[2] == 3
assert x2.size()[0] == 8
assert x3.size()[0] == 8
assert x4.size()[0] == 8
assert x5.size()[1] == 22
assert x7.size()[3] == 96
assert x8.size()[1] == 22
'''
expected_dynamic = '''
def specify_constraints(a, b, c, d, e, f):
def specify_constraints(x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x11, x12):
return [
# c:
dynamic_dim(c, 1),
# d:
dynamic_dim(d, 1) == dynamic_dim(c, 1),
# e:
dynamic_dim(e, 1) == dynamic_dim(c, 1),
dynamic_dim(x6, 1),
dynamic_dim(x10, 1) == dynamic_dim(x6, 1),
dynamic_dim(x9, 1) == dynamic_dim(x6, 1),
]
'''

Expand Down
63 changes: 8 additions & 55 deletions torch/fx/experimental/symbolic_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1750,80 +1750,33 @@ def prettify_results(self, original_signature: inspect.Signature):
# LocalSource.name() wraps the name with L[""]. We use regular
# expression to do the replacement to avoid traversing up
# the source hierarchy manually.
def extract_and_rewrite_local(dc):
match = re.search(r"L\['(.+?)'\]", dc)
if match is None:
return
arg = match.expand(r'\1')
dc = re.sub(r"L\['(.+?)'\]", r'\1', dc)
return arg, dc

def group(results, args_index):
groups = defaultdict(list)
for dc in results:
local = extract_and_rewrite_local(dc)
if local is None:
# This can happen, e.g., with `assume_constant_result`.
# In that case, we drop the constraint.
# TODO(avik) Maybe we should generate an assertion here?
continue
arg, dc = local
if arg in args_index:
groups[args_index[arg]].append(dc)
else:
raise ValueError(f"Cannot find param `{arg}` in {signature}")
sorted_groups = []
for idx, dcs in sorted(groups.items()):
_, arg = idx
sorted_groups.append((arg, sorted(dcs)))
return sorted_groups
def unwrap_local_source(source_name):
return re.sub(r"L\['(.+?)'\]", r'\1', source_name)

# Instead of 2 <= dynamic_dim(...) simply suggest dynamic_dim(...).
# There is no change in behavior since 2 is the default lower bound.
def remove_default_lower_bound(dc):
return re.sub(r"2 <= dynamic_dim(.+)", r"dynamic_dim\1", dc)

signature = original_signature.replace(return_annotation=inspect.Signature.empty)
args_index = {}
for i, arg in enumerate(signature.parameters.keys()):
args_index[arg] = (i, arg)

def print_results(grouped, indent, result_fn):
nonlocal buf

space = False
for arg, results in grouped:
if space:
buf += "\n"
else:
space = True
buf += f"\n{indent}# {arg}:"
for result in results:
buf += f"\n{indent}{result_fn(result)}"

buf = ""
indent = 4 * " "
if self._static_results:
grouped_static_results = group(self._static_results, args_index)
sorted_static_results = [unwrap_local_source(res) for res in sorted(self._static_results)]
buf += "\nThe following dimensions have been specialized and CANNOT be dynamic."
buf += f"\n```\ndef specializations{str(signature)}:"
print_results(
grouped_static_results,
indent,
lambda result: f"assert {result}",
)
for result in sorted_static_results:
buf += f"\n{indent}assert {result}"
buf += "\n```\n"
if self._dynamic_results:
grouped_dynamic_results = group(self._dynamic_results, args_index)
sorted_dynamic_results = sorted(self._dynamic_results)
buf += "\nThe following dimensions CAN be dynamic."
buf += "\nYou can use the following code to specify the constraints they must satisfy:"
buf += f"\n```\ndef specify_constraints{str(signature)}:"
buf += f"\n{indent}return ["
print_results(
grouped_dynamic_results,
indent * 2,
lambda result: f"{remove_default_lower_bound(result)},",
)
for result in sorted_dynamic_results:
buf += f"\n{indent*2}{remove_default_lower_bound(unwrap_local_source(result))},"
buf += f"\n{indent}]\n```\n"
return buf

Expand Down

0 comments on commit 496212f

Please sign in to comment.