Skip to content

Commit

Permalink
[export] use tree_map for _flatten_dynamic_shapes (#125415)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #125415

Fixing the implementation of `_flatten_dynamic_shapes()`, to follow how `_process_dynamic_shapes()` does it. The previous implementation would misinterpret some nested dynamic shapes specs, causing it to miss out on some shapes specs, for example with nested inputs/constant input tuples:

```
inputs = (
    (2, 1),
    (
        torch.randn(2, 1),
        torch.randn(2, 2),
        torch.randn(2, 3),
    )
)

dynamic_shapes = (
    (None, None),
    (
        None,
        None,
        None,
    )
)
```
This would get interpreted as 2 shapes specs for 2d and 3d tensors. Fixing so this doesn't happen.

Test Plan: Existing export tests

Differential Revision: D56894923
  • Loading branch information
pianpwk authored and facebook-github-bot committed May 2, 2024
1 parent b1b0399 commit 5c9ddd1
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 22 deletions.
29 changes: 29 additions & 0 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4561,6 +4561,35 @@ def forward(self, x, y, div="floor"):
self.assertEqual(div_spec.arg.name, "div")
self.assertEqual(div_spec.arg.value, "floor")

def test_nested_dynamic_shapes_spec(self):
class Foo(torch.nn.Module):
def forward(self, x):
(a0, a1), (b0, b1), (c0, c1, c2) = x
return a0 + a1 + b0 + b1 + c0 + c1 + c2

f = Foo()
inputs = (
(1, 2),
(
torch.randn(4, 4),
torch.randn(4, 4),
),
(
torch.randn(4, 4),
torch.randn(4, 4),
torch.randn(4, 4),
),
)
# make sure this gets parsed correctly as 7 individual inputs, not 3 tensors
dynamic_shapes = {
"x": (
(None, None),
(None, None),
(None, None, None),
)
}
export(f, (inputs,), dynamic_shapes=dynamic_shapes)


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
class TestOneOffModelExportResult(TestCase):
Expand Down
3 changes: 3 additions & 0 deletions torch/_export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
Constraint,
dims,
dynamic_dim,
_combine_args,
)
from torch.export.exported_program import (
_disable_prexisiting_fake_mode,
Expand Down Expand Up @@ -175,9 +176,11 @@ def capture_pre_autograd_graph(
_restore_state_dict(f, m)

flat_args, _ = pytree.tree_flatten((args, kwargs or {}))
combined_args = _combine_args(f, args, kwargs)
range_constraints = make_constraints(
fake_mode,
m,
combined_args,
dynamic_shapes,
0,
)
Expand Down
45 changes: 23 additions & 22 deletions torch/_export/non_strict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torch._guards import Source
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.export import Constraint
from torch.export.dynamic_shapes import _Dim
from torch.export.dynamic_shapes import _tree_map
from torch.export.graph_signature import CustomObjArgument
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
Expand All @@ -30,7 +30,6 @@
KeyPath,
MappingKey,
SequenceKey,
tree_flatten,
tree_map_with_path,
)

Expand Down Expand Up @@ -180,25 +179,17 @@ def make_fake_inputs(nn_module, args, kwargs, dynamic_shapes):


def _flatten_dynamic_shapes(
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]]
):
def _is_dynamic_shape_leaf(x):
if isinstance(x, dict):
x = list(x.values())
return x is None or all(isinstance(y, (_Dim, int)) or y is None for y in x)

if isinstance(dynamic_shapes, (list, tuple)):
flat_dynamic_shapes = []
for item in dynamic_shapes:
flat_shapes, _ = tree_flatten(
dynamic_shapes, is_leaf=_is_dynamic_shape_leaf
)
flat_dynamic_shapes += flat_shapes
else:
flat_dynamic_shapes, _ = tree_flatten(
dynamic_shapes, is_leaf=_is_dynamic_shape_leaf
)
return flat_dynamic_shapes
combined_args: Dict[str, Any],
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any]],
) -> List[Any]:
flat_shapes = []

def _tree_map_helper(t, shape):
nonlocal flat_shapes
flat_shapes.append(shape)

_tree_map(_tree_map_helper, combined_args, dynamic_shapes)
return flat_shapes


def produce_guards_and_solve_constraints(
Expand Down Expand Up @@ -260,6 +251,7 @@ def produce_guards_and_solve_constraints(
def make_constraints(
fake_mode: FakeTensorMode,
gm: torch.fx.GraphModule,
combined_args: Dict[str, Any],
dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
num_lifted_inputs: int,
):
Expand All @@ -280,7 +272,16 @@ def make_constraints(
if not dynamic_shapes:
return range_constraints

flat_dynamic_shapes = _flatten_dynamic_shapes(dynamic_shapes)
# get individual dynamic shapes spec for each input
if not isinstance(dynamic_shapes, dict):
assert isinstance(dynamic_shapes, (tuple, list))
combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc]
flat_dynamic_shapes = _flatten_dynamic_shapes(combined_args, dynamic_shapes)

# check number of shapes vs. number of inputs
num_placeholders = [node.op == "placeholder" for node in gm.graph.nodes].count(True)
assert len(flat_dynamic_shapes) == num_placeholders - num_lifted_inputs

input_dims = defaultdict(list)
free_symbols = set()
for input_index, node in enumerate(gm.graph.nodes):
Expand Down
5 changes: 5 additions & 0 deletions torch/export/_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch._utils_internal import log_export_usage
from torch.export.dynamic_shapes import _combine_args
from torch.export.exported_program import OutputKind
from torch.fx._utils import first_call_function_nn_module_stack
from torch.fx.experimental.symbolic_shapes import (
Expand Down Expand Up @@ -1061,9 +1062,11 @@ def forward(self, *args, **kwargs):
except (ConstraintViolationError, ValueRangeError) as e:
raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: TRY200

combined_args = _combine_args(mod, args, kwargs)
range_constraints = make_constraints(
fake_mode,
ep_non_strict.gm,
combined_args,
dynamic_shapes,
num_lifted,
)
Expand Down Expand Up @@ -1269,9 +1272,11 @@ def forward(self, *args, **kwargs):
),
len(export_graph_signature.input_specs),
)
combined_args = _combine_args(mod, args, kwargs)
range_constraints = make_constraints(
dynamo_fake_mode,
gm,
combined_args,
dynamic_shapes,
num_lifted,
)
Expand Down

0 comments on commit 5c9ddd1

Please sign in to comment.