Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Document example of Proxy use #50583

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
68 changes: 68 additions & 0 deletions torch/fx/examples/inline_function.py
@@ -0,0 +1,68 @@
import torch
from torch.fx import Proxy, symbolic_trace
from torch.fx.node import map_arg


'''
How to inline a function into an existing Graph

If you're using the default Tracer to symbolically trace, ``torch.nn``
ansley marked this conversation as resolved.
Show resolved Hide resolved
module instances will appear as ``call_module`` calls rather than being
traced through. Let's say this behavior is almost what you need; the
only problem is that there's a single module call that you want to
replace with an inlined trace of the function. Creating a custom Tracer
would be too much. Instead, you can accomplish this using Proxies.

The following code demonstrates how to trace a module and inline it
into an existing Graph using Proxy. We'll trace our Graph, trace our
function, then iterate through our Graph's list of Nodes until we find
the right place to swap out the ``call_module`` Node with an inlined
trace. At that point, we'll create Proxies from the Node's args and
kwargs. Finally, we'll call our traced module with those Proxies and
insert the ``output`` Node from that call into our Graph. (This final
step will automatically inline the entire traced function.)
'''


# Sample module
class M(torch.nn.Module):
def __init__(self):
super().__init__()
self.relu = torch.nn.ReLU()

def forward(self, x):
return self.relu(x) + 1.0

# Symbolically trace an instance of `M`. After tracing, `self.relu` is
# represented as a `call_module` Node. The full operation in the
# generated `forward` function's code will appear as `self.relu(x)`
m = symbolic_trace(M())

# Trace through the ReLU module only. This allows us to get its Graph
# to inline
traced_relu = symbolic_trace(m.relu)

# Insert nodes from the ReLU graph in place of the original call to
# `self.relu`
for node in m.graph.nodes:
# Find `call_module` Node in `m` that corresponds to `self.relu`.
# This is the Node we want to swap out for an inlined version of the
# same call
if (node.op, node.target) == ("call_module", "relu"):
with m.graph.inserting_before(node):
# Create a Proxy from each Node in the current Node's
# args/kwargs
proxy_args = map_arg(node.args, Proxy)
proxy_kwargs = map_arg(node.kwargs, Proxy)
# Call `traced_relu` with the newly-created Proxy arguments.
# `traced_relu` is the generic version of the function; by
# calling it with Proxies created from Nodes in `m`, we're
# creating a "closure" that includes those Nodes. The result
ansley marked this conversation as resolved.
Show resolved Hide resolved
# of this call is another Proxy, which we can hook into
# our existing Graph to complete the function inlining.
proxy_output = traced_relu(*proxy_args, **proxy_kwargs)
ansley marked this conversation as resolved.
Show resolved Hide resolved
# Replace the relu `call_module` node with the inlined
# version of the function
node.replace_all_uses_with(proxy_output.node)
# Make sure that the old relu Node is erased
m.graph.erase_node(node)
56 changes: 56 additions & 0 deletions torch/fx/examples/proxy_based_graph_creation.py
@@ -0,0 +1,56 @@
import torch
from torch.fx import Proxy, Graph, GraphModule


'''
How to create a Graph using Proxy objects instead of tracing

It's possible to directly create a Proxy object around a raw Node. This
can be used to create a Graph independently of symbolic tracing.

The following code demonstrates how to use Proxy with a raw Node to
append operations to a fresh Graph. We'll create two parameters (``x``
and ``y``), perform some operations on those parameters, then add
everything we created to the new Graph. We'll then wrap that Graph in
a GraphModule. Doing so creates a runnable instance of ``nn.Module``
where previously-created operations are represented in the Module's
``forward`` function.

By the end of the tutorial, we'll have added the following method to an
empty ``nn.Module`` class.

.. code-block:: python

def forward(self, x, y):
cat_1 = torch.cat([x, y]); x = y = None
tanh_1 = torch.tanh(cat_1); cat_1 = None
neg_1 = torch.neg(tanh_1); tanh_1 = None
return neg_1

'''


# Create a graph independently of symbolic tracing
graph = Graph()

# Create raw Nodes
raw1 = graph.placeholder('x')
raw2 = graph.placeholder('y')

# Initialize Proxies using the raw Nodes
y = Proxy(raw1)
z = Proxy(raw2)

# Create other operations using the Proxies `y` and `z`
a = torch.cat([y, z])
b = torch.tanh(a)
c = torch.neg(b)

# Create a new output Node and add it to the Graph. By doing this, the
# Graph will contain all the Nodes we just created (since they're all
# linked to the output Node)
graph.output(c.node)

# Wrap our created Graph in a GraphModule to get a final, runnable
# `nn.Module` instance
mod = GraphModule(torch.nn.Module(), graph)
5 changes: 2 additions & 3 deletions torch/fx/examples/replace_op.py
Expand Up @@ -23,9 +23,8 @@

To examine how the Graph evolves during op replacement, add the
statement `print(traced.graph)` after the line you want to inspect.
Alternatively, see the Nodes in a tabular format by adding
`from inspect_utils import print_IR` to the top of this file and calling
`print_IR(traced.graph)`.
Alternatively, call `traced.graph.print_tabular()` to see the IR in a
tabular format.
"""

# Sample module
Expand Down
73 changes: 73 additions & 0 deletions torch/fx/examples/wrap_output_dynamically.py
@@ -0,0 +1,73 @@

import torch
from torch.fx import Proxy, GraphModule, symbolic_trace

from enum import Enum, auto

'''
Wrap Graph output dynamically

The following code demonstrates how change an existing Graph based on
parameters specified at runtime. We'll let the user specify an
activation function from a predefined Enum list, then we'll symbolically
trace it. Next, we'll create a Proxy from the last operation in the
Graph. We'll call our traced activation function with this Proxy and
insert the ``output`` Node from that call into our Graph. (This final
step will automatically inline the entire traced function.)
'''


# Sample module
class M(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, x, y):
y = torch.cat([x, y])
return y

# Symbolically trace an instance of `M`
traced = symbolic_trace(M())

# Selected activation functions
class ActivationFunction(Enum):
RELU = auto()
LEAKY_RELU = auto()
PRELU = auto()

# Map activation function names to their implementation
activation_functions = {
ActivationFunction.RELU: torch.nn.ReLU(),
ActivationFunction.LEAKY_RELU: torch.nn.LeakyReLU(),
ActivationFunction.PRELU: torch.nn.PReLU(),
}

def wrap_in_activation_function(m: GraphModule, fn: ActivationFunction) -> GraphModule:
# Get output node
output_node = next(iter(reversed(m.graph.nodes)))
assert output_node.op == "output"
ansley marked this conversation as resolved.
Show resolved Hide resolved

# Get the actual output (the "input" of the output node). This is
# the Node we want to wrap in a user-specified activation function
assert len(output_node.all_input_nodes) == 1
wrap_node = output_node.all_input_nodes[0]

# Wrap the actual output in a Proxy
wrap_proxy = Proxy(wrap_node)

# Get the implementation of the specified activation function and
# symbolically trace it
fn_impl = activation_functions[fn]
fn_impl_traced = symbolic_trace(fn_impl)

# Call the specified activation function using the Proxy wrapper for
# `output_op`. The result of this call is another Proxy, which we
# can hook into our existing Graph.
with traced.graph.inserting_before(wrap_node):
fn_impl_output_node = fn_impl_traced(wrap_proxy)
new_args = (fn_impl_output_node.node,)
output_node.args = new_args


# Example call
wrap_in_activation_function(traced, ActivationFunction.LEAKY_RELU)