diff --git a/exir/serde/export_serialize.py b/exir/serde/export_serialize.py index 799a1dbe78f..fef2b2411fa 100644 --- a/exir/serde/export_serialize.py +++ b/exir/serde/export_serialize.py @@ -1190,13 +1190,17 @@ def deserialize_tensor_meta( ), ) - def deserialize_graph_output(self, output) -> torch.fx.Node: + def deserialize_graph_output(self, output) -> Optional[Union[torch.fx.Node, int]]: if output.type == "as_tensor": return self.serialized_name_to_node[output.as_tensor.name] elif output.type == "as_sym_int": return self.serialized_name_to_node[output.as_sym_int.as_name] elif output.type == "as_sym_bool": return self.serialized_name_to_node[output.as_sym_bool.as_name] + elif output.type == "as_int": + return output.as_int + elif output.type == "as_none": + return None else: raise SerializeError(f"Unable to deserialize output node {output}") @@ -1249,7 +1253,8 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph: output_node.meta["val"] = output_node.args[0].meta["val"] else: output_node.meta["val"] = tuple( - arg.meta["val"] for arg in output_node.args[0] + arg.meta["val"] if isinstance(arg, torch.fx.Node) else arg + for arg in output_node.args[0] ) return self.graph diff --git a/exir/serde/serialize.py b/exir/serde/serialize.py index 5eb28b830ce..dea86155f21 100644 --- a/exir/serde/serialize.py +++ b/exir/serde/serialize.py @@ -35,7 +35,6 @@ LoweredBackendModule as SerdeLoweredBackendModule, ) from torch._export.serde.schema import SchemaVersion -from torch._export.serde.serialize import SerializeError from torch._export.serde.union import _Union from torch._export.verifier import load_verifier from torch.fx.experimental import symbolic_shapes @@ -479,23 +478,22 @@ def deserialize_metadata(self, metadata: Dict[str, str]) -> Dict[str, Any]: return res - def deserialize_graph_output(self, output: schema.Argument) -> torch.fx.Node: - if isinstance(output.value, schema.TensorArgument): - if output.value.name in self.state_dict: # TODO(T157676982) - val = self.state_dict[output.value.name] - setattr(self.module, output.value.name, val) - node = self.graph.create_node( - "get_attr", - output.value.name, - name=output.value.name, - ) - node.meta = {"val": ""} - return node - return self.serialized_name_to_node[output.value.name] - elif isinstance(output.value, (schema.SymIntArgument, schema.SymBoolArgument)): - return self.serialized_name_to_node[output.value.as_name] - else: - raise SerializeError(f"Unable to deserialize output node {output}") + def deserialize_graph_output( + self, output: schema.Argument + ) -> Optional[Union[torch.fx.Node, int]]: + if ( + output.type == "as_tensor" and output.value.name in self.state_dict + ): # TODO(T157676982) + val = self.state_dict[output.value.name] + setattr(self.module, output.value.name, val) + node = self.graph.create_node( + "get_attr", + output.value.name, + name=output.value.name, + ) + node.meta = {"val": ""} + return node + return super().deserialize_graph_output(output) # pyre-ignore def deserialize_alloc_inputs(self, serialized_inputs: List[schema.NamedArgument]): diff --git a/sdk/debug_format/et_schema.py b/sdk/debug_format/et_schema.py index af95bc7f03a..9a6af4edba9 100644 --- a/sdk/debug_format/et_schema.py +++ b/sdk/debug_format/et_schema.py @@ -260,7 +260,12 @@ def gen_operator_graph( assert len(args) == 1 # Args of op=='output' is a wrapped list of return nodes ([ret_1, ret_2, ...], ) in_nodes = [ - nodes[FXOperatorGraph._get_node_name(ret)] for ret in args[0] + ( + nodes[FXOperatorGraph._get_node_name(ret)] + if ret is not None + else [] + ) + for ret in args[0] ] node = ValueNode( name,