test/test_fx.py (13 changes: 8 additions & 5 deletions)

@@ -174,10 +174,12 @@ def forward(self, a, b):
                 return a + b
         m = M()
         g = symbolic_trace(m).graph
-        t = Proxy(g.result)
+        new_g = torch.fx.Graph()
+        new_g.graph_copy(g)
+        t = Proxy(new_g.nodes[-1])
         # test that we can use proxy objects to generate more graph code later for things that do not need to work with modules.
-        g.output((t + t).node)
-        gm = GraphModule(m, g)
+        new_g.output((t + t).node)
+        gm = GraphModule(m, new_g)
         self.assertEqual(gm(3, 4), 14)

     @skipIfNoTorchVision

@@ -466,9 +468,10 @@ def test_deepcopy_graphmodule_with_transform(self):
         traced = symbolic_trace(st)

         def transform(traced):
-            new_graph = copy.deepcopy(traced.graph)
+            new_graph = torch.fx.Graph()
+            new_graph.graph_copy(traced.graph)
             relu_out = new_graph.create_node(
-                op='call_method', target='neg', args=(new_graph.result,), kwargs={})
+                op='call_method', target='neg', args=(new_graph.nodes[-1],), kwargs={})
             new_graph.output(relu_out)
             return GraphModule(traced, new_graph)
         transformed = transform(traced)
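Note: the snippet below is a minimal self-contained sketch of the copy-then-extend pattern the first test now exercises, assuming the APIs exactly as introduced in this PR (the read-only Graph.nodes tuple, Graph.graph_copy, and Proxy wrapping a Node). It mirrors the test rather than documenting a stable public interface.

    import torch
    from torch.fx import symbolic_trace, GraphModule, Proxy

    class M(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    m = M()
    g = symbolic_trace(m).graph

    # nodes is read-only now, so extending a traced graph starts from a
    # fresh Graph populated via graph_copy instead of mutating g in place.
    new_g = torch.fx.Graph()
    new_g.graph_copy(g)

    # Wrapping the last copied node in a Proxy lets plain Python operators
    # keep recording new nodes into new_g.
    t = Proxy(new_g.nodes[-1])
    new_g.output((t + t).node)

    gm = GraphModule(m, new_g)
    assert gm(3, 4) == 14  # (3 + 4) + (3 + 4)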
torch/fx/graph.py (20 changes: 16 additions & 4 deletions)

@@ -67,9 +67,21 @@ def map_arg(a: Argument, fn: Callable[[Node], Argument]) -> Argument:

 class Graph:
     def __init__(self):
-        self.nodes : List[Node] = []
+        self._nodes : List[Node] = []
         self._used_names : Dict[str, int] = {}  # base name -> number

+    @property
+    def nodes(self):
+        return tuple(self._nodes)
+
+    def graph_copy(self, g : 'Graph'):
+        """
+        Append all nodes from graph `g` to this graph
+        """
+        val_map : Dict[Node, Node] = {}
+        for node in g._nodes:
+            val_map[node] = self.node_copy(node, lambda n : val_map[n])
+
     def _mark_uses(self, a: Argument):
         def add_use(n: Node):
             n.uses += 1
@@ -86,7 +98,7 @@ def create_node(self, op: str, target: Target,
         self._mark_uses(args)
         self._mark_uses(kwargs)
         n = Node(self, name if name is not None else self._name(target), op, target, args, kwargs)
-        self.nodes.append(n)
+        self._nodes.append(n)
         return n

     # sugar for above when you know the op
@@ -161,7 +173,7 @@ def _name(self, target: Target) -> str:
     def python_code(self, root_module: str) -> Tuple[str, str, List[str]]:
         free_vars: List[str] = []
         body: List[str] = []
-        for node in self.nodes:
+        for node in self._nodes:
             if node.op == 'placeholder':
                 assert isinstance(node.target, str)
                 free_vars.append(node.target)
@@ -237,7 +249,7 @@ def format_node(n : Node) -> Optional[str]:
                 f'args = {format_arg(n.args)}, kwargs = {format_arg(n.kwargs)})'


-        node_strs = [format_node(node) for node in self.nodes]
+        node_strs = [format_node(node) for node in self._nodes]
         param_str = ', '.join(placeholder_names)
         s = f'graph({param_str}):'
         for node_str in node_strs:
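Note: a short sketch of the behavioral change in this file, assuming the patched Graph above: nodes is now a tuple snapshot of the private _nodes list, so indexing and iteration still work but direct mutation does not, and edits funnel through create_node (which also marks uses and assigns a unique name).

    import torch
    from torch.fx import symbolic_trace

    class M(torch.nn.Module):
        def forward(self, x):
            return x + 1

    g = symbolic_trace(M()).graph

    snapshot = g.nodes            # a tuple copy, not the internal list
    try:
        snapshot.append(None)     # tuples have no append: mutation now fails
    except AttributeError as err:
        print(err)

    # Growing the graph goes through the Graph API instead.
    n = g.create_node(op='call_method', target='neg',
                      args=(g.nodes[-1],), kwargs={})
    assert g.nodes[-1] is n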
torch/fx/graph_module.py (30 changes: 26 additions & 4 deletions)

@@ -97,6 +97,17 @@ def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str):
     setattr(to_module, field, from_obj)

 class GraphModule(torch.nn.Module):
+    """
+    GraphModule is an nn.Module generated from an fx.Graph. GraphModule has
+    important attributes:
+
+        graph : The graph from which this GraphModule was generated
+        code : The Python source code for the function generated from `graph`
+        forward : The Python method generated from `graph`
+
+    Note that when `graph` is reassigned, `code` and `forward` will be automatically
+    regenerated.
+    """
     def __new__(cls: 'Type[GraphModule]', *args, **kwargs):
         # each instance of a graph module needs its own forward method
         # so create a new singleton class for each instance.
@@ -148,10 +159,21 @@ def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph):
         else:
             raise RuntimeError('Unsupported type ' + str(root) + ' passed for root!')
         self.graph = graph
-        self._generate_forward()

-    def _generate_forward(self) -> None:
-        body, result, free_variables = self.graph.python_code(root_module='self')
+    # TorchScript breaks trying to compile the graph setter because of the
+    # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842
+    #
+    # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway
+    __ignored_properties__ = ['graph']
+
+    @property
+    def graph(self):
+        return self._graph
+
+    @graph.setter
+    def graph(self, val) -> None:
+        self._graph = val
+        body, result, free_variables = self._graph.python_code(root_module='self')
         body = '\n'.join('    ' + line for line in body.split('\n')) + '\n'
         self.code = f"""\
 def forward(self, {', '.join(free_variables)}):
@@ -163,7 +185,7 @@ def forward(self, {', '.join(free_variables)}):

     def __reduce__(self):
         dict_without_graph = self.__dict__.copy()
-        del dict_without_graph['graph']
+        del dict_without_graph['_graph']
         return (deserialize_graphmodule, (dict_without_graph,))

     # because __reduce__ is defined for serialization,
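Note: a sketch of what the new property buys, assuming the setter shown above (whose tail, which compiles `code` into `forward`, is elided by the diff): every assignment to gm.graph re-runs code generation, per the class docstring, so `code` and `forward` cannot go stale relative to `graph`.

    import torch
    from torch.fx import symbolic_trace

    class M(torch.nn.Module):
        def forward(self, a, b):
            return a + b

    gm = symbolic_trace(M())
    print(gm.code)                # forward generated from the traced graph

    # Build a transformed graph that negates the original result...
    new_g = torch.fx.Graph()
    new_g.graph_copy(gm.graph)
    neg = new_g.create_node(op='call_method', target='neg',
                            args=(new_g.nodes[-1],), kwargs={})
    new_g.output(neg)

    # ...and reassign it: the setter regenerates code and forward on the spot.
    gm.graph = new_g
    print(gm.code)                # now ends in a .neg() call
    print(gm(torch.tensor(3.), torch.tensor(4.)))  # tensor(-7.)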