dynamic_shapes + retrace exported program (#110276)
An `ExportedProgram`'s `__call__` signature differs from the original module's, so `dynamic_shapes` that follow the original signature would fail when used to re-export an `ExportedProgram`.

This PR fixes that: `dynamic_shapes` written against the original signature now work when re-exporting.

Differential Revision: D49764011

Pull Request resolved: #110276
Approved by: https://github.com/tugsbayasgalan
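
A minimal sketch of the behavior this enables, mirroring the updated test below (the Foo module body, the dim name, and the input sizes are illustrative):

import torch

class Foo(torch.nn.Module):
    def forward(self, x):
        return x.sin()

inp = torch.ones(6, 5)
dim0_x = torch.export.Dim("dim0_x")

# dynamic_shapes is keyed by Foo.forward's parameter name, "x".
exported = torch.export.export(Foo(), (inp,), dynamic_shapes={"x": {0: dim0_x}})

# Re-exporting preserves the dynamic dim, even though ExportedProgram's
# __call__ signature is (*args, **kwargs) rather than (x).
reexported = torch.export.export(exported, (inp,))
print(reexported(torch.ones(7, 5)))  # a different size for dim 0 still works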
avikchaudhuri authored and pytorchmergebot committed Sep 29, 2023
1 parent c2c7c40 commit 359c2a5
Showing 2 changed files with 15 additions and 12 deletions.
4 changes: 1 addition & 3 deletions test/export/test_export.py
@@ -1255,9 +1255,7 @@ def forward(self, x):
             torch._dynamo.exc.UserError,
             "Cannot provide constraints for already exported program.",
         ):
-            _ = torch.export.export(
-                exported, (inp,), dynamic_shapes={"args": [{0: dim0_x}]}
-            )
+            _ = torch.export.export(exported, (inp,), dynamic_shapes={"x": {0: dim0_x}})
         # Reexported program should still work for dynamic shapes.
         reexported = torch.export.export(exported, (inp,))
         self.assertTrue(reexported(torch.ones(7, 5)), Foo()(torch.ones(7, 5)))
23 changes: 14 additions & 9 deletions torch/_export/__init__.py
@@ -96,9 +96,8 @@ def export__RC__(
     kwargs = kwargs if kwargs is not None else {}
 
     from collections.abc import Mapping, Sequence
-    from typing import get_origin, get_args
 
-    def assoc_zip(combined_args, dynamic_shapes):
+    def tree_zip(combined_args, dynamic_shapes):
         if isinstance(combined_args, (tuple, list)):
             if not isinstance(dynamic_shapes, Sequence):
                 raise UserError(
@@ -112,7 +111,7 @@ def assoc_zip(combined_args, dynamic_shapes):
                     f"Expected {dynamic_shapes} to have {len(combined_args)} items",
                 )
             for i, shape in enumerate(dynamic_shapes):
-                yield from assoc_zip(combined_args[i], shape)
+                yield from tree_zip(combined_args[i], shape)
         elif isinstance(combined_args, dict):
             if not isinstance(dynamic_shapes, Mapping):
                 raise UserError(
@@ -126,7 +125,7 @@ def assoc_zip(combined_args, dynamic_shapes):
                     f"Expected {dynamic_shapes} to have {len(combined_args)} items",
                 )
             for k, shape in dynamic_shapes.items():
-                yield from assoc_zip(combined_args[k], shape)
+                yield from tree_zip(combined_args[k], shape)
         elif dataclasses.is_dataclass(combined_args):
             if not type(dynamic_shapes) == type(combined_args):
                 raise UserError(
@@ -135,7 +134,7 @@ def assoc_zip(combined_args, dynamic_shapes):
                     f"got {dynamic_shapes} instead",
                 )
             for f in dataclasses.fields(combined_args):
-                yield from assoc_zip(getattr(combined_args, f.name), getattr(dynamic_shapes, f.name))
+                yield from tree_zip(getattr(combined_args, f.name), getattr(dynamic_shapes, f.name))
         elif isinstance(combined_args, torch.Tensor):
             yield (combined_args, dynamic_shapes)
         else:
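
For reference, tree_zip walks the argument tree and the dynamic_shapes tree in lockstep and yields a (tensor, shape-spec) pair at each leaf. A condensed sketch of the traversal, with the error checks and the dataclass branch omitted, and placeholder strings standing in for real Dim objects:

import torch

def tree_zip(combined_args, dynamic_shapes):
    # Recurse by position through sequences, by key through mappings,
    # and pair each leaf tensor with its shape spec.
    if isinstance(combined_args, (tuple, list)):
        for i, shape in enumerate(dynamic_shapes):
            yield from tree_zip(combined_args[i], shape)
    elif isinstance(combined_args, dict):
        for k, shape in dynamic_shapes.items():
            yield from tree_zip(combined_args[k], shape)
    elif isinstance(combined_args, torch.Tensor):
        yield (combined_args, dynamic_shapes)

args = {"x": torch.ones(2, 3), "pair": (torch.ones(4), torch.ones(5))}
shapes = {"x": {0: "dim_x"}, "pair": [{0: "dim_a"}, {0: "dim_b"}]}
print(list(tree_zip(args, shapes)))  # three pairs, one per leaf tensor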
@@ -188,11 +187,17 @@ def update_symbols(tensor, shape):
                 "try None instead",
             )
 
-    import inspect
-    signature = inspect.signature(f.forward) if isinstance(f, torch.nn.Module) else inspect.signature(f)
-    combined_args = signature.bind(*args, **kwargs).arguments
+    if isinstance(f, ExportedProgram):
+        combined_args = {
+            k: args[i] if i < len(args) else kwargs[k]
+            for i, k in enumerate(dynamic_shapes)
+        }
+    else:
+        import inspect
+        signature = inspect.signature(f.forward) if isinstance(f, torch.nn.Module) else inspect.signature(f)
+        combined_args = signature.bind(*args, **kwargs).arguments
 
-    for tensor, shape in assoc_zip(combined_args, dynamic_shapes):
+    for tensor, shape in tree_zip(combined_args, dynamic_shapes):
         update_symbols(tensor, shape)
 
     constraints = []
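
The new ExportedProgram branch avoids inspect.signature because the rewritten __call__ takes (*args, **kwargs), which cannot bind the original parameter names. Instead, the keys of dynamic_shapes (which follow the original signature, in order) are matched against positional args first and kwargs second. The same comprehension in isolation (values are illustrative):

import torch

dynamic_shapes = {"x": {0: "dim0_x"}, "y": None}  # original parameter names, in order
args, kwargs = (torch.ones(6, 5),), {"y": torch.ones(3)}

combined_args = {
    k: args[i] if i < len(args) else kwargs[k]  # positional first, then keyword
    for i, k in enumerate(dynamic_shapes)
}
# {"x": args[0], "y": kwargs["y"]}: the same mapping that
# signature.bind(*args, **kwargs).arguments would produce for the original module.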
