In [None]:
%load_ext autoreload
%autoreload 2
import os
os.chdir('..')

In [None]:
import networkx as nx
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
import numpy as np
import scipy.sparse
import matplotlib.pyplot as plt
from nn_homology import nn_graph

import persim # see persim.scikit-tda.org
from ripser import ripser # see ripser.scikit-tda.org

In [None]:
# Experiment configuration shared by the cells below.
model_name = 'lenet5_nmp'
dataset_name = 'cifar10'
data_location = '../data'  # location of the datasets (MNIST, FashionMNIST, CIFAR, etc.)
# NOTE(review): presumably the random seed of the saved run; in this notebook it
# is only used as a path component of the checkpoint locations.
seed = 42

# Checkpoints are addressed as remote_saves/<model>/<dataset>/<seed>/<subdir>/model_lt_20.pth.tar
_ckpt_template = 'remote_saves/{}/{}/{}/{}/model_lt_20.pth.tar'
model_loc0 = _ckpt_template.format(model_name, dataset_name, seed, 0)  # location of saved, un-pruned model
# NOTE(review): subdirectory 9 is presumably the prune-iteration index of the pruned model — confirm.
model_loc1 = _ckpt_template.format(model_name, dataset_name, seed, 9)  # location of saved, pruned model

# NOTE(review): presumably (batch, channels, height, width) of one CIFAR-10 input — confirm.
input_size = (1, 3, 32, 32)
from archs.cifar10.LeNet5 import LeNet5_nmp as Mc

## Unpruned Model Homology

In [None]:
# load the model.
# NOTE: torch.load unpickles arbitrary Python objects — only load checkpoints
# that come from a trusted source.
# When CUDA is unavailable, remap the stored tensors to the CPU so that a
# checkpoint saved on a GPU still loads; when CUDA is available the saved
# device placement is kept unchanged (map_location=None is torch's default).
model = torch.load(model_loc0, map_location=None if torch.cuda.is_available() else 'cpu')
# Architecture object; its `param_info` attribute is read when building the parameter graph.
model_class = Mc()

In [None]:
# Build the graph representation of the un-pruned model from the loaded model,
# the architecture's `param_info` and the configured `input_size`.
# NOTE(review): `ignore_zeros=True` presumably leaves zero-valued weights out of
# the graph — confirm in nn_homology.nn_graph.
NNG = nn_graph.NNGraph()
NNG.parameter_graph(model, model_class.param_info, input_size, ignore_zeros=True)

In [None]:
# helper function for testing model, outputs accuracy
def test(model, test_loader, criterion):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)
    return accuracy

In [None]:
# Sanity check: evaluate the un-pruned model on the CIFAR-10 test split.
# NOTE(review): the Normalize constants are single-element tuples although
# `input_size` suggests 3-channel images; presumably this mirrors the transform
# used when the checkpoints were trained — confirm before changing it.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
testdataset = datasets.CIFAR10(data_location, train=False, transform=transform)
# drop_last=True discards the final incomplete batch of the test split.
test_loader = torch.utils.data.DataLoader(
    testdataset,
    batch_size=32,
    shuffle=False,
    num_workers=0,
    drop_last=True,
)
criterion = torch.nn.CrossEntropyLoss()
accuracy_unpruned = test(model, test_loader, criterion)
print('Accuracy: {}'.format(accuracy_unpruned))

In [None]:
def sparse_min_row(csr_mat):
    """Return the minimum stored value of every row of a CSR matrix.

    Only explicitly stored entries (``csr_mat.data``) are considered. Rows
    without any stored entry are reported as 0.

    Args:
        csr_mat: sparse matrix exposing ``indptr``, ``data`` and ``shape``
            in CSR layout.

    Returns:
        1-D float array with one value per row.
    """
    stored_per_row = np.diff(csr_mat.indptr)
    row_starts = csr_mat.indptr[:-1]
    row_minima = np.zeros(csr_mat.shape[0])
    # reduceat folds data[start_i:start_{i+1}]; keeping only the starts of
    # non-empty rows makes every folded segment exactly one row's entries.
    row_minima[stored_per_row != 0] = np.minimum.reduceat(
        csr_mat.data, row_starts[stored_per_row > 0]
    )
    return row_minima

In [None]:
# compute rips persistent homology (up to 1st dimension) over entire network 
# using (sparse) adjacency matrix as distance matrix.

def sparse_min_row(csr_mat):
    ret = np.zeros(csr_mat.shape[0])
    ret[np.diff(csr_mat.indptr) != 0] = np.minimum.reduceat(csr_mat.data,csr_mat.indptr[:-1][np.diff(csr_mat.indptr)>0])
    return ret

sps = nx.to_scipy_sparse_matrix(NNG.G)
mrs = sparse_min_row(sps)
sps.setdiag(mrs)


%time rips = ripser(sps, distance_matrix=True, maxdim=1)
rips

In [None]:
# plot the dimension-0 persistence diagram of the un-pruned network
# (only rips['dgms'][0] is passed, so dimension 1 is not drawn).
# NOTE(review): points at infinity are presumably drawn on the dotted
# line which represents \infty — confirm against persim.
persim.plot_diagrams(rips['dgms'][0])

## Pruned LT Homology

In [None]:
# load the LT model.
# NOTE: torch.load unpickles arbitrary Python objects — only load checkpoints
# that come from a trusted source.
# When CUDA is unavailable, remap the stored tensors to the CPU so that a
# checkpoint saved on a GPU still loads; when CUDA is available the saved
# device placement is kept unchanged (map_location=None is torch's default).
model_lt = torch.load(model_loc1, map_location=None if torch.cuda.is_available() else 'cpu')

In [None]:
# Evaluate the pruned (LT) model with the same loader and criterion as the un-pruned one.
accuracy_lt = test(model_lt, test_loader, criterion)
print('Accuracy: {}'.format(accuracy_lt))

In [None]:
# Build the graph representation of the pruned (LT) model.
# The parameter description comes from `model_class`, exactly as for the
# un-pruned graph; a bare `param_info` name is never defined in this notebook.
NNGLT = nn_graph.NNGraph()
NNGLT.parameter_graph(model_lt, model_class.param_info, input_size, ignore_zeros=True)

In [None]:
# compute rips persistent homology (up to 1st dimension) over entire network 
# using (sparse) adjacency matrix as distance matrix.
sps = nx.to_scipy_sparse_matrix(NNGLT.G)
mrs = sparse_min_row(sps)
sps.setdiag(mrs)
%time rips_lt = ripser(sps, distance_matrix=True, maxdim=1)
rips_lt

In [None]:
# plot the dimension-0 persistence diagram of the pruned (LT) network
# (only rips_lt['dgms'][0] is passed, so dimension 1 is not drawn).
# NOTE(review): points at infinity are presumably drawn on the dotted
# line which represents \infty — confirm against persim.
persim.plot_diagrams(rips_lt['dgms'][0])

## Compare Networks

In [None]:
# compute the bottleneck distance between networks and plot the implicit matching. 
# bottleneck distance is defined as the distance between the farthest-apart matched points. 
# NOTE: the persim package ignores points at infinity, so this calculation still returns 
# a bounded result. Technically, the bottleneck distance between the two networks is \infty. 
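# NOTE(review): the disabled snippet below does not run as written. Its first line
# passes matching=False and binds only `distance_bottleneck`, while its second line
# uses `matching` and `D`, which are never assigned anywhere in this notebook.
# %time distance_bottleneck = persim.bottleneck(rips['dgms'][0], rips_lt['dgms'][0], matching=False)
# persim.bottleneck_matching(rips['dgms'][0], rips_lt['dgms'][0], matching, D, labels=['FC $H_0$', 'LT $H_0$'])
# print('Bottleneck Distance: {}'.format(distance_bottleneck))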
# %time sliced = persim.sliced_wasserstein(rips['dgms'][0], rips_lt['dgms'][0])