In [37]:
import torch
import torch.nn as nn
from typing import Union
import re, copy
import numpy as np

In [38]:
# Seed torch's global RNG so the random tensors created below (e.g. `input_dummy`)
# are reproducible; the trailing ';' hides the returned Generator's repr.
torch.manual_seed(0);

<torch._C.Generator at 0x7f0153e50c90>

In [39]:
channels = 1
height = 28
width = 28

input_dummy = torch.randn((1, channels, height, width))

## Graph Tracer

In [40]:
class GraphTracer:
    """Recover module-to-module connectivity of a model from its TorchScript trace.

    The model is traced with ``torch.jit.trace`` and the string form of every graph
    node is parsed with regular expressions:

    * ``prim::CallMethod[name="forward"]`` nodes become *forward edges*
      (result value -> (input value, called sub-module));
    * three-operand ``aten::<op>(%a, %b, %c)`` nodes are recorded so that a
      sub-module fed by such a result (e.g. an ``aten::add``) can be linked to the
      modules that produced its two tensor operands.

    Every (sub-)module name and every traced value name is mapped to an integer
    index in ``name_2_indx_map``: named modules get ``0..n-1`` in
    ``named_modules()`` order (root excluded), traced values get the following
    indices in order of first appearance.

    Attributes:
        graph: the traced TorchScript graph.
        modules_map: module index -> module.
        name_2_indx_map: module / value name -> index.
        forward_edges: result-value index -> (input-value index, module index).
        ATens: result-value index -> {'op': str, 'args': (index, index, str)}.
        edges_list: list of (source module index, target module index).
    """

    # Matches e.g. '%input.3 : Tensor = prim::CallMethod[name="forward"](%_0, %input.1)'.
    # Groups: (1) result value, (2) module value, (3) input value.
    _FORWARD_CALL_RE = re.compile(r'%(.*?) :.*prim::CallMethod\[name="forward"\]\(%(.*?), %(.*?)\)')
    # Matches three-operand ATen calls, e.g. '%s : Tensor = aten::add(%a, %b, %3)'.
    # Groups: (1) result value, (2) op name, (3)-(5) operands.
    _ATEN_OP_RE = re.compile(r'%(.*?) :.*aten::(.*?)\(%(.*?), %(.*?), %(.*?)\)')

    def __init__(self, model: Union[nn.Sequential, nn.Module], dummy_input: torch.Tensor) -> None:
        """Trace ``model`` on ``dummy_input`` and build the index maps and edge list.

        Args:
            model: the network to analyse.
            dummy_input: an example input tensor accepted by ``model.forward``.
        """
        trace = torch.jit.trace(model, dummy_input)
        # Run the traced module once (presumably a sanity check that the trace executes).
        _ = trace(dummy_input)
        # Keep the copy referenced for the whole constructor while its graph is parsed.
        traced_copy = copy.deepcopy(trace)
        self.graph = traced_copy.graph

        self.modules_map, self.name_2_indx_map = self.get_named_modules(model)
        self.forward_edges = self.get_foward_edges()
        self.ATens = self.get_ATen_operations()
        self.edges_list = self.get_graph_edges()

    def from_name_2_indx(self, name):
        """Return the index of ``name``, assigning the next free index if it is new.

        A new name gets ``max(existing indices) + 1`` (``0`` for an empty map), so
        indices stay contiguous with those assigned by ``get_named_modules``.
        """
        if name not in self.name_2_indx_map:
            self.name_2_indx_map[name] = max(self.name_2_indx_map.values(), default=-1) + 1
        return self.name_2_indx_map[name]

    def _resolve_name(self, name):
        """Map a TorchScript value name onto a key of ``name_2_indx_map``.

        Children whose attribute names are not identifiers (e.g. the ``'0'``, ``'1'``
        children of ``nn.Sequential``) appear in the trace with a leading underscore
        (``%_0``, ``%_1``). An exact match wins; otherwise a single leading ``_`` is
        dropped when that yields a known name. Any other name is returned unchanged.
        """
        if name in self.name_2_indx_map:
            return name
        if name.startswith('_') and name[1:] in self.name_2_indx_map:
            return name[1:]
        return name

    def get_named_modules(self, module: nn.Module):
        """Index every named (sub-)module of ``module``, skipping the root itself.

        Returns:
            ``(modules_map, name_2_indx_map)``: index -> module and name -> index,
            with indices ``0..n-1`` in ``named_modules()`` order.
        """
        named = [(name, mod) for name, mod in module.named_modules() if name]
        modules_map = {indx: mod for indx, (_, mod) in enumerate(named)}
        name_2_indx_map = {name: indx for indx, (name, _) in enumerate(named)}
        return modules_map, name_2_indx_map

    def get_foward_edges(self):
        """Parse every sub-module ``forward`` call in the traced graph.

        Returns:
            dict mapping the index of each call's result value to the tuple
            ``(input value index, called module index)``, in graph order.
        """
        forward_edges = {}
        for node in self.graph.nodes():
            match = self._FORWARD_CALL_RE.search(str(node))
            if not match:
                continue
            result = self._resolve_name(match.group(1))
            target = self._resolve_name(match.group(2))
            source = self._resolve_name(match.group(3))
            # Assign indices in the order source, target, result.
            source_indx = self.from_name_2_indx(source)
            target_indx = self.from_name_2_indx(target)
            forward_edges[self.from_name_2_indx(result)] = (source_indx, target_indx)
        return forward_edges

    # Correctly-spelled alias; the original (misspelled) name is kept for callers.
    get_forward_edges = get_foward_edges

    def get_graph_edges(self):
        """Build the module-to-module edge list from ``forward_edges`` and ``ATens``.

        Forward calls are visited in graph order:

        * input is the previous call's result -> edge from the previous module;
        * input is an earlier call's result (branching) -> edge from that module;
        * otherwise the input must be a recorded ATen result -> one edge from each
          module producing its two tensor operands.

        Returns:
            list of ``(source module index, target module index)`` tuples. The
            edge from the graph input into the first module is not included.

        Raises:
            ValueError: propagated from ``get_ATen_operands`` when an input value
                cannot be attributed to any module or ATen op.
        """
        edges = []
        last_result = None

        for result_node, (src, trg) in self.forward_edges.items():
            # `is None`, not falsiness: a result index can legitimately be 0.
            if last_result is None:
                edges.append(('input', trg))
            elif src == last_result:
                edges.append((edges[-1][1], trg))
            elif src in self.forward_edges:
                edges.append((self.forward_edges[src][1], trg))
            else:
                src1, src2 = self.get_ATen_operands(src)
                edges.append((src1, trg))
                edges.append((src2, trg))
            last_result = result_node

        # Drop the leading ('input', ...) edge; an empty trace yields no edges.
        return edges[1:]

    def get_ATen_operands(self, node):
        """Return the modules that produced the two tensor operands of ATen result ``node``.

        Args:
            node: index of a value produced by a recorded three-operand ATen op.

        Returns:
            ``(module producing operand 2, module producing operand 1)``. Note the
            reversed order.

        Raises:
            ValueError: if ``node`` is not a recorded ATen result.
            KeyError: if an operand was not produced by a sub-module ``forward`` call.
        """
        if node not in self.ATens:
            raise ValueError(
                f'cannot resolve the producer of value index {node}: it is neither the '
                f'previous module output, an earlier module output, nor a recorded ATen result'
            )
        args = self.ATens[node]['args']
        src1, src2 = args[1], args[0]
        return self.forward_edges[src1][1], self.forward_edges[src2][1]

    def get_ATen_operations(self):
        """ ATen is PyTorch's tensor library backend, which provides a set of operations that operate on
        tensors directly. These include arithmetic operations (add, mul, etc.), mathematical
        functions (sin, cos, etc.), and tensor manipulation operations (view, reshape, etc.).

        Only ATen calls with exactly three ``%``-operands are recorded.

        Returns:
            dict mapping the result value index to ``{'op': name, 'args': (operand-1
            index, operand-2 index, raw third operand name)}``. Keys use the same
            index space as ``forward_edges`` so ``get_ATen_operands`` can look them up.
        """
        ATens = {}
        for node in self.graph.nodes():
            match = self._ATEN_OP_RE.search(str(node))
            if not match:
                continue
            operator1 = self.from_name_2_indx(self._resolve_name(match.group(3)))
            operator2 = self.from_name_2_indx(self._resolve_name(match.group(4)))
            result_node = self.from_name_2_indx(self._resolve_name(match.group(1)))
            const_operator = match.group(5)
            ATens[result_node] = {'op': match.group(2), 'args': (operator1, operator2, const_operator)}
        return ATens

    def remove_ignored_nodes(self, default_ignored_nodes):
        """Return ``edges_list`` with edges re-routed around modules of ignored types.

        Used for layers that 'DynapcnnNetwork' will ignore. For each edge:

        * source is ignored -> replaced by an edge from every predecessor of the
          source to the original target;
        * target is ignored -> replaced by an edge from the original source to
          every successor of the target;
        * otherwise the edge is kept.

        Duplicates are dropped (first occurrence wins). ``edges_list`` is not modified.
        NOTE(review): runs of consecutive ignored modules are not collapsed transitively.

        Args:
            default_ignored_nodes: a class or tuple of classes, as accepted by ``isinstance``.

        Returns:
            list of ``(source, target)`` index tuples.
        """
        edges = list(self.edges_list)
        new_edges = []

        def is_ignored(indx):
            return isinstance(self.modules_map.get(indx), default_ignored_nodes)

        for src, trg in edges:
            if is_ignored(src):
                candidates = [(pred, trg) for pred, succ in edges if succ == src]
            elif is_ignored(trg):
                candidates = [(src, succ) for pred, succ in edges if pred == trg]
            else:
                candidates = [(src, trg)]
            for new_edge in candidates:
                if new_edge not in new_edges:
                    new_edges.append(new_edge)

        return new_edges

    def plot_graph(self):
        """Draw ``edges_list`` as a directed graph (requires networkx and matplotlib)."""
        import matplotlib.pyplot as plt
        import networkx as nx

        G = nx.DiGraph(self.edges_list)
        layout = nx.spring_layout(G)
        nx.draw(G, pos=layout, with_labels=True, node_size=800)
        plt.title('GraphTracer (new)')
        plt.show()

