Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions torch/_export/serde/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1368,7 +1368,7 @@ def deserialize(

module_call_graph = self.deserialize_module_call_graph(serialized_graph_module.module_call_graph)
return GraphModuleDeserializer.Result(
graph_module=torch._export.exported_program._create_graph_module_for_export(self.module, self.graph),
graph_module=ep._create_graph_module_for_export(self.module, self.graph),
signature=self.signature,
module_call_graph=module_call_graph,
names_to_symbols=self.symbol_name_to_symbol,
Expand Down Expand Up @@ -1419,7 +1419,7 @@ def deserialize_input(self, inp: Argument) -> Any:
assert isinstance(value, GraphArgument)
with self.save_graph_module():
self.deserialize_graph(value.graph)
submodule = torch._export.exported_program._create_graph_module_for_export(self.module, self.graph)
submodule = ep._create_graph_module_for_export(self.module, self.graph)
self.module.register_module(value.name, submodule)
return self.graph.create_node(
"get_attr",
Expand Down
24 changes: 22 additions & 2 deletions torch/export/exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import dataclasses
import functools
import types
import warnings
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -141,8 +142,6 @@ def __init__(
Dict[str, Union[torch.Tensor, torch._C.ScriptObject]]
] = None,
):
from torch._export.exported_program import _create_graph_module_for_export

# Remove codegen related things from the graph. It should just be a flat graph.
graph._codegen = torch.fx.graph.CodeGen()
self._graph_module = _create_graph_module_for_export(root, graph)
Expand Down Expand Up @@ -664,3 +663,24 @@ def get_shape_env(gm):
if k not in shape_env.replacements:
range_constraints[k] = v
return range_constraints


def _create_graph_module_for_export(root, graph):
try:
gm = torch.fx.GraphModule(root, graph)
except SyntaxError:
# If custom objects stored in memory are being used in the graph,
# the generated python code will result in a syntax error on the custom
# object, since it is unable to parse the in-memory object. However
# we can still run the graph eagerly through torch.fx.Interpreter,
# so we will bypass this error.
warnings.warn(
"Unable to execute the generated python source code from "
"the graph. The graph module will no longer be directly callable, "
"but you can still run the ExportedProgram, and if needed, you can "
"run the graph module eagerly using torch.fx.Interpreter."
)
gm = torch.fx.GraphModule(root, torch.fx.Graph())
gm._graph = graph

return gm