Can we recover the same adversary-prediction ability using only the sorted edge weights? That is, if we look only at the top *n* edges in the filtration, is that good enough — or does persistent homology actually contribute something beyond the edge ordering?

In [1]:
import os
import parse
import pickle
import copy
import math

import torch
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import dionysus as dion
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.spatial.distance import hamming, cosine
from sklearn import svm
from sklearn.model_selection import cross_val_score
import sklearn
import networkx as nx
import seaborn as sns

from pt_activation.models.cff_sigmoid import CFF

%load_ext autoreload
%autoreload 2

In [2]:
# Legend text for the two sample populations compared in this notebook.
PLT_LABELS = ['Unaltered', 'Adversarial']
# Fill colours, presumably paired index-wise with PLT_LABELS.
COLORS = ['#12355b', '#ff6978']
# Colour used for graph edges in plots.
EDGE_COLOR = '#272d2d'

In [3]:
def get_adv_info(filename):
    """Parse an adversary file name of the form ``true-{t}_adv-{a}_sample-{s}.npy``.

    Returns:
        ``{'true class': t, 'adv class': a, 'sample': s}`` with int values.

    Raises:
        ValueError: if ``filename`` does not match the pattern or a field is
            not an integer.
    """
    format_string = 'true-{}_adv-{}_sample-{}.npy'
    parsed = parse.parse(format_string, filename)
    if parsed is None:
        # parse.parse returns None on a mismatch; report the offending name
        # instead of failing later with "'NoneType' object is not subscriptable".
        raise ValueError('unrecognised adversary file name: {!r}'.format(filename))
    return {'true class': int(parsed[0]), 'adv class': int(parsed[1]), 'sample': int(parsed[2])}

def read_adversaries(loc):
    """Load every ``*.npy`` adversary file found directly inside directory ``loc``.

    Each returned dict holds the fields parsed from the file name by
    ``get_adv_info`` plus ``'adversary'``, the array loaded with ``np.load``.
    The order follows ``os.listdir`` (arbitrary); callers sort as needed.
    """
    ret = []
    for name in os.listdir(loc):
        path = os.path.join(loc, name)
        # Match the extension only: a substring test also accepted names such
        # as "x.npy.bak", which get_adv_info cannot parse.
        if not (name.endswith('.npy') and os.path.isfile(path)):
            continue
        info = get_adv_info(name)
        info['adversary'] = np.load(path)
        ret.append(info)
    return ret
    

In [4]:
adv_directory_loc = '/home/tgebhart/projects/pt_activation/logdir/adversaries/mnist/carliniwagnerl2/cff_sigmoid.pt'
# Load every stored adversary, then order them by their source test-sample index.
adversaries = read_adversaries(adv_directory_loc)
adversaries.sort(key=lambda rec: rec['sample'])

In [5]:
def create_sample_graph(f, tnms, wm):
    """Group the two-entry columns of the persistence matrix of ``f`` into subgraphs.

    Runs ``dion.homology_persistence`` on ``f``. Each column of the resulting
    matrix with exactly two entries contributes an edge between the node names
    of those two entries. The edge is attached to the first existing subgraph
    that already contains its first node; otherwise it starts a new subgraph
    keyed by that node's name.

    Args:
        f: dionysus filtration.
        tnms: inverse of the ``nm`` map from ``compute_dynamic_filtration2``
            (presumably vertex id -> node name).
        wm: dict of ``(name, name) -> weight``; both orientations are tried.
            If neither is present, the simplex value ``f[i].data`` is used.

    Returns:
        ``(subgraphs, dgm, lifetimes)``. ``subgraphs`` maps a root node name to
        an ``nx.Graph``. ``dgm`` is ``dion.init_diagrams(m, f)[0]``.
        ``lifetimes`` is ``create_lifetimes(f, subgraphs, dgm, tnms)``.
    """
    subgraphs = {}
    m = dion.homology_persistence(f)
    dgm = dion.init_diagrams(m, f)[0]
    for i, c in enumerate(m):
        if len(c) != 2:
            continue
        src = tnms[f[c[0].index][0]]
        dst = tnms[f[c[1].index][0]]

        # Prefer the explicit weight map (either orientation), else the filtration value.
        if (src, dst) in wm:
            w = wm[(src, dst)]
        elif (dst, src) in wm:
            w = wm[(dst, src)]
        else:
            w = f[i].data

        for g in subgraphs.values():
            if g.has_node(src):
                if dst in subgraphs:
                    # dst already roots another subgraph. Record only the node
                    # (by name, like every other node); the subgraphs are not merged.
                    g.add_node(dst)
                else:
                    g.add_edge(src, dst, weight=w)
                break
        else:
            # No existing subgraph contains src: start a new one rooted at src.
            g = nx.Graph()
            g.add_edge(src, dst, weight=w)
            subgraphs[src] = g

    return subgraphs, dgm, create_lifetimes(f, subgraphs, dgm, tnms)