## Tracing 1: `nn.Sequential` model

In [41]:
# Three conv -> ReLU -> avg-pool stages followed by a two-layer fully-connected head.
# For a 1 x 28 x 28 input the spatial size goes 28 -> 24 -> 12 -> 8 -> 4 -> 2 -> 1,
# so Flatten yields exactly 128 features for the first Linear layer.
ann1 = nn.Sequential(
    # stage 1
    nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1, bias=False),
    nn.ReLU(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # stage 2
    nn.Conv2d(in_channels=20, out_channels=32, kernel_size=5, stride=1, bias=False),
    nn.ReLU(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # stage 3
    nn.Conv2d(in_channels=32, out_channels=128, kernel_size=3, stride=1, bias=False),
    nn.ReLU(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    # fully-connected head
    nn.Flatten(),
    nn.Linear(in_features=128, out_features=500, bias=False),
    nn.ReLU(),
    nn.Linear(in_features=500, out_features=10, bias=False),
)

In [42]:
# Trace the nn.Sequential model on the dummy input; this builds modules_map and edges_list.
gtracer1 = GraphTracer(ann1, input_dummy)

In [43]:
# Index assigned by GraphTracer -> module (named_modules() order, root excluded).
for indx in gtracer1.modules_map:
    print(indx, gtracer1.modules_map[indx])

0 Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1), bias=False)
1 ReLU()
2 AvgPool2d(kernel_size=2, stride=2, padding=0)
3 Conv2d(20, 32, kernel_size=(5, 5), stride=(1, 1), bias=False)
4 ReLU()
5 AvgPool2d(kernel_size=2, stride=2, padding=0)
6 Conv2d(32, 128, kernel_size=(3, 3), stride=(1, 1), bias=False)
7 ReLU()
8 AvgPool2d(kernel_size=2, stride=2, padding=0)
9 Flatten(start_dim=1, end_dim=-1)
10 Linear(in_features=128, out_features=500, bias=False)
11 ReLU()
12 Linear(in_features=500, out_features=10, bias=False)


