In [1]:
%load_ext autoreload
%autoreload 2
import os
os.chdir('..')

In [2]:
import networkx as nx
import random
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn.functional as F
import numpy as np
import pandas as pd
import scipy.sparse
import matplotlib.pyplot as plt
from nn_homology import nn_graph
import dionysus as dion
from networkx.drawing.nx_agraph import graphviz_layout
from scipy.cluster import hierarchy
from scipy.spatial.distance import squareform
from scipy.sparse.csgraph import maximum_flow

import persim # see persim.scikit-tda.org
from ripser import ripser # see ripser.scikit-tda.org

In [3]:
# ---- Experiment configuration (read by every later cell) --------------------
# Architecture under study; project-local module, importable from the project root.
from archs.cifar10.resnet import resnet18 as Mc

# Which network / data / seeds.
model_name = 'resnet18'
dataset_name = 'cifar10'
seeds = [0]
input_size = (1, 3, 32, 32)  # input tensor size handed to the graph builder (presumably NCHW)

# Paths; the 'saves/...' templates are filled with (model_name, dataset_name, seed).
data_location = '../data'  # root of the torchvision datasets (MNIST, FashionMNIST, CIFAR, ...)
model_loc0 = 'saves/{}/{}/{}/prune_all/global/0/model_lt_20.pth.tar'  # saved, un-pruned model
isd_loc = 'saves/{}/{}/{}/prune_all/global/initial_state_dict_lt.pth.tar'  # initial model

# Graph filtration: edge threshold taken from this percentile of |weight|.
percentile = 85
percentile_filtration = True

# Retraining length and pruning schedule.
epochs = 5
# Alternative schedule: np.flip([80,90,95,98,99,99.5,99.7,99.8,99.9,100])
# Percent of weights removed: 100 minus 10 log-spaced values in [0.01, 20],
# i.e. from 99.99% down to 80%.
prune_percents = 100 - np.geomspace(0.01, 20, num=10)

In [4]:
# Show the prune schedule (percentage of weights removed at each step).
prune_percents

array([99.99      , 99.97673082, 99.94585452, 99.8740079 , 99.70682668,
       99.31780968, 98.41259895, 96.30624765, 91.40494055, 80.        ])

In [5]:
def to_int(x):
    """Scale ``x`` by 1e8 so small float values survive a later integer cast.

    Despite its name this does not cast anything: the result keeps the float
    type of ``x``. Callers apply ``.astype('int')`` themselves.
    """
    scale = 1e8
    return scale * x