def create_lifetimes(f, subgraphs, dgm, ids):
    """Map each subgraph key to the lifetime of the diagram point born at it.

    For every point in ``dgm``, the key is ``ids[f[point.data][0]]``. Points
    whose key is not in ``subgraphs`` are skipped.

    Returns:
        dict of key -> ``birth - death`` for finite points, or ``birth`` when
        the death is infinite. If several points share a key, the last one wins.
    """
    result = {}
    for point in dgm:
        name = ids[f[point.data][0]]
        if name not in subgraphs:
            continue
        is_finite = point.death < float('inf')
        result[name] = point.birth - point.death if is_finite else point.birth
    return result
    

In [6]:
def _build_filtration_records(model, loader, batch_size, up_to, extra_rows=None):
    """Shared driver: run ``model`` over ``loader`` and build per-sample filtration data.

    Args:
        model: network called as ``model(data, hiddens=True)`` and exposing
            ``compute_dynamic_filtration2``.
        loader: iterable of ``(data, target)`` batches.
        batch_size: batch size of ``loader``; advances the sample counter.
        up_to: stop after the batch in which the counter reaches this value.
        extra_rows: optional list, indexed by global sample position, of dicts
            merged into that sample's result row.

    Returns:
        ``(DataFrame of rows, subgraphs, diagrams, lifetimes)``.
    """
    device = torch.device("cpu")
    model.eval()
    t = 0
    rows, subgraphs, diagrams, lifetimes = [], [], [], []
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output, hiddens = model(data, hiddens=True)
            # nll_loss averages over the batch (reduction='mean'), so this is
            # the per-sample loss only when batch_size == 1.
            batch_loss = F.nll_loss(output, target).item()
            # Index of the max log-probability per sample, shape (batch, 1).
            pred = output.max(1, keepdim=True)[1]
            targets = target.cpu().numpy()
            preds = pred.cpu().numpy()

            for s in range(data.shape[0]):
                idx = t + s
                this_hiddens = [hiddens[0][s], hiddens[1][s], hiddens[2][s]]
                print('Filtration: {}'.format(idx))
                f, nm, wm = model.compute_dynamic_filtration2(data[s], this_hiddens, percentile=0, return_nm=True, absolute_value=True, input_layer=False)
                tnm = {v: k for k, v in nm.items()}
                sg, dg, lifetime = create_sample_graph(f, tnm, wm)
                row = {'loss': batch_loss, 'class': targets[s], 'prediction': preds[s][0]}
                if extra_rows is not None:
                    row.update(extra_rows[idx])
                rows.append(row)
                subgraphs.append(sg)
                diagrams.append(dg)
                lifetimes.append(lifetime)

            t += batch_size
            if t >= up_to:
                break

    return pd.DataFrame(rows), subgraphs, diagrams, lifetimes


def create_subgraphs(model, batch_size, up_to):
    """Build filtration subgraphs for the first ``up_to`` MNIST test images.

    Returns:
        ``(res_df, subgraphs, diagrams, lifetimes)``. ``res_df`` has columns
        ``loss``, ``class`` and ``prediction``.
    """
    kwargs = {'num_workers': 1, 'pin_memory': True}
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([
                           transforms.ToTensor(),
#                            transforms.Normalize((0.1307,), (0.3081,))
                       ])), batch_size=batch_size, shuffle=False, **kwargs)
    return _build_filtration_records(model, test_loader, batch_size, up_to)


def create_adversary_subgraphs(model, batch_size, up_to, adversaries):
    """Build filtration subgraphs for the first ``up_to`` entries of ``adversaries``.

    ``adversaries`` is a list of dicts as produced by ``read_adversaries``.

    Returns:
        ``(adv_df, subgraphs, diagrams, lifetimes)``. ``adv_df`` has columns
        ``loss``, ``class`` (the true class), ``prediction`` and ``sample``
        (the source test-sample index).
    """
    adv_images = torch.tensor(np.array([a['adversary'] for a in adversaries]))
    adv_labels = torch.tensor(np.array([a['true class'] for a in adversaries]))

    print(adv_images.shape, adv_labels.shape)

    advs = torch.utils.data.TensorDataset(adv_images, adv_labels)
    test_loader = torch.utils.data.DataLoader(advs, batch_size=batch_size, shuffle=False)
    # The file naming (true/adv/sample) allows several adversaries per test
    # sample, so row position alone does not identify the source sample.
    extra = [{'sample': a['sample']} for a in adversaries]
    return _build_filtration_records(model, test_loader, batch_size, up_to, extra_rows=extra)

In [7]:
model_location = '/home/tgebhart/projects/pt_activation/logdir/models/mnist/cff_sigmoid.pt'
model = CFF()
# All inference in this notebook uses torch.device("cpu"). map_location lets a
# checkpoint saved from a GPU run load on a CPU-only machine.
model.load_state_dict(torch.load(model_location, map_location='cpu'))

In [9]:
# Clean MNIST test samples: one subgraph dict / diagram / lifetime dict per image.
res_df, sample_graphs, dgms, lifetimes = create_subgraphs(model, batch_size=1, up_to=2000)