In [44]:
# One (source index, target index) edge per line, followed by the number of edges.
print(*gtracer1.edges_list, len(gtracer1.edges_list), sep='\n')

(0, 1)
(1, 2)
(2, 3)
(3, 4)
(4, 5)
(5, 6)
(6, 7)
(7, 8)
(8, 9)
(9, 10)
(10, 11)
(11, 12)
12


## Tracing 2: the same network as an `nn.Module` subclass with named layers

In [45]:
class ANN(nn.Module):
    """The same layer stack as ``ann1``, written as an ``nn.Module`` with named attributes."""

    def __init__(self) -> None:
        super().__init__()
        # Registration order matters: GraphTracer indexes modules in named_modules() order.
        self.con1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1, bias=False)
        self.rel1 = nn.ReLU()
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=32, kernel_size=5, stride=1, bias=False)
        self.rel2 = nn.ReLU()
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=128, kernel_size=3, stride=1, bias=False)
        self.rel3 = nn.ReLU()
        self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.flat = nn.Flatten()
        self.fc1 = nn.Linear(in_features=128, out_features=500, bias=False)
        self.rel4 = nn.ReLU()
        self.fc2 = nn.Linear(in_features=500, out_features=10, bias=False)

    def forward(self, x):
        """Apply every layer once, in registration order; the output has 10 features."""
        pipeline = (
            self.con1, self.rel1, self.pool1,
            self.conv2, self.rel2, self.pool2,
            self.conv3, self.rel3, self.pool3,
            self.flat, self.fc1, self.rel4, self.fc2,
        )
        out = x
        for layer in pipeline:
            out = layer(out)
        return out

