Document example of Proxy use
ghstack-source-id: 66a8c17645b9ae3e1be91c933691f8614d1b8738
Pull Request resolved: #50583
Ansley Ussery committed Jan 21, 2021
1 parent 480bb7d commit 710b191
Showing 4 changed files with 199 additions and 3 deletions.
68 changes: 68 additions & 0 deletions torch/fx/examples/inline_function.py
@@ -0,0 +1,68 @@
import torch
from torch.fx import Proxy, symbolic_trace
from torch.fx.node import map_arg


'''
How to inline a function into an existing Graph

If you're using the default Tracer to symbolically trace, ``torch.nn``
module instances will appear as ``call_module`` Nodes rather than being
traced through. Suppose this behavior is almost what you need, except
that there's a single module call you want to replace with an inlined
trace of that module's forward function. Creating a custom Tracer for
this would be overkill. Instead, you can accomplish it with Proxies.

The following code demonstrates how to trace a module and inline it
into an existing Graph using Proxy. We'll trace our module, trace the
function we want to inline, then iterate through our Graph's list of
Nodes until we find the right place to swap out the ``call_module``
Node for an inlined trace. At that point, we'll create Proxies from the
Node's args and kwargs. Finally, we'll call our traced module with
those Proxies and insert the ``output`` Node from that call into our
Graph. (This final step will automatically inline the entire traced
function.)
'''


# Sample module
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(x) + 1.0

# Symbolically trace an instance of `M`. After tracing, `self.relu` is
# represented as a `call_module` Node. The full operation in the
# generated `forward` function's code will appear as `self.relu(x)`
m = symbolic_trace(M())

# Trace through the ReLU module only. This allows us to get its Graph
# to inline
traced_relu = symbolic_trace(m.relu)

# Insert nodes from the ReLU graph in place of the original call to
# `self.relu`
for node in m.graph.nodes:
    # Find `call_module` Node in `m` that corresponds to `self.relu`.
    # This is the Node we want to swap out for an inlined version of the
    # same call
    if (node.op, node.target) == ("call_module", "relu"):
        with m.graph.inserting_before(node):
            # Create a Proxy from each Node in the current Node's
            # args/kwargs
            proxy_args = map_arg(node.args, Proxy)
            proxy_kwargs = map_arg(node.kwargs, Proxy)
            # Call `traced_relu` with the newly-created Proxy arguments.
            # `traced_relu` is the generic version of the function; by
            # calling it with Proxies created from Nodes in `m`, we're
            # creating a "closure" that includes those Nodes. The result
            # of this call is another Proxy, which we can hook into
            # our existing Graph to complete the function inlining.
            proxy_output = traced_relu(*proxy_args, **proxy_kwargs)
            # Replace the relu `call_module` node with the inlined
            # version of the function
            node.replace_all_uses_with(proxy_output.node)
            # Make sure that the old relu Node is erased
            m.graph.erase_node(node)
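
# A minimal verification sketch (an addition, not part of the original
# example): recompile the GraphModule so its generated `forward` reflects
# the edited Graph, then check it against the original eager computation.
m.recompile()
x = torch.randn(3, 4)
assert torch.equal(m(x), torch.relu(x) + 1.0)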
56 changes: 56 additions & 0 deletions torch/fx/examples/proxy_based_graph_creation.py
@@ -0,0 +1,56 @@
import torch
from torch.fx import Proxy, Graph, GraphModule


'''
How to create a Graph using Proxy objects instead of tracing

It's possible to directly create a Proxy object around a raw Node. This
can be used to create a Graph independently of symbolic tracing.
The following code demonstrates how to use Proxy with a raw Node to
append operations to a fresh Graph. We'll create two parameters (``x``
and ``y``), perform some operations on those parameters, then add
everything we created to the new Graph. We'll then wrap that Graph in
a GraphModule. Doing so creates a runnable instance of ``nn.Module``
where previously-created operations are represented in the Module's
``forward`` function.
By the end of the tutorial, we'll have added the following method to an
empty ``nn.Module`` class:

.. code-block:: python

    def forward(self, x, y):
        cat_1 = torch.cat([x, y]); x = y = None
        tanh_1 = torch.tanh(cat_1); cat_1 = None
        neg_1 = torch.neg(tanh_1); tanh_1 = None
        return neg_1
'''


# Create a graph independently of symbolic tracing
graph = Graph()

# Create raw Nodes
raw1 = graph.placeholder('x')
raw2 = graph.placeholder('y')

# Initialize Proxies using the raw Nodes
y = Proxy(raw1)
z = Proxy(raw2)

# Create other operations using the Proxies `y` and `z`
a = torch.cat([y, z])
b = torch.tanh(a)
c = torch.neg(b)

# Create a new output Node and add it to the Graph. By doing this, the
# Graph will contain all the Nodes we just created (since they're all
# linked to the output Node)
graph.output(c.node)

# Wrap our created Graph in a GraphModule to get a final, runnable
# `nn.Module` instance
mod = GraphModule(torch.nn.Module(), graph)
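
# A small usage sketch (an addition, not in the original file): run the
# GraphModule we just built and compare it against the equivalent
# eager-mode computation.
inp1, inp2 = torch.randn(3), torch.randn(3)
expected = torch.neg(torch.tanh(torch.cat([inp1, inp2])))
assert torch.equal(mod(inp1, inp2), expected)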
5 changes: 2 additions & 3 deletions torch/fx/examples/replace_op.py
@@ -23,9 +23,8 @@
To examine how the Graph evolves during op replacement, add the
statement `print(traced.graph)` after the line you want to inspect.
-Alternatively, see the Nodes in a tabular format by adding
-`from inspect_utils import print_IR` to the top of this file and calling
-`print_IR(traced.graph)`.
+Alternatively, call `traced.graph.print_tabular()` to see the IR in a
+tabular format.
"""

# Sample module
73 changes: 73 additions & 0 deletions torch/fx/examples/wrap_output_dynamically.py
@@ -0,0 +1,73 @@

import torch
from torch.fx import Proxy, GraphModule, symbolic_trace

from enum import Enum, auto

'''
Wrap Graph output dynamically

The following code demonstrates how to change an existing Graph based on
parameters specified at runtime. We'll let the user specify an
activation function from a predefined Enum list, then we'll symbolically
trace it. Next, we'll create a Proxy from the last operation in the
Graph. We'll call our traced activation function with this Proxy and
insert the ``output`` Node from that call into our Graph. (This final
step will automatically inline the entire traced function.)
'''


# Sample module
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        y = torch.cat([x, y])
        return y

# Symbolically trace an instance of `M`
traced = symbolic_trace(M())

# Selected activation functions
class ActivationFunction(Enum):
    RELU = auto()
    LEAKY_RELU = auto()
    PRELU = auto()

# Map activation function names to their implementation
activation_functions = {
    ActivationFunction.RELU: torch.nn.ReLU(),
    ActivationFunction.LEAKY_RELU: torch.nn.LeakyReLU(),
    ActivationFunction.PRELU: torch.nn.PReLU(),
}

def wrap_in_activation_function(m: GraphModule, fn: ActivationFunction) -> GraphModule:
    # Get output node
    output_node = next(iter(reversed(m.graph.nodes)))
    assert output_node.op == "output"

    # Get the actual output (the "input" of the output node). This is
    # the Node we want to wrap in a user-specified activation function
    assert len(output_node.all_input_nodes) == 1
    wrap_node = output_node.all_input_nodes[0]

    # Wrap the actual output in a Proxy
    wrap_proxy = Proxy(wrap_node)

    # Get the implementation of the specified activation function and
    # symbolically trace it
    fn_impl = activation_functions[fn]
    fn_impl_traced = symbolic_trace(fn_impl)

    # Call the specified activation function using the Proxy wrapper for
    # `wrap_node`. The result of this call is another Proxy, which we
    # can hook into our existing Graph. The new Node must come after the
    # Node it consumes, so insert it after `wrap_node` rather than before.
    with m.graph.inserting_after(wrap_node):
        fn_impl_output_node = fn_impl_traced(wrap_proxy)
        new_args = (fn_impl_output_node.node,)
        output_node.args = new_args

    return m


# Example call
wrap_in_activation_function(traced, ActivationFunction.LEAKY_RELU)
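
# A minimal check (an addition, not part of the original example): recompile
# so the generated `forward` picks up the new Node, then compare against
# applying LeakyReLU in eager mode.
traced.recompile()
x, y = torch.randn(3), torch.randn(3)
expected = torch.nn.LeakyReLU()(torch.cat([x, y]))
assert torch.equal(traced(x, y), expected)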