Filtration: 0
filtration size 77275
Filtration: 1
filtration size 142673
Filtration: 2
filtration size 17357
Filtration: 3
filtration size 100185
Filtration: 4
filtration size 59808
Filtration: 5
filtration size 12070
Filtration: 6
filtration size 103369
Filtration: 7
filtration size 86325
Filtration: 8
filtration size 38112
Filtration: 9
filtration size 23850
Filtration: 10
filtration size 126050
Filtration: 11
filtration size 74575
Filtration: 12
filtration size 23435
Filtration: 13
filtration size 74521
Filtration: 14
filtration size 77261
Filtration: 15
filtration size 19130
Filtration: 16
filtration size 36701
Filtration: 17
filtration size 119416
Filtration: 18
filtration size 83625
Filtration: 19
filtration size 30624
Filtration: 20
filtration size 63195
Filtration: 21
filtration size 106579
Filtration: 22
filtration size 117739
Filtration: 23
filtration size 18964
Filtration: 24
filtration size 52214
Filtration: 25
filtration size 131342
Filtration: 26
filtration size 94605
Fil

filtration size 69984
Filtration: 218
filtration size 13916
Filtration: 219
filtration size 124514
Filtration: 220
filtration size 112630
Filtration: 221
filtration size 136336
Filtration: 222
filtration size 113955
Filtration: 223
filtration size 60916
Filtration: 224
filtration size 4200
Filtration: 225
filtration size 92600
Filtration: 226
filtration size 18424
Filtration: 227
filtration size 5927
Filtration: 228
filtration size 44342
Filtration: 229
filtration size 94840
Filtration: 230
filtration size 35836
Filtration: 231
filtration size 37904
Filtration: 232
filtration size 128657
Filtration: 233
filtration size 58008
Filtration: 234
filtration size 88358
Filtration: 235
filtration size 15704
Filtration: 236
filtration size 125532
Filtration: 237
filtration size 91150
Filtration: 238
filtration size 22263
Filtration: 239
filtration size 42956
Filtration: 240
filtration size 76464
Filtration: 241
filtration size 46940
Filtration: 242
filtration size 25678
Filtration: 243
filtrati

filtration size 102681
Filtration: 432
filtration size 67537
Filtration: 433
filtration size 31754
Filtration: 434
filtration size 24761
Filtration: 435
filtration size 31279
Filtration: 436
filtration size 5343
Filtration: 437
filtration size 116014
Filtration: 438
filtration size 89937
Filtration: 439
filtration size 59499
Filtration: 440
filtration size 137012
Filtration: 441
filtration size 87727
Filtration: 442
filtration size 77847
Filtration: 443
filtration size 5576
Filtration: 444
filtration size 62343
Filtration: 445
filtration size 118393
Filtration: 446
filtration size 119531
Filtration: 447
filtration size 31401
Filtration: 448
filtration size 9719
Filtration: 449
filtration size 113470
Filtration: 450
filtration size 54631
Filtration: 451
filtration size 83074
Filtration: 452
filtration size 122041
Filtration: 453
filtration size 124544
Filtration: 454
filtration size 12135
Filtration: 455
filtration size 38836
Filtration: 456
filtration size 11057
Filtration: 457
filtrat

Filtration: 646
filtration size 59102
Filtration: 647
filtration size 59628
Filtration: 648
filtration size 109424
Filtration: 649
filtration size 113366
Filtration: 650
filtration size 35433
Filtration: 651
filtration size 128136
Filtration: 652
filtration size 60290
Filtration: 653
filtration size 78638
Filtration: 654
filtration size 105059
Filtration: 655
filtration size 125378
Filtration: 656
filtration size 136753
Filtration: 657
filtration size 104898
Filtration: 658
filtration size 144947
Filtration: 659
filtration size 75796
Filtration: 660
filtration size 11923
Filtration: 661
filtration size 16101
Filtration: 662
filtration size 42018
Filtration: 663
filtration size 73495
Filtration: 664
filtration size 74321
Filtration: 665
filtration size 138970
Filtration: 666
filtration size 78608
Filtration: 667
filtration size 58771
Filtration: 668
filtration size 37385
Filtration: 669
filtration size 92558
Filtration: 670
filtration size 15708
Filtration: 671
filtration size 107860
Fi

Filtration: 860
filtration size 131623
Filtration: 861
filtration size 119334
Filtration: 862
filtration size 40697
Filtration: 863
filtration size 106291
Filtration: 864
filtration size 94554
Filtration: 865
filtration size 95963
Filtration: 866
filtration size 103507
Filtration: 867
filtration size 95581
Filtration: 868
filtration size 62888
Filtration: 869
filtration size 91703
Filtration: 870
filtration size 35621
Filtration: 871
filtration size 126785
Filtration: 872
filtration size 33225
Filtration: 873
filtration size 125379
Filtration: 874
filtration size 20815
Filtration: 875
filtration size 110710
Filtration: 876
filtration size 135553
Filtration: 877
filtration size 107844
Filtration: 878
filtration size 77993
Filtration: 879
filtration size 25309
Filtration: 880
filtration size 34912
Filtration: 881
filtration size 101077
Filtration: 882
filtration size 30150
Filtration: 883
filtration size 27626
Filtration: 884
filtration size 83708
Filtration: 885
filtration size 72191
Fi

Filtration: 1072
filtration size 33377
Filtration: 1073
filtration size 83191
Filtration: 1074
filtration size 69149
Filtration: 1075
filtration size 31818
Filtration: 1076
filtration size 80153
Filtration: 1077
filtration size 106475
Filtration: 1078
filtration size 25745
Filtration: 1079
filtration size 16508
Filtration: 1080
filtration size 12543
Filtration: 1081
filtration size 91275
Filtration: 1082
filtration size 53268
Filtration: 1083
filtration size 52860
Filtration: 1084
filtration size 123918
Filtration: 1085
filtration size 66657
Filtration: 1086
filtration size 36176
Filtration: 1087
filtration size 5615
Filtration: 1088
filtration size 33682
Filtration: 1089
filtration size 49222
Filtration: 1090
filtration size 6115
Filtration: 1091
filtration size 22392
Filtration: 1092
filtration size 68156
Filtration: 1093
filtration size 87952
Filtration: 1094
filtration size 92479
Filtration: 1095
filtration size 41650
Filtration: 1096
filtration size 114183
Filtration: 1097
filtrat

filtration size 136372
Filtration: 1282
filtration size 91276
Filtration: 1283
filtration size 37294
Filtration: 1284
filtration size 96392
Filtration: 1285
filtration size 90890
Filtration: 1286
filtration size 43947
Filtration: 1287
filtration size 35627
Filtration: 1288
filtration size 15975
Filtration: 1289
filtration size 18323
Filtration: 1290
filtration size 10572
Filtration: 1291
filtration size 15926
Filtration: 1292
filtration size 117736
Filtration: 1293
filtration size 104271
Filtration: 1294
filtration size 89657
Filtration: 1295
filtration size 26125
Filtration: 1296
filtration size 119569
Filtration: 1297
filtration size 105925
Filtration: 1298
filtration size 40580
Filtration: 1299
filtration size 135835
Filtration: 1300
filtration size 35166
Filtration: 1301
filtration size 30764
Filtration: 1302
filtration size 70685
Filtration: 1303
filtration size 133290
Filtration: 1304
filtration size 40823
Filtration: 1305
filtration size 49653
Filtration: 1306
filtration size 64

Filtration: 1490
filtration size 133362
Filtration: 1491
filtration size 129709
Filtration: 1492
filtration size 128617
Filtration: 1493
filtration size 27931
Filtration: 1494
filtration size 20288
Filtration: 1495
filtration size 62768
Filtration: 1496
filtration size 130120
Filtration: 1497
filtration size 75527
Filtration: 1498
filtration size 86150
Filtration: 1499
filtration size 76216
Filtration: 1500
filtration size 76047
Filtration: 1501
filtration size 12241
Filtration: 1502
filtration size 82847
Filtration: 1503
filtration size 29437
Filtration: 1504
filtration size 89073
Filtration: 1505
filtration size 103866
Filtration: 1506
filtration size 89723
Filtration: 1507
filtration size 76093
Filtration: 1508
filtration size 24137
Filtration: 1509
filtration size 128194
Filtration: 1510
filtration size 24622
Filtration: 1511
filtration size 91629
Filtration: 1512
filtration size 67315
Filtration: 1513
filtration size 41028
Filtration: 1514
filtration size 104673
Filtration: 1515
f

filtration size 119444
Filtration: 1699
filtration size 110595
Filtration: 1700
filtration size 115495
Filtration: 1701
filtration size 103568
Filtration: 1702
filtration size 77734
Filtration: 1703
filtration size 55758
Filtration: 1704
filtration size 18766
Filtration: 1705
filtration size 100768
Filtration: 1706
filtration size 105243
Filtration: 1707
filtration size 71692
Filtration: 1708
filtration size 101114
Filtration: 1709
filtration size 113537
Filtration: 1710
filtration size 98263
Filtration: 1711
filtration size 106126
Filtration: 1712
filtration size 137355
Filtration: 1713
filtration size 65713
Filtration: 1714
filtration size 90337
Filtration: 1715
filtration size 58620
Filtration: 1716
filtration size 8938
Filtration: 1717
filtration size 119149
Filtration: 1718
filtration size 136452
Filtration: 1719
filtration size 21913
Filtration: 1720
filtration size 87545
Filtration: 1721
filtration size 92630
Filtration: 1722
filtration size 42503
Filtration: 1723
filtration siz

filtration size 76470
Filtration: 1908
filtration size 119773
Filtration: 1909
filtration size 58844
Filtration: 1910
filtration size 68946
Filtration: 1911
filtration size 24599
Filtration: 1912
filtration size 23802
Filtration: 1913
filtration size 40197
Filtration: 1914
filtration size 47054
Filtration: 1915
filtration size 97487
Filtration: 1916
filtration size 110018
Filtration: 1917
filtration size 50024
Filtration: 1918
filtration size 46492
Filtration: 1919
filtration size 108155
Filtration: 1920
filtration size 12175
Filtration: 1921
filtration size 103582
Filtration: 1922
filtration size 45935
Filtration: 1923
filtration size 90225
Filtration: 1924
filtration size 115916
Filtration: 1925
filtration size 104173
Filtration: 1926
filtration size 49756
Filtration: 1927
filtration size 93473
Filtration: 1928
filtration size 73837
Filtration: 1929
filtration size 12168
Filtration: 1930
filtration size 98108
Filtration: 1931
filtration size 80832
Filtration: 1932
filtration size 110

In [10]:
# NOTE: the recorded run of this cell stopped with "RuntimeError: can't alloc"
# after roughly 900 filtrations.
adv_df, adv_sample_graphs, adv_dgms, adv_lifetimes = create_adversary_subgraphs(model, batch_size=1, up_to=2000, adversaries=adversaries)

