In [4]:
%load_ext autoreload
%autoreload 2
import os

In [9]:
# %load compute_homology.py
import networkx as nx
import torch
import os
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
import numpy as np
import scipy.sparse
import matplotlib.pyplot as plt
import pickle
import persim  # see persim.scikit-tda.org
from ripser import ripser  # see ripser.scikit-tda.org

from archs.mnist.fc1 import fc1 as fc1_mnist
from archs.cifar10.fc1 import fc1 as fc1_cifar10

from archs.mnist.AlexNet import AlexNet as AlexNet_mnist
from archs.mnist.AlexNet import AlexNet_nmp as AlexNet_nmp_mnist
from archs.cifar10.AlexNet import AlexNet as AlexNet_cifar10
from archs.cifar10.AlexNet import AlexNet_nmp as AlexNet_nmp_cifar10

from archs.mnist.LeNet5 import LeNet5 as LeNet5_mnist
from archs.mnist.LeNet5 import LeNet5_nmp as LeNet5_nmp_mnist
from archs.cifar10.LeNet5 import LeNet5 as LeNet5_cifar10
from archs.cifar10.LeNet5 import LeNet5_nmp as LeNet5_nmp_cifar10

from archs.mnist.resnet import resnet18 as resnet18_mnist
from archs.mnist.resnet_nmp import resnet18 as resnet18_nmp_mnist
from archs.cifar10.resnet import resnet18 as resnet18_cifar10
from archs.cifar10.resnet_nmp import resnet18 as resnet18_nmp_cifar10

from nn_homology import nn_graph
import argparse

import matplotlib.pyplot as plt

# Cache of NNGraph objects keyed by "<model_name>_<dataset>". computer_per_model_homology
# builds an entry once and afterwards only refreshes its adjacency from new weights.
model_graph_dict = {}

def get_model_param_info(model_name, dataset):
    """Return the ``param_info`` of the architecture ``<model_name>_<dataset>``.

    A fresh instance of the matching model class is created only to read its
    ``param_info`` attribute, which is what ``NNGraph.parameter_graph`` consumes.

    Args:
        model_name: Architecture prefix, e.g. ``"lenet5_nmp"``.
        dataset: ``"mnist"`` or ``"cifar10"``.

    Returns:
        The ``param_info`` attribute of the instantiated model.

    Raises:
        KeyError: If ``<model_name>_<dataset>`` is not a registered architecture.
    """
    model_param = {
        "fc1_mnist": fc1_mnist,
        "fc1_cifar10": fc1_cifar10,

        "alexnet_mnist": AlexNet_mnist,
        "alexnet_nmp_mnist": AlexNet_nmp_mnist,
        "alexnet_cifar10": AlexNet_cifar10,
        "alexnet_nmp_cifar10": AlexNet_nmp_cifar10,

        # Previously mapped to AlexNet_mnist by mistake.
        "lenet5_mnist": LeNet5_mnist,
        "lenet5_nmp_mnist": LeNet5_nmp_mnist,
        "lenet5_cifar10": LeNet5_cifar10,
        "lenet5_nmp_cifar10": LeNet5_nmp_cifar10,

        "resnet18_mnist": resnet18_mnist,
        "resnet18_nmp_mnist": resnet18_nmp_mnist,
        "resnet18_cifar10": resnet18_cifar10,
        "resnet18_nmp_cifar10": resnet18_nmp_cifar10
    }
    architecture = model_name + "_" + dataset
    print("Getting parameters for: ", architecture)
    if architecture not in model_param:
        raise KeyError("Unknown architecture {!r}; expected one of: {}".format(
            architecture, ", ".join(sorted(model_param))))
    return model_param[architecture]().param_info


def _epoch_dirs(root_dir):
    """Return the names in ``root_dir`` that start with a digit, in numeric order."""
    def sort_key(name):
        # "2" before "10"; digit-prefixed names that are not plain integers go last.
        return (0, int(name), name) if name.isdigit() else (1, 0, name)

    return sorted((name for name in os.listdir(root_dir) if name[0].isdigit()),
                  key=sort_key)


