Skip to content

Commit

Permalink
rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
angelayi committed Jun 6, 2023
1 parent 8bb3414 commit 57daabb
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 103 deletions.
76 changes: 1 addition & 75 deletions test/export/test_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
import torch
import torch._dynamo as torchdynamo
from torch._export import export
from torch._export.serde.serialize import ExportedProgramSerializer
from torch._export.serde.serialize import ExportedProgramSerializer
from torch._export.serde.serialize import (
ExportedProgramSerializer,
deserialize,
Expand Down Expand Up @@ -39,27 +37,11 @@ def forward(self, x, w, b):
torch.ones([512]),
torch.ones([512]),
),
<<<<<<< HEAD
<<<<<<< HEAD
)

serialized, _ = ExportedProgramSerializer().serialize(exported_module)
node = serialized.graph_module.graph.nodes[-7]
self.assertEqual(node.target, "aten.var_mean.correction")
=======
).module

serialized, _ = serialize(exported_module)
node = serialized.graph.nodes[0]
self.assertEqual(node.target.name, "aten.var_mean.correction")
>>>>>>> [export] Initial serialization
=======
)

serialized, _ = ExportedProgramSerializer().serialize(exported_module)
node = serialized.graph_module.graph.nodes[0]
self.assertEqual(node.target, "aten.var_mean.correction")
>>>>>>> idk what happened
# aten::native_layer_norm returns 3 tensors
self.assertEqual(len(node.outputs), 2)

Expand All @@ -80,27 +62,11 @@ def forward(self, x):

input = torch.arange(10.0).reshape(5, 2)
input.requires_grad = True
<<<<<<< HEAD
<<<<<<< HEAD
exported_module = export(MyModule(), (input,))

serialized, _ = ExportedProgramSerializer().serialize(exported_module)
node = serialized.graph_module.graph.nodes[-1]
self.assertEqual(node.target, "aten.split.Tensor")
=======
exported_module = export(MyModule(), (input,)).module

serialized, _ = serialize(exported_module)
node = serialized.graph.nodes[0]
self.assertEqual(node.target.name, "aten.split.Tensor")
>>>>>>> [export] Initial serialization
=======
exported_module = export(MyModule(), (input,))

serialized, _ = ExportedProgramSerializer().serialize(exported_module)
node = serialized.graph_module.graph.nodes[0]
self.assertEqual(node.target, "aten.split.Tensor")
>>>>>>> idk what happened
self.assertEqual(len(node.outputs), 1)
# Input looks like:
# tensor([[0, 1],
Expand Down Expand Up @@ -139,27 +105,11 @@ def forward(self, x):
exported_module = export(
MyModule(),
(torch.ones([512, 512], requires_grad=True),),
<<<<<<< HEAD
<<<<<<< HEAD
)

serialized, _ = ExportedProgramSerializer().serialize(exported_module)
node = serialized.graph_module.graph.nodes[-1]
self.assertEqual(node.target, "aten.var_mean.correction")
=======
).module

serialized, _ = serialize(exported_module)
node = serialized.graph.nodes[0]
self.assertEqual(node.target.name, "aten.var_mean.correction")
>>>>>>> [export] Initial serialization
=======
)

serialized, _ = ExportedProgramSerializer().serialize(exported_module)
node = serialized.graph_module.graph.nodes[0]
self.assertEqual(node.target, "aten.var_mean.correction")
>>>>>>> idk what happened
self.assertEqual(len(node.outputs), 2)

# check the names are unique
Expand All @@ -180,41 +130,17 @@ def f(x: torch.Tensor) -> torch.Tensor:
return torch.searchsorted(x, values, side="right", right=True)

x, _ = torch.sort(torch.randn(3, 4))
<<<<<<< HEAD
<<<<<<< HEAD
exported_module = export(f, (x,))
serialized, _ = ExportedProgramSerializer().serialize(exported_module)

node = serialized.graph_module.graph.nodes[-1]
self.assertEqual(node.target, "aten.searchsorted.Tensor")
=======
exported_module = export(f, (x,)).module
serialized, _ = serialize(exported_module)