torch.Size([4116, 1, 28, 28]) torch.Size([4116])
Filtration: 0
filtration size 9572
Filtration: 1
filtration size 77275
Filtration: 2
filtration size 144309
Filtration: 3
filtration size 7828
Filtration: 4
filtration size 91744
Filtration: 5
filtration size 106766
Filtration: 6
filtration size 59808
Filtration: 7
filtration size 67517
Filtration: 8
filtration size 14605
Filtration: 9
filtration size 12070
Filtration: 10
filtration size 96467
Filtration: 11
filtration size 103369
Filtration: 12
filtration size 86325
Filtration: 13
filtration size 90960
Filtration: 14
filtration size 38112
Filtration: 15
filtration size 38112
Filtration: 16
filtration size 23850
Filtration: 17
filtration size 24735
Filtration: 18
filtration size 123275
Filtration: 19
filtration size 74575
Filtration: 20
filtration size 93178
Filtration: 21
filtration size 23435
Filtration: 22
filtration size 17658
Filtration: 23
filtration size 34051
Filtration: 24
filtration size 16904
Filtration: 25
filtration size 398

filtration size 75182
Filtration: 217
filtration size 107790
Filtration: 218
filtration size 90805
Filtration: 219
filtration size 18659
Filtration: 220
filtration size 14130
Filtration: 221
filtration size 26154
Filtration: 222
filtration size 102148
Filtration: 223
filtration size 95256
Filtration: 224
filtration size 91536
Filtration: 225
filtration size 51031
Filtration: 226
filtration size 119688
Filtration: 227
filtration size 27293
Filtration: 228
filtration size 28614
Filtration: 229
filtration size 19037
Filtration: 230
filtration size 41813
Filtration: 231
filtration size 111424
Filtration: 232
filtration size 135032
Filtration: 233
filtration size 8808
Filtration: 234
filtration size 17420
Filtration: 235
filtration size 14909
Filtration: 236
filtration size 18033
Filtration: 237
filtration size 40934
Filtration: 238
filtration size 22776
Filtration: 239
filtration size 30700
Filtration: 240
filtration size 137890
Filtration: 241
filtration size 110439
Filtration: 242
filtra

filtration size 66170
Filtration: 432
filtration size 30929
Filtration: 433
filtration size 40094
Filtration: 434
filtration size 138588
Filtration: 435
filtration size 111512
Filtration: 436
filtration size 87319
Filtration: 437
filtration size 34608
Filtration: 438
filtration size 75584
Filtration: 439
filtration size 41545
Filtration: 440
filtration size 6571
Filtration: 441
filtration size 16875
Filtration: 442
filtration size 132455
Filtration: 443
filtration size 79261
Filtration: 444
filtration size 24152
Filtration: 445
filtration size 20321
Filtration: 446
filtration size 20982
Filtration: 447
filtration size 58463
Filtration: 448
filtration size 19873
Filtration: 449
filtration size 72123
Filtration: 450
filtration size 118675
Filtration: 451
filtration size 45121
Filtration: 452
filtration size 62105
Filtration: 453
filtration size 49074
Filtration: 454
filtration size 80433
Filtration: 455
filtration size 122782
Filtration: 456
filtration size 31356
Filtration: 457
filtrati

filtration size 68768
Filtration: 647
filtration size 108353
Filtration: 648
filtration size 17688
Filtration: 649
filtration size 89026
Filtration: 650
filtration size 132368
Filtration: 651
filtration size 48655
Filtration: 652
filtration size 96250
Filtration: 653
filtration size 31896
Filtration: 654
filtration size 14610
Filtration: 655
filtration size 114839
Filtration: 656
filtration size 56060
Filtration: 657
filtration size 79514
Filtration: 658
filtration size 92117
Filtration: 659
filtration size 73569
Filtration: 660
filtration size 118370
Filtration: 661
filtration size 58086
Filtration: 662
filtration size 81482
Filtration: 663
filtration size 110018
Filtration: 664
filtration size 54364
Filtration: 665
filtration size 144884
Filtration: 666
filtration size 34819
Filtration: 667
filtration size 67021
Filtration: 668
filtration size 13085
Filtration: 669
filtration size 113525
Filtration: 670
filtration size 62008
Filtration: 671
filtration size 6817
Filtration: 672
filtra

filtration size 127450
Filtration: 861
filtration size 15684
Filtration: 862
filtration size 120142
Filtration: 863
filtration size 102950
Filtration: 864
filtration size 20045
Filtration: 865
filtration size 7756
Filtration: 866
filtration size 87331
Filtration: 867
filtration size 83778
Filtration: 868
filtration size 29138
Filtration: 869
filtration size 113730
Filtration: 870
filtration size 8820
Filtration: 871
filtration size 21962
Filtration: 872
filtration size 100255
Filtration: 873
filtration size 42615
Filtration: 874
filtration size 37631
Filtration: 875
filtration size 77859
Filtration: 876
filtration size 20480
Filtration: 877
filtration size 7307
Filtration: 878
filtration size 141809
Filtration: 879
filtration size 130884
Filtration: 880
filtration size 136864
Filtration: 881
filtration size 49334
Filtration: 882
filtration size 108268
Filtration: 883
filtration size 7337
Filtration: 884
filtration size 12561
Filtration: 885
filtration size 70756
Filtration: 886
filtrat