def compute_homology(model, dataset, root_dir, model_file="model_lt_20.pth.tar"):
    """Run ``computer_per_model_homology`` on every epoch checkpoint in ``root_dir``.

    ``root_dir`` is a seed directory holding one sub-directory per pruning
    iteration ("0", "1", ...). These are visited in numeric order, and each one
    that contains ``model_file`` is processed.

    Args:
        model: Architecture name, e.g. ``"lenet5_nmp"``.
        dataset: ``"mnist"`` or ``"cifar10"``.
        root_dir: Seed directory. It should end with "/" because the per-model
            step builds its output paths by string concatenation.
        model_file: Checkpoint file name looked up inside each epoch directory.
    """
    for epoch in _epoch_dirs(root_dir):
        print("epoch: ", epoch)
        checkpoint = os.path.join(root_dir, epoch, model_file)
        if os.path.isfile(checkpoint):
            computer_per_model_homology(model, dataset, root_dir, epoch, checkpoint)

                


def computer_per_model_homology(model_name, dataset, root_dir, epoch, model_location):
    """Compute and save the persistent homology of one saved checkpoint.

    The checkpoint at ``model_location`` is loaded and its parameter graph is
    built or refreshed. ripser is then run on the graph's adjacency matrix, and
    two files are written:

    * ``<root_dir>pickle/<epoch>.pickle``: the raw ripser result dict.
    * ``<root_dir>persim/<epoch>.jpg``: the persistence-diagram plot.

    Args:
        model_name: Architecture prefix, e.g. ``"lenet5_nmp"``.
        dataset: ``"mnist"`` or ``"cifar10"``; selects the dummy input shape.
        root_dir: Seed directory, e.g.
            ``/home/udit/programs/LTHT/remote_data/saves/alexnet_nmp/mnist/0/``.
            It must end with "/" because output paths are concatenated onto it.
        epoch: Epoch directory name (e.g. ``"0"``); used as the output file stem.
        model_location: Path to the checkpoint passed to ``torch.load``.

    Raises:
        ValueError: If ``dataset`` is not ``"mnist"`` or ``"cifar10"``.
    """
    rips_pickle_dir = root_dir + "pickle/"
    persim_image_dir = root_dir + "persim/"

    input_dims = {"mnist": (1, 1, 28, 28), "cifar10": (1, 3, 32, 32)}
    if dataset not in input_dims:
        # Previously input_dim was left unbound here and failed later.
        raise ValueError("Unsupported dataset {!r}; expected one of {}".format(
            dataset, sorted(input_dims)))
    input_dim = input_dims[dataset]

    # NOTE(review): torch.load unpickles the checkpoint -- only load trusted files.
    model = torch.load(model_location)

    architecture = model_name + "_" + dataset
    # epoch is a directory name such as "0", so compare it as a string. A bare
    # `epoch == 0` never matches a string, and the graph would never be rebuilt.
    if (architecture not in model_graph_dict) or str(epoch) == "0":
        print(("Architecture: {} not found, creating").format(architecture))
        param_info = get_model_param_info(model_name, dataset)
        NNG = nn_graph.NNGraph()
        NNG.parameter_graph(model, param_info, input_dim, ignore_zeros=True)
        model_graph_dict[architecture] = NNG
    else:
        print(("Architecture: {} found, loading ... ").format(architecture))
        NNG = model_graph_dict[architecture]
        NNG.update_adjacency(model)

    rips = ripser(scipy.sparse.csr_matrix(NNG.get_adjacency()), distance_matrix=True, maxdim=2, do_cocycles=True)

    # Save the ripser result as a pickle; `with` closes the file even on error.
    os.makedirs(rips_pickle_dir, exist_ok=True)
    with open(rips_pickle_dir + str(epoch) + ".pickle", "wb") as rips_pickle:
        pickle.dump(rips, rips_pickle)

    # Plot the persistence diagrams and save them as an image.
    persim.plot_diagrams(rips['dgms'])
    os.makedirs(persim_image_dir, exist_ok=True)
    plt.savefig(persim_image_dir + str(epoch) + ".jpg")
    plt.clf()


