diff --git a/torch/fx/examples/inspect_utils.py b/torch/fx/examples/inspect_utils.py new file mode 100644 index 000000000000..26833c1c6e14 --- /dev/null +++ b/torch/fx/examples/inspect_utils.py @@ -0,0 +1,14 @@ +from tabulate import tabulate + +""" +The methods in this file may be used to examine the state of the code +and how the Graph evolves at any time during execution. If you're +unsure of what's happening in an example in this folder, try adding one +of these methods before and after a key line. +""" + +def print_IR(graph): + node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] + for n in graph.nodes] + print(tabulate(node_specs, + headers=['opcode', 'name', 'target', 'args', 'kwargs'])) diff --git a/torch/fx/examples/replace_op.py b/torch/fx/examples/replace_op.py new file mode 100644 index 000000000000..22931f92de1e --- /dev/null +++ b/torch/fx/examples/replace_op.py @@ -0,0 +1,66 @@ +import torch +from torch.fx import symbolic_trace +import operator + + +""" +How to replace one op with another +1. Iterate through all Nodes in your GraphModule's Graph. +2. Determine if the current Node should be replaced. (Suggested: match +on the Node's ``target`` attribute). +3. Create a replacement Node and add it to the Graph. +4. Use the FX built-in ``replace_all_uses_with`` to replace all uses of +the current Node with the replacement. +5. Delete the old Node from the graph. +6. Call ``recompile`` on the GraphModule. This updates the generated +Python code to reflect the new Graph state. + +Currently, FX does not provide any way to guarantee that replaced +operators are syntactically valid. It's up to the user to confirm that +any new operators will work with the existing operands. + +The following code demonstrates an example of replacing any instance of +addition with a bitwise AND. + +To examine how the Graph evolves during op replacement, add the +statement `print(traced.graph)` after the line you want to inspect. +Alternatively, see the Nodes in a tabular format by adding +`from inspect_utils import print_IR` to the top of this file and calling +`print_IR(traced.graph)`. +""" + +# Sample module +class M(torch.nn.Module): + def forward(self, x, y): + return x + y, torch.add(x, y), x.add(y) + +# Symbolically trace an instance of the module +traced = symbolic_trace(M()) + +# As demonstrated in the above example, there are several different ways +# to denote addition. The possible cases are: +# 1. `x + y` - A `call_function` Node with target +# ``. This is `operator.add`, so we can +# match on equality with that function directly. +# 2. `torch.add(x, y)` - A `call_function` Node with target +# ``. +# This is `torch.add`, which we can similarly match directly. +# 3. `x.add(y)` - The Tensor method call, whose target we can match +# as a string. + +patterns = set([operator.add, torch.add, "add"]) + +# Go through all the nodes in the Graph +for n in traced.graph.nodes: + # If the target matches one of the patterns + if any(n.target == pattern for pattern in patterns): + # Set the insert point, add the new node, and replace all uses + # of `n` with the new node + with traced.graph.inserting_after(n): + new_node = traced.graph.call_function(torch.bitwise_and, n.args, n.kwargs) + n.replace_all_uses_with(new_node) + # Remove the old node from the graph + traced.graph.erase_node(n) + +# Don't forget to recompile! +traced.recompile()