Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Document example of Proxy use (#50583)
Summary: Pull Request resolved: #50583 Test Plan: Imported from OSS Reviewed By: pbelevich Differential Revision: D26010501 Pulled By: ansley fbshipit-source-id: 947121af7e57c16c96f849fbbb3fa83e97d003b2
- Loading branch information
1 parent
89cafde
commit d5dc65a
Showing
4 changed files
with
203 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import torch | ||
from torch.fx import Proxy, symbolic_trace | ||
from torch.fx.node import map_arg | ||
|
||
|
||
''' | ||
How to inline a function into an existing Graph | ||
One reason you might want to inline a function is to get around FX's | ||
default tracing behavior. For example, unless you've defined a custom | ||
Tracer, the out-of-the-box implementation of ``symbolic_trace`` causes | ||
references to ``torch.nn`` module instances to appear as | ||
``call_module`` calls rather than being traced through. Let's say this | ||
behavior is almost what you need; the only problem is that there's a | ||
single module call that you want to replace with an inlined trace of the | ||
function. Creating a custom Tracer would be too much. Instead, you can | ||
accomplish this using Proxies. | ||
The following code demonstrates how to trace a module and inline it | ||
into an existing Graph using Proxy. We'll trace our Graph, then iterate | ||
through its Nodes until we find the right place to swap out the | ||
``call_module`` Node with an inlined trace. At that point, we'll create | ||
Proxies from the Node's args and kwargs. Finally, we'll call the | ||
function we want to replace with those Proxies--which will, in essence, | ||
"trace" that function. Finally, we'll insert the result of that call | ||
into our Graph. (This last step will automatically inline the function.) | ||
''' | ||
|
||
|
||
# Sample module | ||
class M(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.relu = torch.nn.ReLU() | ||
|
||
def forward(self, x): | ||
return self.relu(x) + 1.0 | ||
|
||
# Symbolically trace an instance of `M`. After tracing, `self.relu` is | ||
# represented as a `call_module` Node. The full operation in the | ||
# generated `forward` function's code will appear as `self.relu(x)` | ||
m = symbolic_trace(M()) | ||
|
||
# Insert nodes from the ReLU graph in place of the original call to | ||
# `self.relu` | ||
for node in m.graph.nodes: | ||
# Find `call_module` Node in `m` that corresponds to `self.relu`. | ||
# This is the Node we want to swap out for an inlined version of the | ||
# same call | ||
if (node.op, node.target) == ("call_module", "relu"): | ||
with m.graph.inserting_before(node): | ||
# Create a Proxy from each Node in the current Node's | ||
# args/kwargs | ||
proxy_args = map_arg(node.args, Proxy) | ||
proxy_kwargs = map_arg(node.kwargs, Proxy) | ||
# Call `m.relu` with the newly-created Proxy arguments. | ||
# `m.relu` is the generic version of the function; by | ||
# calling it with Proxies created from Nodes in `m`, we're | ||
# emitting Nodes that reference exiting values in the IR. | ||
# The result of this call is another Proxy, which we can | ||
# hook into our existing Graph to complete the function | ||
# inlining. | ||
proxy_output = m.relu(*proxy_args, **proxy_kwargs) | ||
# Replace the relu `call_module` node with the inlined | ||
# version of the function | ||
node.replace_all_uses_with(proxy_output.node) | ||
# Make sure that the old relu Node is erased | ||
m.graph.erase_node(node) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import torch | ||
from torch.fx import Proxy, Graph, GraphModule | ||
|
||
|
||
''' | ||
How to create a Graph using Proxy objects instead of tracing | ||
It's possible to directly create a Proxy object around a raw Node. This | ||
can be used to create a Graph independently of symbolic tracing. | ||
The following code demonstrates how to use Proxy with a raw Node to | ||
append operations to a fresh Graph. We'll create two parameters (``x`` | ||
and ``y``), perform some operations on those parameters, then add | ||
everything we created to the new Graph. We'll then wrap that Graph in | ||
a GraphModule. Doing so creates a runnable instance of ``nn.Module`` | ||
where previously-created operations are represented in the Module's | ||
``forward`` function. | ||
By the end of the tutorial, we'll have added the following method to an | ||
empty ``nn.Module`` class. | ||
.. code-block:: python | ||
def forward(self, x, y): | ||
cat_1 = torch.cat([x, y]); x = y = None | ||
tanh_1 = torch.tanh(cat_1); cat_1 = None | ||
neg_1 = torch.neg(tanh_1); tanh_1 = None | ||
return neg_1 | ||
''' | ||
|
||
|
||
# Create a graph independently of symbolic tracing | ||
graph = Graph() | ||
|
||
# Create raw Nodes | ||
raw1 = graph.placeholder('x') | ||
raw2 = graph.placeholder('y') | ||
|
||
# Initialize Proxies using the raw Nodes | ||
y = Proxy(raw1) | ||
z = Proxy(raw2) | ||
|
||
# Create other operations using the Proxies `y` and `z` | ||
a = torch.cat([y, z]) | ||
b = torch.tanh(a) | ||
c = torch.neg(b) | ||
|
||
# Create a new output Node and add it to the Graph. By doing this, the | ||
# Graph will contain all the Nodes we just created (since they're all | ||
# linked to the output Node) | ||
graph.output(c.node) | ||
|
||
# Wrap our created Graph in a GraphModule to get a final, runnable | ||
# `nn.Module` instance | ||
mod = GraphModule(torch.nn.Module(), graph) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
|
||
import torch | ||
from torch.fx import Proxy, GraphModule, Node, symbolic_trace | ||
|
||
from enum import Enum, auto | ||
|
||
''' | ||
Wrap Graph output dynamically | ||
The following code demonstrates how change an existing Graph based on | ||
parameters specified at runtime. We'll let the user specify an | ||
activation function from a predefined Enum list, then we'll symbolically | ||
trace it. Next, we'll create a Proxy from the last operation in the | ||
Graph. We'll call our traced activation function with this Proxy and | ||
insert the ``output`` Node from that call into our Graph. (This final | ||
step will automatically inline the entire traced function.) | ||
''' | ||
|
||
|
||
# Sample module | ||
class M(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def forward(self, x, y): | ||
y = torch.cat([x, y]) | ||
return y | ||
|
||
# Symbolically trace an instance of `M` | ||
traced = symbolic_trace(M()) | ||
|
||
# Selected activation functions | ||
class ActivationFunction(Enum): | ||
RELU = auto() | ||
LEAKY_RELU = auto() | ||
PRELU = auto() | ||
|
||
# Map activation function names to their implementation | ||
activation_functions = { | ||
ActivationFunction.RELU: torch.nn.ReLU(), | ||
ActivationFunction.LEAKY_RELU: torch.nn.LeakyReLU(), | ||
ActivationFunction.PRELU: torch.nn.PReLU(), | ||
} | ||
|
||
def wrap_in_activation_function(m: GraphModule, fn: ActivationFunction) -> GraphModule: | ||
# Get output node | ||
output_node: Optional[Node] = None | ||
for n in reversed(m.graph.nodes): | ||
if n.op == "output": | ||
output_node = n | ||
break | ||
assert output_node | ||
|
||
# Get the actual output (the "input" of the output node). This is | ||
# the Node we want to wrap in a user-specified activation function | ||
assert len(output_node.all_input_nodes) == 1 | ||
wrap_node = output_node.all_input_nodes[0] | ||
|
||
# Wrap the actual output in a Proxy | ||
wrap_proxy = Proxy(wrap_node) | ||
|
||
# Get the implementation of the specified activation function and | ||
# symbolically trace it | ||
fn_impl = activation_functions[fn] | ||
fn_impl_traced = symbolic_trace(fn_impl) | ||
|
||
# Call the specified activation function using the Proxy wrapper for | ||
# `output_op`. The result of this call is another Proxy, which we | ||
# can hook into our existing Graph. | ||
with traced.graph.inserting_before(wrap_node): | ||
fn_impl_output_node = fn_impl_traced(wrap_proxy) | ||
new_args = (fn_impl_output_node.node,) | ||
output_node.args = new_args | ||
|
||
|
||
# Example call | ||
wrap_in_activation_function(traced, ActivationFunction.LEAKY_RELU) |