Commit

Allow dynamic shapes of tuple type for inputs of dataclass type (#117917)

Summary:
In `torch.export.export(f, args, kwargs, ..., dynamic_shapes=None, ...)`, a `dataclass` is an acceptable input type (for both args and kwargs). The `dynamic_shapes` specification for a `dataclass` input must be an instance of the same `dataclass` type, with each tensor attribute replaced by the dynamic-shape specification of the corresponding tensor (https://github.com/pytorch/pytorch/blob/main/torch/export/dynamic_shapes.py#L375).

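For illustration, a minimal sketch of the pre-existing dataclass-typed specification (the `DataClass` container and the `torch.export.register_dataclass` registration call are assumptions made for this example, not part of this change):

```python
from dataclasses import dataclass

import torch
from torch.export import Dim, export


# Hypothetical input container; some real containers enforce tensor-only fields.
@dataclass
class DataClass:
    a: torch.Tensor
    b: torch.Tensor


# Dataclass inputs must be registered as pytree nodes before export;
# torch.export.register_dataclass is assumed to be available for that.
torch.export.register_dataclass(DataClass)


class Foo(torch.nn.Module):
    def forward(self, inputs: DataClass):
        return inputs.a.sum() + inputs.b.sum()


batch = Dim("batch")
inputs = DataClass(a=torch.randn(4, 2, 3), b=torch.randn(4, 3, 4))

# Pre-existing form: dynamic_shapes mirrors the dataclass type itself, with
# each tensor field replaced by that tensor's dynamic-shape specification.
ep = export(
    Foo(),
    (inputs,),
    dynamic_shapes={"inputs": DataClass(a={0: batch}, b={0: batch})},
)
```
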
However, some `dataclass` types restrict the types of their attributes (e.g., requiring them to be tensors), so an instance of the same `dataclass` cannot be constructed to hold the dynamic-shape specifications.

For an input of `dataclass` type, this change additionally allows `dynamic_shapes` of tuple (or list) type, giving a dynamic-shape specification for each tensor of the input in the same order as produced by the input dataclass type's registered `flatten_fn` (https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py#L103).

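Continuing the sketch above, the new tuple-type specification for the same input looks like the following; the ordering of entries follows the order produced by the registered `flatten_fn` (here: `a`, then `b`):

```python
# New in this change: pass a list/tuple of per-tensor specifications instead of
# constructing a DataClass instance that holds non-tensor values.
ep = export(
    Foo(),
    (inputs,),
    dynamic_shapes={"inputs": [{0: batch}, {0: batch}]},
)
```
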
Test Plan: buck test //caffe2/test:test_export

Differential Revision: D52932856

Pull Request resolved: #117917
Approved by: https://github.com/avikchaudhuri
BoyuanFeng authored and pytorchmergebot committed Jan 22, 2024
1 parent 4df65bf commit 792dfa7
Showing 2 changed files with 48 additions and 18 deletions.
48 changes: 38 additions & 10 deletions test/export/test_export.py
@@ -8,8 +8,8 @@
from dataclasses import dataclass

import torch
import torch.nn.functional as F
import torch._dynamo as torchdynamo
import torch.nn.functional as F
from functorch.experimental.control_flow import cond, map
from torch import Tensor
from torch._dynamo.test_case import TestCase
@@ -31,13 +31,8 @@
)
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FLASH_ATTENTION,
)
from torch.testing._internal.common_device_type import (
onlyCPU,
onlyCUDA,
)
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
from torch.testing._internal.common_device_type import onlyCPU, onlyCUDA
from torch.testing._internal.common_utils import (
run_tests,
TestCase as TorchTestCase,
@@ -57,6 +52,13 @@
treespec_loads,
)

try:
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

HAS_TORCHREC = True
except ImportError:
HAS_TORCHREC = False

try:
from . import testing
except ImportError:
@@ -711,7 +713,7 @@ def forward(self, inputs):
efoo = export(
foo,
inputs,
dynamic_shapes={"inputs": DataClass(a={0: batch}, b={0: batch})},
dynamic_shapes={"inputs": [{0: batch}, {0: batch}]},
)
self.assertEqual(
[
@@ -722,6 +724,32 @@
["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
)

# pass dynamic shapes of inputs [pytree-registered classes]
if HAS_TORCHREC:
# skipping tests if torchrec not available
class Foo(torch.nn.Module):
def forward(self, kjt) -> torch.Tensor:
return kjt.values() + 0, kjt.offsets() + 0
foo = Foo()
kjt = KeyedJaggedTensor(
values=torch.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
keys=["index_0", "index_1"],
lengths=torch.IntTensor([0, 2, 0, 1, 1, 1, 0, 3]),
offsets=torch.IntTensor([0, 0, 2, 2, 3, 4, 5, 5, 8]),
)
inputs = (kjt,)
dim = Dim("dim")
dim_plus_one = Dim("dim_plus_one")
efoo = torch.export.export(
foo,
inputs,
dynamic_shapes={"kjt": [{0: dim}, None, {0: dim}, {0: dim_plus_one}]},
)
self.assertEqual(
[out.shape for out in efoo(*inputs)],
[out.shape for out in foo(*inputs)]
)

# pass dynamic shapes of inputs [distinct, error]
class Foo(torch.nn.Module):
def forward(self, x, y):
@@ -1828,7 +1856,7 @@ def forward(self, x: Input):
self.assertIsInstance(s, int)

dim0_x_f, dim0_x_p = torch.export.dims("dim0_x_f", "dim0_x_p")
dynamic_shapes = {"x": Input(f={0: dim0_x_f}, p={0: dim0_x_p})}
dynamic_shapes = {"x": [{0: dim0_x_f}, {0: dim0_x_p}]}
ep_dynamic = torch.export.export(
mod, example_inputs, dynamic_shapes=dynamic_shapes
)
18 changes: 10 additions & 8 deletions torch/export/dynamic_shapes.py
@@ -9,6 +9,7 @@

import torch
from torch._subclasses.fake_tensor import FakeTensor
from torch.utils._pytree import SUPPORTED_NODES
from .exported_program import ExportedProgram

if TYPE_CHECKING:
@@ -371,17 +372,18 @@ def tree_zip(combined_args, dynamic_shapes):
)
for k, shape in dynamic_shapes.items():
yield from tree_zip(combined_args[k], shape)
elif dataclasses.is_dataclass(combined_args):
if not type(dynamic_shapes) == type(combined_args):
elif type(combined_args) in SUPPORTED_NODES:
if not isinstance(dynamic_shapes, Sequence):
raise UserError(
UserErrorType.INVALID_INPUT,
f"Expected dynamic_shapes of a {type(combined_args)} to be a {type(combined_args)}, "
f"got {dynamic_shapes} instead",
)
for f in dataclasses.fields(combined_args):
yield from tree_zip(
getattr(combined_args, f.name), getattr(dynamic_shapes, f.name)
f"Expected dynamic_shapes of a user-registered class (e.g., "
f"{type(combined_args)}) to be a Sequence that matches the "
f"flattened structure, but got {dynamic_shapes} instead",
)
yield from tree_zip(
SUPPORTED_NODES[type(combined_args)].flatten_fn(combined_args)[0],
dynamic_shapes,
)
elif isinstance(combined_args, torch.Tensor):
yield (combined_args, dynamic_shapes)
else:
