Commit

pytree
angelayi committed Jun 1, 2023
1 parent 72a4011 commit 3f11071
Showing 3 changed files with 23 additions and 23 deletions.
4 changes: 2 additions & 2 deletions test/export/test_serialize.py
@@ -147,10 +147,10 @@ def check_graph(self, fn, inputs) -> None:
"""Export a graph, serialize it, deserialize it, and compare the results."""
exported_module = export(fn, inputs, {})
serialized_struct, state_dict = serialize(exported_module)
loaded_graph = deserialize(serialized_struct, state_dict)
deserialized_ep = deserialize(serialized_struct, state_dict)

orig_outputs = exported_module(*inputs)
loaded_outputs = loaded_graph(*inputs)
loaded_outputs = deserialized_ep(*inputs)

flat_orig_outputs, _ = pytree.tree_flatten(orig_outputs)
flat_loaded_outputs, _ = pytree.tree_flatten(loaded_outputs)
4 changes: 2 additions & 2 deletions torch/_export/exported_program.py
@@ -42,8 +42,8 @@
 # Information to maintain user calling/returning specs
 @dataclasses.dataclass
 class CallSpec:
-    in_spec: Optional[pytree.TreeSpec] = None
-    out_spec: Optional[pytree.TreeSpec] = None
+    in_spec: pytree.TreeSpec
+    out_spec: pytree.TreeSpec
 
 
 # Extra information for joint graphs
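With the Optional defaults removed above, a CallSpec can no longer be constructed empty; both fields must be real TreeSpec objects. A minimal sketch of what that implies (a standalone copy of the dataclass; the example inputs below are made up for illustration):

import dataclasses

import torch.utils._pytree as pytree


@dataclasses.dataclass
class CallSpec:
    in_spec: pytree.TreeSpec   # no longer Optional, no default
    out_spec: pytree.TreeSpec


# Both specs must now be supplied up front, e.g. from flattened example inputs/outputs.
_, in_spec = pytree.tree_flatten(((1, 2), {"scale": 3}))
_, out_spec = pytree.tree_flatten((4,))
call_spec = CallSpec(in_spec=in_spec, out_spec=out_spec)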
38 changes: 19 additions & 19 deletions torch/_export/serde/serialize.py
@@ -11,6 +11,7 @@
 from torch.fx.experimental.symbolic_shapes import is_concrete_int
 from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
 import torch._export.exported_program as ep
+from torch.utils._pytree import pytree_to_str, str_to_pytree
 from .schema import (  # type: ignore[attr-defined]
     Argument,
     BackwardSignature,
@@ -198,22 +199,30 @@ def deserialize_operator(serialized_target: str):
     target = torch.ops
     for name in serialized_target.split("."):
         if not hasattr(target, name):
-            log.warning(f"Could not find operator {serialized_target}. Returning target as string.")  # noqa: G004
-            return serialized_target
+            log.warning(f"Could not find operator {serialized_target}. Returning fake operator.")  # noqa: G004
+
+            # Create a random fake placeholder op
+            def fake_op(x):
+                return x
+            fake_op.__name__ = serialized_target
+            return fake_op
         else:
             target = getattr(target, name)
     return target
 
 
 def serialize_call_spec(call_spec: ep.CallSpec) -> CallSpec:
-    # TODO(angelayi): spec
-    return CallSpec(in_spec="", out_spec="")
+    return CallSpec(
+        in_spec=pytree_to_str(call_spec.in_spec),
+        out_spec=pytree_to_str(call_spec.out_spec),
+    )
 
 
 def deserialize_call_spec(call_spec: CallSpec) -> ep.CallSpec:
-    # TODO(angelayi): spec
-    return ep.CallSpec(in_spec=None, out_spec=None)
-
+    return ep.CallSpec(
+        in_spec=str_to_pytree(call_spec.in_spec),
+        out_spec=str_to_pytree(call_spec.out_spec),
+    )
 
 
 def serialize_signature(sig: ep.ExportGraphSignature) -> GraphSignature:
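For reference, the helpers imported earlier in this file turn a TreeSpec into a plain string and back, which is what makes the call spec storable in the serialized schema. A small sketch of that round trip (assuming a PyTorch build from around this commit, where torch.utils._pytree still exposes pytree_to_str and str_to_pytree):

import torch.utils._pytree as pytree
from torch.utils._pytree import pytree_to_str, str_to_pytree

leaves, spec = pytree.tree_flatten({"x": (1, 2), "y": [3]})
serialized = pytree_to_str(spec)   # plain string, easy to embed in the schema
restored = str_to_pytree(serialized)

# Unflattening with the restored spec reproduces the original container structure.
assert pytree.tree_unflatten(leaves, restored) == {"x": (1, 2), "y": [3]}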
@@ -593,16 +602,6 @@ def deserialize(

         # Nodes: convert to call_function nodes.
         for serialized_node in serialized_graph.nodes:
-            if serialized_node.target == "torch.set_grad_enabled":
-                # Hack for torch.no_grad support. In the long run this should become
-                # a higher order op but this is fine for now. See [NOTE: nograd support]
-                fx_node = graph.call_function(
-                    torch.set_grad_enabled,
-                    (self.deserialize_input(serialized_node.inputs[0].arg),),
-                )
-                fx_node.meta.update(deserialize_metadata(serialized_node.metadata))
-                continue
-
             target = deserialize_operator(serialized_node.target)
 
             # For convenience: if this node returns a single tensor, name the
@@ -627,7 +626,7 @@ def deserialize(
             assert isinstance(output.value, TensorArgument)
             outputs.append(self.serialized_name_to_node[output.value.name])
 
-        graph.output(tuple(outputs) if len(outputs) > 1 else outputs[0])
+        graph.output(tuple(outputs))
 
         sig = deserialize_signature(serialized_graph_module.signature)
         call_spec = deserialize_call_spec(serialized_graph_module.call_spec)
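Note the behavioral change in the output node: a single return value is no longer unwrapped, so the deserialized graph always returns a tuple, with the original output structure presumably recovered through out_spec. A tiny torch.fx sketch of the resulting convention (the graph built here is illustrative, not taken from the diff):

import torch
import torch.fx as fx

graph = fx.Graph()
x = graph.placeholder("x")
y = graph.call_function(torch.sin, (x,))
graph.output((y,))   # always a tuple, mirroring graph.output(tuple(outputs))

gm = fx.GraphModule(torch.nn.Module(), graph)
out = gm(torch.randn(3))
assert isinstance(out, tuple) and len(out) == 1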
@@ -745,8 +744,9 @@ def deserialize(
     )
     state_dict = deserialize_state_dict(serialized_state_dict)
 
+    # TODO(angelyi): serialize constraints
     return ep.ExportedProgram(
-        state_dict, graph_module.graph, sig, call_spec, state_dict,
+        state_dict, graph_module.graph, sig, call_spec, state_dict, {}, {},
     )


