[export] Initial deserialization v2 (#102716)
v2 of #102126, mentally stacked on top of #102707.

Pull Request resolved: #102716
Approved by: https://github.com/avikchaudhuri, https://github.com/zhxchen17
angelayi authored and pytorchmergebot committed Jun 7, 2023
1 parent adcefcb commit e930c0f
Showing 6 changed files with 604 additions and 34 deletions.
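This PR introduces `deserialize` as the counterpart to `serialize` in `torch._export.serde.serialize`, so an `ExportedProgram` can round-trip through the serialized schema. A minimal sketch of that round trip, inferred from the tests in this diff (anything beyond what the tests exercise is an assumption):

    import torch
    from torch._export import export
    from torch._export.serde.serialize import deserialize, serialize

    class M(torch.nn.Module):
        def forward(self, x):
            return x + x

    # Trace the module to an ExportedProgram, then round-trip it through the schema.
    ep = export(M(), (torch.ones(4),), [])
    serialized_struct, state_dict = serialize(ep)          # schema struct + weights
    loaded_ep = deserialize(serialized_struct, state_dict)

    # The deserialized program should produce the same results as the original.
    torch.testing.assert_close(ep(torch.ones(4)), loaded_ep(torch.ones(4)))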
test/export/test_serialize.py (69 additions, 6 deletions)
@@ -4,12 +4,16 @@
 import torch
 import torch._dynamo as torchdynamo
 from torch._export import export
-from torch._export.serde.serialize import ExportedProgramSerializer
+from torch._export.serde.serialize import (
+    ExportedProgramSerializer,
+    deserialize,
+    serialize,
+)
 import torch.utils._pytree as pytree
 from torch.testing._internal.common_utils import run_tests, TestCase


 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
+@unittest.skip("serializing constraints failing")
 class TestSerialize(TestCase):
     def test_serialize_multiple_returns_from_node(self) -> None:
         class MyModule(torch.nn.Module):
@@ -36,7 +40,7 @@ def forward(self, x, w, b):

         serialized, _ = ExportedProgramSerializer().serialize(exported_module)
         node = serialized.graph_module.graph.nodes[-7]
-        self.assertEqual(node.target, "aten.var_mean.correction")
+        self.assertEqual(node.target, "torch._ops.aten.var_mean.correction")
         # aten::var_mean.correction returns 2 tensors
         self.assertEqual(len(node.outputs), 2)

@@ -61,7 +65,7 @@ def forward(self, x):

         serialized, _ = ExportedProgramSerializer().serialize(exported_module)
         node = serialized.graph_module.graph.nodes[-1]
-        self.assertEqual(node.target, "aten.split.Tensor")
+        self.assertEqual(node.target, "torch._ops.aten.split.Tensor")
         self.assertEqual(len(node.outputs), 1)
         # Input looks like:
         # tensor([[0, 1],
@@ -104,7 +108,7 @@ def forward(self, x):

         serialized, _ = ExportedProgramSerializer().serialize(exported_module)
         node = serialized.graph_module.graph.nodes[-1]
-        self.assertEqual(node.target, "aten.var_mean.correction")
+        self.assertEqual(node.target, "torch._ops.aten.var_mean.correction")
         self.assertEqual(len(node.outputs), 2)

         # check the names are unique
@@ -129,13 +133,72 @@ def f(x: torch.Tensor) -> torch.Tensor:
         serialized, _ = ExportedProgramSerializer().serialize(exported_module)

         node = serialized.graph_module.graph.nodes[-1]
-        self.assertEqual(node.target, "aten.searchsorted.Tensor")
+        self.assertEqual(node.target, "torch._ops.aten.searchsorted.Tensor")
         self.assertEqual(len(node.inputs), 6)
         self.assertEqual(node.inputs[2].arg.as_bool, False)
         self.assertEqual(node.inputs[3].arg.as_bool, True)
         self.assertEqual(node.inputs[4].arg.as_string, "right")
         self.assertEqual(node.inputs[5].arg.as_none, ())
+
+
+@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
+class TestDeserialize(TestCase):
+    def check_graph(self, fn, inputs) -> None:
+        """Export a graph, serialize it, deserialize it, and compare the results."""
+        # TODO(angelayi): test better with some sort of wrapper around all
+        # export tests
+
+        ep = export(fn, inputs, [])
+        serialized_struct, state_dict = serialize(ep)
+        deserialized_ep = deserialize(serialized_struct, state_dict)
+
+        orig_outputs = ep(*inputs)
+        loaded_outputs = deserialized_ep(*inputs)
+
+        flat_orig_outputs, _ = pytree.tree_flatten(orig_outputs)
+        flat_loaded_outputs, _ = pytree.tree_flatten(loaded_outputs)
+
+        for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs):
+            self.assertTrue(torch.allclose(orig, loaded))
+
+    def test_multi_return(self) -> None:
+        """
+        Test multiple returns from a single node (e.g. layer_norm has 2 outputs).
+        """
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, w, b):
+                return torch.nn.functional.layer_norm(
+                    x,
+                    x.size()[1:],
+                    weight=w,
+                    bias=b,
+                    eps=1e-5,
+                )
+
+        inputs = (
+            torch.ones([512, 512], requires_grad=True),
+            torch.ones([512]),
+            torch.ones([512]),
+        )
+        self.check_graph(MyModule(), inputs)
+
+    def test_basic(self) -> None:
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                x = x + x
+                x = x * x
+                x = x / x
+                return x, x.clone()
+
+        inputs = (torch.ones([512], requires_grad=True),)
+        self.check_graph(MyModule(), inputs)


 if __name__ == '__main__':
     run_tests()
torch/_export/__init__.py (1 addition, 0 deletions)
@@ -203,6 +203,7 @@ def export(
         export_graph_signature,
         flat_args,
     )
+    assert orig_out_spec is not None
     exported_program = ExportedProgram(
         gm,
         gm.graph,
torch/_export/exported_program.py (2 additions, 2 deletions)
@@ -42,8 +42,8 @@
 # Information to maintain user calling/returning specs
 @dataclasses.dataclass
 class CallSpec:
-    in_spec: Optional[pytree.TreeSpec] = None
-    out_spec: Optional[pytree.TreeSpec] = None
+    in_spec: pytree.TreeSpec
+    out_spec: pytree.TreeSpec


 # Extra information for joint graphs
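Since `in_spec` and `out_spec` are now mandatory, here is an illustrative sketch (not from the PR) of how such specs are typically produced with `pytree`; the `(args, kwargs)` input structure is an assumption:

    import torch
    import torch.utils._pytree as pytree

    # Flatten an example input structure and an example output to get TreeSpecs.
    flat_in, in_spec = pytree.tree_flatten(((torch.ones(2),), {}))
    flat_out, out_spec = pytree.tree_flatten(torch.ones(2))
    spec = CallSpec(in_spec=in_spec, out_spec=out_spec)  # both fields now required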
torch/_export/serde/schema.py (31 additions, 2 deletions)
@@ -10,10 +10,23 @@ class _Union:
     @classmethod
     def create(cls, **kwargs):
         assert len(kwargs) == 1
-        return cls(**{**{field.name: None for field in fields(cls)}, **kwargs})
+        return cls(**{**{f.name: None for f in fields(cls)}, **kwargs})

     def __post_init__(self):
-        assert sum(1 for field in fields(self) if getattr(self, field.name) is not None) == 1
+        assert sum(1 for f in fields(self) if getattr(self, f.name) is not None) == 1
+
+    @property
+    def value(self):
+        val = next((getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None), None)
+        assert val is not None
+        return val
+
+    @property
+    def type(self):
+        val_type = next((f.name for f in fields(self) if getattr(self, f.name) is not None), None)
+        assert val_type is not None
+        return val_type


 class ScalarType(Enum):
     UNKNOWN = 0
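The new `value`/`type` properties make `_Union` a tagged union: exactly one field is populated, and the properties report which one. A hypothetical usage sketch using the `SymInt` union defined below in this file:

    # SymInt has fields as_symbol: str and as_int: int; create() sets exactly one.
    s = SymInt.create(as_int=8)
    assert s.type == "as_int"    # name of the populated field
    assert s.value == 8          # its contents

    sym = SymInt.create(as_symbol="s0")
    assert (sym.type, sym.value) == ("as_symbol", "s0")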
@@ -63,6 +76,12 @@ class SymInt(_Union):
     as_int: int


+@dataclass
+class SymBool(_Union):
+    as_symbol: str
+    as_bool: bool
+
+
 @dataclass
 class TensorMeta:
     dtype: ScalarType
@@ -80,6 +99,12 @@ class SymIntArgument(_Union):
     as_int: int


+@dataclass
+class SymBoolArgument(_Union):
+    as_name: str
+    as_bool: bool
+
+
 @dataclass
 class TensorArgument:
     name: str
@@ -104,6 +129,8 @@ class Argument(_Union):
     as_device: Device
     as_bool: bool
     as_bools: List[bool]
+    as_sym_bool: SymBoolArgument
+    as_sym_bools: List[SymBoolArgument]


 @dataclass
@@ -132,6 +159,7 @@ class Graph:
     nodes: List[Node]
     tensor_values: Dict[str, TensorValue]
     sym_int_values: Dict[str, SymInt]
+    sym_bool_values: Dict[str, SymBool]


 @dataclass
@@ -164,6 +192,7 @@ class GraphModule:
     call_spec: CallSpec


+# TODO(angelayi) to add symbol to hint
 @dataclass
 class ExportedProgram:
     graph_module: GraphModule
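Taken together, the schema additions let a symbolic boolean flow through the serialized graph as a node argument. An illustrative sketch based only on the dataclasses above (the field values are made up):

    # Wrap a named symbolic boolean as a node argument.
    sym_bool = SymBoolArgument.create(as_name="b0")
    arg = Argument.create(as_sym_bool=sym_bool)
    assert arg.type == "as_sym_bool"
    assert arg.value.value == "b0"   # the inner union resolves to the name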
