example of splitting up an FX graph into smaller subgraphs with own submodules (#45404)

Summary: Pull Request resolved: #45404

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D23956147

Pulled By: Lilyjjo

fbshipit-source-id: a35e33a0b9f1ed5f3fb6e5cd146f66c29bf3d518
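
As a quick illustration of the new split_module API, here is a minimal sketch (not part of the commit; the toy model, the callback, and the two-way split are invented for this example):

import itertools

import torch
from torch.fx import symbolic_trace
from torch.fx.experimental.subgraph_creation_example import split_module

# Hypothetical toy model, used only for illustration.
class Toy(torch.nn.Module):
    def forward(self, x):
        return (x + 1).relu() * 2

toy = Toy()
traced = symbolic_trace(toy)

counter = itertools.count()

def first_two_then_rest(node):
    # Send the first two non-placeholder nodes to partition 0 and the rest
    # to partition 1. (Interleaving dependent ops across partitions would
    # create a cyclic dependency and make split_module raise.)
    return 0 if next(counter) < 2 else 1

module_with_submodules = split_module(traced, toy, first_two_then_rest)

# The result is a GraphModule whose base graph calls submod_0 and submod_1.
print(module_with_submodules.submod_0.graph)
x = torch.rand(4)
assert torch.equal(toy(x), module_with_submodules(x))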
Lilyjjo authored and facebook-github-bot committed Oct 2, 2020
1 parent 1552a92 commit f6dc256
Showing 2 changed files with 212 additions and 0 deletions.
38 changes: 38 additions & 0 deletions test/test_fx.py
@@ -9,6 +9,7 @@
from torch.fx.experimental import GraphManipulation
from torch.fx.experimental import shape_prop
from torch.fx.experimental.Partitioner import DAG, Partitioner
from torch.fx.experimental.subgraph_creation_example import split_module

from torch.fx.proxy import TraceError

@@ -780,5 +781,42 @@ def forward(self, a, b):
assert(r.input_nodes == d.input_nodes)
assert(r.output_nodes == d.output_nodes)

def test_subgraph_creation(self):
class MyModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 4))
self.linear = torch.nn.Linear(4, 5)

def forward(self, x, y):
z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
w = self.linear(y).clamp(min=0.0, max=1.0)
return z + w

# symbolically trace model
my_module = MyModule()
my_module_traced = symbolic_trace(my_module)

        # round-robin partitioning across NPARTITIONS partitions
partition_counter = 0
NPARTITIONS = 3

def mod_partition(node: Node):
nonlocal partition_counter
partition = partition_counter % NPARTITIONS
partition_counter = (partition_counter + 1) % NPARTITIONS
return partition

        # split the traced module into a module that calls submodules
module_with_submodules = split_module(my_module_traced, my_module, mod_partition)

x = torch.rand(3, 4)
y = torch.rand(3, 4)

orig_out = my_module_traced(x, y)
submodules_out = module_with_submodules(x, y)

self.assertEqual(orig_out, submodules_out)

if __name__ == '__main__':
run_tests()
174 changes: 174 additions & 0 deletions torch/fx/experimental/subgraph_creation_example.py
@@ -0,0 +1,174 @@
import torch
from torch.fx.graph_module import GraphModule
from typing import Callable, List, Dict, Set, Any, Optional

class Partition:
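    # Bookkeeping for one subgraph: the original nodes assigned to it, the
    # values crossing its boundary (inputs/outputs), its dependency edges to
    # other partitions, and the new Graph plus the node mapping (environment)
    # and attribute targets used to build its submodule.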
def __init__(self, name: str):
self.name: str = name
self.node_names: List[str] = []
self.inputs: Set[str] = set()
self.outputs: Set[str] = set()
self.partitions_dependent_on: Set[str] = set()
self.partition_dependents: Set[str] = set()
self.graph : torch.fx.graph.Graph = torch.fx.graph.Graph()
self.environment : Dict[torch.fx.node.Node, torch.fx.node.Node] = {}
self.targets : Dict[str, Any] = {}

def __repr__(self) -> str:
return f"name: {self.name},\n" \
f" nodes: {self.node_names},\n" \
f" inputs: {self.inputs},\n" \
f" outputs: {self.outputs},\n" \
f" partitions depenent on: {self.partitions_dependent_on},\n" \
f" parition dependents: {self.partition_dependents}"

# Splits the main graph of a GraphModule into partition subgraphs and rebuilds it as a module that calls them as submodules
def split_module(
m: GraphModule,
root_m: torch.nn.Module,
split_callback: Callable[[torch.fx.node.Node], int],
) -> GraphModule:
partitions: Dict[str, Partition] = {}
orig_nodes: Dict[str, torch.fx.node.Node] = {}

def record_cross_partition_use(def_node : torch.fx.node.Node, use_node : Optional[torch.fx.node.Node]):
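        # A value crosses partitions when its defining node and using node
        # carry different `_fx_partition` tags (placeholders/get_attrs and the
        # overall graph result have none). Record it as an output of the
        # defining partition and an input of the using partition, and note the
        # dependency edge between the two partitions.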
def_partition_name = getattr(def_node, '_fx_partition', None)
use_partition_name = getattr(use_node, '_fx_partition', None)
if def_partition_name != use_partition_name:
if def_partition_name is not None:
def_partition = partitions[def_partition_name]
def_partition.outputs.add(def_node.name)
if use_partition_name is not None:
def_partition.partition_dependents.add(use_partition_name)

if use_partition_name is not None:
use_partition = partitions[use_partition_name]
use_partition.inputs.add(def_node.name)
if def_partition_name is not None:
use_partition.partitions_dependent_on.add(def_partition_name)

    # split nodes into partitions
for node in m.graph.nodes:
orig_nodes[node.name] = node

        # TODO: currently placeholders/parameters aren't assigned partitions by the
        # callback; rather, they're added to the graphs where they are used, below
if node.op in ["placeholder", "get_attr"]:
continue
partition_name = str(split_callback(node))

# add node to partitions
partition = partitions.get(partition_name)
if partition is None:
partitions[partition_name] = partition = Partition(partition_name)

partition.node_names.append(node.name)
node._fx_partition = partition_name

torch.fx.graph.map_arg(node.args, lambda def_node: record_cross_partition_use(def_node, node))
torch.fx.graph.map_arg(node.kwargs, lambda def_node: record_cross_partition_use(def_node, node))

torch.fx.graph.map_arg(m.graph.result, lambda n: record_cross_partition_use(n, None))

# find partitions with no dependencies
root_partitions : List[str] = []
for partition_name, partition in partitions.items():
if not len(partition.partitions_dependent_on):
root_partitions.append(partition_name)

# check partitions for circular dependencies and create topological partition ordering
sorted_partitions : List[str] = []
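    # Kahn's algorithm: repeatedly emit a partition with no unresolved
    # dependencies and remove its outgoing edges; if some partition is never
    # emitted, the partition graph contains a cycle.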
while root_partitions:
root_partition = root_partitions.pop()
sorted_partitions.append(root_partition)
for dependent in partitions[root_partition].partition_dependents:
partitions[dependent].partitions_dependent_on.remove(root_partition)
if not partitions[dependent].partitions_dependent_on:
root_partitions.append(dependent)
if len(sorted_partitions) != len(partitions):
raise RuntimeError("cycle exists between partitions!")

    # add placeholders to partitions
for partition_name in sorted_partitions:
partition = partitions[partition_name]
for input in partition.inputs:
placeholder = partition.graph.placeholder(input)
partition.environment[orig_nodes[input]] = placeholder

# Transform nodes and collect targets for partition's submodule
for node in m.graph.nodes:
if hasattr(node, '_fx_partition'):
partition = partitions[node._fx_partition]

# swap out old graph nodes in kw/args with references to new nodes in this submodule
environment = partition.environment
gathered_args = torch.fx.graph.map_arg(node.args, lambda n : environment[n])
gathered_kwargs = torch.fx.graph.map_arg(node.kwargs, lambda n : environment[n])

if node.op not in ['call_module', 'get_attr']:
target = node.target
else:
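                # call_module/get_attr targets are dotted paths rooted at the
                # original module; walk the path so the submodule can own the
                # attribute, and point the new node at the path's final atom.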
target_atoms = node.target.split('.')
target_attr = m
for atom in target_atoms:
if not hasattr(target_attr, atom):
raise RuntimeError(f'Operator target {node.target} not found!')
target_attr = getattr(target_attr, atom)
partition.targets[node.target] = target_attr
target = target_atoms[-1]

assert isinstance(gathered_args, tuple)
assert isinstance(gathered_kwargs, dict)
new_node = partition.graph.create_node(op=node.op, target=target, args=gathered_args,
kwargs=gathered_kwargs)
partition.environment[node] = new_node

# Set up values to construct base module
base_mod_env : Dict[str, torch.fx.node.Node] = {}
base_mod_graph : torch.fx.graph.Graph = torch.fx.graph.Graph()
base_mod_attrs : Dict[str, torch.fx.graph_module.GraphModule] = {}
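    # Seed the base module with the original placeholders and get_attr nodes
    # (stashing the fetched attributes); everything else is produced by the
    # submodule calls emitted below.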
for node in m.graph.nodes:
if node.op == 'placeholder':
base_mod_env[node.name] = base_mod_graph.placeholder(node.name)
elif node.op == 'get_attr':
base_mod_env[node.name] = base_mod_graph.get_attr(node.target)
attr_val = m
for atom in node.target.split('.'):
if not hasattr(attr_val, atom):
raise RuntimeError(f'Node target {node.target} not found!')
attr_val = getattr(attr_val, atom)
base_mod_attrs[node.target] = attr_val

    # Iterate over the partitions in topological order again to:
# 1) Finish off submodule Graphs by setting corresponding outputs
# 2) Construct GraphModules for each submodule
# 3) Construct the base graph by emitting calls to those submodules in
# topological order

for partition_name in sorted_partitions:
partition = partitions[partition_name]

        # Set output values: a single output is emitted bare, multiple outputs as a tuple
output_vals = tuple(partition.environment[orig_nodes[name]] for name in partition.outputs)
output_vals = output_vals[0] if len(output_vals) == 1 else output_vals # type: ignore
partition.graph.output(output_vals)

# Construct GraphModule for this partition
submod_name = f'submod_{partition_name}'
base_mod_attrs[submod_name] = torch.fx.graph_module.GraphModule(partition.targets, partition.graph)

# Emit call in base graph to this submodule

output_val = base_mod_graph.call_module(submod_name, [base_mod_env[name] for name in partition.inputs]) # type: ignore
if len(partition.outputs) > 1:
            # Unpack multiple return values: indexing the Proxy records getitem nodes in the base graph
output_val_proxy = torch.fx.proxy.Proxy(output_val)
for i, output_name in enumerate(partition.outputs):
base_mod_env[output_name] = output_val_proxy[i].node # type: ignore
else:
base_mod_env[list(partition.outputs)[0]] = output_val

# Set output value for base graph
base_mod_graph.output(torch.fx.graph.map_arg(m.graph.result, lambda n : base_mod_env[n.name]))

return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
