Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Document single op replacement (#50116)
Summary: Pull Request resolved: #50116 Test Plan: Imported from OSS Reviewed By: jamesr66a Differential Revision: D25803457 Pulled By: ansley fbshipit-source-id: de2f3c0bd037859117dde55ba677fb5da34ab639
- Loading branch information
1 parent
ea087e2
commit ba1ce71
Showing
2 changed files
with
80 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from tabulate import tabulate | ||
|
||
""" | ||
The methods in this file may be used to examine the state of the code | ||
and how the Graph evolves at any time during execution. If you're | ||
unsure of what's happening in an example in this folder, try adding one | ||
of these methods before and after a key line. | ||
""" | ||
|
||
def print_IR(graph): | ||
node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] | ||
for n in graph.nodes] | ||
print(tabulate(node_specs, | ||
headers=['opcode', 'name', 'target', 'args', 'kwargs'])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import torch | ||
from torch.fx import symbolic_trace | ||
import operator | ||
|
||
|
||
""" | ||
How to replace one op with another | ||
1. Iterate through all Nodes in your GraphModule's Graph. | ||
2. Determine if the current Node should be replaced. (Suggested: match | ||
on the Node's ``target`` attribute). | ||
3. Create a replacement Node and add it to the Graph. | ||
4. Use the FX built-in ``replace_all_uses_with`` to replace all uses of | ||
the current Node with the replacement. | ||
5. Delete the old Node from the graph. | ||
6. Call ``recompile`` on the GraphModule. This updates the generated | ||
Python code to reflect the new Graph state. | ||
Currently, FX does not provide any way to guarantee that replaced | ||
operators are syntactically valid. It's up to the user to confirm that | ||
any new operators will work with the existing operands. | ||
The following code demonstrates an example of replacing any instance of | ||
addition with a bitwise AND. | ||
To examine how the Graph evolves during op replacement, add the | ||
statement `print(traced.graph)` after the line you want to inspect. | ||
Alternatively, see the Nodes in a tabular format by adding | ||
`from inspect_utils import print_IR` to the top of this file and calling | ||
`print_IR(traced.graph)`. | ||
""" | ||
|
||
# Sample module | ||
class M(torch.nn.Module): | ||
def forward(self, x, y): | ||
return x + y, torch.add(x, y), x.add(y) | ||
|
||
# Symbolically trace an instance of the module | ||
traced = symbolic_trace(M()) | ||
|
||
# As demonstrated in the above example, there are several different ways | ||
# to denote addition. The possible cases are: | ||
# 1. `x + y` - A `call_function` Node with target | ||
# `<built-in function add>`. This is `operator.add`, so we can | ||
# match on equality with that function directly. | ||
# 2. `torch.add(x, y)` - A `call_function` Node with target | ||
# `<built-in method add of type object at MEMORY-LOCATION-OF-TORCH>`. | ||
# This is `torch.add`, which we can similarly match directly. | ||
# 3. `x.add(y)` - The Tensor method call, whose target we can match | ||
# as a string. | ||
|
||
patterns = set([operator.add, torch.add, "add"]) | ||
|
||
# Go through all the nodes in the Graph | ||
for n in traced.graph.nodes: | ||
# If the target matches one of the patterns | ||
if any(n.target == pattern for pattern in patterns): | ||
# Set the insert point, add the new node, and replace all uses | ||
# of `n` with the new node | ||
with traced.graph.inserting_after(n): | ||
new_node = traced.graph.call_function(torch.bitwise_and, n.args, n.kwargs) | ||
n.replace_all_uses_with(new_node) | ||
# Remove the old node from the graph | ||
traced.graph.erase_node(n) | ||
|
||
# Don't forget to recompile! | ||
traced.recompile() |