## Modifying a computation graph with torch.fx

In [33]:
import torch
import numpy as np
from torch.fx import symbolic_trace
from typing import Dict

class MyModel(torch.nn.Module):
    """Toy two-layer model used as the tracing target throughout this notebook.

    ``forward`` computes ``linear2(x[1] + linear((x + y) * 2)) + 1``. The
    operations are issued in the same order as that one-line expression
    (index, add, multiply, linear, add, linear2, add), so symbolic tracing
    yields the node names ``getitem``, ``add``, ``mul``, ``linear``,
    ``add_1``, ``linear2`` and ``add_2`` that later cells refer to.
    """

    def __init__(self):
        super().__init__()
        # Two 10 -> 10 fully connected layers, registered in this order.
        self.linear = torch.nn.Linear(10, 10)
        self.linear2 = torch.nn.Linear(10, 10)

    def forward(self, x, y):
        row = x[1]
        scaled = (x + y) * 2
        hidden = row + self.linear(scaled)
        return self.linear2(hidden) + 1
# Instantiate the toy model and trace it into a torch.fx.GraphModule.
model = MyModel()
gm = torch.fx.symbolic_trace(model)
print(gm)

MyModel(
  (linear): Linear(in_features=10, out_features=10, bias=True)
  (linear2): Linear(in_features=10, out_features=10, bias=True)
)



def forward(self, x, y):
    getitem = x[1]
    add = x + y;  x = y = None
    mul = add * 2;  add = None
    linear = self.linear(mul);  mul = None
    add_1 = getitem + linear;  getitem = linear = None
    linear2 = self.linear2(add_1);  add_1 = None
    add_2 = linear2 + 1;  linear2 = None
    return add_2
    
# To see more debug info, please use `graph_module.print_readable()`


### Rename a node


In [34]:
def rename_node(gm, old_name, new_name):
    """Rename the first node called ``old_name`` to ``new_name`` and recompile.

    FX writes each node's name verbatim as a local variable in the generated
    ``forward`` (e.g. ``ret = linear2 + 1``), so ``new_name`` must be a valid
    Python identifier that no other node in the graph already uses. Otherwise
    the regenerated code either fails to compile or reuses one variable for
    two different values.

    Args:
        gm: The ``torch.fx.GraphModule`` to edit in place.
        old_name: Current name of the node to rename.
        new_name: Desired new name.

    Returns:
        ``gm`` itself, recompiled. If no node is named ``old_name`` the graph
        is left unchanged (but still recompiled).

    Raises:
        ValueError: If ``new_name`` is not a valid identifier, is a Python
            keyword, or is already the name of another node.
    """
    import keyword

    nodes = list(gm.graph.nodes)
    target = next((node for node in nodes if node.name == old_name), None)
    if target is not None and new_name != old_name:
        if not new_name.isidentifier() or keyword.iskeyword(new_name):
            raise ValueError(f"{new_name!r} is not a valid Python identifier")
        if any(node.name == new_name for node in nodes):
            raise ValueError(f"a node named {new_name!r} already exists in the graph")
        # NOTE(review): assigning ``name`` directly may bypass FX's internal
        # name allocator; confirm nodes created later cannot be given new_name.
        target.name = new_name
    gm.recompile()
    return gm

# Re-trace so the in-place rename below starts from an untouched graph.
gm = torch.fx.symbolic_trace(model)
new_gm = rename_node(gm, "add_2", "ret")
print(new_gm)

MyModel(
  (linear): Linear(in_features=10, out_features=10, bias=True)
  (linear2): Linear(in_features=10, out_features=10, bias=True)
)



def forward(self, x, y):
    getitem = x[1]
    add = x + y;  x = y = None
    mul = add * 2;  add = None
    linear = self.linear(mul);  mul = None
    add_1 = getitem + linear;  getitem = linear = None
    linear2 = self.linear2(add_1);  add_1 = None
    ret = linear2 + 1;  linear2 = None
    return ret
    
# To see more debug info, please use `graph_module.print_readable()`


### Change return variable
Construct a new graph that contains copies of the original graph's nodes but returns the specified node instead of the original output.