def max_flow_edgelist(master_g, master_capacity=1./10., num_outputs=10):
    """Collect the edges that carry flow in master -> output max-flow problems.

    A synthetic 'master' source node is connected to every first-layer node
    (node names containing 'conv1' but not 'layer'). Then, for each output node
    'Output_0_<k>', an integer max-flow from 'master' to that node is solved with
    ``scipy.sparse.csgraph.maximum_flow``.

    Edge weights ``w`` become integer capacities ``int(1e8 * (1/w - 1))``,
    shifted so that the smallest capacity is 1. If the graph stores weights as
    1/(1+|p|), this recovers |p| scaled by 1e8. That storage convention is an
    assumption about nn_graph (TODO confirm).

    Parameters
    ----------
    master_g : networkx.DiGraph
        Network graph. NOTE: mutated in place (the 'master' node and its edges
        are added). Pass a copy if the original must stay untouched.
    master_capacity : float
        Weight given to each master -> first-layer edge before conversion.
    num_outputs : int
        Number of output nodes, named 'Output_0_0' ... 'Output_0_{num_outputs-1}'.

    Returns
    -------
    list of (u, v) tuples
        For each output in turn, every edge carrying positive flow, oriented in
        the direction of the flow. Duplicates across outputs are kept. Edges
        touching the synthetic 'master' node are excluded.

    Raises
    ------
    ValueError
        If no first-layer node exists, or a capacity does not fit in int32
        (scipy's maximum_flow works with int32 capacities).
    """
    for node in list(master_g.nodes()):
        if 'conv1' in node and 'layer' not in node:
            master_g.add_edge('master', node, weight=master_capacity)
    nodes = list(master_g.nodes())
    node_index = {node: i for i, node in enumerate(nodes)}
    if 'master' not in node_index:
        raise ValueError("no first-layer ('conv1') nodes found to attach 'master' to")

    # to_scipy_sparse_matrix was removed in NetworkX 3.0; older SciPy releases
    # only accept scipy.sparse.csr_matrix in maximum_flow, so convert explicitly.
    to_sparse = getattr(nx, 'to_scipy_sparse_array', None) or nx.to_scipy_sparse_matrix
    sps = scipy.sparse.csr_matrix(to_sparse(master_g, nodelist=nodes, weight='weight'))

    sps.data = to_int((1. / sps.data) - 1).astype('int')
    sps.data = sps.data - sps.data.min() + 1
    if sps.data.max() > np.iinfo(np.int32).max:
        raise ValueError('edge capacities overflow int32; rescale to_int or the weights')

    output_nodes = ['Output_0_{}'.format(o) for o in range(num_outputs)]
    mix = node_index['master']

    flow_results = {}
    for output_node in output_nodes:
        eix = node_index[output_node]
        mf = maximum_flow(sps, mix, eix)
        flow_results[output_node] = mf
        if mf.flow_value > 0:
            print(output_node, mf.flow_value, mix, eix)

    edgelist = []
    for output_node, mf in flow_results.items():
        print(output_node)
        # SciPy >= 1.8 exposes the flow matrix as `flow`; older releases call it `residual`.
        flow = getattr(mf, 'flow', None)
        if flow is None:
            flow = mf.residual
        flow = flow.tocoo()
        # The flow matrix is antisymmetric: F[u, v] > 0 means flow travels u -> v.
        # Read it directly (keeping direction) instead of via an undirected graph,
        # where the +f / -f entries of each pair overwrite one another.
        carries = (flow.data > 0) & (flow.row != mix) & (flow.col != mix)
        found = [(nodes[u], nodes[v]) for u, v in zip(flow.row[carries], flow.col[carries])]
        edgelist.extend(found)
        print(len(found))
    return edgelist

def train(model, device, train_loader, optimizer, epoch, criterion):
    """Run one training epoch while keeping pruned (zero) weights at zero.

    After each backward pass, the gradient of every parameter whose name
    contains 'weight' is zeroed wherever the weight's *magnitude* is below
    ``EPS``. Weights that were pruned to 0 therefore stay 0 through
    ``optimizer.step()``.

    Parameters
    ----------
    model : torch.nn.Module
    device : torch.device
        Device the batches are moved to.
    train_loader : iterable of (inputs, targets)
    optimizer : torch.optim.Optimizer
    epoch : int
        Epoch number; currently unused (kept for the call signature / logging).
    criterion : callable
        Loss function called as ``criterion(output, targets)``.

    Returns
    -------
    float
        Loss of the last batch, or NaN if the loader yielded no batches.
    """
    EPS = 1e-6
    model.train()
    train_loss = None
    for imgs, targets in train_loader:
        optimizer.zero_grad()
        imgs, targets = imgs.to(device), targets.to(device)
        output = model(imgs)
        train_loss = criterion(output, targets)
        train_loss.backward()

        # Freeze pruned weights. Compare |w|: testing `w < EPS` would also
        # freeze every negative weight. Done on-device, no numpy round-trip.
        with torch.no_grad():
            for name, p in model.named_parameters():
                if 'weight' in name and p.grad is not None:
                    p.grad.masked_fill_(p.abs() < EPS, 0.)
        optimizer.step()

    if train_loss is None:
        return float('nan')
    return train_loss.item()

def test(model, test_loader, criterion):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()  # sum up batch loss
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            correct += pred.eq(target.data.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)
        accuracy = 100. * correct / len(test_loader.dataset)
    return accuracy, test_loss

In [6]:
# Build the train / test datasets for the configured dataset.
if dataset_name == 'mnist':
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    traindataset = datasets.MNIST(data_location, train=True, download=False, transform=transform)
    testdataset = datasets.MNIST(data_location, train=False, transform=transform)
