In [1]:
%load_ext autoreload
%autoreload 2
import os
os.chdir('..')

In [2]:
import networkx as nx
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
import numpy as np
import scipy.sparse
import matplotlib.pyplot as plt
from nn_homology import nn_graph
from archs.mnist.AlexNet import AlexNet_nmp

import persim # see persim.scikit-tda.org
from ripser import ripser # see ripser.scikit-tda.org

To generate the model used in this example, I ran:

`python main.py --arch_type alexnet --dataset mnist --prune_percent 95 --prune_iterations 2 --end_iter 30`

In [3]:
# Global-like variable definitions.
model_name = 'alexnet_nmp'
dataset_name = 'mnist'
data_location = '../data' # location of training data (MNIST, FashionMNIST, CIFAR, etc.)
seed = 0
model_loc0 = 'remote_saves/{}/{}/{}/0/model_lt_20.pth.tar'.format(model_name, dataset_name, seed) # location of saved, un-pruned model 
model_loc1 = 'remote_saves/{}/{}/{}/1/model_lt_20.pth.tar'.format(model_name, dataset_name, seed) # location of saved, pruned model (after 1 prune iteration)

## Unpruned Model Homology

In [None]:
# load the model.
model = torch.load(model_loc0)
alnt = AlexNet_nmp()
print(alnt.param_info)

In [None]:
NNG = nn_graph.NNGraph()
%time NNG.parameter_graph(model, alnt.param_info, (1,1,28,28))

In [7]:
# helper function for testing model, outputs accuracy
def test(model, test_loader, criterion):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)
    return accuracy

In [8]:
# Test to make sure model works.
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])
testdataset = datasets.MNIST(data_location, train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(testdataset, batch_size=32, shuffle=False, num_workers=0,drop_last=True)
criterion = torch.nn.CrossEntropyLoss()
print('Accuracy: {}'.format(test(model, test_loader, criterion)))

NameError: name 'model' is not defined

In [None]:
# out of curiosity...
# nx.dag_longest_path(G)

In [None]:
# THIS WILL TAKE A LONG TIME WITH UNPRUNED ALEXNET
# compute rips persistent homology (up to 1st dimension) over entire network 
# using (sparse) adjacency matrix as distance matrix.
%time rips = ripser(nx.to_scipy_sparse_matrix(NNG.G), distance_matrix=True, maxdim=1)
rips

In [None]:
# plot persistence diagram in dimensions 0 and 1 (on same axes).
# points at infinity (homology groups) are plotted on the dotted 
# line which represents the point \infty.
persim.plot_diagrams(rips['dgms'][0])

## Pruned LT Homology

In [4]:
# load the LT model.
model_lt = torch.load(model_loc1)



In [5]:
alnt = AlexNet_nmp()

In [9]:
# test this mode-l's accuracy. (not actually a LT because I'm lazy and picked too high of a threshold)
print('Accuracy: {}'.format(test(model_lt, test_loader, criterion)))

Accuracy: 98.91


In [None]:
# compute networkx representation of LT NN.
# NNG.update_adjacency(model_lt)

In [None]:
# compute rips persistent homology (up to 1st dimension) over entire network 
# using (sparse) adjacency matrix as distance matrix.
# %time rips_lt = ripser(scipy.sparse.csr_matrix(NNG.get_adjacency()), distance_matrix=True, maxdim=1)
# rips_lt

In [None]:
# plot persistence diagram in dimensions 0 and 1 (on same axes).
# points at infinity (homology groups) are plotted on the dotted 
# line which represents the point \infty.
# persim.plot_diagrams(rips_lt['dgms'][0])

In [10]:
NNGLT = nn_graph.NNGraph()
%time NNGLT.parameter_graph(model_lt, alnt.param_info, (1,1,28,28), ignore_zeros=True)

Layer: Conv1
Layer: Conv2
Layer: Conv3
Layer: Conv4
Layer: Conv5
Layer: Conv6
Layer: Conv7
Layer: Conv8
Layer: Linear1
Layer: Linear2
Layer: Linear3
CPU times: user 5min 48s, sys: 3.42 s, total: 5min 51s
Wall time: 5min 42s


<networkx.classes.graph.Graph at 0x7f8f8056aca0>

In [None]:
%time rips_lt2 = ripser(nx.to_scipy_sparse_matrix(NNGLT.G), distance_matrix=True, maxdim=1)
rips_lt2

In [None]:
persim.plot_diagrams(rips_lt2['dgms'][0])