RuntimeError: can't alloc

In [None]:
# Save both subgraph collections next to the adversary files.
adv_pkl_path = os.path.join(adv_directory_loc, 'adv_samples.pkl')
with open(adv_pkl_path, 'wb') as f:
    pickle.dump(adv_sample_graphs, f)
clean_pkl_path = os.path.join(adv_directory_loc, 'samples.pkl')
with open(clean_pkl_path, 'wb') as f:
    pickle.dump(sample_graphs, f)

# To reload instead of recomputing:
# with open(os.path.join(adv_directory_loc, 'samples.pkl'), 'rb') as f:
#     sample_graphs = pickle.load(f)
# with open(os.path.join(adv_directory_loc, 'adv_samples.pkl'), 'rb') as f:
#     adv_sample_graphs = pickle.load(f)

In [None]:
# for dgm in dgms:
#     dion.plot.plot_diagram(dgm, show=True)

In [None]:
# for adv_dgm in adv_dgms:
#     dion.plot.plot_diagram(adv_dgm, show=True)

In [None]:
# thru = 3
# all_gois = []
# for i in range(len(sample_graphs)):
#     print(i)
#     ks = list(sample_graphs[i].keys())
#     a = [sample_graphs[i][k] for k in ks[:thru]]
#     all_gois.append(nx.compose_all(a))
    
# adv_all_gois = []
# for i in range(len(adv_sample_graphs)):
#     print(i)
#     ks = list(adv_sample_graphs[i].keys())
#     a = [adv_sample_graphs[i][k] for k in ks[:thru]]
#     adv_all_gois.append(nx.compose_all(a))

In [None]:
def hamming_distance(g1, g2, ret_labels=False):
    """Hamming distance between the node sets of two graphs.

    Each graph becomes a 0/1 indicator vector over the union of both node sets.
    The result is the fraction of positions where the vectors differ.

    Returns:
        The distance, or ``(distance, labels)`` when ``ret_labels`` is true.
        ``labels`` is the node order used for the vectors.
    """
    labels = list(set(list(g1.nodes) + list(g2.nodes)))
    ind1 = np.array([1.0 if node in g1.nodes else 0.0 for node in labels])
    ind2 = np.array([1.0 if node in g2.nodes else 0.0 for node in labels])
    dist = hamming(ind1, ind2)
    if not ret_labels:
        return dist
    return dist, labels
    
def edge_hamming_distance(g1, g2, ret_labels=False):
    """Hamming distance between the edge sets of two graphs.

    Edges are compared as the tuples yielded by ``g.edges``. Each graph becomes
    a 0/1 indicator vector over the union of both edge sets.

    Returns:
        The distance, or ``(distance, labels)`` when ``ret_labels`` is true.
        ``labels`` is the edge order used for the vectors.
    """
    labels = list(set(list(g1.edges) + list(g2.edges)))
    ind1 = np.array([1.0 if e in g1.edges else 0.0 for e in labels])
    ind2 = np.array([1.0 if e in g2.edges else 0.0 for e in labels])
    dist = hamming(ind1, ind2)
    if not ret_labels:
        return dist
    return dist, labels
    
def lifetime_weighted_edge_distance(subgraphs1,subgraphs2,lifetimes1,lifetimes2,ret_labels=False):
    """Cosine distance between two samples' edges weighted by subgraph lifetime.

    Every edge of subgraph ``k`` gets weight ``lifetimes[k] / max(lifetimes)``
    for its sample. If an edge occurs in several subgraphs, the last one wins.
    The two weight vectors are then compared with ``scipy`` cosine distance.

    Args:
        subgraphs1, subgraphs2: ``{key: graph}`` dicts; graphs expose ``.edges``.
        lifetimes1, lifetimes2: ``{key: lifetime}`` dicts for the same samples,
            matched to subgraphs BY KEY. A subgraph without an entry gets weight 0.
        ret_labels: also return the edge order used for the vectors.

    Returns:
        The distance, or ``(distance, edges)`` when ``ret_labels`` is true.

    Raises:
        ValueError: if either lifetime dict is empty.
    """
    def _edge_weights(subgraphs, lifetimes):
        # Pair lifetimes with subgraphs by key. Pairing by position is wrong:
        # lifetime dicts are filled in diagram order and subgraph dicts in
        # creation order, and their key sets need not coincide.
        max_life = max(lifetimes.values())
        weights = {}
        for key, g in subgraphs.items():
            w = lifetimes.get(key, 0.0) / max_life
            for e in g.edges:
                weights[e] = w
        return weights

    edges1 = _edge_weights(subgraphs1, lifetimes1)
    edges2 = _edge_weights(subgraphs2, lifetimes2)
    # NOTE(review): edges are compared as oriented tuples; an undirected edge
    # reported as (u, v) in one graph and (v, u) in the other will not match.
    edgesetlist = list(set(list(edges1.keys()) + list(edges2.keys())))
    g1_vec = np.array([edges1.get(e, 0.0) for e in edgesetlist], dtype=float)
    g2_vec = np.array([edges2.get(e, 0.0) for e in edgesetlist], dtype=float)
    dist = cosine(g1_vec, g2_vec)
    if ret_labels:
        return dist, edgesetlist
    return dist
    
