In [1]:
import os

import torch
import networkx as nx
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn

from nn_homology.nn_graph import activation_graph, parameter_graph

In [2]:
# define a simple LeNet-style model
class LeNet5(nn.Module):
    """Small LeNet-style CNN: two 3x3 convolutions, one max-pool, three linear layers.

    Besides the layers themselves, every instance carries ``param_info``: a
    list of per-layer description dicts (layer type, name and, for
    convolution/pooling layers, kernel size, stride and padding). The notebook
    passes it to the ``nn_homology`` graph builders together with the model.

    Args:
        num_classes: number of outputs of the final linear layer (default 10).

    NOTE(review): ``param_info`` describes five layers (Conv1, Conv2, MaxPool1,
    Linear1, Linear2) while ``classifier`` contains three ``nn.Linear``
    modules -- confirm that leaving out an entry for the last one is what
    nn_homology expects.
    """

    def __init__(self, num_classes=10):
        super().__init__()

        # Convolutional front end. padding=1 with a 3x3 kernel and stride 1
        # keeps the spatial size; the 2x2 max-pool halves it.
        conv_stack = [
            nn.Conv2d(1, 64, kernel_size=(3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        self.features = nn.Sequential(*conv_stack)

        # Fully connected head. The first layer expects 64 channels of 14x14
        # feature maps, flattened.
        dense_stack = [
            nn.Linear(64*14*14, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*dense_stack)

        # Layer metadata, in network order, mirroring the hyper-parameters
        # used for the modules above.
        conv_entries = [
            {'layer_type': 'Conv2d', 'kernel_size': (3, 3), 'stride': 1, 'padding': 1, 'name': label}
            for label in ('Conv1', 'Conv2')
        ]
        pool_entry = {'layer_type': 'MaxPool2d', 'kernel_size': (2, 2), 'stride': 2, 'padding': 0, 'name': 'MaxPool1'}
        linear_entries = [
            {'layer_type': 'Linear', 'name': label}
            for label in ('Linear1', 'Linear2')
        ]
        self.param_info = conv_entries + [pool_entry] + linear_entries

    def forward(self, x):
        """Apply the conv stack, flatten everything after dim 0, apply the linear head."""
        feature_maps = self.features(x)
        flattened = torch.flatten(feature_maps, start_dim=1)
        return self.classifier(flattened)

## Parameter Graph

In [3]:
# Expected size of the input layer; passed to parameter_graph together with the model.
# Presumably (batch, channels, height, width) for one single-channel 28x28 image,
# which matches the model's 1 input channel and 64*14*14 flattened size -- TODO confirm.
input_size = (1,1,28,28)

In [4]:
# Fix the RNG seed before building the model so the randomly initialised
# weights -- and every graph derived from them below -- are identical on each
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# load the model (untrained): weights come from PyTorch's default
# initialisation; no checkpoint is read.
model = LeNet5()

In [5]:
# Build the parameter graph of the untrained model from its layer metadata
# (model.param_info) and the expected input size.
# The call prints a "Layer: ..." / "Channel: ..." line as it works through the
# network, so the cell output is long.
# G is presumably a networkx directed graph (it is passed to
# nx.dag_longest_path below) -- TODO confirm.
G = parameter_graph(model, model.param_info, input_size)

Layer: Conv1
Channel: 0
Layer: Conv2
Channel: 0
Channel: 1
Channel: 2
Channel: 3
Channel: 4
Channel: 5
Channel: 6
Channel: 7
Channel: 8
Channel: 9
Channel: 10
Channel: 11
Channel: 12
Channel: 13
Channel: 14
Channel: 15
Channel: 16
Channel: 17
Channel: 18
Channel: 19
Channel: 20
Channel: 21
Channel: 22
Channel: 23
Channel: 24
Channel: 25
Channel: 26
Channel: 27
Channel: 28
Channel: 29
Channel: 30
Channel: 31
Channel: 32
Channel: 33
Channel: 34
Channel: 35
Channel: 36
Channel: 37
Channel: 38
Channel: 39
Channel: 40
Channel: 41
Channel: 42
Channel: 43
Channel: 44
Channel: 45
Channel: 46
Channel: 47
Channel: 48
Channel: 49
Channel: 50
Channel: 51
Channel: 52
Channel: 53
Channel: 54
Channel: 55
Channel: 56
Channel: 57
Channel: 58
Channel: 59
Channel: 60
Channel: 61
Channel: 62
Channel: 63
Layer: MaxPool1
Channel: 0
Channel: 1
Channel: 2
Channel: 3
Channel: 4
Channel: 5
Channel: 6
Channel: 7
Channel: 8
Channel: 9
Channel: 10
Channel: 11
Channel: 12
Channel: 13
Channel: 14
Channel: 15
Channel

In [6]:
# sanity check: show a longest path through the parameter graph.
# In the recorded output it visits one node per layer name, from Conv1
# through Linear2 to Output.
nx.dag_longest_path(G)

['Conv1_0_747',
 'Conv2_48_718',
 'MaxPool1_0_690',
 'Linear1_0_177',
 'Linear2_0_6',
 'Output_0_8']

## Activation Graph

In [9]:
# MNIST preprocessing: convert images to tensors, then normalise with the
# mean/std pair (0.1307, 0.3081) commonly used for MNIST.
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
# download=True fetches the test split into ../../data when it is not already
# there, so the notebook also runs on a fresh checkout; files that are already
# present are reused rather than downloaded again.
testdataset = datasets.MNIST('../../data', train=False, download=True, transform=transform)
# batch_size=1 with shuffle=False: the first batch is always the first test sample.
test_loader = torch.utils.data.DataLoader(testdataset, batch_size=1, shuffle=False, num_workers=0)

In [11]:
G_act = None

# Keep the model and the input on the same device (CPU here).
device = torch.device("cpu")
model = model.to(device)

# The loader's batch size is 1, so the activation graph is computed for a
# single example: element 0 of the first batch, moved to the chosen device.
data = next(iter(test_loader))[0].to(device)
G_act = activation_graph(model, model.param_info, data)

Layer: Conv1
Channel: 0
Layer: Conv2
Channel: 0
Channel: 1
Channel: 2
Channel: 3
Channel: 4
Channel: 5
Channel: 6
Channel: 7
Channel: 8
Channel: 9
Channel: 10
Channel: 11
Channel: 12
Channel: 13
Channel: 14
Channel: 15
Channel: 16
Channel: 17
Channel: 18
Channel: 19
Channel: 20
Channel: 21
Channel: 22
Channel: 23
Channel: 24
Channel: 25
Channel: 26
Channel: 27
Channel: 28
Channel: 29
Channel: 30
Channel: 31
Channel: 32
Channel: 33
Channel: 34
Channel: 35
Channel: 36
Channel: 37
Channel: 38
Channel: 39
Channel: 40
Channel: 41
Channel: 42
Channel: 43
Channel: 44
Channel: 45
Channel: 46
Channel: 47
Channel: 48
Channel: 49
Channel: 50
Channel: 51
Channel: 52
Channel: 53
Channel: 54
Channel: 55
Channel: 56
Channel: 57
Channel: 58
Channel: 59
Channel: 60
Channel: 61
Channel: 62
Channel: 63
Layer: MaxPool1
Channel: 0
Channel: 1
Channel: 2
Channel: 3
Channel: 4
Channel: 5
Channel: 6
Channel: 7
Channel: 8
Channel: 9
Channel: 10
Channel: 11
Channel: 12
Channel: 13
Channel: 14
Channel: 15
Channel

In [12]:
# sanity check: show a longest path through the activation graph.
# NOTE(review): in the recorded output this path ends at a MaxPool1 node,
# whereas the parameter graph's longest path continues through Linear1,
# Linear2 and Output -- confirm whether that is expected for this input.
nx.dag_longest_path(G_act)

['Conv1_0_717', 'Conv2_3_744', 'MaxPool1_26_773']