In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.fx import symbolic_trace
from enum import Enum
import torch.fx as fx

In [2]:
#Defining the convolutional neural network
class LeNet5(nn.Module):
    def __init__(self, num_classes):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0)
        self.bn1 = nn.BatchNorm2d(6)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0)
        self.bn2 = nn.BatchNorm2d(16)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.fc = nn.Linear(400, 120)
        self.relu = nn.ReLU()
        self.fc1 = nn.Linear(120, 84)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(84, num_classes)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.maxpool1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu2(out)
        out = self.maxpool2(out)

        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        out = self.relu(out)
        out = self.fc1(out)
        out = self.relu1(out)
        out = self.fc2(out)
        return out

In [3]:
model = LeNet5(num_classes=10) # 10 for MNIST digit recognition
x = torch.randn(1, 1, 32, 32)
y = model(x)
print(y)

tensor([[ 0.1319, -0.1176, -0.1194,  0.1197,  0.0881, -0.2426, -0.1008,  0.0876,
         -0.0484, -0.0988]], grad_fn=<AddmmBackward0>)


In [4]:
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced: torch.fx.GraphModule = symbolic_trace(model)

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)

graph():
    %x : [num_users=1] = placeholder[target=x]
    %conv1 : [num_users=1] = call_module[target=conv1](args = (%x,), kwargs = {})
    %bn1 : [num_users=1] = call_module[target=bn1](args = (%conv1,), kwargs = {})
    %relu1 : [num_users=1] = call_module[target=relu1](args = (%bn1,), kwargs = {})
    %maxpool1 : [num_users=1] = call_module[target=maxpool1](args = (%relu1,), kwargs = {})
    %conv2 : [num_users=1] = call_module[target=conv2](args = (%maxpool1,), kwargs = {})
    %bn2 : [num_users=1] = call_module[target=bn2](args = (%conv2,), kwargs = {})
    %relu2 : [num_users=1] = call_module[target=relu2](args = (%bn2,), kwargs = {})
    %maxpool2 : [num_users=2] = call_module[target=maxpool2](args = (%relu2,), kwargs = {})
    %size : [num_users=1] = call_method[target=size](args = (%maxpool2, 0), kwargs = {})
    %reshape : [num_users=1] = call_method[target=reshape](args = (%maxpool2, %size, -1), kwargs = {})
    %fc : [num_users=1] = call_module[target=fc](args = (%reshape

In [5]:
symbolic_traced.graph.print_tabular()

opcode       name      target    args                  kwargs
-----------  --------  --------  --------------------  --------
placeholder  x         x         ()                    {}
call_module  conv1     conv1     (x,)                  {}
call_module  bn1       bn1       (conv1,)              {}
call_module  relu1     relu1     (bn1,)                {}
call_module  maxpool1  maxpool1  (relu1,)              {}
call_module  conv2     conv2     (maxpool1,)           {}
call_module  bn2       bn2       (conv2,)              {}
call_module  relu2     relu2     (bn2,)                {}
call_module  maxpool2  maxpool2  (relu2,)              {}
call_method  size      size      (maxpool2, 0)         {}
call_method  reshape   reshape   (maxpool2, size, -1)  {}
call_module  fc        fc        (reshape,)            {}
call_module  relu      relu      (fc,)                 {}
call_module  fc1       fc1       (relu,)               {}
call_module  relu1_1   relu1     (fc1,)                {}
call

In [6]:
for node in symbolic_traced.graph.nodes:
    if node.op == "call_module":
        print(node.op, node.target, node.args, getattr(symbolic_traced, node.target))
        print(type(getattr(symbolic_traced, node.target)))
    elif node.op == "call_method":
        print(node.op, node.target, node.args)

call_module conv1 (x,) Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
<class 'torch.nn.modules.conv.Conv2d'>
call_module bn1 (conv1,) BatchNorm2d(6, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
call_module relu1 (bn1,) ReLU()
<class 'torch.nn.modules.activation.ReLU'>
call_module maxpool1 (relu1,) MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
<class 'torch.nn.modules.pooling.MaxPool2d'>
call_module conv2 (maxpool1,) Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
<class 'torch.nn.modules.conv.Conv2d'>
call_module bn2 (conv2,) BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
<class 'torch.nn.modules.batchnorm.BatchNorm2d'>
call_module relu2 (bn2,) ReLU()
<class 'torch.nn.modules.activation.ReLU'>
call_module maxpool2 (relu2,) MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
<class 'torch.nn.modules.pooling.MaxPool2d'>
call_method siz

In [7]:
# Helper function that calculates the size of each layer in the ECG
def calculate_IRS_SIZE():
    traced = symbolic_trace(model)
    module_to_size = {}

    for node in traced.graph.nodes:
        if node.op == "call_module":
            module = getattr(traced, node.target)

            def make_hook(name):
                def hook(module, input, output):
                    module_to_size[name] = output.shape
                return hook

            module.register_forward_hook(make_hook(node.target))
    
    traced(x)
    return module_to_size

IRS_sizes = calculate_IRS_SIZE()
print("IRS Sizes:", IRS_sizes)

IRS Sizes: {'conv1': torch.Size([1, 6, 28, 28]), 'bn1': torch.Size([1, 6, 28, 28]), 'relu1': torch.Size([1, 84]), 'maxpool1': torch.Size([1, 6, 14, 14]), 'conv2': torch.Size([1, 16, 10, 10]), 'bn2': torch.Size([1, 16, 10, 10]), 'relu2': torch.Size([1, 16, 10, 10]), 'maxpool2': torch.Size([1, 16, 5, 5]), 'fc': torch.Size([1, 120]), 'relu': torch.Size([1, 120]), 'fc1': torch.Size([1, 84]), 'fc2': torch.Size([1, 10])}


In [8]:
all_operators = [node for node in symbolic_traced.graph.nodes if node.op == "call_module" or node.op == "call_function"]
print(all_operators)

[conv1, bn1, relu1, maxpool1, conv2, bn2, relu2, maxpool2, fc, relu, fc1, relu1_1, fc2]


In [9]:
for node in symbolic_traced.graph.nodes:
    if node.op == "call_module":
        print(node, list(node.users.keys()), node.all_input_nodes)

conv1 [bn1] [x]
bn1 [relu1] [conv1]
relu1 [maxpool1] [bn1]
maxpool1 [conv2] [relu1]
conv2 [bn2] [maxpool1]
bn2 [relu2] [conv2]
relu2 [maxpool2] [bn2]
maxpool2 [size, reshape] [relu2]
fc [relu] [reshape]
relu [fc1] [fc]
fc1 [relu1_1] [relu]
relu1_1 [fc2] [fc1]
fc2 [output] [relu1_1]


# Fusion Plan Algorithm

In [10]:
# Define an enum of values many-to-many, one-to-one
class mapping_type(Enum):
    MANY_TO_MANY = 1
    ONE_TO_ONE = 2
    SHUFFLE = 3
    REORGANIZE = 4

mapping_type_table = {
    torch.nn.modules.batchnorm.BatchNorm2d: mapping_type.ONE_TO_ONE,
    torch.nn.modules.activation.ReLU: mapping_type.ONE_TO_ONE,
    torch.nn.modules.pooling.MaxPool2d: mapping_type.MANY_TO_MANY,
    torch.nn.modules.conv.Conv2d: mapping_type.MANY_TO_MANY,
    torch.nn.modules.linear.Linear: mapping_type.SHUFFLE
}

In [11]:
# # Testing table lookup
# for node in symbolic_traced.graph.nodes:
#     if node.op == "call_module":
#         print(node.op, node.target, node.args, getattr(symbolic_traced, node.target))
#         print(mapping_type_table.get(type(getattr(symbolic_traced, node.target))))

In [12]:
class MAPPING_RELATIONSHIP(Enum):
    FUSE_BREAK = 0
    FUSE_ONE_TO_ONE = 1
    FUSE_MANY_TO_MANY = 2
    FUSE_SHUFFLE = 3

def mapping_check(op, successor):
    print("Mapping check")
    print("op:", op)
    print("successor:", successor)
    print("op type:", type(getattr(symbolic_traced, op.target)))
    print("successor type:", type(getattr(symbolic_traced, successor.target)))
    op1_mapping = mapping_type_table.get(type(getattr(symbolic_traced, op.target)))
    successor_mapping = mapping_type_table.get(type(getattr(symbolic_traced, successor.target)))

    # CASES depending on op and successor
    if op1_mapping == mapping_type.ONE_TO_ONE and successor_mapping == mapping_type.ONE_TO_ONE:
        return MAPPING_RELATIONSHIP.FUSE_ONE_TO_ONE
    if op1_mapping == mapping_type.MANY_TO_MANY and successor_mapping == mapping_type.MANY_TO_MANY:
        return MAPPING_RELATIONSHIP.FUSE_BREAK
    if op1_mapping == mapping_type.MANY_TO_MANY and successor_mapping == mapping_type.ONE_TO_ONE:
        return MAPPING_RELATIONSHIP.FUSE_MANY_TO_MANY
    if op1_mapping == mapping_type.ONE_TO_ONE and successor_mapping == mapping_type.MANY_TO_MANY:
        return MAPPING_RELATIONSHIP.FUSE_MANY_TO_MANY
    if op1_mapping == mapping_type.SHUFFLE and successor_mapping == mapping_type.SHUFFLE:
        return MAPPING_RELATIONSHIP.FUSE_SHUFFLE
    if op1_mapping == mapping_type.SHUFFLE and successor_mapping == mapping_type.ONE_TO_ONE:
        return MAPPING_RELATIONSHIP.FUSE_SHUFFLE
    if op1_mapping == mapping_type.SHUFFLE and successor_mapping == mapping_type.MANY_TO_MANY:
        return MAPPING_RELATIONSHIP.FUSE_BREAK # TODO: this should be a fuse check
    if op1_mapping == mapping_type.ONE_TO_ONE and successor_mapping == mapping_type.SHUFFLE:
        return MAPPING_RELATIONSHIP.FUSE_SHUFFLE
    if op1_mapping == mapping_type.MANY_TO_MANY and successor_mapping == mapping_type.SHUFFLE:
        return MAPPING_RELATIONSHIP.FUSE_BREAK # TODO: this should be a fuse check
    
    return MAPPING_RELATIONSHIP.FUSE_BREAK # DEFAULT CASE
    


In [None]:
def generate_seed(nodes):
    # Find all the one-to-one mapping operators
    one_to_one_nodes = [node for node in nodes if mapping_type_table.get(type(getattr(symbolic_traced, node.target))) == mapping_type.ONE_TO_ONE]
    # Using the IRS_size, return the one-to-one mapping operator with the smallest output size (take product of dimensions)
    min_size = float('inf')
    seed_node = None
    for node in one_to_one_nodes:
        size = np.prod(IRS_sizes[node.target])
        if size < min_size:
            min_size = size
            seed_node = node
    return seed_node
    

def successors(op):
    # # Get the successors that folow the current node
    # successors = []
    # # Find index of the current node
    # index = all_operators.index(op)
    # # Successors are the nodes that follow the current node
    # successors = all_operators[index + 1:]

    # return successors

    successors = list(op.users.keys())

    # Remove successors that meet the conditions
    successors = [s for s in successors if not (s.op == "output" or s.op == "call_method")]

    return successors

def predecessors(op):
    # Get the predecessors that folow the current node
    # predecessors = []
    # # Find index of the current node
    # index = all_operators.index(op)
    # # Predecessors are the nodes that precede the current node
    # predecessors = all_operators[:index]

    # return predecessors

    # pred = [n for n in fx.graph.map_arg(test_node.args, lambda x: x if isinstance(x, fx.Node) else None) if n]

    pred = list(op.all_input_nodes)

    # Check if the pred is "placeholder" or "size"/"reshape"
    pred = [p for p in pred if not (p.op == "call_method" or p.op == "placeholder")]

    return pred

def fuse_successor(op, successor, block):
    # Check the mapping relationship
    relation = mapping_check(op, successor)

    # no relationship exists
    if relation == MAPPING_RELATIONSHIP.FUSE_BREAK:
        return
    
    # Step 2.2: check constraint requirement

    # Step 2.3: if benefit of fusion unknown, get the latency and check 
    # TODO: this only applies if the fusion is potentially beneficial

    # Block is the combination of op and successor
    block.add(successor)
    # Step 2.4: Recursively head to successor
    for fusing_op in successors(successor):
        fuse_successor(successor, fusing_op, block)

def fuse_predecessor(sp, predecessor, block):
    # Check relation
    relation = mapping_check(sp, predecessor)

    # no relationship exists
    if relation == MAPPING_RELATIONSHIP.FUSE_BREAK:
        return
    
    # Step 2.2: check constraint requirement

    # Step 2.3: if benefit of fusion unknown, get the latency and check
    # TODO: this only applies if the fusion is potentially beneficial

    # Block is the combination of op and successor
    block.add(predecessor)

    # Step 2.4: Recursively head to predecessor
    for fusing_op in predecessors(predecessor):
        fuse_predecessor(sp, fusing_op, block)

unfused_ops = all_operators
# seed_node = generate_seed(unfused_ops)
# # print(successors(unfused_ops[0]))
# # print(predecessors(unfused_ops[0]))
# # print(successors(unfused_ops[4]))
# # print(predecessors(unfused_ops[4]))

# test_block = {unfused_ops[-1]}
# test_block = {seed_node}
# fuse_successor(seed_node, successors(seed_node)[0], test_block)

# print("Predecssors: ", predecessors(unfused_ops[3]))

# print("Block after fusing:", test_block)

# test_node = unfused_ops[-1]
# # Print node successors using users and keys
# print("Successors: ", successors(test_node))

# test_node = unfused_ops[3]
# print("Successors: ", successors(test_node))
# test_node = unfused_ops[4]
# print("Predecessors: ", predecessors(test_node))

# test_node = unfused_ops[7]
# print("Successors: ", successors(test_node))

# test_node = unfused_ops[8]
# print("Predecessors: ", predecessors(test_node))

# # <Algorithm Entry>
unfused_ops = set(all_operators)
print("Unfused ops: ", unfused_ops)
# Step 1: start fuse from the selected seed
while (sp := generate_seed(unfused_ops)):
    block = {sp}
    # Step 2: head to successor
    for successor in successors(sp):
        print("Successor: ", successor)
        print("sp: ", sp)
        fuse_successor(sp, successor, block)
    # Step 3: head to predecessors
    for predecessor in predecessors(sp):
        fuse_predecessor(sp, predecessor, block)
    unfused_ops -= block
    
print("Fused ops: ", block)
print("Unfused ops: ", unfused_ops)

Unfused ops:  {fc1, fc, bn2, relu1, conv1, relu2, maxpool1, fc2, relu, maxpool2, relu1_1, conv2, bn1}
Successor:  maxpool1
sp:  relu1
Mapping check
op: relu1
successor: maxpool1
op type: <class 'torch.nn.modules.activation.ReLU'>
successor type: <class 'torch.nn.modules.pooling.MaxPool2d'>
Mapping check
op: maxpool1
successor: conv2
op type: <class 'torch.nn.modules.pooling.MaxPool2d'>
successor type: <class 'torch.nn.modules.conv.Conv2d'>
Mapping check
op: relu1
successor: bn1
op type: <class 'torch.nn.modules.activation.ReLU'>
successor type: <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
Mapping check
op: relu1
successor: conv1
op type: <class 'torch.nn.modules.activation.ReLU'>
successor type: <class 'torch.nn.modules.conv.Conv2d'>
Successor:  fc2
sp:  relu1_1
Mapping check
op: relu1_1
successor: fc2
op type: <class 'torch.nn.modules.activation.ReLU'>
successor type: <class 'torch.nn.modules.linear.Linear'>
Mapping check
op: relu1_1
successor: fc1
op type: <class 'torch.nn.module

In [221]:
# Dictionary to store output sizes
layer_outputs = {}

def hook_fn(module, input, output):
    layer_outputs[module] = output.shape

# Register hooks on all modules
hooks = []
for name, module in model.named_modules():
    if not isinstance(module, nn.Sequential) and not isinstance(module, LeNet5):
        hooks.append(module.register_forward_hook(hook_fn))

# Create a dummy input and pass through model
dummy_input = torch.randn(1, 1, 32, 32)
_ = model(dummy_input)

# Print the recorded output shapes
for layer, shape in layer_outputs.items():
    print(f"{layer.__class__.__name__} -> {shape}")

# Remove the hooks (clean up)
for hook in hooks:
    hook.remove()


Conv2d -> torch.Size([1, 6, 28, 28])
BatchNorm2d -> torch.Size([1, 6, 28, 28])
ReLU -> torch.Size([1, 84])
MaxPool2d -> torch.Size([1, 6, 14, 14])
Conv2d -> torch.Size([1, 16, 10, 10])
BatchNorm2d -> torch.Size([1, 16, 10, 10])
ReLU -> torch.Size([1, 16, 10, 10])
MaxPool2d -> torch.Size([1, 16, 5, 5])
Linear -> torch.Size([1, 120])
ReLU -> torch.Size([1, 120])
Linear -> torch.Size([1, 84])
Linear -> torch.Size([1, 10])


In [222]:
# Print the recorded output shapes
for layer, shape in layer_outputs.items():
    print(f"{layer.__class__.__name__} -> {shape}")

# Remove the hooks (clean up)
for hook in hooks:
    hook.remove()


Conv2d -> torch.Size([1, 6, 28, 28])
BatchNorm2d -> torch.Size([1, 6, 28, 28])
ReLU -> torch.Size([1, 84])
MaxPool2d -> torch.Size([1, 6, 14, 14])
Conv2d -> torch.Size([1, 16, 10, 10])
BatchNorm2d -> torch.Size([1, 16, 10, 10])
ReLU -> torch.Size([1, 16, 10, 10])
MaxPool2d -> torch.Size([1, 16, 5, 5])
Linear -> torch.Size([1, 120])
ReLU -> torch.Size([1, 120])
Linear -> torch.Size([1, 84])
Linear -> torch.Size([1, 10])


In [96]:
traced = symbolic_trace(model)
module_to_size = {}

for node in traced.graph.nodes:
    if node.op == "call_module":
        module = getattr(traced, node.target)

        def make_hook(name):
            def hook(module, input, output):
                module_to_size[name] = output.shape
            return hook

        module.register_forward_hook(make_hook(node.target))

x = torch.randn(1, 1, 32, 32)
traced(x)

print("Result:", module_to_size)

Hook called for: conv1
Hook called for: conv1
name bn1
Module: bn1, Output shape: torch.Size([1, 6, 28, 28])
Hook called for: bn1
Hook called for: bn1
name relu1
Module: relu1, Output shape: torch.Size([1, 6, 28, 28])
name relu1
Module: relu1, Output shape: torch.Size([1, 6, 28, 28])
Hook called for: relu1
Hook called for: relu1
Hook called for: relu1
Hook called for: relu1
Hook called for: maxpool1
Hook called for: maxpool1
Hook called for: conv2
Hook called for: conv2
name bn2
Module: bn2, Output shape: torch.Size([1, 16, 10, 10])
Hook called for: bn2
Hook called for: bn2
name relu2
Module: relu2, Output shape: torch.Size([1, 16, 10, 10])
Hook called for: relu2
Hook called for: relu2
Hook called for: maxpool2
Hook called for: maxpool2
Hook called for: fc
Hook called for: fc
name relu
Module: relu, Output shape: torch.Size([1, 120])
Hook called for: relu
Hook called for: relu
Hook called for: fc1
Hook called for: fc1
name relu1
Module: relu1, Output shape: torch.Size([1, 84])
name rel