elif dataset_name == 'cifar10':
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])
    traindataset = datasets.CIFAR10(data_location, train=True, download=False, transform=transform)
    testdataset = datasets.CIFAR10(data_location, train=False, transform=transform)
else:
    # Fail here instead of with a NameError on `traindataset` below.
    raise ValueError("Unsupported dataset_name {!r}; expected 'mnist' or 'cifar10'".format(dataset_name))

train_loader = torch.utils.data.DataLoader(traindataset, batch_size=60, shuffle=True, num_workers=0, drop_last=False)
# Evaluate on every test sample. drop_last=True with batch_size=60 silently
# skipped the final partial batch (40 images of a 10,000-image test set).
test_loader = torch.utils.data.DataLoader(testdataset, batch_size=60, shuffle=False, num_workers=0, drop_last=False)

criterion = torch.nn.CrossEntropyLoss()

In [None]:
# For each seed: load the initial model, build its directed parameter graph,
# and keep only the edges returned by max_flow_edgelist (max-flow from the
# first-layer 'conv1' nodes to the outputs).
# NOTE(review): only the objects from the LAST seed survive this loop. The
# pruning cell below reads NNG, mc, model_fps and seed from here, so with
# several seeds it runs for the last seed only.
# Row layout of `results`: seed, prune_type, prune_percent, epoch, accuracy, loss
results = []
for seed in seeds:
    
    torch.manual_seed(seed)
    
#     model = torch.load(model_loc0.format(model_name, dataset_name, seed))
    # torch.load unpickles the file; only load checkpoints from trusted sources.
    model = torch.load(isd_loc.format(model_name, dataset_name, seed))
    
    # Directed parameter graph of the network.
    NNGD = nn_graph.NNGraph(undirected=False)
    mc = Mc()
    
    if percentile_filtration:
        # Flatten all non-BatchNorm 'weight' parameters so the edge threshold can
        # be set from a percentile of |weight|.
        # NOTE(review): unlike the pruning cells, 1-D parameters are not
        # filtered here, so 1-D weights whose names lack 'bn' (if any) are included.
        ps = []
        for name, param in model.named_parameters():
            if 'weight' in name and 'bn' not in name:
                pnum = param.data.cpu().numpy()
                ps.append(pnum.flatten())
        ps = np.concatenate(ps)
        # Threshold has the form 1/(1+|w|), presumably matching how nn_graph
        # stores edge weights -- TODO confirm.
        NNGD.parameter_graph(model, mc.param_info, input_size, ignore_zeros=False, threshold=1./(1.+np.percentile(np.abs(ps), percentile)), verbose=True)
    else:
        NNGD.parameter_graph(model, mc.param_info, input_size, ignore_zeros=True, verbose=True)
    
    # Run max-flow on a copy: max_flow_edgelist adds a 'master' node to its argument.
    edgelist = max_flow_edgelist(NNGD.G.copy())
    
    # Restrict the parameter graph to the flow-carrying edges.
    new_graph = NNGD.G.edge_subgraph(edgelist).copy()
    
    # Flattened parameter vector of the model (read by the pruning cell below).
    model_weights = nn_graph.get_weights(model)
    model_param_info = nn_graph.append_params(mc.param_info, model_weights)
    model_fps = nn_graph.flatten_params(model_param_info).copy()
    
    # Undirected wrapper around the subgraph. The pruning cell uses its
    # graph_idx_vec / adj_vec to choose which parameters to keep.
    NNG = nn_graph.NNGraph()
    NNG.G = new_graph.to_undirected()
    NNG.update_indices()




