In [1]:
from typing import Callable, Optional

import torch
import torch.nn as nn


# Layouts of the `features` stage of `VGG`, keyed by architecture name.
# An integer entry adds a 3x3 convolution with that many output channels
# (multiplied by `capacity`); an 'M' entry adds a 2x2 max-pooling with stride 2.
# The first convolution of each network is not listed here: it is the `pilot` stage.
_CONFIGS = {
    'VGG8': ['M', 256, 256, 'M', 512, 512, 'M'],
    'VGG9': [128, 'M', 256, 256, 'M', 512, 512, 'M'],
    'VGG11': [128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M'],
}


class VGG(nn.Module):
    """VGG-style convolutional classifier assembled from a named layer layout.

    The network is the composition of four stages:

    * ``pilot``: a 3x3 convolution from 3 input channels to ``128 * capacity``
      channels, optionally followed by batch normalisation, then a ReLU;
    * ``features``: the convolution/pooling stack listed in ``_CONFIGS[config]``;
    * ``avgpool``: adaptive average pooling to a 4x4 spatial map;
    * ``classifier``: three linear layers, the last one with ``num_classes`` outputs.

    Args:
        config: key of ``_CONFIGS`` selecting the layout of ``features``.
        capacity: multiplier applied to the channel counts of the convolutions.
        use_bn: if ``True``, convolutions and hidden linear layers have no bias
            and are followed by batch normalisation; if ``False``, they keep
            their bias and the classifier uses dropout instead.
        num_classes: number of outputs of the last linear layer.
        seed: if non-negative, it is passed to ``torch.manual_seed`` right
            before the weights are initialised (this reseeds PyTorch's global
            random number generator).

    Raises:
        KeyError: if ``config`` is not a key of ``_CONFIGS``.
    """

    def __init__(self, config: str, capacity: int = 1, use_bn: bool = False, num_classes: int = 10, seed: int = -1) -> None:

        super().__init__()

        self.pilot      = self._make_pilot(capacity, use_bn)
        self.features   = self._make_features(config, capacity, use_bn)
        self.avgpool    = nn.AdaptiveAvgPool2d((4, 4))
        self.classifier = self._make_classifier(capacity, use_bn, num_classes)

        self._initialize_weights(seed=seed)

    @staticmethod
    def _make_pilot(capacity: int, use_bn: bool) -> nn.Sequential:
        """Build the first convolutional block (3 -> ``128 * capacity`` channels)."""

        modules = []
        # the convolution's bias is dropped when batch normalisation follows
        modules += [nn.Conv2d(3, 128 * capacity, kernel_size=3, padding=1, bias=not use_bn)]
        modules += [nn.BatchNorm2d(128 * capacity)] if use_bn else []
        modules += [nn.ReLU(inplace=True)]

        return nn.Sequential(*modules)

    @staticmethod
    def _make_features(config: str, capacity: int, use_bn: bool) -> nn.Sequential:
        """Build the convolution/pooling stack described by ``_CONFIGS[config]``."""

        modules = []
        in_channels = 128 * capacity  # number of channels produced by the pilot
        for v in _CONFIGS[config]:
            if v == 'M':
                modules += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                out_channels = v * capacity
                modules += [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=not use_bn)]
                modules += [nn.BatchNorm2d(out_channels)] if use_bn else []
                modules += [nn.ReLU(inplace=True)]
                in_channels = out_channels

        return nn.Sequential(*modules)

    @staticmethod
    def _make_classifier(capacity: int, use_bn: bool, num_classes: int) -> nn.Sequential:
        """Build the three-layer fully-connected head.

        The first layer expects ``512 * capacity * 4 * 4`` inputs, i.e., the
        flattened 4x4 output of ``avgpool`` when the last convolution of
        ``features`` has ``512 * capacity`` channels.
        """

        modules = []
        modules += [nn.Linear(512 * capacity * 4 * 4, 1024, bias=not use_bn)]
        modules += [nn.BatchNorm1d(1024)] if use_bn else []
        modules += [nn.ReLU(inplace=True)]
        modules += [] if use_bn else [nn.Dropout()]  # dropout only when there is no batch normalisation
        modules += [nn.Linear(1024, 1024, bias=not use_bn)]
        modules += [nn.BatchNorm1d(1024)] if use_bn else []
        modules += [nn.ReLU(inplace=True)]
        modules += [] if use_bn else [nn.Dropout()]
        modules += [nn.Linear(1024, num_classes)]

        return nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply pilot, features, pooling and classifier to a batch `x`."""

        x = self.pilot(x)
        x = self.features(x)
        x = self.avgpool(x)

        # flatten everything but the batch dimension; unlike `view`, `flatten`
        # also accepts tensors whose memory layout is not contiguous
        x = torch.flatten(x, 1)

        x = self.classifier(x)

        return x

    def _initialize_weights(self, seed: int) -> None:
        """(Re-)initialise the parameters of all the sub-modules.

        Convolutions use Kaiming-normal weights, linear layers use normal
        weights with standard deviation 0.01, biases are zeroed, and batch
        normalisation layers are reset to unit scale and zero shift.
        """

        if seed >= 0:
            torch.manual_seed(seed)

        for m in self.modules():

            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
                # the classifier's 1d normalisations are reset like the 2d ones
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)


In [2]:
# instantiate the 'VGG8' layout with batch normalisation (all other arguments keep their defaults)
net = VGG('VGG8', use_bn=True)


In [3]:
import networkx as nx
from collections import OrderedDict

from typing import Tuple


def tree_traversal(parent_module: torch.nn.Module,
                   parent_name:   str = '',
                   G:             Optional[nx.DiGraph] = None,
                   node_2_pttype: Optional[OrderedDict] = None) -> Tuple[nx.DiGraph, OrderedDict]:
    """Recursively record the hierarchy of sub-modules of a PyTorch module.

    Every module becomes a node named after its dotted path from the root
    (the root itself is named `parent_name`), and every child is connected to
    its parent by an arc pointing from the child to the parent.

    Args:
        parent_module: the module whose children are visited.
        parent_name: the node name of `parent_module`.
        G: graph to which nodes and arcs are added; a new one is created when
            omitted.
        node_2_pttype: mapping from node name to module class that is filled
            during the visit; a new one is created when omitted.

    Returns:
        The graph and the node-to-class mapping (the same objects that were
        passed in, when they were passed).
    """

    # create the accumulators inside the call: default values are evaluated
    # only once, so mutable defaults would be shared by all the traversals
    if G is None:
        G = nx.DiGraph()
    if node_2_pttype is None:
        node_2_pttype = OrderedDict()

    G.add_node(parent_name)
    node_2_pttype[parent_name] = type(parent_module)

    for child_name, child_module in parent_module.named_children():
        full_name = '.'.join([parent_name, child_name]) if parent_name != '' else child_name
        G.add_edge(full_name, parent_name)  # arcs point from the child to its parent
        tree_traversal(child_module, full_name, G, node_2_pttype)

    return G, node_2_pttype


def get_pytorch_network_tree(net: torch.nn.Module) -> nx.DiGraph:
    """Return the tree of sub-modules of `net` as a directed graph.

    Nodes are named after the dotted path of each module (the root is `''`),
    arcs point from each module to its parent, and the node attribute
    `'type'` stores the class of the corresponding module.
    """

    # pass fresh accumulators explicitly, so that the tree of one network can
    # never contain nodes left over from a previous call
    T, node_2_pttype = tree_traversal(net, '', nx.DiGraph(), OrderedDict())
    nx.set_node_attributes(T, node_2_pttype, 'type')

    return T


In [4]:
from collections import namedtuple
from enum import unique, Enum, IntEnum


# template appearance of GraphViz components

# each field is named after the GraphViz attribute it is passed as
GVNodeAppearance = namedtuple(
    'GVNodeAppearance',
    (
        'fontsize', 'fontcolor',                  # node label
        'penwidth', 'pencolor',                   # node boundary
        'shape', 'height', 'width', 'fillcolor',  # node body
    ),
)
GVEdgeAppearance = namedtuple(
    'GVEdgeAppearance',
    ('penwidth', 'color'),                        # edge line
)


# graph components highlights

@unique
class NodeState(IntEnum):
    """Highlighting states that a node can be drawn in."""

    INACTIVE = 0   # not highlighted
    ACTIVEPOS = 1  # highlighted as a "positive" element
    ACTIVENEG = 2  # highlighted as a "negative" element


# appearance of a node in each highlighting state
# (table columns: state, label colour, boundary width, boundary colour, fill colour)
style_node_state = {
    state: GVNodeAppearance(fontsize='6', fontcolor=fontcolor,
                            penwidth=penwidth, pencolor=pencolor,
                            shape='circle', height='1.0', width='1.0', fillcolor=fillcolor)
    for state, fontcolor, penwidth, pencolor, fillcolor in (
        (NodeState.INACTIVE,  'black',     '1', 'darkgray',  'gray'),
        (NodeState.ACTIVEPOS, 'darkgreen', '2', 'darkgreen', 'chartreuse'),
        (NodeState.ACTIVENEG, 'crimson',   '2', 'crimson',   'firebrick2'),
    )
}


@unique
class EdgeState(IntEnum):
    """Highlighting states that an arc can be drawn in."""

    INACTIVE = 0   # not highlighted
    ACTIVEPOS = 1  # highlighted as a "positive" element
    ACTIVENEG = 2  # highlighted as a "negative" element


# appearance of an arc in each highlighting state
# (table columns: state, line width, line colour)
style_edge_state = {
    state: GVEdgeAppearance(penwidth=penwidth, color=color)
    for state, penwidth, color in (
        (EdgeState.INACTIVE,  '1', 'gray'),
        (EdgeState.ACTIVEPOS, '2', 'darkgreen'),
        (EdgeState.ACTIVENEG, '2', 'crimson'),
    )
}


@unique
class NodeContainerLeaf(IntEnum):
    """Distinguishes nodes that group other nodes from terminal nodes."""

    CONTAINER = 0  # node with children
    LEAF = 1       # node without children


# appearance of container and leaf nodes
# (table columns: kind, label colour, boundary width, boundary colour, fill colour)
style_node_containerleaf = {
    kind: GVNodeAppearance(fontsize='6', fontcolor=fontcolor,
                           penwidth=penwidth, pencolor=pencolor,
                           shape='circle', height='1.0', width='1.0', fillcolor=fillcolor)
    for kind, fontcolor, penwidth, pencolor, fillcolor in (
        (NodeContainerLeaf.CONTAINER, 'black',     '1', 'black',     'sienna1'),
        (NodeContainerLeaf.LEAF,      'darkgreen', '2', 'darkgreen', 'palegreen3'),
    )
}


# leaf node types

@unique
class NodeType(IntEnum):
    """Kinds of operation that a node of a computational graph can stand for."""

    CONTAINER = 0      # groups other nodes
    LINEAR = 1         # linear operation
    NONLINEAR = 2      # non-linear operation
    POOLING = 3        # pooling operation
    NORMALISATION = 4  # normalisation operation


# appearance of a node for each operation type: only the fill colour changes
style_node_type = {
    node_type: GVNodeAppearance(fontsize='6', fontcolor='black',
                                penwidth='1', pencolor='black',
                                shape='circle', height='1.0', width='1.0', fillcolor=fillcolor)
    for node_type, fillcolor in (
        (NodeType.CONTAINER,     'lightslategray'),
        (NodeType.LINEAR,        'lightskyblue'),
        (NodeType.NONLINEAR,     'lightseagreen'),
        (NodeType.POOLING,       'lightsalmon'),
        (NodeType.NORMALISATION, 'lightgoldenrod'),
    )
}


@unique
class NodeTypeQuant(IntEnum):
    """Operation types with a dedicated, darker palette.

    The members carry the same integer values as the homonymous members of
    `NodeType`.
    NOTE(review): presumably these mark the quantised version of each
    operation; confirm against the code that uses them.
    """

    LINEAR = 1
    NONLINEAR = 2
    POOLING = 3


# appearance of a node for each `NodeTypeQuant` member: only the fill colour changes
style_node_typequant = {
    node_type: GVNodeAppearance(fontsize='6', fontcolor='black',
                                penwidth='1', pencolor='black',
                                shape='circle', height='1.0', width='1.0', fillcolor=fillcolor)
    for node_type, fillcolor in (
        (NodeTypeQuant.LINEAR,    'skyblue3'),
        (NodeTypeQuant.NONLINEAR, 'seagreen'),
        (NodeTypeQuant.POOLING,   'salmon2'),
    )
}


# graph rewriting rules

@unique
class NodeGRR(IntEnum):
    """Role played by a node in a graph rewriting rule (GRR)."""

    CONTEXT = 0      # K-term
    TEMPLATE = 1     # L-term
    REPLACEMENT = 2  # R-term


# appearance of a node for each role in a rewriting rule: only the fill colour changes
style_node_grr = {
    role: GVNodeAppearance(fontsize='6', fontcolor='black',
                           penwidth='1', pencolor='black',
                           shape='circle', height='1.0', width='1.0', fillcolor=fillcolor)
    for role, fillcolor in (
        (NodeGRR.CONTEXT,     'goldenrod'),
        (NodeGRR.TEMPLATE,    'turquoise4'),
        (NodeGRR.REPLACEMENT, 'orange'),
    )
}


@unique
class EdgeGRR(IntEnum):
    """Role played by an arc in a graph rewriting rule (GRR)."""

    CONTEXT = 0              # K-term
    TEMPLATE = 1             # L-term (to be removed)
    REPLACEMENT = 2          # R-term (to be added)
    CONTEXT2TEMPLATE = 3     # L-term (to be removed)
    CONTEXT2REPLACEMENT = 4  # R-term (to be added)


# appearance of an arc for each role in a rewriting rule: only the colour changes
style_edge_grr = {
    role: GVEdgeAppearance(penwidth='2', color=color)
    for role, color in (
        (EdgeGRR.CONTEXT,             'goldenrod'),
        (EdgeGRR.TEMPLATE,            'turquoise4'),
        (EdgeGRR.REPLACEMENT,         'orange'),
        (EdgeGRR.CONTEXT2TEMPLATE,    'turquoise'),
        (EdgeGRR.CONTEXT2REPLACEMENT, 'orangered2'),
    )
}


In [5]:
from typing import TypeVar
import graphviz as gv
import os
from IPython.display import display, IFrame

from typing import Union, Set, Tuple, List, Dict


# type of the identifiers of graph nodes: either integers or strings
NodeName = TypeVar('NodeName', int, str)


def nx_2_gv(G:            nx.DiGraph,
            node_2_label: Union[Dict[NodeName, str], None] = None,
            node_2_style: Union[Dict[NodeName, GVNodeAppearance], None] = None,
            arc_2_style:  Union[Dict[Tuple[NodeName, NodeName], GVEdgeAppearance], None] = None,
            revert_arcs:  bool = False) -> gv.Digraph:
    """Convert a `networkx` directed graph into a `graphviz.Digraph`.

    GraphViz's annotation format allows descriptive renderings of graphs, which
    make it an ideal tool for didactic purposes. This function takes a NetworkX
    directed graph and produces an equivalent `graphviz` representation,
    annotating the nodes and arcs as specified by the keyword arguments.

    Args:
        G: the graph to convert.
        node_2_label: label of each node; nodes without an entry get an empty
            label.
        node_2_style: appearance of each node; nodes without an entry use
            GraphViz's defaults. All nodes are drawn with `style='filled'`.
        arc_2_style: appearance of each arc, keyed by the `(tail, head)` pair
            as it appears in `G.edges`; arcs without an entry use GraphViz's
            defaults.
        revert_arcs: if `True`, every arc is drawn in the opposite direction.

    Returns:
        The `graphviz.Digraph` whose nodes are named `str(node)`.
    """

    node_2_label = {} if node_2_label is None else node_2_label
    node_2_style = {} if node_2_style is None else node_2_style
    arc_2_style  = {} if arc_2_style is None else arc_2_style

    gvG = gv.Digraph()

    for node in G.nodes:

        label      = node_2_label.get(node, '')
        appearance = node_2_style.get(node)
        style      = {} if appearance is None else appearance._asdict()

        gvG.node(str(node), label, **style, style='filled')

    for arc in G.edges:

        appearance = arc_2_style.get(arc)
        style      = {} if appearance is None else appearance._asdict()

        # reverting the arcs prints trees top-down in case leaves point to root, and not viceversa
        tail, head = (arc[1], arc[0]) if revert_arcs else (arc[0], arc[1])
        gvG.edge(str(tail), str(head), **style)

    return gvG


def print_and_render(G:        gv.Digraph,
                     filename: str,
                     width:    int,
                     height:   int) -> None:
    """Render a GraphViz graph to disk and embed the result in the notebook.

    The graph is rendered in the current directory under the name `filename`,
    then the file `filename + '.pdf'` is displayed inside an `IFrame` of the
    given `width` and `height`.
    NOTE(review): this assumes that `G` renders to PDF (presumably the default
    output format of `graphviz`); confirm if the format is ever changed.
    """
    G.render(filename=filename, directory=os.path.curdir)

    rendered_file = filename + '.pdf'
    frame = IFrame(rendered_file, width, height)
    display(frame)


def show_highlighted_part(G:        nx.DiGraph,
                          VH:       Set[NodeName],
                          EH:       Union[Set[Tuple[NodeName, NodeName]], None] = None,
                          show_pos: bool = True,
                          filename: str = 'temp',
                          width:    int = 900,
                          height:   int = 700) -> None:
    """Render `G` with some of its nodes and arcs highlighted.

    Args:
        G: the graph to show.
        VH: the nodes to highlight; must be a subset of the nodes of `G`.
        EH: the arcs to highlight; must be a subset of the arcs of `G`. When
            omitted, the arcs of the sub-graph induced by `VH` are highlighted.
        show_pos: if `True` the "positive" highlighting style is used,
            otherwise the "negative" one.
        filename: name (without extension) under which the picture is
            rendered; pass different names to keep the pictures of different
            calls from overwriting one another.
        width: width of the frame showing the picture.
        height: height of the frame showing the picture.

    Raises:
        AssertionError: if `VH` or `EH` contain elements that are not in `G`.
    """

    assert VH.issubset(set(G.nodes)), 'VH must contain only nodes of G'

    if EH is not None:
        assert EH.issubset(set(G.edges)), 'EH must contain only arcs of G'
    else:
        EH = set(G.subgraph(VH).edges)

    # the styles do not depend on the single node or arc: look them up once
    active_node   = style_node_state[NodeState.ACTIVEPOS if show_pos else NodeState.ACTIVENEG]
    inactive_node = style_node_state[NodeState.INACTIVE]
    active_arc    = style_edge_state[EdgeState.ACTIVEPOS if show_pos else EdgeState.ACTIVENEG]
    inactive_arc  = style_edge_state[EdgeState.INACTIVE]

    node_2_style = {n: active_node if n in VH else inactive_node for n in G.nodes}
    arc_2_style  = {a: active_arc if a in EH else inactive_arc for a in G.edges}

    gvG = nx_2_gv(G, node_2_style=node_2_style, arc_2_style=arc_2_style)
    print_and_render(gvG, filename, width, height)


In [6]:
# build the tree of modules of the network (arcs point from each module to its parent)
T = get_pytorch_network_tree(net)

# a module without incoming arcs has no children, hence it is a leaf
nsd = {
    n: style_node_containerleaf[NodeContainerLeaf.LEAF if T.in_degree(n) == 0 else NodeContainerLeaf.CONTAINER]
    for n in T.nodes
}
gvT = nx_2_gv(
    T,
    node_2_label={n: n for n in T.nodes},
    node_2_style=nsd,
    revert_arcs=True,
)
print_and_render(gvT, 'VGG8_leaves', 900, 160)


In [7]:
# operation type associated with each class of module that appears in the tree
pttype_2_type = {
    VGG:                  NodeType.CONTAINER,
    nn.Sequential:        NodeType.CONTAINER,
    nn.Conv2d:            NodeType.LINEAR,
    nn.Linear:            NodeType.LINEAR,
    nn.BatchNorm1d:       NodeType.NORMALISATION,
    nn.BatchNorm2d:       NodeType.NORMALISATION,
    nn.MaxPool2d:         NodeType.POOLING,
    nn.AdaptiveAvgPool2d: NodeType.POOLING,
    nn.ReLU:              NodeType.NONLINEAR,
}

# colour every module of the tree according to its operation type
gvT = nx_2_gv(
    T,
    node_2_label={n: n for n in T.nodes},
    node_2_style={
        n: style_node_type[pttype_2_type[pttype]]
        for n, pttype in nx.get_node_attributes(T, 'type').items()
    },
    revert_arcs=True,
)
print_and_render(gvT, 'VGG8_optypes', 900, 160)


## Algebraic graph rewriting

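The task of *graph rewriting* is concerned with transforming graphs into other graphs, which makes it an important topic in graph theory.
In the cells below, a graph rewriting rule (GRR) is described by three graphs: a template $L$ to look for, a replacement $R$ to put in its place, and a context $K$ that is shared by both and is left untouched.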


In [8]:
# define mock-up node types to mimic standard operations in DL computational graphs
nxtypes = {'L', 'NL', 'P', 'BN'}

nxtype_2_type = {
    'L':  NodeType.LINEAR,
    'NL': NodeType.NONLINEAR,
    'P':  NodeType.POOLING,
    'BN': NodeType.NORMALISATION
}

# define nodes and connectivity of an example graph
G_node_2_nxtype = OrderedDict([( 0, 'L'),
                               ( 1, 'BN'),
                               ( 2, 'NL'),
                               ( 3, 'P'),
                               ( 4, 'L'),  ( 9, 'L'),
                               ( 5, 'BN'), (10, 'BN'),
                               ( 6, 'NL'),
                               ( 7, 'L'),
                               ( 8, 'BN'),
                               (11, 'L'),
                               (12, 'NL'),
                               (13, 'P'),
                               (14, 'L'),  (19, 'L'),
                               (15, 'BN'), (20, 'P'),
                               (16, 'NL'),
                               (17, 'L'),
                               (18, 'BN'),
                               (21, 'L'),
                               (22, 'NL'),
                               (23, 'P'),
                               (24, 'L')])
VG                          = set(G_node_2_nxtype.keys())  # nodes set
EG                          = {( 0,  1),                   # arcs set
                               ( 1,  2),
                               ( 2,  3),
                               ( 3,  4), ( 3,  9),
                               ( 4,  5), ( 9, 10),
                               ( 5,  6),
                               ( 6,  7),
                               ( 7,  8),
                               ( 8, 11), (10, 11),
                               (11, 12),
                               (12, 13),
                               (13, 14), (13, 19),
                               (14, 15), (19, 20),
                               (15, 16),
                               (16, 17),
                               (17, 18),
                               (18, 21), (20, 21),
                               (21, 22),
                               (22, 23),
                               (23, 24)}

# build (directed) graph
G = nx.DiGraph()
G.add_nodes_from(VG)
G.add_edges_from(EG)

# label nodes with type information
nx.set_node_attributes(G, G_node_2_nxtype, 'nxtype')

# show graph
nsd  = {n: style_node_type[nxtype_2_type[nxtype]] for n, nxtype in nx.get_node_attributes(G, 'nxtype').items()}  # GraphViz style dictionary
gvG = nx_2_gv(G, node_2_style=nsd)
print_and_render(gvG, 'G_nxtypes', 900, 700)


In [9]:
# define template L nodes and connectivity
L_node_2_nxtype = OrderedDict([( 0, 'NL'),
                               ( 1, 'P'),
                               ( 2, 'L'),  ( 7, 'L'),
                               ( 3, 'BN'), ( 8, 'BN'),
                               ( 4, 'NL'),
                               ( 5, 'L'),
                               ( 6, 'BN'),
                               ( 9, 'L'),
                               (10, 'NL'),
                               (11, 'P')])
VL                          = set(L_node_2_nxtype.keys())
EL                          = {( 0,  1),
                               ( 1,  2), ( 1,  7),
                               ( 2,  3), ( 7,  8),
                               ( 3,  4),
                               ( 4,  5),
                               ( 5,  6),
                               ( 6,  9), ( 8,  9),
                               ( 9, 10),
                               (10, 11)}

# build (directed) graph
L = nx.DiGraph()
L.add_nodes_from(VL)
L.add_edges_from(EL)

# relabel nodes to reflect the role they have in the GRR
VK = {0, 11}
nx.relabel_nodes(L, {n: '/'.join([('K' if n in VK else 'L') + '-term', str(n)]) for n in VL}, copy=False)
VL = set(L.nodes)
EL = set(L.edges)

# define context K as sub-graph of L
VK = {n for n in L.nodes if n.startswith('K')}
K  = L.subgraph(VK)  # in this case we use the induced sub-graph (https://en.wikipedia.org/wiki/Induced_subgraph), which has no edges
EK = set(K.edges)

# isolate "core" template L\K and context-template connections
VLK     = VL.difference(VK)
LK      = L.subgraph(VLK)
ELK     = set(LK.edges)
EK2LK2K = EL.difference(EK | ELK)

# define "core" replacement R\K
RK_node_2_nxtype = OrderedDict([(12, 'P'),
                                (13, 'L'),  (16, 'L'),
                                (14, 'NL'),
                                (15, 'L'),
                                (17, 'L'),
                                (18, 'NL')])
VRK              = set(RK_node_2_nxtype.keys())
ERK              = {(12, 13), (12, 16),
                    (13, 14),
                    (14, 15),
                    (15, 17), (16, 17),
                    (17, 18)}

# build (directed) graph
RK = nx.DiGraph()
RK.add_nodes_from(VRK)
RK.add_edges_from(ERK)

# relabel nodes to reflect the role they have in the GRR
nx.relabel_nodes(RK, {n: '/'.join(['R-term', str(n)]) for n in VRK}, copy=False)
VRK = set(RK.nodes)
ERK = set(RK.edges)

# glue R\K to the context graph
S = nx.compose(L, RK)
EK2RK2K = {('K-term/0', 'R-term/12'), ('R-term/18', 'K-term/11')}
S.add_edges_from(EK2RK2K)

# identify the (full) replacement R
R  = S.subgraph(VK | VRK)
VR = set(R.nodes)
ER = set(R.edges)


In [10]:
# prepare the styles of the GRR's components
node_grr = dict()
node_grr.update({n: NodeGRR.CONTEXT     for n in VK})
node_grr.update({n: NodeGRR.TEMPLATE    for n in VLK})
node_grr.update({n: NodeGRR.REPLACEMENT for n in VRK})

arc_grr = dict()
arc_grr.update({a: EdgeGRR.CONTEXT             for a in EK})
arc_grr.update({a: EdgeGRR.TEMPLATE            for a in ELK})
arc_grr.update({a: EdgeGRR.REPLACEMENT         for a in ERK})
arc_grr.update({a: EdgeGRR.CONTEXT2TEMPLATE    for a in EK2LK2K})
arc_grr.update({a: EdgeGRR.CONTEXT2REPLACEMENT for a in EK2RK2K})

nx.set_node_attributes(S, node_grr, 'node_grr')
nx.set_edge_attributes(S, arc_grr,  'arc_grr')

# show the GRR components
nsd = {n: style_node_grr[grr] for n, grr in nx.get_node_attributes(S, 'node_grr').items()}
asd = {a: style_edge_grr[grr] for a, grr in nx.get_edge_attributes(S, 'arc_grr').items()}
gvS = nx_2_gv(S, node_2_style=nsd, arc_2_style=asd)
print_and_render(gvS, 'S_GRR', 900, 700)


In [11]:
# highlight all the nodes of the template L (context and core) inside S
show_highlighted_part(S, VL)

In [12]:
# highlight only the "core" template L\K, i.e., the template nodes that are not in the context
show_highlighted_part(S, VLK)

In [13]:
# highlight no nodes, only the arcs that connect the context K and the "core" template L\K
show_highlighted_part(S, set(), EH=EK2LK2K)

In [14]:
# highlight the "core" replacement R\K
show_highlighted_part(S, VRK)

In [20]:
from networkx.algorithms import isomorphism


def find_morphisms(G: nx.DiGraph,
                   L: nx.DiGraph,
                   node_match: Optional[Callable[[dict, dict], bool]] = None,
                   edge_match: Optional[Callable[[dict, dict], bool]] = None) -> List[Dict[NodeName, NodeName]]:
    """Find the occurrences of the template `L` inside the graph `G`.

    Args:
        G: the graph to search.
        L: the template to look for.
        node_match: optional predicate receiving the attribute dictionaries of
            a node of `G` and of a node of `L`, and returning whether the two
            nodes may be matched. When omitted, node attributes (e.g., the
            'nxtype' labels) are ignored and matching is purely structural.
        edge_match: optional predicate with the same role as `node_match`, but
            applied to the attribute dictionaries of arcs.

    Returns:
        One dictionary per match found by `DiGraphMatcher`, in the order in
        which the matcher yields them; each dictionary maps nodes of `G` to
        the nodes of `L` they correspond to.
    """

    matcher = isomorphism.DiGraphMatcher(G, L, node_match=node_match, edge_match=edge_match)

    return list(matcher.subgraph_isomorphisms_iter())


# enumerate the matches of the template L inside the example graph G
isomorphisms = find_morphisms(G, L)

In [21]:
# highlight the nodes of G covered by the first match (assumes that at least one match was found)
show_highlighted_part(G, set(isomorphisms[0].keys()), show_pos=True)

In [22]:
# highlight the nodes of G covered by the second match with the "negative" style (assumes that at least two matches were found)
show_highlighted_part(G, set(isomorphisms[1].keys()), show_pos=False)