node = serialized.graph.nodes[1]
self.assertEqual(node.target.name, "aten.searchsorted.Tensor")
>>>>>>> [export] Initial serialization
=======
exported_module = export(f, (x,))
serialized, _ = ExportedProgramSerializer().serialize(exported_module)

node = serialized.graph_module.graph.nodes[1]
self.assertEqual(node.target, "aten.searchsorted.Tensor")
>>>>>>> idk what happened
self.assertEqual(len(node.inputs), 6)
self.assertEqual(node.inputs[2].arg.as_bool, False)
self.assertEqual(node.inputs[3].arg.as_bool, True)
self.assertEqual(node.inputs[4].arg.as_string, "right")
self.assertEqual(node.inputs[5].arg.as_none, ())
<<<<<<< HEAD
<<<<<<< HEAD

=======

>>>>>>> [export] Initial serialization
=======

>>>>>>> idk what happened

@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestDeserialize(TestCase):
Expand Down Expand Up @@ -273,4 +199,4 @@ def forward(self, x):


if __name__ == '__main__':
run_tests()
run_tests()
29 changes: 1 addition & 28 deletions torch/_export/serde/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,25 +236,11 @@ def serialize_signature(sig: ep.ExportGraphSignature) -> GraphSignature:
backward_signature = None

graph_signature = GraphSignature(
<<<<<<< HEAD
<<<<<<< HEAD
=======
>>>>>>> lint
inputs_to_parameters=sig.inputs_to_parameters, # type: ignore[arg-type]
inputs_to_buffers=sig.inputs_to_buffers, # type: ignore[arg-type]
user_inputs=sig.user_inputs, # type: ignore[arg-type]
user_outputs=sig.user_outputs, # type: ignore[arg-type]
buffers_to_mutate=sig.buffers_to_mutate, # type: ignore[arg-type]
<<<<<<< HEAD
=======
inputs_to_parameters=sig.inputs_to_parameters,
inputs_to_buffers=sig.inputs_to_buffers,
user_inputs=sig.user_inputs,
user_outputs=sig.user_outputs,
buffers_to_mutate=sig.buffers_to_mutate,
>>>>>>> idk what happened
=======
>>>>>>> lint
backward_signature=backward_signature,
)
return graph_signature
Expand Down Expand Up @@ -415,11 +401,7 @@ def serialize_input(self, arg) -> Argument:
if arg.op == "get_attr":
return Argument.create(as_tensor=TensorArgument(name=str(arg.target)))
elif self.is_sym_int_arg(arg):
<<<<<<< HEAD
return Argument.create(as_sym_int=SymIntArgument.create(as_name=arg.name))
=======
return Argument.create(as_sym_int=SymIntArgument.create(asName=arg.name))
>>>>>>> idk what happened
else:
return Argument.create(as_tensor=TensorArgument(name=arg.name))
elif isinstance(arg, bool):
Expand Down Expand Up @@ -504,11 +486,6 @@ def serialize_outputs(self, node: torch.fx.Node) -> List[Argument]:
returns = node.target._schema.returns

# Check single value return
<<<<<<< HEAD
if len(returns) == 0:
return []
=======
>>>>>>> idk what happened
if _is_single_tensor_return(node.target):
return [Argument.create(as_tensor=self.serialize_tensor_output(node.name, meta_val))]
elif len(returns) == 1 and isinstance(returns[0].real_type, torch.SymIntType): # type: ignore[attr-defined]
Expand Down Expand Up @@ -560,12 +537,8 @@ def serialize(self, graph_module: torch.fx.GraphModule) -> GraphModule:
self.node = node
getattr(self, f"handle_{node.op}")(node)
except Exception as e:
<<<<<<< HEAD
raise SerializeError(f"Failed serializing node {node}") from e
=======
if not isinstance(e, SerializeError):
raise SerializeError(f"Failed serializing node {node}") from e
>>>>>>> idk what happened

graph = Graph(
inputs=self.inputs,
Expand Down Expand Up @@ -837,4 +810,4 @@ def deserialize(
return (
ExportedProgramDeserializer(expected_opset_version)
.deserialize(serialized_exported_program, state_dict)
)
)

0 comments on commit 57daabb

Please sign in to comment.