From 6c1392271059f0ff36eb38189dcd1c7b50fbfabf Mon Sep 17 00:00:00 2001 From: Olivia Liu Date: Tue, 16 Apr 2024 11:47:05 -0700 Subject: [PATCH] ETRecord ser/de handling "None" outputs and more (#3039) Summary: For ease of communication, let me assign nicknames to the files related to this diff: * File A: *caffe2/torch/_export/serde/serialize.py* * File B: *executorch/exir/serde/serialize.py* * File C: *executorch/exir/serde/export_serialize.py* Recently, we noticed that the error `torch._export.serde.serialize.SerializeError: Unable to deserialize output node Argument(as_none=[])` (P1210590561) was thrown from File B when deserializing ETRecord. It's possible that the error has been there since the beginning, but we've just never tested that logic path. In this diff, I made a fix in File B to resolve this particular issue. This diff also adds handling for the "None" output case in the SDK logic. ***Keep on reading if you don't think the code changes make sense:*** I explored the history of file changes. In chronological order: 1. In D48258552, `deserialize_graph_output()` was copied from File A to File B, with some modifications made. The `deserialize_graph_output()` in File B overrides that in File A due to polymorphism. 2. In D52446586, File C was created by ***copying*** File A. As a result of this diff, the `deserialize_graph_output()` in File B now overrides that in File C. 3. Also in D52446586, the `deserialize_graph_output()` in File A had some significant changes; File C got the new version of `deserialize_graph_output()`. But this diff didn't update the `deserialize_graph_output()` in File B. 4. D55391674 added handling for "None" outputs to File A. This diff brings (parts of) File C up to date with File A, and makes `deserialize_graph_output()` in File B properly override that in File A. In the future, we should figure out how to keep File C and File A in sync. 
Recently, File C was broken because it didn't stay in sync with File A in D54855251 and had to be fixed by D55776877. There will be a design review session this Friday to discuss consolidating the serialization code for edge and export. Reviewed By: tarun292 Differential Revision: D56091104 --- exir/serde/export_serialize.py | 9 +++++++-- exir/serde/serialize.py | 34 ++++++++++++++++------------------ sdk/debug_format/et_schema.py | 7 ++++++- 3 files changed, 29 insertions(+), 21 deletions(-) diff --git a/exir/serde/export_serialize.py b/exir/serde/export_serialize.py index 799a1dbe78f..fef2b2411fa 100644 --- a/exir/serde/export_serialize.py +++ b/exir/serde/export_serialize.py @@ -1190,13 +1190,17 @@ def deserialize_tensor_meta( ), ) - def deserialize_graph_output(self, output) -> torch.fx.Node: + def deserialize_graph_output(self, output) -> Optional[Union[torch.fx.Node, int]]: if output.type == "as_tensor": return self.serialized_name_to_node[output.as_tensor.name] elif output.type == "as_sym_int": return self.serialized_name_to_node[output.as_sym_int.as_name] elif output.type == "as_sym_bool": return self.serialized_name_to_node[output.as_sym_bool.as_name] + elif output.type == "as_int": + return output.as_int + elif output.type == "as_none": + return None else: raise SerializeError(f"Unable to deserialize output node {output}") @@ -1249,7 +1253,8 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph: output_node.meta["val"] = output_node.args[0].meta["val"] else: output_node.meta["val"] = tuple( - arg.meta["val"] for arg in output_node.args[0] + arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg + for arg in output_node.args[0] ) return self.graph diff --git a/exir/serde/serialize.py b/exir/serde/serialize.py index e1521b98c09..34d55252d83 100644 --- a/exir/serde/serialize.py +++ b/exir/serde/serialize.py @@ -40,7 +40,6 @@ SCHEMA_VERSION, ) from torch._export.serde.schema import SchemaVersion -from torch._export.serde.serialize 
import SerializeError from torch._export.serde.union import _Union from torch._export.verifier import load_verifier from torch.fx.experimental import symbolic_shapes @@ -484,23 +483,22 @@ def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]: return res - def deserialize_graph_output(self, output: schema.Argument) -> torch.fx.Node: - if isinstance(output.value, schema.TensorArgument): - if output.value.name in self.state_dict: # TODO(T157676982) - val = self.state_dict[output.value.name] - setattr(self.module, output.value.name, val) - node = self.graph.create_node( - "get_attr", - output.value.name, - name=output.value.name, - ) - node.meta = {"val": ""} - return node - return self.serialized_name_to_node[output.value.name] - elif isinstance(output.value, (schema.SymIntArgument, schema.SymBoolArgument)): - return self.serialized_name_to_node[output.value.as_name] - else: - raise SerializeError(f"Unable to deserialize output node {output}") + def deserialize_graph_output( + self, output: schema.Argument + ) -> Optional[Union[torch.fx.Node, int]]: + if ( + output.type == "as_tensor" and output.value.name in self.state_dict + ): # TODO(T157676982) + val = self.state_dict[output.value.name] + setattr(self.module, output.value.name, val) + node = self.graph.create_node( + "get_attr", + output.value.name, + name=output.value.name, + ) + node.meta = {"val": ""} + return node + return super().deserialize_graph_output(output) # pyre-ignore def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]): diff --git a/sdk/debug_format/et_schema.py b/sdk/debug_format/et_schema.py index af95bc7f03a..9a6af4edba9 100644 --- a/sdk/debug_format/et_schema.py +++ b/sdk/debug_format/et_schema.py @@ -260,7 +260,12 @@ def gen_operator_graph( assert len(args) == 1 # Args of op=='output' is a wrapped list of return nodes ([ret_1, ret_2, ...], ) in_nodes = [ - nodes[FXOperatorGraph._get_node_name(ret)] for ret in args[0] + ( + 
nodes[FXOperatorGraph._get_node_name(ret)] + if ret is not None + else [] + ) + for ret in args[0] ] node = ValueNode( name,