In [35]:
def modify_return_node(gm, output_node_name, *, remove_dead_code=False):
    """Return a new GraphModule whose ``forward`` returns node ``output_node_name``.

    Every node of ``gm.graph`` except the original ``output`` node is copied
    into a fresh ``torch.fx.Graph`` (``gm`` itself is not modified). A new
    output node returning the requested node is then appended.

    Args:
        gm: Source ``torch.fx.GraphModule``; also used as the ``root`` of the
            returned GraphModule.
        output_node_name: Name of the node whose value should be returned.
        remove_dead_code: If True, drop nodes that no longer contribute to the
            returned value (e.g. layers computed after the new return point).
            Defaults to False, which keeps every copied node.

    Returns:
        A new ``torch.fx.GraphModule`` built from ``gm`` and the new graph.

    Raises:
        AssertionError: If no node named ``output_node_name`` exists. It is
            raised explicitly rather than via ``assert`` so the check still
            runs under ``python -O``; the type is kept for compatibility.
    """
    new_graph = torch.fx.Graph()
    env: Dict[torch.fx.Node, torch.fx.Node] = {}
    return_node = None
    for node in gm.graph.nodes:
        if node.op == "output":
            continue
        # Map each argument through ``env`` so the copy points at the
        # already-copied versions of its inputs.
        new_node = new_graph.node_copy(node, lambda arg: env[arg])
        env[node] = new_node
        if node.name == output_node_name:
            return_node = new_node
    if return_node is None:
        raise AssertionError(f"{output_node_name} not found")
    new_graph.output(return_node)
    new_graph.lint()
    new_gm = torch.fx.GraphModule(gm, new_graph)
    if remove_dead_code:
        # Run DCE only after the GraphModule owns the graph: FX's purity check
        # for ``call_module`` nodes looks up the submodule on the owning module.
        new_gm.graph.eliminate_dead_code()
        new_gm.graph.lint()
        new_gm.recompile()
    return new_gm

gm = torch.fx.symbolic_trace(model)
# Return the intermediate ``add_1`` (x[1] + linear(...)) instead of the final sum.
new_gm = modify_return_node(gm, "add_1")
print(new_gm)

GraphModule(
  (linear): Linear(in_features=10, out_features=10, bias=True)
  (linear2): Linear(in_features=10, out_features=10, bias=True)
)



def forward(self, x, y):
    getitem = x[1]
    add = x + y;  x = y = None
    mul = add * 2;  add = None
    linear = self.linear(mul);  mul = None
    add_1 = getitem + linear;  getitem = linear = None
    linear2 = self.linear2(add_1)
    add_2 = linear2 + 1;  linear2 = None
    return add_1
    
# To see more debug info, please use `graph_module.print_readable()`


### Replace a node

In [36]:
def replace_node(gm, old_node_name, new_node):
    """Reroute all uses of node ``old_node_name`` to ``new_node`` and erase it.

    Only the first node with a matching name is affected. If none matches,
    the graph is left as-is. ``gm`` is recompiled either way and returned,
    so edits are visible in ``gm.code`` and when calling ``gm``.

    Args:
        gm: The ``torch.fx.GraphModule`` to edit in place.
        old_node_name: Name of the node to remove.
        new_node: Node (already in ``gm.graph``) that takes over its uses.

    Returns:
        ``gm`` itself.
    """
    graph = gm.graph
    victim = next((n for n in graph.nodes if n.name == old_node_name), None)
    if victim is not None:
        victim.replace_all_uses_with(new_node)
        graph.erase_node(victim)
    gm.recompile()
    return gm

gm = symbolic_trace(model)
# Insert the new input directly after the existing placeholders. A bare
# ``gm.graph.placeholder(...)`` is added at the end of the graph, after the
# ``output`` node, so ``ph`` would be used before it is defined. The generated
# code happens to still run, but ``Graph.lint()`` and any pass that walks the
# nodes in order would reject the graph.
last_input = [n for n in gm.graph.nodes if n.op == "placeholder"][-1]
with gm.graph.inserting_after(last_input):
    new_node = gm.graph.placeholder('ph')
gm = replace_node(gm, 'linear', new_node)
gm.graph.lint()  # every node is now defined before its first use
print(gm)

MyModel(
  (linear): Linear(in_features=10, out_features=10, bias=True)
  (linear2): Linear(in_features=10, out_features=10, bias=True)
)



def forward(self, x, y, ph):
    getitem = x[1]
    add = x + y;  x = y = None
    mul = add * 2;  add = None
    add_1 = getitem + ph;  getitem = ph = None
    linear2 = self.linear2(add_1);  add_1 = None
    add_2 = linear2 + 1;  linear2 = None
    return add_2
    
# To see more debug info, please use `graph_module.print_readable()`