def weighted_edge_distance(g1, g2, ret_labels=False):
    """Cosine distance between the edge-weight vectors of two graphs.

    Vectors are indexed by the union of both edge sets. An edge contributes its
    ``'weight'`` attribute in the graph that has it, and 0 otherwise.

    Returns:
        The distance, or ``(distance, labels)`` when ``ret_labels`` is true.
        ``labels`` is the edge order used for the vectors.
    """
    labels = list(set(list(g1.edges) + list(g2.edges)))

    def weight_vector(g):
        vec = np.zeros(len(labels))
        for pos, e in enumerate(labels):
            if e in g.edges:
                vec[pos] = g[e[0]][e[1]]['weight']
        return vec

    dist = cosine(weight_vector(g1), weight_vector(g2))
    if not ret_labels:
        return dist
    return dist, labels
    
def compute_kernel_matrix(gs, ls, take=3):
    """Build a symmetric similarity matrix over samples.

    Entry ``[i, j]`` is ``1 - lifetime_weighted_edge_distance`` between the
    first ``take`` subgraphs of samples ``i`` and ``j``, using each sample's
    full lifetime dict.

    Args:
        gs: list of per-sample ``{key: graph}`` dicts.
        ls: list of per-sample ``{key: lifetime}`` dicts aligned with ``gs``.
        take: applied as ``list(keys)[:take]``. ``None`` keeps every subgraph.
            A negative value follows slice semantics, so ``-1`` drops the last
            subgraph rather than keeping all of them.

    Returns:
        ``(len(gs), len(gs))`` numpy array.
    """
    # Truncate each sample once rather than once per (i, j) pair.
    truncated = [{k: g[k] for k in list(g.keys())[:take]} for g in gs]
    n = len(gs)
    ret = np.zeros((n, n))
    for i in range(n):
        print('row: ', i)
        for j in range(i, n):
            # Cosine distance is symmetric, so one call fills both halves.
            sim = 1 - lifetime_weighted_edge_distance(truncated[i], truncated[j], ls[i], ls[j])
            ret[i, j] = sim
            ret[j, i] = sim
    return ret

In [None]:
# Short aliases for the clean-sample subgraphs and their lifetimes.
gs, ls = sample_graphs, lifetimes

In [None]:
# take=None keeps every subgraph. take=-1 slices keys[:-1] and would silently
# drop each sample's last subgraph.
X = compute_kernel_matrix(gs, ls, take=None)
y = res_df['class'].values

In [None]:
clf = svm.SVC(kernel='precomputed', decision_function_shape='ovo')
# Per-fold accuracy of 10-fold CV on the precomputed kernel (shown as cell output).
cross_val_score(clf, X, y, cv=10)

In [None]:
cross_val_score(clf, X, y, cv=10).mean()

In [None]:
n_correct = int((res_df['class'] == res_df['prediction']).sum())
print('Natural performance: ', n_correct / res_df.shape[0])

In [None]:
take = 5
# Vocabulary of "u-v" edge names found in the first `take` subgraphs of any clean sample.
edges = set()
for graphs in sample_graphs:
    for k in list(graphs.keys())[:take]:
        for u, v in graphs[k].edges():
            edges.add(str(u) + '-' + str(v))

# Edge-count matrix: one row per sample, one column per edge name. It is filled
# through numpy because chained assignment (edf.iloc[i][name] += 1) updates a
# temporary copy under pandas Copy-on-Write and leaves the frame at zero.
edge_columns = list(edges)
edge_col = {name: j for j, name in enumerate(edge_columns)}
counts = np.zeros((len(sample_graphs), len(edge_columns)))
for i, graphs in enumerate(sample_graphs):
    print('Sample: {}/{}'.format(i, len(sample_graphs)))
    for k in list(graphs.keys())[:take]:
        for u, v in graphs[k].edges():
            # For weighted features, add graphs[k][u][v]['weight'] instead of 1.
            counts[i, edge_col[str(u) + '-' + str(v)]] += 1
edf = pd.DataFrame(counts, columns=edge_columns)

In [None]:
edf.head(n=5)

In [None]:
kernel = 'linear'
X = edf.values
y = res_df['class'].values
clf = svm.SVC(kernel=kernel, decision_function_shape='ovo')
# 10-fold cross-validated accuracy of an SVM on the edge-count features.
cvs = cross_val_score(clf, X, y, cv=10)

In [None]:
np.mean(cvs)

In [None]:
# Final classifier fit on all clean samples; reused to predict adversarial inputs.
t_fit = svm.SVC(kernel=kernel, decision_function_shape='ovo').fit(X, y)

In [None]:
# take = 3
# adv_edges = set()
# for i in range(len(adv_sample_graphs)):
#     for k in list(adv_sample_graphs[i].keys())[:take]:
#         for x in adv_sample_graphs[i][k].edges(data=True):
#             edge_name = str(x[0])+'-'+str(x[1])
#             adv_edges.add(edge_name)

