Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[export] allow register dataclass as pytree node #106160

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
89 changes: 89 additions & 0 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
import torch
import torch._dynamo as torchdynamo
from torch._export import export, dynamic_dim
from torch._export.utils import register_dataclass_as_pytree_node
from torch._export.constraints import constrain_as_size, constrain_as_value
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._pytree import tree_flatten, tree_unflatten, LeafSpec, TreeSpec
from functorch.experimental.control_flow import map
from dataclasses import dataclass


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
Expand Down Expand Up @@ -330,6 +333,92 @@ def fn_ddo(x):
):
_ = export(fn_ddo, (torch.tensor([2, 3, 5]),), constraints=None)

def test_pytree_regster_data_class(self):

@dataclass
class MyDataClass:
x: int
y: int
z: int = None

dt = MyDataClass(x=3, y=4)
flat, spec = tree_flatten(dt)
self.assertTrue(spec, LeafSpec())
self.assertTrue(len(flat) == 1)

register_dataclass_as_pytree_node(MyDataClass)

flat, spec = tree_flatten(dt)
self.assertEqual(
spec,
TreeSpec(
MyDataClass,
(
MyDataClass,
['x', 'y'],
['z']
),
[LeafSpec(), LeafSpec()]
)
)
self.assertEqual(flat, [3, 4])

orig_dt = tree_unflatten(flat, spec)
self.assertTrue(isinstance(orig_dt, MyDataClass))
self.assertEqual(orig_dt.x, 3)
self.assertEqual(orig_dt.y, 4)
self.assertEqual(orig_dt.z, None)

# Override the registration with keep none fields
register_dataclass_as_pytree_node(MyDataClass, return_none_fields=True)

flat, spec = tree_flatten(dt)
self.assertEqual(
spec,
TreeSpec(
MyDataClass,
(
MyDataClass,
['x', 'y', 'z'],
[],
),
[LeafSpec(), LeafSpec(), LeafSpec()]
)
)
self.assertEqual(flat, [3, 4, None])

orig_dt = tree_unflatten(flat, spec)
self.assertTrue(isinstance(orig_dt, MyDataClass))
self.assertEqual(orig_dt.x, 3)
self.assertEqual(orig_dt.y, 4)
self.assertEqual(orig_dt.z, None)

def test_pytree_regster_nested_data_class(self):

@dataclass
class Inner:
x: int
y: int

@dataclass
class Outer:
xy: Inner
ab: Inner

xy = Inner(1, 2)
ab = Inner(3, 4)
dt = Outer(xy, ab)
inp = {"dt1": (dt, ({},)), "dt2": ((torch.ones(1),), dt)}

register_dataclass_as_pytree_node(Inner)
register_dataclass_as_pytree_node(Outer)

flat, spec = tree_flatten(inp)
self.assertEqual(flat, [1, 2, 3, 4, torch.ones(1), 1, 2, 3, 4])

unflat = tree_unflatten(flat, spec)
self.assertEqual(unflat, inp)


if __name__ == '__main__':
run_tests()
54 changes: 54 additions & 0 deletions torch/_export/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import dataclasses

from typing import Any, List, Optional, Tuple

from torch.utils._pytree import (
_register_pytree_node,
Context,
FlattenFunc,
MaybeFromStrFunc,
ToStrFunc,
UnflattenFunc,
)


def register_dataclass_as_pytree_node(
typ: Any,
flatten_fn: Optional[FlattenFunc] = None,
unflatten_fn: Optional[UnflattenFunc] = None,
to_str_fn: Optional[ToStrFunc] = None,
maybe_from_str_fn: Optional[MaybeFromStrFunc] = None,
*,
return_none_fields: bool = False,
) -> None:
assert dataclasses.is_dataclass(
typ
), f"Only dataclasses can be registered with this function: {typ}"

def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
flattened = []
flat_names = []
none_names = []
for f in dataclasses.fields(obj):
name, val = f.name, getattr(obj, f.name)
if val is not None or return_none_fields:
flattened.append(val)
flat_names.append(name)
else:
none_names.append(name)
return flattened, (typ, flat_names, none_names)

def default_unflatten_fn(values: List[Any], context: Context) -> Any:
typ, flat_names, none_names = context
return typ(**dict(zip(flat_names, values)), **{k: None for k in none_names})

flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn
unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn

_register_pytree_node(
typ,
flatten_fn,
unflatten_fn,
None,
None,
)