In [5]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import symbolic_trace


In [6]:
#Defining the convolutional neural network
class LeNet5(nn.Module):
    def __init__(self, num_classes):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(6)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0)
        self.bn2 = nn.BatchNorm2d(16)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc = nn.Linear(400, 120)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(120, 84)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(84, num_classes)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.maxpool1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)
        out = self.maxpool2(out)

        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        out = self.relu(out)
        out = self.fc1(out)
        out = self.relu1(out)
        out = self.fc2(out)
        return out

In [7]:
model = LeNet5(num_classes=10) # 10 for MNIST digit recognition
x = torch.randn(1, 1, 32, 32)
y = model(x)
print(y)

tensor([[-0.0709, -0.0130,  0.0522, -0.0422, -0.1070, -0.0388,  0.0270,  0.0878,
         -0.0439,  0.0869]], grad_fn=<AddmmBackward0>)


In [8]:
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(model)

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)

graph():
    %x : [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %bn1 : [num_users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
    %relu1 : [num_users=1] = call_module[target=relu1](args = (%bn1,), kwargs = {})
    %maxpool1 : [num_users=1] = call_module[target=maxpool1](args = (%relu1,), kwargs = {})
    %conv2 : [num_users=1] = call_module[target=conv2](args = (%maxpool1,), kwargs = {})
    %bn2 : [num_users=1] = call_module[target=bn2](args = (%conv2,), kwargs = {})
    %relu2 : [num_users=1] = call_module[target=relu2](args = (%bn2,), kwargs = {})
    %maxpool2 : [num_users=2] = call_module[target=maxpool2](args = (%relu2,), kwargs = {})
    %size : [num_users=1] = call_method[target=size](args = (%maxpool2, 0), kwargs = {})
    %reshape : [num_users=1] = call_method[target=reshape](args = (%maxpool2, %size, -1), kwargs = {})
    %fc : [num_users=1] = call_module[target=fc](args = (%reshape

In [26]:
symbolic_traced.graph.print_tabular()


opcode       name      target    args                  kwargs
-----------  --------  --------  --------------------  --------
placeholder  x         x         ()                    {}
call_module  conv1     conv1     (x,)                  {}
call_module  bn1       bn1       (conv1,)              {}
call_module  relu1     relu1     (bn1,)                {}
call_module  maxpool1  maxpool1  (relu1,)              {}
call_module  conv2     conv2     (maxpool1,)           {}
call_module  bn2       bn2       (conv2,)              {}
call_module  relu2     relu2     (bn2,)                {}
call_module  maxpool2  maxpool2  (relu2,)              {}
call_method  size      size      (maxpool2, 0)         {}
call_method  reshape   reshape   (maxpool2, size, -1)  {}
call_module  fc        fc        (reshape,)            {}
call_module  relu      relu      (fc,)                 {}
call_module  fc1       fc1       (relu,)               {}
call_module  relu1_1   relu1     (fc1,)                {}
call

In [None]:
# class ECG:
#     def __init__(self, model):
#         symbolic_traced = symbolic_trace(model)
#         for
        

In [33]:
for node in symbolic_traced.graph.nodes:
    if node.op == "call_module":
        print(node.op, node.target, node.args, getattr(symbolic_traced, node.target))


call_module conv1 (x,) Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
call_module bn1 (conv1,) BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
call_module relu1 (bn1,) ReLU()
call_module maxpool1 (relu1,) MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
call_module conv2 (maxpool1,) Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
call_module bn2 (conv2,) BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
call_module relu2 (bn2,) ReLU()
call_module maxpool2 (relu2,) MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
call_module fc (reshape,) Linear(in_features=400, out_features=120, bias=True)
call_module relu (fc,) ReLU()
call_module fc1 (relu,) Linear(in_features=120, out_features=84, bias=True)
call_module relu1 (fc1,) ReLU()
call_module fc2 (relu1_1,) Linear(in_features=84, out_features=10, bias=True)


In [10]:
all_operators = {node.target for node in symbolic_traced.graph.nodes if node.op == "call_module" or node.op == "call_function"}
print(all_operators)

{'maxpool1', 'bn1', 'maxpool2', 'fc2', 'conv1', 'relu1', 'fc', 'bn2', 'relu', 'fc1', 'conv2', 'relu2'}


# Fusion Plan Algorithm

In [None]:
def generate_seed(ops):
    pass

def successors(op):
    pass

def predecessors(op):
    pass

def fuse_successor(sp, successor, block):
    pass

def fuse_predecessor(sp, predecessor, block):
    pass

# <Algorithm Entry>
unfused_ops = all_operators
# Step 1: start fuse from the selected seed
while sp := generate_seed(unfused_ops):
    block = {sp}
    # Step 2: head to successor
    for successor in successors(sp):
        fuse_successor(sp, successor, block)
    # Step 3: head to predecessors
    for predecessor in predecessors(sp):
        fuse_predecessor(sp, predecessor, block)
    unfused_ops -= block
    
    