In [None]:
# Adversarial edge-count features over the same columns as the clean training
# features (list(edges)). Edge names unseen in training are dropped.
adv_columns = list(edges)
adv_col = {name: j for j, name in enumerate(adv_columns)}
# One row per adversarial sample, so rows line up with adv_df.
adv_counts = np.zeros((len(adv_sample_graphs), len(adv_columns)))
for i, graphs in enumerate(adv_sample_graphs):
    print('Sample: {}/{}'.format(i, len(adv_sample_graphs)))
    for k in list(graphs.keys())[:take]:
        for u, v in graphs[k].edges():
            j = adv_col.get(str(u) + '-' + str(v))
            if j is not None:
                # Count 1 per occurrence, the same scale as the clean features t_fit was trained on.
                adv_counts[i, j] += 1
adv_edf = pd.DataFrame(adv_counts, columns=adv_columns)

In [None]:
adv_X = adv_edf.values
adv_preds = t_fit.predict(adv_X)

In [None]:
adv_preds[0:5]

In [None]:
adv_df.head(n=5)

In [None]:
(len(adv_df), len(adv_df.columns))

In [None]:
# Fraction of adversarial samples whose SVM prediction equals their true class.
n_recovered = int((adv_df['class'] == adv_preds[:adv_df.shape[0]]).sum())
print('Recovery Accuracy: {}'.format(n_recovered / adv_df.shape[0]))

In [None]:
# Fraction of adversarial samples where the SVM agrees with the attacked network's prediction.
n_fooled = int((adv_df['prediction'] == adv_preds).sum())
print('Adversary Class Percentage: {}'.format(n_fooled / adv_df.shape[0]))

In [None]:
plt.imshow(np.reshape(adversaries[4]['adversary'], (28, 28)))

In [None]:
adv_df['prediction'].value_counts() / len(adv_df)

In [None]:
recovered_preds = adv_df[adv_df['class'] == adv_preds]['prediction']
recovered_preds.value_counts() / recovered_preds.shape[0]

In [None]:
# One plot colour per MNIST digit, indexed by class label 0-9.
colors = [
    'black', 'blue', 'red', 'green', 'yellow',
    'orange', 'purple', 'pink', 'silver', 'cyan',
]
labels = [*range(10)]

In [None]:
# from sklearn.decomposition import PCA
# from mpl_toolkits.mplot3d import Axes3D
# import PyQt5

# fig = plt.figure(1, figsize=(4, 3))
# plt.clf()
# ax = Axes3D(fig)

# plt.cla()
# pca = PCA(n_components=3)
# pca.fit(X)
# X_pca = pca.transform(X)

# for i in range(len(X_pca)):
#     ax.scatter(X_pca[i,0], X_pca[i,1], X_pca[i,2], color=colors[res_df['prediction'].iloc[i]])

In [None]:
# from sklearn import manifold
# from collections import OrderedDict

# fig = plt.figure(1, figsize=(4, 3))
# plt.clf()
# ax = Axes3D(fig)
# plt.cla()

# # X_dimmed = manifold.TSNE(n_components=3, init='pca', random_state=0).fit_transform(X)
# X_dimmed = manifold.Isomap(10, 3).fit_transform(X)

# for i in range(len(X_dimmed)):
#     ax.scatter(X_dimmed[i,0], X_dimmed[i,1], X_dimmed[i,2], color=colors[res_df['class'].iloc[i]], label=labels[res_df['class'].iloc[i]])

# handles, labs = plt.gca().get_legend_handles_labels()
# by_label = OrderedDict(zip(labs, handles))
# ax.legend(by_label.values(), by_label.keys())

In [None]:
from collections import OrderedDict
from sklearn import manifold
from sklearn.decomposition import PCA
# NOTE: plot_take is not referenced by the cells that follow.
plot_take = 1000

In [None]:
fig, ax = plt.subplots()
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X)

# Take colour and legend label from the same column so each legend entry shows
# its own colour. One scatter call per class also replaces a per-point loop.
color_by = 'class'
point_classes = res_df[color_by].values
for cls in np.unique(point_classes):
    mask = point_classes == cls
    ax.scatter(X_pca[mask, 0], X_pca[mask, 1], color=colors[cls], label=labels[cls])
ax.legend()

In [None]:
from sklearn import manifold
fig, ax = plt.subplots()
X_dimmed = manifold.Isomap(n_neighbors=15, n_components=2).fit_transform(X)
# Alternative embeddings:
# X_dimmed = manifold.TSNE(n_components=2, init='pca', random_state=5).fit_transform(X)
# X_dimmed = manifold.SpectralEmbedding(n_neighbors=100, n_components=2).fit_transform(X)
# X_dimmed = manifold.MDS(2, max_iter=200, n_init=10).fit_transform(X)
# X_dimmed = manifold.LocallyLinearEmbedding(10, 2, eigen_solver='auto', method='standard').fit_transform(X)

# Colour and label each embedded point by the network's prediction.
predictions = res_df['prediction']
for row, (px, py) in enumerate(X_dimmed):
    p = predictions.iloc[row]
    ax.scatter(px, py, color=colors[p], label=labels[p])
# Deduplicate legend entries (one per prediction value).
handles, labs = plt.gca().get_legend_handles_labels()
by_label = OrderedDict(zip(labs, handles))
ax.legend(by_label.values(), by_label.keys())