Layer: conv1
Layer: layer1.0.conv1
Layer: layer1.0.conv2
Layer: layer1.0.shortcut
Layer: layer1.1.conv1
Layer: layer1.1.conv2
Layer: layer1.1.shortcut
Layer: layer2.0.conv1
Layer: layer2.0.conv2
Layer: layer2.0.shortcut
Layer: layer2.1.conv1
Layer: layer2.1.conv2
Layer: layer2.1.shortcut
Layer: layer3.0.conv1
Layer: layer3.0.conv2
Layer: layer3.0.shortcut
Layer: layer3.1.conv1
Layer: layer3.1.conv2
Layer: layer3.1.shortcut
Layer: layer4.0.conv1
Layer: layer4.0.conv2
Layer: layer4.0.shortcut
Layer: layer4.1.conv1
Layer: layer4.1.conv2
Layer: layer4.1.shortcut
Layer: MaxPool
Layer: Linear1
Output_0_0 675759621 589322 589314
Output_0_1 721438901 589322 589312
Output_0_2 750518444 589322 589315
Output_0_3 694638474 589322 589316
Output_0_4 770222539 589322 589320
Output_0_5 651799819 589322 589318
Output_0_6 681613968 589322 589313


In [None]:
def select_homology_params(fps, homology_idx, limit):
    """Return a copy of ``fps`` with all but at most ``limit`` entries zeroed.

    Entries indexed by ``homology_idx`` (the max-flow subgraph) are kept first.
    - If there are more of them than ``limit``, only the ``limit``
      largest-magnitude ones survive. (Previously the first ``limit`` in raw
      index order were kept.)
    - Otherwise the remaining budget is filled with the largest-magnitude
      entries not already kept, so the total kept never exceeds ``limit``.
      (Previously the union with the global top-``limit`` overshot the budget.)
    """
    homology_idx = np.unique(homology_idx)
    kept = np.zeros(fps.shape[0], dtype=bool)
    if homology_idx.size > limit:
        by_magnitude = homology_idx[np.argsort(-np.abs(fps[homology_idx]), kind='stable')]
        kept[by_magnitude[:limit]] = True
    else:
        kept[homology_idx] = True
        order = np.argsort(-np.abs(fps), kind='stable')
        kept[order[~kept[order]][:limit - homology_idx.size]] = True
    pruned = np.zeros_like(fps)
    pruned[kept] = fps[kept]
    return pruned


# Uses NNG, mc, model_fps and seed left by the previous cell (i.e. the last seed).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_edges = model_fps.shape[0] - 1
# Parameter indices lying on the max-flow ("homology") subgraph.
homology_idx = NNG.graph_idx_vec[NNG.adj_vec != 0]

for prune_percent in prune_percents:
    # Number of parameters that survive at this prune percentage.
    limit = num_edges - int((float(prune_percent) / 100.) * num_edges)

    # Every run restarts from the initial weights.
    model_original = torch.load(isd_loc.format(model_name, dataset_name, seed))
    original_weights = nn_graph.get_weights(model_original)
    original_param_info = nn_graph.append_params(mc.param_info, original_weights)
    fps = nn_graph.flatten_params(original_param_info).copy()

    # Every homology parameter counts toward the budget regardless of sign
    # (the former `fps2 > 0` count ignored negative weights).
    num_params = np.unique(homology_idx).size
    print(num_params, limit, num_edges, prune_percent)

    fps2 = select_homology_params(fps, homology_idx, limit)
    print('new percentage', 100*np.sum(fps2 != fps) / num_edges)

    ps = nn_graph.inverse_flatten_params(fps2, original_param_info)

    # Write the pruned arrays back, in named_parameters() order, into the
    # multi-dimensional non-BatchNorm weights.
    i = 0
    for name, param in model_original.named_parameters():
        if 'weight' in name and 'bn' not in name and len(param.shape) > 1:
            if tuple(np.shape(ps[i])) != tuple(param.shape):
                raise ValueError('Shape mismatch writing {}: got {}, expected {}'.format(
                    name, tuple(np.shape(ps[i])), tuple(param.shape)))
            param.data = torch.Tensor(ps[i]).to(device)
            i += 1
    # Biases / BatchNorm parameters were not touched above; move them as well
    # so the whole model lives on `device`.
    model_original.to(device)

    optimizer = torch.optim.SGD(model_original.parameters(), lr=1.2e-3, weight_decay=1e-4, momentum=0.9)
    for epoch in range(1, epochs + 1):
        train(model_original, device, train_loader, optimizer, epoch, criterion)
        acc, loss = test(model_original, test_loader, criterion)
        print('Seed: {}, Prune Percentage: {}, Epoch: {}, Test Accuracy: {}, Test Loss: {}'.format(seed, prune_percent, epoch, acc, loss))
        results.append([seed,'homology',prune_percent,epoch,acc,loss])

