Skip to content

Commit

Permalink
Document single op replacement (#50116)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #50116

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D25803457

Pulled By: ansley

fbshipit-source-id: de2f3c0bd037859117dde55ba677fb5da34ab639
  • Loading branch information
Ansley Ussery authored and facebook-github-bot committed Jan 9, 2021
1 parent ea087e2 commit ba1ce71
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 0 deletions.
14 changes: 14 additions & 0 deletions torch/fx/examples/inspect_utils.py
@@ -0,0 +1,14 @@
from tabulate import tabulate

"""
The methods in this file may be used to examine the state of the code
and how the Graph evolves at any time during execution. If you're
unsure of what's happening in an example in this folder, try adding one
of these methods before and after a key line.
"""

def print_IR(graph):
node_specs = [[n.op, n.name, n.target, n.args, n.kwargs]
for n in graph.nodes]
print(tabulate(node_specs,
headers=['opcode', 'name', 'target', 'args', 'kwargs']))
66 changes: 66 additions & 0 deletions torch/fx/examples/replace_op.py
@@ -0,0 +1,66 @@
import torch
from torch.fx import symbolic_trace
import operator


"""
How to replace one op with another
1. Iterate through all Nodes in your GraphModule's Graph.
2. Determine if the current Node should be replaced. (Suggested: match
on the Node's ``target`` attribute).
3. Create a replacement Node and add it to the Graph.
4. Use the FX built-in ``replace_all_uses_with`` to replace all uses of
the current Node with the replacement.
5. Delete the old Node from the graph.
6. Call ``recompile`` on the GraphModule. This updates the generated
Python code to reflect the new Graph state.
Currently, FX does not provide any way to guarantee that replaced
operators are syntactically valid. It's up to the user to confirm that
any new operators will work with the existing operands.
The following code demonstrates an example of replacing any instance of
addition with a bitwise AND.
To examine how the Graph evolves during op replacement, add the
statement `print(traced.graph)` after the line you want to inspect.
Alternatively, see the Nodes in a tabular format by adding
`from inspect_utils import print_IR` to the top of this file and calling
`print_IR(traced.graph)`.
"""

# Sample module
class M(torch.nn.Module):
def forward(self, x, y):
return x + y, torch.add(x, y), x.add(y)

# Symbolically trace an instance of the module
traced = symbolic_trace(M())

# As demonstrated in the above example, there are several different ways
# to denote addition. The possible cases are:
# 1. `x + y` - A `call_function` Node with target
# `<built-in function add>`. This is `operator.add`, so we can
# match on equality with that function directly.
# 2. `torch.add(x, y)` - A `call_function` Node with target
# `<built-in method add of type object at MEMORY-LOCATION-OF-TORCH>`.
# This is `torch.add`, which we can similarly match directly.
# 3. `x.add(y)` - The Tensor method call, whose target we can match
# as a string.

patterns = set([operator.add, torch.add, "add"])

# Go through all the nodes in the Graph
for n in traced.graph.nodes:
# If the target matches one of the patterns
if any(n.target == pattern for pattern in patterns):
# Set the insert point, add the new node, and replace all uses
# of `n` with the new node
with traced.graph.inserting_after(n):
new_node = traced.graph.call_function(torch.bitwise_and, n.args, n.kwargs)
n.replace_all_uses_with(new_node)
# Remove the old node from the graph
traced.graph.erase_node(n)

# Don't forget to recompile!
traced.recompile()

0 comments on commit ba1ce71

Please sign in to comment.