From d91364e7562ea09170b7488d6cc4602d9ae14483 Mon Sep 17 00:00:00 2001 From: Ansley Ussery Date: Tue, 5 Jan 2021 14:30:50 -0800 Subject: [PATCH] Document single op replacement ghstack-source-id: 4e40784cfdfa1268a9cfb98c8c68734e19291a3f Pull Request resolved: https://github.com/pytorch/pytorch/pull/50116 --- torch/fx/examples/replace_op.py | 57 +++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 torch/fx/examples/replace_op.py diff --git a/torch/fx/examples/replace_op.py b/torch/fx/examples/replace_op.py new file mode 100644 index 000000000000..c80408d53de1 --- /dev/null +++ b/torch/fx/examples/replace_op.py @@ -0,0 +1,57 @@ +import torch +from torch.fx import symbolic_trace +import re + + +""" +How to replace one op with another +1. Iterate through all nodes in your GraphModule's Graph. +2. Determine if the current node should be replaced. (Suggested: match +on the node's ``target`` attribute). +3. Create a replacement node. +4. Use the fx built-in ``replace_all_uses_with`` to replace the current +node with the replacement. +5. Call `recompile` on the GraphModule. This updates the generated +Python code to reflect the new Graph state. + +The following code demonstrates an example of replacing any instance of +addition with a bitwise AND. +""" + +# Sample module +class M(torch.nn.Module): + def forward(self, x): + val = torch.neg(x) + torch.relu(x) + return torch.add(val, val) + +# Symbolically trace an instance of the module +traced = symbolic_trace(M()) + +# If you have ``torch`` imported, you have two ways to denonte addition: +# the ``+`` operator and the method ``torch.add``. ``+`` is a Python +# built-in, so it's represented as a ``call_function`` Node with a target +# of ````. ``torch.add`` is part of ``torch``, so +# it becomes a ``call_method`` Node with a target of +# ````. +# +# To determine whether a given node represents addition, we can match +# on the ``target`` attribute. In this particular case, we have two +# different representations of "addition", so we'll use a regex to match +# on ``target``. (If you want to replace a ``torch``-specific operator, +# you can match on a simple string.) +regexp = re.compile(r"(?<=[\s])add(?=[\>\s])") + +# Go through all the nodes in the Graph +for n in traced.graph.nodes: + # If the target matches the regex + if regexp.search(str(n.target)): + # Create a replacement node with the new op + new_node = traced.graph.call_function(torch.bitwise_and, n.args, n.kwargs) + # Move the new node to the correct spot + n.append(new_node) + # Replace all uses of `n` with the new node + n.replace_all_uses_with(new_node) + # Remove the old node from the graph + traced.graph.erase_node(n) + +traced.recompile()