In [32]:
import torch
import torch.nn as nn
import torch.fx as fx
import torch.optim as optim
import torch.nn.functional as F
from convolution import CustomConv2D

In [33]:
class MNISTModel(nn.Module):
    def __init__(self, is_training=True):
        super(MNISTModel, self).__init__()

        if is_training:
            self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
            self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        else:
            self.conv1 = CustomConv2D(1, 32, kernel_size=3, stride=1, padding=1)
            self.conv2 = CustomConv2D(32, 64, kernel_size=3, stride=1, padding=1)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [37]:
class ReplaceReluWithSigmoid(fx.Interpreter):
    def call_function(self, target, args, kwargs):
        if target == torch.relu:
            print(f"Replacing ReLU with Sigmoid at {target}")
            return torch.sigmoid(*args)
        return super().call_function(target, args, kwargs)

In [38]:
model = MNISTModel(is_training=True)
traced_model = fx.symbolic_trace(model)

print(traced_model.graph)

graph():
    %x : [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %pool : [num_users=1] = call_module[target=pool](args = (%conv1,), kwargs = {})
    %conv2 : [num_users=1] = call_module[target=conv2](args = (%pool,), kwargs = {})
    %pool_1 : [num_users=2] = call_module[target=pool](args = (%conv2,), kwargs = {})
    %size : [num_users=1] = call_method[target=size](args = (%pool_1, 0), kwargs = {})
    %reshape : [num_users=1] = call_method[target=reshape](args = (%pool_1, %size, -1), kwargs = {})
    %fc1 : [num_users=1] = call_module[target=fc1](args = (%reshape,), kwargs = {})
    %relu : [num_users=1] = call_function[target=torch.nn.functional.relu](args = (%fc1,), kwargs = {inplace: False})
    %fc2 : [num_users=1] = call_module[target=fc2](args = (%relu,), kwargs = {})
    return fc2


In [39]:
transformer = ReplaceReluWithSigmoid(traced_model)
transformed_output = transformer.run(torch.randn(1, 1, 28, 28))
print("Transformed Output (ReLU replaced by Sigmoid):", transformed_output)

Transformed Output (ReLU replaced by Sigmoid): tensor([[-0.1405,  0.0976,  0.1702, -0.0080, -0.1772,  0.0951,  0.0214,  0.0063,
          0.0360, -0.0378]], grad_fn=<AddmmBackward0>)


In [40]:
def print_graph_in_tabular(graph):
    headers = ["opcode", "name", "target", "args", "kwargs"]
    row_format = "{:<15} {:<10} {:<25} {:<20} {}"

    print(row_format.format(*headers))
    print("-" * 90)

    for node in graph.nodes:
        opcode = node.op
        name = node.name
        target = node.target
        args = str(node.args)
        kwargs = str(node.kwargs)

        print(row_format.format(opcode, name, str(target), args, kwargs))

print("Transformed Graph:")
print_graph_in_tabular(traced_model.graph)

Transformed Graph:
opcode          name       target                    args                 kwargs
------------------------------------------------------------------------------------------
placeholder     x          x                         ()                   {}
call_module     conv1      conv1                     (x,)                 {}
call_module     pool       pool                      (conv1,)             {}
call_module     conv2      conv2                     (pool,)              {}
call_module     pool_1     pool                      (conv2,)             {}
call_method     size       size                      (pool_1, 0)          {}
call_method     reshape    reshape                   (pool_1, size, -1)   {}
call_module     fc1        fc1                       (reshape,)           {}
call_function   relu       <function relu at 0x7ee9b959f5b0> (fc1,)               {'inplace': False}
call_module     fc2        fc2                       (relu,)              {}
output         