From d5dc65a45c45570b24cd29ab97d81f94a7743343 Mon Sep 17 00:00:00 2001 From: Ansley Ussery Date: Fri, 22 Jan 2021 10:59:16 -0800 Subject: [PATCH] Document example of Proxy use (#50583) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/50583 Test Plan: Imported from OSS Reviewed By: pbelevich Differential Revision: D26010501 Pulled By: ansley fbshipit-source-id: 947121af7e57c16c96f849fbbb3fa83e97d003b2 --- torch/fx/examples/inline_function.py | 68 ++++++++++++++++ .../fx/examples/proxy_based_graph_creation.py | 56 ++++++++++++++ torch/fx/examples/replace_op.py | 5 +- torch/fx/examples/wrap_output_dynamically.py | 77 +++++++++++++++++++ 4 files changed, 203 insertions(+), 3 deletions(-) create mode 100644 torch/fx/examples/inline_function.py create mode 100644 torch/fx/examples/proxy_based_graph_creation.py create mode 100644 torch/fx/examples/wrap_output_dynamically.py diff --git a/torch/fx/examples/inline_function.py b/torch/fx/examples/inline_function.py new file mode 100644 index 000000000000..5a2b2057f8d8 --- /dev/null +++ b/torch/fx/examples/inline_function.py @@ -0,0 +1,68 @@ +import torch +from torch.fx import Proxy, symbolic_trace +from torch.fx.node import map_arg + + +''' +How to inline a function into an existing Graph + +One reason you might want to inline a function is to get around FX's +default tracing behavior. For example, unless you've defined a custom +Tracer, the out-of-the-box implementation of ``symbolic_trace`` causes +references to ``torch.nn`` module instances to appear as +``call_module`` calls rather than being traced through. Let's say this +behavior is almost what you need; the only problem is that there's a +single module call that you want to replace with an inlined trace of the +function. Creating a custom Tracer would be too much. Instead, you can +accomplish this using Proxies. + +The following code demonstrates how to trace a module and inline it +into an existing Graph using Proxy. We'll trace our Graph, then iterate +through its Nodes until we find the right place to swap out the +``call_module`` Node with an inlined trace. At that point, we'll create +Proxies from the Node's args and kwargs. Finally, we'll call the +function we want to replace with those Proxies--which will, in essence, +"trace" that function. Finally, we'll insert the result of that call +into our Graph. (This last step will automatically inline the function.) +''' + + +# Sample module +class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, x): + return self.relu(x) + 1.0 + +# Symbolically trace an instance of `M`. After tracing, `self.relu` is +# represented as a `call_module` Node. The full operation in the +# generated `forward` function's code will appear as `self.relu(x)` +m = symbolic_trace(M()) + +# Insert nodes from the ReLU graph in place of the original call to +# `self.relu` +for node in m.graph.nodes: + # Find `call_module` Node in `m` that corresponds to `self.relu`. + # This is the Node we want to swap out for an inlined version of the + # same call + if (node.op, node.target) == ("call_module", "relu"): + with m.graph.inserting_before(node): + # Create a Proxy from each Node in the current Node's + # args/kwargs + proxy_args = map_arg(node.args, Proxy) + proxy_kwargs = map_arg(node.kwargs, Proxy) + # Call `m.relu` with the newly-created Proxy arguments. + # `m.relu` is the generic version of the function; by + # calling it with Proxies created from Nodes in `m`, we're + # emitting Nodes that reference exiting values in the IR. + # The result of this call is another Proxy, which we can + # hook into our existing Graph to complete the function + # inlining. + proxy_output = m.relu(*proxy_args, **proxy_kwargs) + # Replace the relu `call_module` node with the inlined + # version of the function + node.replace_all_uses_with(proxy_output.node) + # Make sure that the old relu Node is erased + m.graph.erase_node(node) diff --git a/torch/fx/examples/proxy_based_graph_creation.py b/torch/fx/examples/proxy_based_graph_creation.py new file mode 100644 index 000000000000..bb120f9e0b92 --- /dev/null +++ b/torch/fx/examples/proxy_based_graph_creation.py @@ -0,0 +1,56 @@ +import torch +from torch.fx import Proxy, Graph, GraphModule + + +''' +How to create a Graph using Proxy objects instead of tracing + +It's possible to directly create a Proxy object around a raw Node. This +can be used to create a Graph independently of symbolic tracing. + +The following code demonstrates how to use Proxy with a raw Node to +append operations to a fresh Graph. We'll create two parameters (``x`` +and ``y``), perform some operations on those parameters, then add +everything we created to the new Graph. We'll then wrap that Graph in +a GraphModule. Doing so creates a runnable instance of ``nn.Module`` +where previously-created operations are represented in the Module's +``forward`` function. + +By the end of the tutorial, we'll have added the following method to an +empty ``nn.Module`` class. + +.. code-block:: python + + def forward(self, x, y): + cat_1 = torch.cat([x, y]); x = y = None + tanh_1 = torch.tanh(cat_1); cat_1 = None + neg_1 = torch.neg(tanh_1); tanh_1 = None + return neg_1 + +''' + + +# Create a graph independently of symbolic tracing +graph = Graph() + +# Create raw Nodes +raw1 = graph.placeholder('x') +raw2 = graph.placeholder('y') + +# Initialize Proxies using the raw Nodes +y = Proxy(raw1) +z = Proxy(raw2) + +# Create other operations using the Proxies `y` and `z` +a = torch.cat([y, z]) +b = torch.tanh(a) +c = torch.neg(b) + +# Create a new output Node and add it to the Graph. By doing this, the +# Graph will contain all the Nodes we just created (since they're all +# linked to the output Node) +graph.output(c.node) + +# Wrap our created Graph in a GraphModule to get a final, runnable +# `nn.Module` instance +mod = GraphModule(torch.nn.Module(), graph) diff --git a/torch/fx/examples/replace_op.py b/torch/fx/examples/replace_op.py index f938ecc0f56b..9987dfa0325d 100644 --- a/torch/fx/examples/replace_op.py +++ b/torch/fx/examples/replace_op.py @@ -23,9 +23,8 @@ To examine how the Graph evolves during op replacement, add the statement `print(traced.graph)` after the line you want to inspect. -Alternatively, see the Nodes in a tabular format by adding -`from inspect_utils import print_IR` to the top of this file and calling -`print_IR(traced.graph)`. +Alternatively, call `traced.graph.print_tabular()` to see the IR in a +tabular format. """ # Sample module diff --git a/torch/fx/examples/wrap_output_dynamically.py b/torch/fx/examples/wrap_output_dynamically.py new file mode 100644 index 000000000000..3d9943c525d7 --- /dev/null +++ b/torch/fx/examples/wrap_output_dynamically.py @@ -0,0 +1,77 @@ + +import torch +from torch.fx import Proxy, GraphModule, Node, symbolic_trace + +from enum import Enum, auto + +''' +Wrap Graph output dynamically + +The following code demonstrates how change an existing Graph based on +parameters specified at runtime. We'll let the user specify an +activation function from a predefined Enum list, then we'll symbolically +trace it. Next, we'll create a Proxy from the last operation in the +Graph. We'll call our traced activation function with this Proxy and +insert the ``output`` Node from that call into our Graph. (This final +step will automatically inline the entire traced function.) +''' + + +# Sample module +class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + y = torch.cat([x, y]) + return y + +# Symbolically trace an instance of `M` +traced = symbolic_trace(M()) + +# Selected activation functions +class ActivationFunction(Enum): + RELU = auto() + LEAKY_RELU = auto() + PRELU = auto() + +# Map activation function names to their implementation +activation_functions = { + ActivationFunction.RELU: torch.nn.ReLU(), + ActivationFunction.LEAKY_RELU: torch.nn.LeakyReLU(), + ActivationFunction.PRELU: torch.nn.PReLU(), +} + +def wrap_in_activation_function(m: GraphModule, fn: ActivationFunction) -> GraphModule: + # Get output node + output_node: Optional[Node] = None + for n in reversed(m.graph.nodes): + if n.op == "output": + output_node = n + break + assert output_node + + # Get the actual output (the "input" of the output node). This is + # the Node we want to wrap in a user-specified activation function + assert len(output_node.all_input_nodes) == 1 + wrap_node = output_node.all_input_nodes[0] + + # Wrap the actual output in a Proxy + wrap_proxy = Proxy(wrap_node) + + # Get the implementation of the specified activation function and + # symbolically trace it + fn_impl = activation_functions[fn] + fn_impl_traced = symbolic_trace(fn_impl) + + # Call the specified activation function using the Proxy wrapper for + # `output_op`. The result of this call is another Proxy, which we + # can hook into our existing Graph. + with traced.graph.inserting_before(wrap_node): + fn_impl_output_node = fn_impl_traced(wrap_proxy) + new_args = (fn_impl_output_node.node,) + output_node.args = new_args + + +# Example call +wrap_in_activation_function(traced, ActivationFunction.LEAKY_RELU)