Skip to content

Commit

Permalink
Document example of Proxy use (#50583)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #50583

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D26010501

Pulled By: ansley

fbshipit-source-id: 947121af7e57c16c96f849fbbb3fa83e97d003b2
  • Loading branch information
Ansley Ussery authored and facebook-github-bot committed Jan 22, 2021
1 parent 89cafde commit d5dc65a
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 3 deletions.
68 changes: 68 additions & 0 deletions torch/fx/examples/inline_function.py
@@ -0,0 +1,68 @@
import torch
from torch.fx import Proxy, symbolic_trace
from torch.fx.node import map_arg


'''
How to inline a function into an existing Graph
One reason you might want to inline a function is to get around FX's
default tracing behavior. For example, unless you've defined a custom
Tracer, the out-of-the-box implementation of ``symbolic_trace`` causes
references to ``torch.nn`` module instances to appear as
``call_module`` calls rather than being traced through. Let's say this
behavior is almost what you need; the only problem is that there's a
single module call that you want to replace with an inlined trace of the
function. Creating a custom Tracer would be too much. Instead, you can
accomplish this using Proxies.
The following code demonstrates how to trace a module and inline it
into an existing Graph using Proxy. We'll trace our Graph, then iterate
through its Nodes until we find the right place to swap out the
``call_module`` Node with an inlined trace. At that point, we'll create
Proxies from the Node's args and kwargs. Finally, we'll call the
function we want to replace with those Proxies--which will, in essence,
"trace" that function. Finally, we'll insert the result of that call
into our Graph. (This last step will automatically inline the function.)
'''


# Sample module
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()

def forward(self, x):
return self.relu(x) + 1.0

# Symbolically trace an instance of `M`. After tracing, `self.relu` is
# represented as a `call_module` Node. The full operation in the
# generated `forward` function's code will appear as `self.relu(x)`
m = symbolic_trace(M())

# Insert nodes from the ReLU graph in place of the original call to
# `self.relu`
for node in m.graph.nodes:
# Find `call_module` Node in `m` that corresponds to `self.relu`.
# This is the Node we want to swap out for an inlined version of the
# same call
if (node.op, node.target) == ("call_module", "relu"):
with m.graph.inserting_before(node):
# Create a Proxy from each Node in the current Node's
# args/kwargs
proxy_args = map_arg(node.args, Proxy)
proxy_kwargs = map_arg(node.kwargs, Proxy)
# Call `m.relu` with the newly-created Proxy arguments.
# `m.relu` is the generic version of the function; by
# calling it with Proxies created from Nodes in `m`, we're
# emitting Nodes that reference exiting values in the IR.
# The result of this call is another Proxy, which we can
# hook into our existing Graph to complete the function
# inlining.
proxy_output = m.relu(*proxy_args, **proxy_kwargs)
# Replace the relu `call_module` node with the inlined
# version of the function
node.replace_all_uses_with(proxy_output.node)
# Make sure that the old relu Node is erased
m.graph.erase_node(node)
56 changes: 56 additions & 0 deletions torch/fx/examples/proxy_based_graph_creation.py
@@ -0,0 +1,56 @@
import torch
from torch.fx import Proxy, Graph, GraphModule


'''
How to create a Graph using Proxy objects instead of tracing
It's possible to directly create a Proxy object around a raw Node. This
can be used to create a Graph independently of symbolic tracing.
The following code demonstrates how to use Proxy with a raw Node to
append operations to a fresh Graph. We'll create two parameters (``x``
and ``y``), perform some operations on those parameters, then add
everything we created to the new Graph. We'll then wrap that Graph in
a GraphModule. Doing so creates a runnable instance of ``nn.Module``
where previously-created operations are represented in the Module's
``forward`` function.
By the end of the tutorial, we'll have added the following method to an
empty ``nn.Module`` class.
.. code-block:: python
def forward(self, x, y):
cat_1 = torch.cat([x, y]); x = y = None
tanh_1 = torch.tanh(cat_1); cat_1 = None
neg_1 = torch.neg(tanh_1); tanh_1 = None
return neg_1
'''


# Create a graph independently of symbolic tracing
graph = Graph()

# Create raw Nodes
raw1 = graph.placeholder('x')
raw2 = graph.placeholder('y')

# Initialize Proxies using the raw Nodes
y = Proxy(raw1)
z = Proxy(raw2)

# Create other operations using the Proxies `y` and `z`
a = torch.cat([y, z])
b = torch.tanh(a)
c = torch.neg(b)

# Create a new output Node and add it to the Graph. By doing this, the
# Graph will contain all the Nodes we just created (since they're all
# linked to the output Node)
graph.output(c.node)

# Wrap our created Graph in a GraphModule to get a final, runnable
# `nn.Module` instance
mod = GraphModule(torch.nn.Module(), graph)
5 changes: 2 additions & 3 deletions torch/fx/examples/replace_op.py
Expand Up @@ -23,9 +23,8 @@
To examine how the Graph evolves during op replacement, add the
statement `print(traced.graph)` after the line you want to inspect.
Alternatively, see the Nodes in a tabular format by adding
`from inspect_utils import print_IR` to the top of this file and calling
`print_IR(traced.graph)`.
Alternatively, call `traced.graph.print_tabular()` to see the IR in a
tabular format.
"""

# Sample module
Expand Down
77 changes: 77 additions & 0 deletions torch/fx/examples/wrap_output_dynamically.py
@@ -0,0 +1,77 @@

import torch
from torch.fx import Proxy, GraphModule, Node, symbolic_trace

from enum import Enum, auto

'''
Wrap Graph output dynamically
The following code demonstrates how change an existing Graph based on
parameters specified at runtime. We'll let the user specify an
activation function from a predefined Enum list, then we'll symbolically
trace it. Next, we'll create a Proxy from the last operation in the
Graph. We'll call our traced activation function with this Proxy and
insert the ``output`` Node from that call into our Graph. (This final
step will automatically inline the entire traced function.)
'''


# Sample module
class M(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
y = torch.cat([x, y])
return y

# Symbolically trace an instance of `M`
traced = symbolic_trace(M())

# Selected activation functions
class ActivationFunction(Enum):
RELU = auto()
LEAKY_RELU = auto()
PRELU = auto()

# Map activation function names to their implementation
activation_functions = {
ActivationFunction.RELU: torch.nn.ReLU(),
ActivationFunction.LEAKY_RELU: torch.nn.LeakyReLU(),
ActivationFunction.PRELU: torch.nn.PReLU(),
}

def wrap_in_activation_function(m: GraphModule, fn: ActivationFunction) -> GraphModule:
# Get output node
output_node: Optional[Node] = None
for n in reversed(m.graph.nodes):
if n.op == "output":
output_node = n
break
assert output_node

# Get the actual output (the "input" of the output node). This is
# the Node we want to wrap in a user-specified activation function
assert len(output_node.all_input_nodes) == 1
wrap_node = output_node.all_input_nodes[0]

# Wrap the actual output in a Proxy
wrap_proxy = Proxy(wrap_node)

# Get the implementation of the specified activation function and
# symbolically trace it
fn_impl = activation_functions[fn]
fn_impl_traced = symbolic_trace(fn_impl)

# Call the specified activation function using the Proxy wrapper for
# `output_op`. The result of this call is another Proxy, which we
# can hook into our existing Graph.
with traced.graph.inserting_before(wrap_node):
fn_impl_output_node = fn_impl_traced(wrap_proxy)
new_args = (fn_impl_output_node.node,)
output_node.args = new_args


# Example call
wrap_in_activation_function(traced, ActivationFunction.LEAKY_RELU)

0 comments on commit d5dc65a

Please sign in to comment.