def main(args):
    """Compute homology for the single (model, dataset, seed) described by ``args``.

    ``args`` must provide ``root_dir``, ``model_name``, ``dataset`` and ``seed``.
    The seed directory is ``<root_dir><model_name>/<dataset>/<seed>/``. If that
    directory does not exist, nothing is computed.
    """
    name, data = args.model_name, args.dataset
    seed_dir = args.root_dir + "{}/{}/{}/".format(name, data, args.seed)
    print("In: ", seed_dir)

    if not os.path.isdir(seed_dir):
        return
    compute_homology(name, data, seed_dir)


# if __name__ == '__main__':
#     parser = argparse.ArgumentParser()

#     parser.add_argument("--root_dir", default="/home/udit/programs/LTHT/data/saves/", type=str)
#     parser.add_argument("--model_name", default='fc1', type=str)
#     parser.add_argument("--dataset", default='mnist', type=str)
#     parser.add_argument("--seed", default='0', type=str)

#     args = parser.parse_args()
#     print(args)
#     main(args)


In [10]:
# Display the kernel's working directory. This is presumably a check that the
# relative project imports above (archs.*, nn_homology) resolve -- confirm.
os.getcwd()

'/home/udit/programs/LTHT/LTHT'

In [11]:
# Root of the saved checkpoints, laid out as <ROOT_DIR><model>/<dataset>/<seed>/<epoch>/.
# It must be an absolute path ending in "/" because downstream paths are built by
# string concatenation. Override it with the LTHT_SAVES_DIR environment variable
# (the local copy lives at /home/udit/programs/LTHT/data/saves/).
ROOT_DIR = os.environ.get("LTHT_SAVES_DIR", "/home/udit/programs/LTHT/remote_data/saves/")
if not ROOT_DIR.endswith("/"):
    ROOT_DIR += "/"

# Other trained models: 'alexnet_nmp', 'fc1', 'lenet5_nmp', 'resnet18', 'resnet18_nmp'
# ('vgg16' is also saved, but get_model_param_info has no entry for it).
model_list = ['lenet5_nmp']
dataset_list = ['mnist', 'cifar10']
random_seed = ['0', '42', '1337']
# To force every cached architecture graph to be rebuilt, run: model_graph_dict = {}

In [None]:
# The unconditional `break`s previously hid the fact that only the first
# combination ever ran. That is now explicit: True runs a quick smoke test on the
# first (model, dataset, seed); False sweeps all of them (slow).
RUN_FIRST_ONLY = True

combinations = [
    (model_name, dataset, seed)
    for model_name in model_list
    for dataset in dataset_list
    for seed in random_seed
]
if RUN_FIRST_ONLY:
    combinations = combinations[:1]

for model_name, dataset, seed in combinations:
    print(model_name, dataset, seed)
    model_dataset_seed_dir = ROOT_DIR + "{}/{}/{}/".format(model_name, dataset, seed)
    if os.path.isdir(model_dataset_seed_dir):
        compute_homology(model_name, dataset, model_dataset_seed_dir)

lenet5_nmp
mnist
0
epoch:  0
Getting parameters for:  lenet5_nmp_mnist
Architecture: lenet5_nmp_mnist not found, creating
Layer: Conv1
Layer: Conv2


In [13]:
rips_dir = "/home/udit/programs/LTHT/data/saves/fc1/mnist/0/pickle/"
for file1 in os.listdir(rips_dir):
    if file1[0].isdigit():
        prune_iteration1 = file1[0]
        rips1 = pickle.load(open(rips_dir+file1, 'rb'))
        for file2 in os.listdir(rips_dir):
            if file2[0].isdigit() and file1 != file2:
                print(file1, file2)
                prune_iteration2 = file2[0]
                rips2 = pickle.load(open(rips_dir+file2, 'rb'))
                %time distance_bottleneck, (matching, D) = persim.bottleneck(rips1['dgms'][0], rips2['dgms'][0], matching=True)
#                 %time persim.bottleneck_matching(rips1['dgms'][0], rips2['dgms'][0], matching, D, labels=['FC $H_0$', 'LT $H_0$'])
                print('Bottleneck Distance: {}'.format(distance_bottleneck))
                break
    break
                

1.pickle 9.pickle
CPU times: user 56.2 s, sys: 35.8 ms, total: 56.2 s
Wall time: 56.2 s
Bottleneck Distance: 0.09660559892654419
