Allow dynamic shapes of tuple type for inputs of dataclass type #117917

Status: Closed · 1 commit
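
In brief: when an input's type is pytree-registered (a dataclass fixture like the test's DataClass, or torchrec's KeyedJaggedTensor), its entry in dynamic_shapes is now written as a Sequence aligned with the flattened children of the input, rather than as an instance of the class itself. A minimal sketch of the two calling conventions, assuming a dataclass with two tensor fields a and b (mirroring the test fixture in the diff below):

import torch
from torch.export import Dim

batch = Dim("batch")

# Old convention: the spec had to mirror the dataclass type itself.
# dynamic_shapes = {"inputs": DataClass(a={0: batch}, b={0: batch})}

# New convention: a Sequence matching the flattened fields (a, b).
dynamic_shapes = {"inputs": [{0: batch}, {0: batch}]}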
48 changes: 38 additions & 10 deletions test/export/test_export.py
@@ -8,8 +8,8 @@
 from dataclasses import dataclass

 import torch
-import torch.nn.functional as F
 import torch._dynamo as torchdynamo
+import torch.nn.functional as F
 from functorch.experimental.control_flow import cond, map
 from torch import Tensor
 from torch._dynamo.test_case import TestCase
@@ -31,13 +31,8 @@
 )
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.testing import FileCheck
-from torch.testing._internal.common_cuda import (
-    PLATFORM_SUPPORTS_FLASH_ATTENTION,
-)
-from torch.testing._internal.common_device_type import (
-    onlyCPU,
-    onlyCUDA,
-)
+from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
+from torch.testing._internal.common_device_type import onlyCPU, onlyCUDA
 from torch.testing._internal.common_utils import (
     run_tests,
     TestCase as TorchTestCase,
@@ -57,6 +52,13 @@
     treespec_loads,
 )

+try:
+    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+
+    HAS_TORCHREC = True
+except ImportError:
+    HAS_TORCHREC = False
+
 try:
     from . import testing
 except ImportError:
@@ -711,7 +713,7 @@ def forward(self, inputs):
         efoo = export(
             foo,
             inputs,
-            dynamic_shapes={"inputs": DataClass(a={0: batch}, b={0: batch})},
+            dynamic_shapes={"inputs": [{0: batch}, {0: batch}]},
         )
         self.assertEqual(
             [
@@ -722,6 +724,32 @@
             ["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
         )

+        # pass dynamic shapes of inputs [pytree-registered classes]
+        if HAS_TORCHREC:
+            # skipped when torchrec is not installed
+            class Foo(torch.nn.Module):
+                def forward(self, kjt):
+                    return kjt.values() + 0, kjt.offsets() + 0
+            foo = Foo()
+            kjt = KeyedJaggedTensor(
+                values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
+                keys=["index_0", "index_1"],
+                lengths=torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3]),
+                offsets=torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]),
+            )
+            inputs = (kjt,)
+            dim = Dim("dim")
+            dim_plus_one = Dim("dim_plus_one")
+            efoo = torch.export.export(
+                foo,
+                inputs,
+                dynamic_shapes={"kjt": [{0: dim}, None, {0: dim}, {0: dim_plus_one}]},
+            )
+            self.assertEqual(
+                [out.shape for out in efoo(*inputs)],
+                [out.shape for out in foo(*inputs)],
+            )
+
         # pass dynamic shapes of inputs [distinct, error]
         class Foo(torch.nn.Module):
             def forward(self, x, y):
@@ -1828,7 +1856,7 @@ def forward(self, x: Input):
         self.assertIsInstance(s, int)

         dim0_x_f, dim0_x_p = torch.export.dims("dim0_x_f", "dim0_x_p")
-        dynamic_shapes = {"x": Input(f={0: dim0_x_f}, p={0: dim0_x_p})}
+        dynamic_shapes = {"x": [{0: dim0_x_f}, {0: dim0_x_p}]}
         ep_dynamic = torch.export.export(
             mod, example_inputs, dynamic_shapes=dynamic_shapes
         )
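
The SUPPORTED_NODES check in the dynamic_shapes.py change below applies to any pytree-registered class, not just torchrec's. A minimal registration sketch, assuming torch.utils._pytree's node-registration helper (its exact name has varied across releases, e.g. _register_pytree_node in older ones); the Point class and its field order are hypothetical:

from dataclasses import dataclass

import torch
import torch.utils._pytree as pytree

@dataclass
class Point:
    x: torch.Tensor
    y: torch.Tensor

# flatten_fn returns (children, context); unflatten_fn rebuilds the instance.
pytree.register_pytree_node(
    Point,
    lambda p: ([p.x, p.y], None),
    lambda children, _ctx: Point(*children),
)

# With Point registered, a dynamic_shapes entry for a Point input is a
# Sequence of two specs, one per flattened child (x, y).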
18 changes: 10 additions & 8 deletions torch/export/dynamic_shapes.py
@@ -9,6 +9,7 @@

 import torch
 from torch._subclasses.fake_tensor import FakeTensor
+from torch.utils._pytree import SUPPORTED_NODES
 from .exported_program import ExportedProgram

 if TYPE_CHECKING:
@@ -371,17 +372,18 @@ def tree_zip(combined_args, dynamic_shapes):
                 )
             for k, shape in dynamic_shapes.items():
                 yield from tree_zip(combined_args[k], shape)
-        elif dataclasses.is_dataclass(combined_args):
-            if not type(dynamic_shapes) == type(combined_args):
+        elif type(combined_args) in SUPPORTED_NODES:
+            if not isinstance(dynamic_shapes, Sequence):
                 raise UserError(
                     UserErrorType.INVALID_INPUT,
-                    f"Expected dynamic_shapes of a {type(combined_args)} to be a {type(combined_args)}, "
-                    f"got {dynamic_shapes} instead",
-                )
-            for f in dataclasses.fields(combined_args):
-                yield from tree_zip(
-                    getattr(combined_args, f.name), getattr(dynamic_shapes, f.name)
+                    f"Expected dynamic_shapes of a user-registered class (e.g., "
+                    f"{type(combined_args)}) to be a Sequence that matches the "
+                    f"flattened structure, but got {dynamic_shapes} instead",
                 )
+            yield from tree_zip(
+                SUPPORTED_NODES[type(combined_args)].flatten_fn(combined_args)[0],
+                dynamic_shapes,
+            )
         elif isinstance(combined_args, torch.Tensor):
             yield (combined_args, dynamic_shapes)
         else:
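
For reference, a small sketch of what the new branch consumes: SUPPORTED_NODES maps each registered type to a node definition whose flatten_fn returns a (children, context) pair, and tree_zip then pairs children[i] with dynamic_shapes[i]. Using dict as a stand-in registered type (dict itself is routed through an earlier branch of tree_zip, but its flatten_fn illustrates the contract), and assuming the NodeDef layout used in the diff above:

import torch
from torch.utils._pytree import SUPPORTED_NODES

x = {"a": torch.ones(2), "b": torch.ones(3)}
children, context = SUPPORTED_NODES[type(x)].flatten_fn(x)
# children -> [tensor([1., 1.]), tensor([1., 1., 1.])], context -> ["a", "b"]
# tree_zip recurses on (children[i], dynamic_shapes[i]) for each child,
# which is why the spec for a registered class must be a Sequence whose
# length matches the flattened children.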