In [None]:
# Homology-pruning runs as a table (one row per seed / prune % / epoch).
col_names = 'seed prune_type prune_percent epoch accuracy loss'.split()
df = pd.DataFrame(data=results, columns=col_names)
df.head()

In [None]:
# Final-epoch test accuracy vs. prune percentage, one line per seed.
epoch = epochs  # later plotting cells also read this global
fig, ax = plt.subplots()
for seed in seeds:
    final = df[(df['seed'] == seed) & (df['epoch'] == epoch)]
    ax.plot(final['prune_percent'], final['accuracy'], marker='o', label='seed {}'.format(seed))
ax.set_xlabel('Prune Percentage')
ax.set_ylabel('Test Accuracy')
ax.set_title('{}: homology pruning, epoch {}'.format(dataset_name.upper(), epoch))
ax.legend()
plt.show()

In [None]:
# sms = []
# for seed in seeds:
#     sms.append(df[(df['seed'] == seed) & (df['epoch'] == epoch)]['accuracy'].values)
# sms = np.array(sms)

# Best (max over epochs) accuracy per prune percentage, one row per seed.
# Columns follow groupby's ascending prune_percent order.
sms = []
for seed in seeds:
    dfs = df[df['seed'] == seed]
    best_by_percent = dfs.groupby('prune_percent')['accuracy'].max()
    sms.append(best_by_percent.values)
sms = np.array(sms)

In [None]:
# Baseline: global magnitude pruning of the initial weights, then retraining.
# Row layout of `results_threshold`: seed, prune_type, prune_percent, epoch, accuracy, loss
results_threshold = []
mc = Mc()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for seed in seeds:

    torch.manual_seed(seed)

    # torch.load unpickles the file; only load checkpoints from trusted sources.
    model = torch.load(isd_loc.format(model_name, dataset_name, seed))
    model_weights = nn_graph.get_weights(model)
    model_param_info = nn_graph.append_params(mc.param_info, model_weights)
    model_fps = nn_graph.flatten_params(model_param_info).copy()
    num_edges = model_fps.shape[0] - 1

    for prune_percent in prune_percents:
        # Number of parameters that survive at this prune percentage.
        limit = num_edges - int((float(prune_percent) / 100.) * num_edges)
        print(limit, num_edges, prune_percent)

        # Every run restarts from the initial weights.
        model_original = torch.load(isd_loc.format(model_name, dataset_name, seed))
        original_weights = nn_graph.get_weights(model_original)
        original_param_info = nn_graph.append_params(mc.param_info, original_weights)

        fps = nn_graph.flatten_params(original_param_info).copy()
        # Keep the `limit` largest-magnitude parameters, zero the rest.
        keep = np.argsort(-np.abs(fps))[:limit]
        fps2 = np.zeros_like(fps)
        fps2[keep] = fps[keep]

        # Reported in percent, like the homology-pruning cell.
        print('new percentage', 100 * np.sum(fps2 != fps) / num_edges)

        ps = nn_graph.inverse_flatten_params(fps2, original_param_info)

        # Only multi-dimensional non-BatchNorm weights, matching the homology
        # cell. A 1-D 'weight' whose name lacks 'bn' (e.g. a BatchNorm inside a
        # `shortcut` Sequential, if the architecture has one) would otherwise
        # consume an entry of `ps` and misalign every later layer.
        i = 0
        for name, param in model_original.named_parameters():
            if 'weight' in name and 'bn' not in name and len(param.shape) > 1:
                if tuple(np.shape(ps[i])) != tuple(param.shape):
                    raise ValueError('Shape mismatch writing {}: got {}, expected {}'.format(
                        name, tuple(np.shape(ps[i])), tuple(param.shape)))
                param.data = torch.Tensor(ps[i]).to(device)
                i += 1
        # Biases / BatchNorm parameters were not touched above; move them too.
        model_original.to(device)

        optimizer = torch.optim.SGD(model_original.parameters(), lr=1.2e-3, weight_decay=1e-4, momentum=0.9)
        for epoch in range(1, epochs + 1):
            train(model_original, device, train_loader, optimizer, epoch, criterion)
            acc, loss = test(model_original, test_loader, criterion)
            print('Seed: {}, Prune Percentage: {}, Epoch: {}, Test Accuracy: {}, Test Loss: {}'.format(seed, prune_percent, epoch, acc, loss))
            results_threshold.append([seed,'threshold',prune_percent,epoch,acc,loss])