ann2 = ANN()

In [46]:
# Trace the nn.Module-subclass model on the same dummy input; this builds modules_map and edges_list.
gtracer2 = GraphTracer(ann2, input_dummy)

In [47]:
# Index assigned by GraphTracer -> module (named_modules() order, root excluded).
for indx in gtracer2.modules_map:
    print(indx, gtracer2.modules_map[indx])

0 Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1), bias=False)
1 ReLU()
2 AvgPool2d(kernel_size=2, stride=2, padding=0)
3 Conv2d(20, 32, kernel_size=(5, 5), stride=(1, 1), bias=False)
4 ReLU()
5 AvgPool2d(kernel_size=2, stride=2, padding=0)
6 Conv2d(32, 128, kernel_size=(3, 3), stride=(1, 1), bias=False)
7 ReLU()
8 AvgPool2d(kernel_size=2, stride=2, padding=0)
9 Flatten(start_dim=1, end_dim=-1)
10 Linear(in_features=128, out_features=500, bias=False)
11 ReLU()
12 Linear(in_features=500, out_features=10, bias=False)


In [48]:
# One (source index, target index) edge per line, followed by the number of edges.
print(*gtracer2.edges_list, len(gtracer2.edges_list), sep='\n')

(0, 1)
(1, 2)
(2, 3)
(3, 4)
(4, 5)
(5, 6)
(6, 7)
(7, 8)
(8, 9)
(9, 10)
(10, 11)
(11, 12)
12
