
Commit d96b415
initial
angelayi committed Jun 1, 2023
1 parent 76bb21d commit d96b415
Showing 3 changed files with 414 additions and 8 deletions.
63 changes: 62 additions & 1 deletion test/export/test_serialize.py
@@ -4,7 +4,12 @@
 import torch
 import torch._dynamo as torchdynamo
 from torch._export import export
-from torch._export.serde.serialize import ExportedProgramSerializer
+from torch._export.serde.serialize import (
+    ExportedProgramSerializer,
+    deserialize,
+    serialize,
+)
+import torch.utils._pytree as pytree
 from torch.testing._internal.common_utils import run_tests, TestCase


@@ -136,5 +141,61 @@ def f(x: torch.Tensor) -> torch.Tensor:
         self.assertEqual(node.inputs[5].arg.as_none, ())
 
 
+@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
+class TestDeserialize(TestCase):
+    def check_graph(self, fn, inputs) -> None:
+        """Export a graph, serialize it, deserialize it, and compare the results."""
+        exported_module = export(fn, inputs, {})
+        serialized_struct, state_dict = serialize(exported_module)
+        loaded_graph = deserialize(serialized_struct, state_dict)
+
+        orig_outputs = exported_module(*inputs)
+        loaded_outputs = loaded_graph(*inputs)
+
+        flat_orig_outputs, _ = pytree.tree_flatten(orig_outputs)
+        flat_loaded_outputs, _ = pytree.tree_flatten(loaded_outputs)
+
+        for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs):
+            self.assertTrue(torch.allclose(orig, loaded))
+
+    def test_multi_return(self) -> None:
+        """
+        Test multiple return from a single node (ex. layer_norm has 2 outputs)
+        """
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, w, b):
+                return torch.nn.functional.layer_norm(
+                    x,
+                    x.size()[1:],
+                    weight=w,
+                    bias=b,
+                    eps=1e-5,
+                )
+
+        inputs = (
+            torch.ones([512, 512], requires_grad=True),
+            torch.ones([512]),
+            torch.ones([512]),
+        )
+        self.check_graph(MyModule(), inputs)
+
+    def test_basic(self) -> None:
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                x = x + x
+                x = x * x
+                x = x / x
+                return x, x.clone()
+
+        inputs = (torch.ones([512], requires_grad=True),)
+        self.check_graph(MyModule(), inputs)
+
+
 if __name__ == '__main__':
     run_tests()
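
For readers following along, the round trip that check_graph exercises looks roughly like this outside the test harness. This is a minimal sketch assuming only the symbols visible in this diff (export, serialize, deserialize, pytree); the Double module is a toy stand-in, not part of the commit.

import torch
import torch.utils._pytree as pytree
from torch._export import export
from torch._export.serde.serialize import deserialize, serialize

# Toy module for illustration; any export-traceable module would do.
class Double(torch.nn.Module):
    def forward(self, x):
        return x + x

inputs = (torch.ones(4),)
ep = export(Double(), inputs, {})          # trace into an exported program
payload, state_dict = serialize(ep)        # schema struct plus tensor state
loaded = deserialize(payload, state_dict)  # rebuild a callable graph

# Compare outputs the same way check_graph does: flatten, then allclose.
orig, _ = pytree.tree_flatten(ep(*inputs))
restored, _ = pytree.tree_flatten(loaded(*inputs))
assert all(torch.allclose(a, b) for a, b in zip(orig, restored))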
22 changes: 20 additions & 2 deletions torch/_export/serde/schema.py
@@ -10,10 +10,28 @@ class _Union:
     @classmethod
     def create(cls, **kwargs):
         assert len(kwargs) == 1
-        return cls(**{**{field.name: None for field in fields(cls)}, **kwargs})
+        obj = cls(**{**{f.name: None for f in fields(cls)}, **kwargs})
+        return obj
+
+    @classmethod
+    def fields(cls):
+        return Enum("FieldEnum", {f.name: f.name for f in fields(cls)})
 
     def __post_init__(self):
-        assert sum(1 for field in fields(self) if getattr(self, field.name) is not None) == 1
+        assert sum(1 for f in fields(self) if getattr(self, f.name) is not None) == 1
+
+    @property
+    def value(self):
+        val = next((getattr(self, f.name) for f in fields(self) if getattr(self, f.name) is not None), None)
+        assert val is not None
+        return val
+
+    @property
+    def type(self):
+        val_type = next((f.name for f in fields(self) if getattr(self, f.name) is not None), None)
+        assert val_type is not None
+        return val_type
+
 
 class ScalarType(Enum):
     UNKNOWN = 0
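
To make the _Union helper concrete: create() builds an instance with exactly one field populated, __post_init__ enforces that invariant, and the value/type properties recover the populated value and its field name. Note that the bare name fields inside create() and the fields() classmethod still resolves to dataclasses.fields from module scope (method bodies do not consult the class namespace during name lookup), so the classmethod does not shadow it. A minimal sketch, assuming the _Union definition above is in scope; IntOrStr and its as_int/as_str fields are hypothetical, not part of the commit:

from dataclasses import dataclass
from typing import Optional

@dataclass
class IntOrStr(_Union):  # hypothetical union subclass, for illustration only
    as_int: Optional[int] = None
    as_str: Optional[str] = None

u = IntOrStr.create(as_int=3)  # exactly one keyword, per create()'s assert
print(u.type)   # "as_int", the name of the single populated field
print(u.value)  # 3, the populated value
# IntOrStr.create(as_int=1, as_str="x")  # AssertionError: len(kwargs) == 1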