In [None]:
# for seed in seeds:
#     dat = np.load('/home/schraterlab/gebhart/projects/LTHT/dumps/lt/{}/{}/{}/prune_all/global/lt_20_bestaccuracy.dat'.format(model_name, dataset_name, seed), allow_pickle=True)
#     print(dat)

In [None]:
# Threshold-pruning runs, with the same columns as the homology table.
df_thresh = pd.DataFrame(data=results_threshold, columns=col_names)
df_thresh.head()

In [None]:
# sms_thresh = []
# for seed in seeds:
#     sms_thresh.append(df_thresh[(df_thresh['seed'] == seed) & (df_thresh['epoch'] == epoch)]['accuracy'].values)
# sms_thresh = np.array(sms_thresh)

# Best (max over epochs) threshold-pruning accuracy per prune percentage, one
# row per seed. Columns follow groupby's ascending prune_percent order.
sms_thresh = []
for seed in seeds:
    dfs = df_thresh[df_thresh['seed'] == seed]
    best_by_percent = dfs.groupby('prune_percent')['accuracy'].max()
    sms_thresh.append(best_by_percent.values)
sms_thresh = np.array(sms_thresh)

In [None]:
# Threshold pruning: best accuracy averaged over seeds, with a ±1 std band.
# sms_thresh columns are in groupby's ascending prune_percent order, so x must
# use that same order. The runs themselves were recorded in prune_percents
# order (descending), which paired each y with the wrong x.
xvals = np.sort(df_thresh['prune_percent'].unique())
yvals = sms_thresh.mean(axis=0)
ystd = sms_thresh.std(axis=0)
fig, ax = plt.subplots()
ax.plot(xvals, yvals)
ax.fill_between(xvals, yvals - ystd, yvals + ystd, color='gray', alpha=0.2)
ax.set_xlabel('Prune Percentage')
ax.set_ylabel('Test Accuracy')
ax.set_title('{}: threshold pruning'.format(dataset_name.upper()))
plt.show()

In [None]:
# Homology vs. threshold pruning: best accuracy averaged over seeds, with ±1 std
# bands. sms / sms_thresh columns are in groupby's ascending prune_percent
# order, so x is taken in that same order. Flipping y to match the run order
# only worked while prune_percents happened to be descending.
fig, ax = plt.subplots()
for frame, stats, label in [(df, sms, 'Homology Pruning'),
                            (df_thresh, sms_thresh, 'Threshold Pruning')]:
    xvals = np.sort(frame['prune_percent'].unique())
    yvals = stats.mean(axis=0)
    ystd = stats.std(axis=0)
    ax.plot(xvals, yvals, label=label)
    ax.fill_between(xvals, yvals - ystd, yvals + ystd, color='gray', alpha=0.2)

# TODO: consider a log-scaled x axis of the percentage *remaining*
# (100 - prune %), since the prune percentages are log-spaced near 100.
ax.set_title(dataset_name.upper())
ax.legend()
ax.set_xlabel('Prune Percentage')
ax.set_ylabel('Test Accuracy')
plt.show()

In [None]:
# fps = nn_graph.flatten_params(original_param_info).copy()
# fps2 = fps.copy()
# fps2[:] = 0.
# fps2[np.argsort(-np.abs(fps))[:345]] = fps[np.argsort(-np.abs(fps))[:345]]
# fps2[NNGD.graph_idx_vec[NNGD.adj_vec != 0]] = fps[NNGD.graph_idx_vec[NNGD.adj_vec != 0]]