### Imports

In [1]:
import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import pickle as pkl
import networkx as nx

from tqdm import tqdm

from tensorflow.keras.optimizers import Adam, Nadam
from scripts.utils import *

from GMI_.models import GMI, LogReg
from GMI_.utils import process


  from .autonotebook import tqdm as notebook_tqdm


### Class Definitions

In [2]:
class BaseEmbedder:
    def __init__(self, graph, embed_shape = (128,)):
        self.embed(graph)
        self.E = list(graph.edges())
        self.graph = graph
        self.embed_shape = embed_shape
    
    def embed(self, graph):
        raise NotImplementedError
    
    def get_embedding(self):
        raise NotImplementedError
        

In [3]:
class GMIEmbedding(BaseEmbedder):
    def __init__(self, embed_dim = 64, graph = None, feature_matrix = None, use_xm = False, debug = False, batch_size = 1, nb_epochs = 500, patience = 20, ortho_ = 0.1, sparse_ = 0.1, lr = 1e-3, l2_coef = 0.0, drop_prob = 0.0, sparse = True, nonlinearity = 'prelu', alpha = 0.8, beta = 1.0, gamma = 1.0, negative_num = 5, epoch_flag = 20, model_name = 'test'):

        self.embed_dim = embed_dim
        self.debug = debug
        
        # Training Params
        self.graph = graph
        self.batch_size = batch_size
        self.nb_epochs = nb_epochs
        self.patience = patience
        self.lr = lr
        self.l2_coef = l2_coef
        self.feature_matrix = feature_matrix
        self.drop_prob = drop_prob
        self.hid_units = embed_dim
        self.sparse = sparse
        self.nonlinearity = nonlinearity
        self.use_xm = use_xm
        self.ortho_ = ortho_
        self.sparse_ = sparse_
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.negative_num = negative_num
        self.epoch_flag = epoch_flag
        self.model_name = model_name
        
        self.time_per_epoch = None
        
        if graph is not None:
            self.embed()
        else:
            self.graph = None
    
    def embed(self):

        ####
        if self.feature_matrix is None:
            feature_matrix = np.identity(len(self.graph))
        else: 
            feature_matrix = self.feature_matrix

        adj_ori = nx.to_scipy_sparse_array(graph)
        features = sp.lil_matrix(feature_matrix)
        features, _ = process.preprocess_features(features)

        nb_nodes = features.shape[0]
        ft_size = features.shape[1]
        
        adj = process.normalize_adj(adj_ori + sp.eye(adj_ori.shape[0]))

        if self.sparse:
            sp_adj = process.sparse_mx_to_torch_sparse_tensor(adj)
        else:
            adj = (adj + sp.eye(adj.shape[0])).todense()

        features = torch.FloatTensor(features[np.newaxis])
        if not self.sparse:
            adj = torch.FloatTensor(adj[np.newaxis])

        if self.feature_matrix is not None: 
            sense_features = torch.FloatTensor(self.feature_matrix)

        model = GMI(ft_size, self.hid_units, self.nonlinearity)
        optimiser = torch.optim.Adam(model.parameters(), lr = self.lr, weight_decay = self.l2_coef)
        
#         if self.use_xm:
#              model.load_state_dict(torch.load(self.model_name + '.pkl'))
        
        if torch.cuda.is_available():
            model.cuda()
            features = features.cuda()
            sp_adj = sp_adj.cuda()
            
        b_xent = nn.BCEWithLogitsLoss()
        xent = nn.CrossEntropyLoss()
        cnt_wait = 0
        best = 1e9
        best_t = 0
        
        
        adj_dense = adj_ori.toarray()
        adj_target = adj_dense + np.eye(adj_dense.shape[0])
        adj_row_avg = 1.0 / np.sum(adj_dense, axis = 1)
        adj_row_avg[np.isnan(adj_row_avg)] = 0.0
        adj_row_avg[np.isinf(adj_row_avg)] = 0.0
        adj_dense = adj_dense * 1.0
        
        for i in range(adj_ori.shape[0]):
            adj_dense[i] = adj_dense[i] * adj_row_avg[i]
        adj_ori = sp.csr_matrix(adj_dense, dtype = np.float32)
        
        start_time = time.time()
        for epoch in tqdm(range(self.nb_epochs)):
            model.train()
            optimiser.zero_grad()
            
            res = model(features, adj_ori, self.negative_num, sp_adj, None, None) 
            
            if self.use_xm == True and feature_matrix is not None:
                
                start_idx = 0
                loop = True
                
                ortho_loss = 0
                sparse_loss = 0
                xm_batch_size = 128
                
                sf = sense_features
                embeds = model.embed(features, sp_adj)
                                
                while loop:
                    end_idx = start_idx + xm_batch_size
                    if end_idx > len(self.graph):
                        loop = False
                        end_idx = len(self.graph)
                                            
                    sf = sense_features[start_idx : end_idx]
                    embeds_ = torch.squeeze(embeds)[start_idx : end_idx]
                    
                    
                    sense_mat = torch.einsum('ij, ik -> ijk', embeds_, sf)
                    E = sense_mat
                    y_norm = torch.diagonal(torch.matmul(embeds_, torch.transpose(embeds_, 0, 1)))
                    sense_norm = torch.diagonal(torch.matmul(sf, torch.transpose(sf, 0, 1)))
                    norm = torch.multiply(y_norm, sense_norm)
                    E = torch.transpose(torch.transpose(E, 0, 2) / norm, 0, 2)
                    E = (E - torch.amin(E, dim = [-1, -2], keepdim = True)) / (torch.amax(E, dim = [-1, -2], keepdim = True) - torch.amin(E, dim = [-1, -2], keepdim = True))
                    
                    E_t = torch.transpose(E, 1, 2)
                    
                    E_o = torch.einsum('aij, ajh -> aih', E, E_t)
                    E_o = torch.sum(E_o)
                    batch_ortho_loss = (self.ortho_ * E_o) / self.batch_size

                    batch_sparse_loss = (self.sparse_ * torch.sum(torch.linalg.norm(E, ord = 1, axis = 0))) / self.batch_size
                        
                    ortho_loss += batch_ortho_loss
                    sparse_loss += batch_sparse_loss
                    
                    start_idx = end_idx
                    
                loss = self.alpha * process.mi_loss_jsd(res[0], res[1]) +\
                       self.beta * process.mi_loss_jsd(res[2], res[3]) +\
                       self.gamma * process.reconstruct_loss(res[4], adj_target) +\
                       ortho_loss +\
                       sparse_loss
            else:
                loss = self.alpha * process.mi_loss_jsd(res[0], res[1]) +\
                       self.beta * process.mi_loss_jsd(res[2], res[3]) +\
                       self.gamma * process.reconstruct_loss(res[4], adj_target)


            if self.debug:
                print('Loss:', loss)

            if loss < best:
                best = loss
                best_t = epoch
                cnt_wait = 0
                torch.save(model.state_dict(), self.model_name + '.pkl')
            else:
                cnt_wait += 1

            if cnt_wait == self.epoch_flag:
                print('Early stopping!')
                break

            loss.backward()
            optimiser.step()
            
        self.time_per_epoch = (time.time() - start_time) / epoch

        if self.debug: 
            print('Loading {}th epoch'.format(best_t))
            
        model.load_state_dict(torch.load(self.model_name + '.pkl'))

        embeds = model.embed(features, sp_adj)
        self.embeddings = embeds
    
    def get_embedding(self):
        return np.squeeze(self.embeddings.numpy())
    

### Large Experiments - Remove 

In [9]:
with open('./data/pubmed.pkl', 'rb') as file: 
    graph_dict = pkl.load(file)
    
graph = nx.Graph(nx.to_numpy_array(graph_dict))    
graph = nx.Graph(nx.to_numpy_array(graph))


sense_feat_dict, sense_features = get_sense_features(graph, ppr_flag = 'std')


Calculating Personalized Page Rank...                     

19717it [12:35, 26.10it/s]


Calculating Katz Centrality...                            

  A = nx.adjacency_matrix(G, nodelist=nodelist, weight=weight).todense().T


Normalizing Features Between 0 And 1...                   Done                                                      

In [10]:
uncorrelated_feats = ['Degree',
                    'Clustering Coefficient',
                    'Personalized Page Rank - Standard Deviation',
                    'Average Neighbor Degree',
                    'Average Neighbor Clustering',
                    'Eccentricity',
                    'Katz Centrality']
sense_features = sense_features[:, [list(sense_feat_dict).index(feat) for feat in uncorrelated_feats]]
sense_feat_dict = {feat : idx for idx, feat in enumerate(uncorrelated_feats)}

In [11]:
results = {}

for run in tqdm(range(1)):
    for d in [16, 32, 64, 128, 256]:

        results[d] = {}

        gmi_og = GMIEmbedding(graph = graph, 
                   embed_dim = d, 
                   feature_matrix = sense_features, 
                   use_xm = False, 
                   ortho_ = 0, 
                   sparse_ = 0, 
                   batch_size = 1)
        embed_og = gmi_og.get_embedding()
        embed_og = (embed_og - np.min(embed_og)) / np.ptp(embed_og)
        feature_dict_og = find_feature_membership(input_embed = embed_og,
                                                            embed_name = 'GMI-SF',
                                                            sense_features = sense_features,
                                                            sense_feat_dict = sense_feat_dict,
                                                            top_k = 8,
                                                            solver = 'nmf')

        explain_og = feature_dict_og['explain_norm']
        error_og = sense_features * np.log((sense_features + 1e-10) / ((embed_og @ feature_dict_og['explain_norm']) + 1e-10)) - sense_features + (embed_og @ feature_dict_og['explain_norm'])
        explain_og = (explain_og - np.min(explain_og)) / np.ptp(explain_og)


        gmi_plus = GMIEmbedding(graph = graph, 
                   embed_dim = d, 
                   feature_matrix = sense_features, 
                   use_xm = True, 
                   alpha = 0.04, 
                   beta = 0.5, 
                   gamma = 0.5,
                   ortho_ = 1e-2, 
                   sparse_ = 1e-1, 
                   batch_size = 1, 
                   epoch_flag = 500)
        embed_plus = gmi_plus.get_embedding()
        embed_plus = (embed_plus - np.min(embed_plus)) / np.ptp(embed_plus)
        feature_dict_plus = find_feature_membership(input_embed = embed_plus,
                                                            embed_name = 'GMI+XM',
                                                            sense_features = sense_features,
                                                            sense_feat_dict = sense_feat_dict,
                                                            top_k = 8,
                                                            solver = 'nmf')

        explain_plus = feature_dict_plus['explain_norm']
        error_plus = sense_features * np.log((sense_features + 1e-10) / ((embed_plus @ feature_dict_plus['explain_norm']) + 1e-10)) - sense_features + (embed_plus @ feature_dict_plus['explain_norm'])
        explain_plus = (explain_plus - np.min(explain_plus)) / np.ptp(explain_plus)


        Y_plus = embed_plus
        sense_mat = tf.einsum('ij, ik -> ijk', Y_plus, sense_features)
        Y_plus_norm = tf.linalg.diag_part(tf.matmul(Y_plus, Y_plus, transpose_b = True), k = 0)
        sense_norm = tf.linalg.diag_part(tf.matmul(sense_features, sense_features, transpose_b = True), k = 0)
        norm = Y_plus_norm * tf.cast(sense_norm, tf.float32)
        D_plus = tf.transpose(tf.transpose(sense_mat) / norm)
        D_plus = (D_plus - tf.reshape(tf.reduce_min(D_plus, axis = [-1, -2]), (-1, 1, 1))) / tf.reshape(tf.reduce_max(D_plus, axis = [-1, -2]) - tf.reduce_min(D_plus, axis = [-1, -2]), (-1, 1, 1))


        Y_og = embed_og
        sense_mat = tf.einsum('ij, ik -> ijk', Y_og, sense_features)
        Y_og_norm = tf.linalg.diag_part(tf.matmul(Y_og, Y_og, transpose_b = True), k = 0)
        sense_norm = tf.linalg.diag_part(tf.matmul(sense_features, sense_features, transpose_b = True), k = 0)
        norm = Y_og_norm * tf.cast(sense_norm, tf.float32)
        D_og = tf.transpose(tf.transpose(sense_mat) / norm)
        D_og = (D_og - tf.reshape(tf.reduce_min(D_og, axis = [-1, -2]), (-1, 1, 1))) / tf.reshape(tf.reduce_max(D_og, axis = [-1, -2]) - tf.reduce_min(D_og, axis = [-1, -2]), (-1, 1, 1))

        norm_og = [np.linalg.norm(D_og[node, :, :], ord = 'nuc') for node in tqdm(range(len(graph)))]
        norm_plus = [np.linalg.norm(D_plus[node, :, :], ord = 'nuc') for node in tqdm(range(len(graph)))]
        
        orth_og = np.squeeze(tf.reduce_sum(D_og @ tf.transpose(D_og, perm = (0, 2, 1)), axis = [1, 2]))
        orth_plus = np.squeeze(tf.reduce_sum(D_plus @ tf.transpose(D_plus, perm = (0, 2, 1)), axis = [1, 2]))

        explain_og_norm = np.linalg.norm(explain_og, ord = 'nuc')
        explain_plus_norm = np.linalg.norm(explain_plus, ord = 'nuc')

        try:
            results[d]['norm_og'].append(norm_og)
            results[d]['norm_plus'].append(norm_plus)
            results[d]['orth_og'].append(orth_og)
            results[d]['orth_plus'].append(orth_plus)
            results[d]['explain_og_norm'].append(explain_og_norm)
            results[d]['explain_plus_norm'].append(explain_plus_norm)
            results[d]['gmi_og_time'].append(gmi_og.time_per_epoch)
            results[d]['gmi+xm_time'].append(gmi_plus.time_per_epoch)
            results[d]['error_og'].append(error_og)
            results[d]['error_plus'].append(error_plus)

            results[d]['embed_og'].append(embed_og)
            results[d]['embed_plus'].append(embed_plus)


        except: 
            results[d]['norm_og'] = [norm_og]
            results[d]['norm_plus'] = [norm_plus]
            results[d]['orth_og'] = [orth_og]
            results[d]['orth_plus'] = [orth_plus]
            results[d]['explain_og_norm'] = [explain_og_norm]
            results[d]['explain_plus_norm'] = [explain_plus_norm]
            results[d]['gmi_og_time'] = [gmi_og.time_per_epoch]
            results[d]['gmi+xm_time'] = [gmi_plus.time_per_epoch]
            results[d]['error_og'] = [error_og]
            results[d]['error_plus'] = [error_plus]

            results[d]['embed_og'] = [embed_og]
            results[d]['embed_plus'] = [embed_plus]

        with open('./results/pubmed_gmi.pkl', 'wb') as file: 
            pkl.dump(results, file)


  0%|                                                     | 0/1 [00:00<?, ?it/s]
  0%|                                                   | 0/500 [00:00<?, ?it/s][A
  0%|                                           | 1/500 [00:03<27:56,  3.36s/it][A
  0%|▏                                          | 2/500 [00:06<28:04,  3.38s/it][A
  1%|▎                                          | 3/500 [00:10<27:56,  3.37s/it][A
  1%|▎                                          | 4/500 [00:13<27:51,  3.37s/it][A
  1%|▍                                          | 5/500 [00:16<27:51,  3.38s/it][A
  1%|▌                                          | 6/500 [00:20<27:54,  3.39s/it][A
  1%|▌                                          | 7/500 [00:23<27:40,  3.37s/it][A
  2%|▋                                          | 8/500 [00:26<27:27,  3.35s/it][A
  2%|▊                                          | 9/500 [00:30<27:25,  3.35s/it][A
  2%|▊                                         | 10/500 [00:33<27:19,  3.35s/it

 19%|████████                                  | 96/500 [05:21<22:35,  3.36s/it][A
 19%|████████▏                                 | 97/500 [05:25<22:34,  3.36s/it][A
 20%|████████▏                                 | 98/500 [05:28<22:26,  3.35s/it][A
 20%|████████▎                                 | 99/500 [05:31<22:23,  3.35s/it][A
 20%|████████▏                                | 100/500 [05:35<22:15,  3.34s/it][A
 20%|████████▎                                | 101/500 [05:38<22:22,  3.37s/it][A
 20%|████████▎                                | 102/500 [05:41<22:17,  3.36s/it][A
 21%|████████▍                                | 103/500 [05:45<22:11,  3.35s/it][A
 21%|████████▌                                | 104/500 [05:48<22:22,  3.39s/it][A
 21%|████████▌                                | 105/500 [05:52<22:13,  3.37s/it][A
 21%|████████▋                                | 106/500 [05:55<22:05,  3.36s/it][A
 21%|████████▊                                | 107/500 [05:58<21:58,  3.35s

 39%|███████████████▊                         | 193/500 [10:47<17:13,  3.37s/it][A
 39%|███████████████▉                         | 194/500 [10:50<17:09,  3.36s/it][A
 39%|███████████████▉                         | 195/500 [10:53<17:09,  3.37s/it][A
 39%|████████████████                         | 196/500 [10:57<17:09,  3.39s/it][A
 39%|████████████████▏                        | 197/500 [11:00<17:02,  3.38s/it][A
 40%|████████████████▏                        | 198/500 [11:04<16:59,  3.38s/it][A
 40%|████████████████▎                        | 199/500 [11:07<16:51,  3.36s/it][A
 40%|████████████████▍                        | 200/500 [11:10<16:45,  3.35s/it][A
 40%|████████████████▍                        | 201/500 [11:14<16:46,  3.37s/it][A
 40%|████████████████▌                        | 202/500 [11:17<16:41,  3.36s/it][A
 41%|████████████████▋                        | 203/500 [11:20<16:40,  3.37s/it][A
 41%|████████████████▋                        | 204/500 [11:24<16:34,  3.36s

 58%|███████████████████████▊                 | 290/500 [16:13<11:45,  3.36s/it][A
 58%|███████████████████████▊                 | 291/500 [16:16<11:45,  3.37s/it][A
 58%|███████████████████████▉                 | 292/500 [16:20<11:42,  3.38s/it][A
 59%|████████████████████████                 | 293/500 [16:23<11:35,  3.36s/it][A
 59%|████████████████████████                 | 294/500 [16:26<11:38,  3.39s/it][A
 59%|████████████████████████▏                | 295/500 [16:30<11:33,  3.38s/it][A
 59%|████████████████████████▎                | 296/500 [16:33<11:28,  3.37s/it][A
 59%|████████████████████████▎                | 297/500 [16:36<11:25,  3.38s/it][A
 60%|████████████████████████▍                | 298/500 [16:40<11:20,  3.37s/it][A
 60%|████████████████████████▌                | 299/500 [16:43<11:17,  3.37s/it][A
 60%|████████████████████████▌                | 300/500 [16:46<11:11,  3.36s/it][A
 60%|████████████████████████▋                | 301/500 [16:50<11:07,  3.36s

 77%|███████████████████████████████▋         | 387/500 [21:41<06:24,  3.40s/it][A
 78%|███████████████████████████████▊         | 388/500 [21:45<06:21,  3.40s/it][A
 78%|███████████████████████████████▉         | 389/500 [21:48<06:16,  3.39s/it][A
 78%|███████████████████████████████▉         | 390/500 [21:51<06:14,  3.41s/it][A
 78%|████████████████████████████████         | 391/500 [21:55<06:12,  3.42s/it][A
 78%|████████████████████████████████▏        | 392/500 [21:58<06:07,  3.40s/it][A
 79%|████████████████████████████████▏        | 393/500 [22:02<06:01,  3.38s/it][A
 79%|████████████████████████████████▎        | 394/500 [22:05<06:00,  3.40s/it][A
 79%|████████████████████████████████▍        | 395/500 [22:08<05:57,  3.41s/it][A
 79%|████████████████████████████████▍        | 396/500 [22:12<05:55,  3.42s/it][A
 79%|████████████████████████████████▌        | 397/500 [22:15<05:50,  3.40s/it][A
 80%|████████████████████████████████▋        | 398/500 [22:19<05:47,  3.40s

 97%|███████████████████████████████████████▋ | 484/500 [27:11<00:54,  3.42s/it][A
 97%|███████████████████████████████████████▊ | 485/500 [27:14<00:51,  3.43s/it][A
 97%|███████████████████████████████████████▊ | 486/500 [27:18<00:47,  3.41s/it][A
 97%|███████████████████████████████████████▉ | 487/500 [27:21<00:44,  3.42s/it][A
 98%|████████████████████████████████████████ | 488/500 [27:24<00:40,  3.41s/it][A
 98%|████████████████████████████████████████ | 489/500 [27:28<00:37,  3.41s/it][A
 98%|████████████████████████████████████████▏| 490/500 [27:31<00:34,  3.41s/it][A
 98%|████████████████████████████████████████▎| 491/500 [27:35<00:30,  3.42s/it][A
 98%|████████████████████████████████████████▎| 492/500 [27:38<00:27,  3.41s/it][A
 99%|████████████████████████████████████████▍| 493/500 [27:41<00:23,  3.41s/it][A
 99%|████████████████████████████████████████▌| 494/500 [27:45<00:20,  3.40s/it][A
 99%|████████████████████████████████████████▌| 495/500 [27:48<00:17,  3.40s

 16%|██████▋                                   | 80/500 [04:40<24:37,  3.52s/it][A
 16%|██████▊                                   | 81/500 [04:43<24:32,  3.51s/it][A
 16%|██████▉                                   | 82/500 [04:47<24:33,  3.53s/it][A
 17%|██████▉                                   | 83/500 [04:50<24:24,  3.51s/it][A
 17%|███████                                   | 84/500 [04:54<24:24,  3.52s/it][A
 17%|███████▏                                  | 85/500 [04:57<24:17,  3.51s/it][A
 17%|███████▏                                  | 86/500 [05:01<24:15,  3.52s/it][A
 17%|███████▎                                  | 87/500 [05:04<24:06,  3.50s/it][A
 18%|███████▍                                  | 88/500 [05:08<24:10,  3.52s/it][A
 18%|███████▍                                  | 89/500 [05:11<24:00,  3.50s/it][A
 18%|███████▌                                  | 90/500 [05:15<23:58,  3.51s/it][A
 18%|███████▋                                  | 91/500 [05:18<23:53,  3.51s

 35%|██████████████▌                          | 177/500 [10:19<18:47,  3.49s/it][A
 36%|██████████████▌                          | 178/500 [10:23<18:47,  3.50s/it][A
 36%|██████████████▋                          | 179/500 [10:26<18:46,  3.51s/it][A
 36%|██████████████▊                          | 180/500 [10:30<18:41,  3.50s/it][A
 36%|██████████████▊                          | 181/500 [10:33<18:39,  3.51s/it][A
 36%|██████████████▉                          | 182/500 [10:37<18:38,  3.52s/it][A
 37%|███████████████                          | 183/500 [10:40<18:32,  3.51s/it][A
 37%|███████████████                          | 184/500 [10:44<18:27,  3.51s/it][A
 37%|███████████████▏                         | 185/500 [10:47<18:23,  3.50s/it][A
 37%|███████████████▎                         | 186/500 [10:51<18:19,  3.50s/it][A
 37%|███████████████▎                         | 187/500 [10:54<18:15,  3.50s/it][A
 38%|███████████████▍                         | 188/500 [10:58<18:12,  3.50s

 55%|██████████████████████▍                  | 274/500 [16:00<13:14,  3.52s/it][A
 55%|██████████████████████▌                  | 275/500 [16:03<13:11,  3.52s/it][A
 55%|██████████████████████▋                  | 276/500 [16:07<13:07,  3.52s/it][A
 55%|██████████████████████▋                  | 277/500 [16:10<13:05,  3.52s/it][A
 56%|██████████████████████▊                  | 278/500 [16:14<13:02,  3.53s/it][A
 56%|██████████████████████▉                  | 279/500 [16:17<12:58,  3.52s/it][A
 56%|██████████████████████▉                  | 280/500 [16:21<12:55,  3.52s/it][A
 56%|███████████████████████                  | 281/500 [16:24<12:53,  3.53s/it][A
 56%|███████████████████████                  | 282/500 [16:28<12:48,  3.52s/it][A
 57%|███████████████████████▏                 | 283/500 [16:31<12:44,  3.52s/it][A
 57%|███████████████████████▎                 | 284/500 [16:35<12:40,  3.52s/it][A
 57%|███████████████████████▎                 | 285/500 [16:38<12:41,  3.54s

 74%|██████████████████████████████▍          | 371/500 [21:42<07:39,  3.56s/it][A
 74%|██████████████████████████████▌          | 372/500 [21:46<07:36,  3.56s/it][A
 75%|██████████████████████████████▌          | 373/500 [21:49<07:31,  3.56s/it][A
 75%|██████████████████████████████▋          | 374/500 [21:53<07:26,  3.54s/it][A
 75%|██████████████████████████████▊          | 375/500 [21:57<07:24,  3.55s/it][A
 75%|██████████████████████████████▊          | 376/500 [22:00<07:20,  3.55s/it][A
 75%|██████████████████████████████▉          | 377/500 [22:04<07:17,  3.56s/it][A
 76%|██████████████████████████████▉          | 378/500 [22:07<07:14,  3.56s/it][A
 76%|███████████████████████████████          | 379/500 [22:11<07:09,  3.55s/it][A
 76%|███████████████████████████████▏         | 380/500 [22:14<07:06,  3.55s/it][A
 76%|███████████████████████████████▏         | 381/500 [22:18<07:02,  3.55s/it][A
 76%|███████████████████████████████▎         | 382/500 [22:21<06:59,  3.55s

 94%|██████████████████████████████████████▍  | 468/500 [27:31<01:56,  3.65s/it][A
 94%|██████████████████████████████████████▍  | 469/500 [27:35<01:52,  3.64s/it][A
 94%|██████████████████████████████████████▌  | 470/500 [27:39<01:48,  3.63s/it][A
 94%|██████████████████████████████████████▌  | 471/500 [27:42<01:45,  3.63s/it][A
 94%|██████████████████████████████████████▋  | 472/500 [27:46<01:41,  3.64s/it][A
 95%|██████████████████████████████████████▊  | 473/500 [27:49<01:38,  3.64s/it][A
 95%|██████████████████████████████████████▊  | 474/500 [27:53<01:34,  3.64s/it][A
 95%|██████████████████████████████████████▉  | 475/500 [27:57<01:31,  3.66s/it][A
 95%|███████████████████████████████████████  | 476/500 [28:00<01:27,  3.65s/it][A
 95%|███████████████████████████████████████  | 477/500 [28:04<01:24,  3.65s/it][A
 96%|███████████████████████████████████████▏ | 478/500 [28:08<01:20,  3.66s/it][A
 96%|███████████████████████████████████████▎ | 479/500 [28:11<01:16,  3.66s

  0%|                                      | 64/19717 [00:16<1:24:47,  3.86it/s][A
  0%|▏                                     | 65/19717 [00:17<1:24:40,  3.87it/s][A
  0%|▏                                     | 66/19717 [00:17<1:24:33,  3.87it/s][A
  0%|▏                                     | 67/19717 [00:17<1:24:32,  3.87it/s][A
  0%|▏                                     | 68/19717 [00:17<1:24:24,  3.88it/s][A
  0%|▏                                     | 69/19717 [00:18<1:24:16,  3.89it/s][A
  0%|▏                                     | 70/19717 [00:18<1:24:11,  3.89it/s][A
  0%|▏                                     | 71/19717 [00:18<1:24:12,  3.89it/s][A
  0%|▏                                     | 72/19717 [00:18<1:23:49,  3.91it/s][A
  0%|▏                                     | 73/19717 [00:19<1:27:32,  3.74it/s][A
  0%|▏                                     | 74/19717 [00:19<1:26:38,  3.78it/s][A
  0%|▏                                     | 75/19717 [00:19<1:26:06,  3.80i

  1%|▎                                    | 161/19717 [00:42<1:23:04,  3.92it/s][A
  1%|▎                                    | 162/19717 [00:42<1:23:29,  3.90it/s][A
  1%|▎                                    | 163/19717 [00:42<1:27:07,  3.74it/s][A
  1%|▎                                    | 164/19717 [00:43<1:30:18,  3.61it/s][A
  1%|▎                                    | 165/19717 [00:43<1:29:16,  3.65it/s][A
  1%|▎                                    | 166/19717 [00:43<1:30:44,  3.59it/s][A
  1%|▎                                    | 167/19717 [00:43<1:27:48,  3.71it/s][A
  1%|▎                                    | 168/19717 [00:44<1:26:31,  3.77it/s][A
  1%|▎                                    | 169/19717 [00:44<1:28:57,  3.66it/s][A
  1%|▎                                    | 170/19717 [00:44<1:26:38,  3.76it/s][A
  1%|▎                                    | 171/19717 [00:44<1:25:08,  3.83it/s][A
  1%|▎                                    | 172/19717 [00:45<1:24:43,  3.84i

  1%|▍                                    | 258/19717 [01:07<1:23:41,  3.88it/s][A
  1%|▍                                    | 259/19717 [01:07<1:24:27,  3.84it/s][A
  1%|▍                                    | 260/19717 [01:07<1:23:59,  3.86it/s][A
  1%|▍                                    | 261/19717 [01:08<1:23:45,  3.87it/s][A
  1%|▍                                    | 262/19717 [01:08<1:22:36,  3.93it/s][A
  1%|▍                                    | 263/19717 [01:08<1:22:14,  3.94it/s][A
  1%|▍                                    | 264/19717 [01:08<1:24:35,  3.83it/s][A
  1%|▍                                    | 265/19717 [01:09<1:24:04,  3.86it/s][A
  1%|▍                                    | 266/19717 [01:09<1:24:07,  3.85it/s][A
  1%|▌                                    | 267/19717 [01:09<1:23:43,  3.87it/s][A
  1%|▌                                    | 268/19717 [01:09<1:23:38,  3.88it/s][A
  1%|▌                                    | 269/19717 [01:10<1:23:29,  3.88i

  2%|▋                                    | 355/19717 [01:32<1:23:41,  3.86it/s][A
  2%|▋                                    | 356/19717 [01:32<1:23:08,  3.88it/s][A
  2%|▋                                    | 357/19717 [01:33<1:24:08,  3.83it/s][A
  2%|▋                                    | 358/19717 [01:33<1:23:39,  3.86it/s][A
  2%|▋                                    | 359/19717 [01:33<1:23:15,  3.87it/s][A
  2%|▋                                    | 360/19717 [01:33<1:23:02,  3.89it/s][A
  2%|▋                                    | 361/19717 [01:34<1:22:49,  3.89it/s][A
  2%|▋                                    | 362/19717 [01:34<1:22:47,  3.90it/s][A
  2%|▋                                    | 363/19717 [01:34<1:24:43,  3.81it/s][A
  2%|▋                                    | 364/19717 [01:34<1:24:00,  3.84it/s][A
  2%|▋                                    | 365/19717 [01:35<1:24:52,  3.80it/s][A
  2%|▋                                    | 366/19717 [01:35<1:24:30,  3.82i

  2%|▊                                    | 452/19717 [01:57<1:23:25,  3.85it/s][A
  2%|▊                                    | 453/19717 [01:57<1:22:50,  3.88it/s][A
  2%|▊                                    | 454/19717 [01:58<1:22:06,  3.91it/s][A
  2%|▊                                    | 455/19717 [01:58<1:25:13,  3.77it/s][A
  2%|▊                                    | 456/19717 [01:58<1:24:12,  3.81it/s][A
  2%|▊                                    | 457/19717 [01:58<1:23:31,  3.84it/s][A
  2%|▊                                    | 458/19717 [01:59<1:22:31,  3.89it/s][A
  2%|▊                                    | 459/19717 [01:59<1:22:18,  3.90it/s][A
  2%|▊                                    | 460/19717 [01:59<1:21:37,  3.93it/s][A
  2%|▊                                    | 461/19717 [01:59<1:22:10,  3.91it/s][A
  2%|▊                                    | 462/19717 [02:00<1:22:06,  3.91it/s][A
  2%|▊                                    | 463/19717 [02:00<1:22:19,  3.90i

  3%|█                                    | 549/19717 [02:22<1:22:35,  3.87it/s][A
  3%|█                                    | 550/19717 [02:22<1:22:18,  3.88it/s][A
  3%|█                                    | 551/19717 [02:23<1:22:34,  3.87it/s][A
  3%|█                                    | 552/19717 [02:23<1:22:14,  3.88it/s][A
  3%|█                                    | 553/19717 [02:23<1:21:23,  3.92it/s][A
  3%|█                                    | 554/19717 [02:23<1:20:33,  3.96it/s][A
  3%|█                                    | 555/19717 [02:24<1:21:12,  3.93it/s][A
  3%|█                                    | 556/19717 [02:24<1:21:43,  3.91it/s][A
  3%|█                                    | 557/19717 [02:24<1:21:33,  3.92it/s][A
  3%|█                                    | 558/19717 [02:25<1:21:29,  3.92it/s][A
  3%|█                                    | 559/19717 [02:25<1:21:30,  3.92it/s][A
  3%|█                                    | 560/19717 [02:25<1:21:12,  3.93i

  3%|█▏                                   | 646/19717 [02:47<1:26:19,  3.68it/s][A
  3%|█▏                                   | 647/19717 [02:48<1:24:48,  3.75it/s][A
  3%|█▏                                   | 648/19717 [02:48<1:27:21,  3.64it/s][A
  3%|█▏                                   | 649/19717 [02:48<1:25:43,  3.71it/s][A
  3%|█▏                                   | 650/19717 [02:48<1:24:22,  3.77it/s][A
  3%|█▏                                   | 651/19717 [02:49<1:23:22,  3.81it/s][A
  3%|█▏                                   | 652/19717 [02:49<1:22:12,  3.87it/s][A
  3%|█▏                                   | 653/19717 [02:49<1:21:47,  3.88it/s][A
  3%|█▏                                   | 654/19717 [02:49<1:25:16,  3.73it/s][A
  3%|█▏                                   | 655/19717 [02:50<1:22:58,  3.83it/s][A
  3%|█▏                                   | 656/19717 [02:50<1:21:49,  3.88it/s][A
  3%|█▏                                   | 657/19717 [02:50<1:20:52,  3.93i

  4%|█▍                                   | 743/19717 [03:12<1:21:12,  3.89it/s][A
  4%|█▍                                   | 744/19717 [03:13<1:21:06,  3.90it/s][A
  4%|█▍                                   | 745/19717 [03:13<1:21:55,  3.86it/s][A
  4%|█▍                                   | 746/19717 [03:13<1:24:11,  3.76it/s][A
  4%|█▍                                   | 747/19717 [03:13<1:23:44,  3.78it/s][A
  4%|█▍                                   | 748/19717 [03:14<1:24:49,  3.73it/s][A
  4%|█▍                                   | 749/19717 [03:14<1:23:39,  3.78it/s][A
  4%|█▍                                   | 750/19717 [03:14<1:22:02,  3.85it/s][A
  4%|█▍                                   | 751/19717 [03:14<1:21:03,  3.90it/s][A
  4%|█▍                                   | 752/19717 [03:15<1:20:57,  3.90it/s][A
  4%|█▍                                   | 753/19717 [03:15<1:20:54,  3.91it/s][A
  4%|█▍                                   | 754/19717 [03:15<1:20:51,  3.91i

  4%|█▌                                   | 840/19717 [03:37<1:20:42,  3.90it/s][A
  4%|█▌                                   | 841/19717 [03:38<1:20:31,  3.91it/s][A
  4%|█▌                                   | 842/19717 [03:38<1:20:28,  3.91it/s][A
  4%|█▌                                   | 843/19717 [03:38<1:19:37,  3.95it/s][A
  4%|█▌                                   | 844/19717 [03:38<1:19:20,  3.96it/s][A
  4%|█▌                                   | 845/19717 [03:39<1:18:48,  3.99it/s][A
  4%|█▌                                   | 846/19717 [03:39<1:19:05,  3.98it/s][A
  4%|█▌                                   | 847/19717 [03:39<1:19:27,  3.96it/s][A
  4%|█▌                                   | 848/19717 [03:40<1:20:52,  3.89it/s][A
  4%|█▌                                   | 849/19717 [03:40<1:20:46,  3.89it/s][A
  4%|█▌                                   | 850/19717 [03:40<1:22:42,  3.80it/s][A
  4%|█▌                                   | 851/19717 [03:40<1:21:39,  3.85i

  5%|█▊                                   | 937/19717 [04:03<1:21:37,  3.83it/s][A
  5%|█▊                                   | 938/19717 [04:03<1:21:18,  3.85it/s][A
  5%|█▊                                   | 939/19717 [04:03<1:22:53,  3.78it/s][A
  5%|█▊                                   | 940/19717 [04:03<1:22:22,  3.80it/s][A
  5%|█▊                                   | 941/19717 [04:04<1:21:52,  3.82it/s][A
  5%|█▊                                   | 942/19717 [04:04<1:22:03,  3.81it/s][A
  5%|█▊                                   | 943/19717 [04:04<1:25:18,  3.67it/s][A
  5%|█▊                                   | 944/19717 [04:04<1:23:43,  3.74it/s][A
  5%|█▊                                   | 945/19717 [04:05<1:22:37,  3.79it/s][A
  5%|█▊                                   | 946/19717 [04:05<1:23:00,  3.77it/s][A
  5%|█▊                                   | 947/19717 [04:05<1:21:39,  3.83it/s][A
  5%|█▊                                   | 948/19717 [04:05<1:21:52,  3.82i

  5%|█▉                                  | 1034/19717 [04:28<1:23:03,  3.75it/s][A
  5%|█▉                                  | 1035/19717 [04:28<1:22:08,  3.79it/s][A
  5%|█▉                                  | 1036/19717 [04:28<1:24:55,  3.67it/s][A
  5%|█▉                                  | 1037/19717 [04:28<1:23:27,  3.73it/s][A
  5%|█▉                                  | 1038/19717 [04:29<1:21:19,  3.83it/s][A
  5%|█▉                                  | 1039/19717 [04:29<1:19:53,  3.90it/s][A
  5%|█▉                                  | 1040/19717 [04:29<1:19:35,  3.91it/s][A
  5%|█▉                                  | 1041/19717 [04:29<1:18:49,  3.95it/s][A
  5%|█▉                                  | 1042/19717 [04:30<1:19:08,  3.93it/s][A
  5%|█▉                                  | 1043/19717 [04:30<1:19:28,  3.92it/s][A
  5%|█▉                                  | 1044/19717 [04:30<1:19:43,  3.90it/s][A
  5%|█▉                                  | 1045/19717 [04:31<1:21:58,  3.80i

  6%|██                                  | 1131/19717 [04:53<1:19:42,  3.89it/s][A
  6%|██                                  | 1132/19717 [04:53<1:19:49,  3.88it/s][A
  6%|██                                  | 1133/19717 [04:53<1:19:46,  3.88it/s][A
  6%|██                                  | 1134/19717 [04:54<1:19:38,  3.89it/s][A
  6%|██                                  | 1135/19717 [04:54<1:20:32,  3.84it/s][A
  6%|██                                  | 1136/19717 [04:54<1:20:27,  3.85it/s][A
  6%|██                                  | 1137/19717 [04:54<1:20:21,  3.85it/s][A
  6%|██                                  | 1138/19717 [04:55<1:23:32,  3.71it/s][A
  6%|██                                  | 1139/19717 [04:55<1:22:25,  3.76it/s][A
  6%|██                                  | 1140/19717 [04:55<1:20:44,  3.83it/s][A
  6%|██                                  | 1141/19717 [04:55<1:20:23,  3.85it/s][A
  6%|██                                  | 1142/19717 [04:56<1:19:02,  3.92i

  6%|██▏                                 | 1228/19717 [05:18<1:18:32,  3.92it/s][A
  6%|██▏                                 | 1229/19717 [05:18<1:17:40,  3.97it/s][A
  6%|██▏                                 | 1230/19717 [05:18<1:18:05,  3.95it/s][A
  6%|██▏                                 | 1231/19717 [05:19<1:18:21,  3.93it/s][A
  6%|██▏                                 | 1232/19717 [05:19<1:18:22,  3.93it/s][A
  6%|██▎                                 | 1233/19717 [05:19<1:18:24,  3.93it/s][A
  6%|██▎                                 | 1234/19717 [05:19<1:17:37,  3.97it/s][A
  6%|██▎                                 | 1235/19717 [05:20<1:18:06,  3.94it/s][A
  6%|██▎                                 | 1236/19717 [05:20<1:17:32,  3.97it/s][A
  6%|██▎                                 | 1237/19717 [05:20<1:17:49,  3.96it/s][A
  6%|██▎                                 | 1238/19717 [05:21<1:18:08,  3.94it/s][A
  6%|██▎                                 | 1239/19717 [05:21<1:21:57,  3.76i

  7%|██▍                                 | 1325/19717 [05:43<1:19:55,  3.84it/s][A
  7%|██▍                                 | 1326/19717 [05:43<1:19:22,  3.86it/s][A
  7%|██▍                                 | 1327/19717 [05:44<1:19:07,  3.87it/s][A
  7%|██▍                                 | 1328/19717 [05:44<1:18:59,  3.88it/s][A
  7%|██▍                                 | 1329/19717 [05:44<1:18:52,  3.89it/s][A
  7%|██▍                                 | 1330/19717 [05:44<1:18:18,  3.91it/s][A
  7%|██▍                                 | 1331/19717 [05:45<1:18:21,  3.91it/s][A
  7%|██▍                                 | 1332/19717 [05:45<1:18:22,  3.91it/s][A
  7%|██▍                                 | 1333/19717 [05:45<1:22:00,  3.74it/s][A
  7%|██▍                                 | 1334/19717 [05:45<1:21:01,  3.78it/s][A
  7%|██▍                                 | 1335/19717 [05:46<1:21:50,  3.74it/s][A
  7%|██▍                                 | 1336/19717 [05:46<1:20:58,  3.78i

  7%|██▌                                 | 1422/19717 [06:08<1:17:50,  3.92it/s][A
  7%|██▌                                 | 1423/19717 [06:08<1:17:50,  3.92it/s][A
  7%|██▌                                 | 1424/19717 [06:09<1:17:01,  3.96it/s][A
  7%|██▌                                 | 1425/19717 [06:09<1:17:24,  3.94it/s][A
  7%|██▌                                 | 1426/19717 [06:09<1:17:40,  3.92it/s][A
  7%|██▌                                 | 1427/19717 [06:09<1:16:57,  3.96it/s][A
  7%|██▌                                 | 1428/19717 [06:10<1:17:22,  3.94it/s][A
  7%|██▌                                 | 1429/19717 [06:10<1:16:45,  3.97it/s][A
  7%|██▌                                 | 1430/19717 [06:10<1:17:13,  3.95it/s][A
  7%|██▌                                 | 1431/19717 [06:10<1:17:32,  3.93it/s][A
  7%|██▌                                 | 1432/19717 [06:11<1:21:48,  3.73it/s][A
  7%|██▌                                 | 1433/19717 [06:11<1:21:22,  3.74i

  8%|██▊                                 | 1519/19717 [06:33<1:16:49,  3.95it/s][A
  8%|██▊                                 | 1520/19717 [06:34<1:17:12,  3.93it/s][A
  8%|██▊                                 | 1521/19717 [06:34<1:17:27,  3.92it/s][A
  8%|██▊                                 | 1522/19717 [06:34<1:16:36,  3.96it/s][A
  8%|██▊                                 | 1523/19717 [06:34<1:17:03,  3.93it/s][A
  8%|██▊                                 | 1524/19717 [06:35<1:18:29,  3.86it/s][A
  8%|██▊                                 | 1525/19717 [06:35<1:18:23,  3.87it/s][A
  8%|██▊                                 | 1526/19717 [06:35<1:17:38,  3.90it/s][A
  8%|██▊                                 | 1527/19717 [06:35<1:17:38,  3.90it/s][A
  8%|██▊                                 | 1528/19717 [06:36<1:17:45,  3.90it/s][A
  8%|██▊                                 | 1529/19717 [06:36<1:17:58,  3.89it/s][A
  8%|██▊                                 | 1530/19717 [06:36<1:17:58,  3.89i

  8%|██▉                                 | 1616/19717 [06:59<1:15:56,  3.97it/s][A
  8%|██▉                                 | 1617/19717 [06:59<1:16:26,  3.95it/s][A
  8%|██▉                                 | 1618/19717 [06:59<1:18:02,  3.87it/s][A
  8%|██▉                                 | 1619/19717 [06:59<1:17:38,  3.88it/s][A
  8%|██▉                                 | 1620/19717 [07:00<1:17:36,  3.89it/s][A
  8%|██▉                                 | 1621/19717 [07:00<1:17:39,  3.88it/s][A
  8%|██▉                                 | 1622/19717 [07:00<1:16:43,  3.93it/s][A
  8%|██▉                                 | 1623/19717 [07:00<1:16:34,  3.94it/s][A
  8%|██▉                                 | 1624/19717 [07:01<1:17:42,  3.88it/s][A
  8%|██▉                                 | 1625/19717 [07:01<1:17:40,  3.88it/s][A
  8%|██▉                                 | 1626/19717 [07:01<1:21:04,  3.72it/s][A
  8%|██▉                                 | 1627/19717 [07:01<1:19:27,  3.79i

  9%|███▏                                | 1713/19717 [07:24<1:16:22,  3.93it/s][A
  9%|███▏                                | 1714/19717 [07:24<1:16:53,  3.90it/s][A
  9%|███▏                                | 1715/19717 [07:24<1:17:03,  3.89it/s][A
  9%|███▏                                | 1716/19717 [07:24<1:17:11,  3.89it/s][A
  9%|███▏                                | 1717/19717 [07:25<1:17:32,  3.87it/s][A
  9%|███▏                                | 1718/19717 [07:25<1:16:30,  3.92it/s][A
  9%|███▏                                | 1719/19717 [07:25<1:16:13,  3.93it/s][A
  9%|███▏                                | 1720/19717 [07:26<1:15:29,  3.97it/s][A
  9%|███▏                                | 1721/19717 [07:26<1:16:13,  3.93it/s][A
  9%|███▏                                | 1722/19717 [07:26<1:15:51,  3.95it/s][A
  9%|███▏                                | 1723/19717 [07:26<1:15:52,  3.95it/s][A
  9%|███▏                                | 1724/19717 [07:27<1:16:17,  3.93i

  9%|███▎                                | 1810/19717 [07:49<1:17:21,  3.86it/s][A
  9%|███▎                                | 1811/19717 [07:49<1:17:01,  3.87it/s][A
  9%|███▎                                | 1812/19717 [07:49<1:16:59,  3.88it/s][A
  9%|███▎                                | 1813/19717 [07:50<1:17:03,  3.87it/s][A
  9%|███▎                                | 1814/19717 [07:50<1:17:04,  3.87it/s][A
  9%|███▎                                | 1815/19717 [07:50<1:17:07,  3.87it/s][A
  9%|███▎                                | 1816/19717 [07:50<1:16:09,  3.92it/s][A
  9%|███▎                                | 1817/19717 [07:51<1:16:24,  3.90it/s][A
  9%|███▎                                | 1818/19717 [07:51<1:16:19,  3.91it/s][A
  9%|███▎                                | 1819/19717 [07:51<1:15:37,  3.94it/s][A
  9%|███▎                                | 1820/19717 [07:51<1:16:00,  3.92it/s][A
  9%|███▎                                | 1821/19717 [07:52<1:16:17,  3.91i

 10%|███▍                                | 1907/19717 [08:14<1:16:42,  3.87it/s][A
 10%|███▍                                | 1908/19717 [08:14<1:16:42,  3.87it/s][A
 10%|███▍                                | 1909/19717 [08:15<1:16:42,  3.87it/s][A
 10%|███▍                                | 1910/19717 [08:15<1:16:39,  3.87it/s][A
 10%|███▍                                | 1911/19717 [08:15<1:16:41,  3.87it/s][A
 10%|███▍                                | 1912/19717 [08:15<1:16:42,  3.87it/s][A
 10%|███▍                                | 1913/19717 [08:16<1:16:39,  3.87it/s][A
 10%|███▍                                | 1914/19717 [08:16<1:16:40,  3.87it/s][A
 10%|███▍                                | 1915/19717 [08:16<1:18:24,  3.78it/s][A
 10%|███▍                                | 1916/19717 [08:16<1:17:49,  3.81it/s][A
 10%|███▌                                | 1917/19717 [08:17<1:17:35,  3.82it/s][A
 10%|███▌                                | 1918/19717 [08:17<1:17:15,  3.84i

 10%|███▋                                | 2004/19717 [08:39<1:15:28,  3.91it/s][A
 10%|███▋                                | 2005/19717 [08:40<1:15:50,  3.89it/s][A
 10%|███▋                                | 2006/19717 [08:40<1:15:55,  3.89it/s][A
 10%|███▋                                | 2007/19717 [08:40<1:16:06,  3.88it/s][A
 10%|███▋                                | 2008/19717 [08:40<1:15:33,  3.91it/s][A
 10%|███▋                                | 2009/19717 [08:41<1:17:37,  3.80it/s][A
 10%|███▋                                | 2010/19717 [08:41<1:17:15,  3.82it/s][A
 10%|███▋                                | 2011/19717 [08:41<1:17:05,  3.83it/s][A
 10%|███▋                                | 2012/19717 [08:41<1:16:50,  3.84it/s][A
 10%|███▋                                | 2013/19717 [08:42<1:16:37,  3.85it/s][A
 10%|███▋                                | 2014/19717 [08:42<1:20:14,  3.68it/s][A
 10%|███▋                                | 2015/19717 [08:42<1:19:19,  3.72i

 11%|███▊                                | 2101/19717 [09:05<1:14:08,  3.96it/s][A
 11%|███▊                                | 2102/19717 [09:05<1:14:37,  3.93it/s][A
 11%|███▊                                | 2103/19717 [09:05<1:14:56,  3.92it/s][A
 11%|███▊                                | 2104/19717 [09:05<1:15:11,  3.90it/s][A
 11%|███▊                                | 2105/19717 [09:06<1:15:08,  3.91it/s][A
 11%|███▊                                | 2106/19717 [09:06<1:14:20,  3.95it/s][A
 11%|███▊                                | 2107/19717 [09:06<1:14:48,  3.92it/s][A
 11%|███▊                                | 2108/19717 [09:06<1:15:05,  3.91it/s][A
 11%|███▊                                | 2109/19717 [09:07<1:14:17,  3.95it/s][A
 11%|███▊                                | 2110/19717 [09:07<1:13:46,  3.98it/s][A
 11%|███▊                                | 2111/19717 [09:07<1:13:28,  3.99it/s][A
 11%|███▊                                | 2112/19717 [09:07<1:14:01,  3.96i

 11%|████                                | 2198/19717 [09:30<1:13:11,  3.99it/s][A
 11%|████                                | 2199/19717 [09:30<1:13:21,  3.98it/s][A
 11%|████                                | 2200/19717 [09:30<1:14:03,  3.94it/s][A
 11%|████                                | 2201/19717 [09:30<1:13:50,  3.95it/s][A
 11%|████                                | 2202/19717 [09:31<1:14:06,  3.94it/s][A
 11%|████                                | 2203/19717 [09:31<1:14:27,  3.92it/s][A
 11%|████                                | 2204/19717 [09:31<1:17:57,  3.74it/s][A
 11%|████                                | 2205/19717 [09:31<1:17:11,  3.78it/s][A
 11%|████                                | 2206/19717 [09:32<1:16:46,  3.80it/s][A
 11%|████                                | 2207/19717 [09:32<1:18:10,  3.73it/s][A
 11%|████                                | 2208/19717 [09:32<1:16:20,  3.82it/s][A
 11%|████                                | 2209/19717 [09:33<1:16:04,  3.84i

 12%|████▏                               | 2295/19717 [09:55<1:16:50,  3.78it/s][A
 12%|████▏                               | 2296/19717 [09:55<1:16:14,  3.81it/s][A
 12%|████▏                               | 2297/19717 [09:55<1:14:53,  3.88it/s][A
 12%|████▏                               | 2298/19717 [09:56<1:14:55,  3.87it/s][A
 12%|████▏                               | 2299/19717 [09:56<1:15:01,  3.87it/s][A
 12%|████▏                               | 2300/19717 [09:56<1:14:07,  3.92it/s][A
 12%|████▏                               | 2301/19717 [09:56<1:14:38,  3.89it/s][A
 12%|████▏                               | 2302/19717 [09:57<1:13:46,  3.93it/s][A
 12%|████▏                               | 2303/19717 [09:57<1:13:12,  3.96it/s][A
 12%|████▏                               | 2304/19717 [09:57<1:13:03,  3.97it/s][A
 12%|████▏                               | 2305/19717 [09:57<1:12:42,  3.99it/s][A
 12%|████▏                               | 2306/19717 [09:58<1:13:25,  3.95i

 12%|████▎                               | 2392/19717 [10:20<1:17:27,  3.73it/s][A
 12%|████▎                               | 2393/19717 [10:20<1:15:51,  3.81it/s][A
 12%|████▎                               | 2394/19717 [10:20<1:14:53,  3.86it/s][A
 12%|████▎                               | 2395/19717 [10:21<1:13:48,  3.91it/s][A
 12%|████▎                               | 2396/19717 [10:21<1:14:02,  3.90it/s][A
 12%|████▍                               | 2397/19717 [10:21<1:14:15,  3.89it/s][A
 12%|████▍                               | 2398/19717 [10:21<1:13:25,  3.93it/s][A
 12%|████▍                               | 2399/19717 [10:22<1:12:50,  3.96it/s][A
 12%|████▍                               | 2400/19717 [10:22<1:13:41,  3.92it/s][A
 12%|████▍                               | 2401/19717 [10:22<1:14:13,  3.89it/s][A
 12%|████▍                               | 2402/19717 [10:23<1:13:24,  3.93it/s][A
 12%|████▍                               | 2403/19717 [10:23<1:13:28,  3.93i

 13%|████▌                               | 2489/19717 [10:45<1:14:25,  3.86it/s][A
 13%|████▌                               | 2490/19717 [10:45<1:16:08,  3.77it/s][A
 13%|████▌                               | 2491/19717 [10:46<1:15:34,  3.80it/s][A
 13%|████▌                               | 2492/19717 [10:46<1:15:15,  3.81it/s][A
 13%|████▌                               | 2493/19717 [10:46<1:17:42,  3.69it/s][A
 13%|████▌                               | 2494/19717 [10:47<1:16:37,  3.75it/s][A
 13%|████▌                               | 2495/19717 [10:47<1:15:52,  3.78it/s][A
 13%|████▌                               | 2496/19717 [10:47<1:16:53,  3.73it/s][A
 13%|████▌                               | 2497/19717 [10:47<1:16:56,  3.73it/s][A
 13%|████▌                               | 2498/19717 [10:48<1:16:22,  3.76it/s][A
 13%|████▌                               | 2499/19717 [10:48<1:15:21,  3.81it/s][A
 13%|████▌                               | 2500/19717 [10:48<1:15:25,  3.80i

 13%|████▋                               | 2586/19717 [11:11<1:17:54,  3.66it/s][A
 13%|████▋                               | 2587/19717 [11:11<1:17:18,  3.69it/s][A
 13%|████▋                               | 2588/19717 [11:11<1:16:30,  3.73it/s][A
 13%|████▋                               | 2589/19717 [11:11<1:15:44,  3.77it/s][A
 13%|████▋                               | 2590/19717 [11:12<1:16:10,  3.75it/s][A
 13%|████▋                               | 2591/19717 [11:12<1:15:13,  3.79it/s][A
 13%|████▋                               | 2592/19717 [11:12<1:15:19,  3.79it/s][A
 13%|████▋                               | 2593/19717 [11:12<1:13:58,  3.86it/s][A
 13%|████▋                               | 2594/19717 [11:13<1:17:04,  3.70it/s][A
 13%|████▋                               | 2595/19717 [11:13<1:16:03,  3.75it/s][A
 13%|████▋                               | 2596/19717 [11:13<1:19:23,  3.59it/s][A
 13%|████▋                               | 2597/19717 [11:13<1:17:43,  3.67i

 14%|████▉                               | 2683/19717 [11:36<1:13:42,  3.85it/s][A
 14%|████▉                               | 2684/19717 [11:36<1:12:45,  3.90it/s][A
 14%|████▉                               | 2685/19717 [11:36<1:12:57,  3.89it/s][A
 14%|████▉                               | 2686/19717 [11:37<1:12:15,  3.93it/s][A
 14%|████▉                               | 2687/19717 [11:37<1:15:40,  3.75it/s][A
 14%|████▉                               | 2688/19717 [11:37<1:15:29,  3.76it/s][A
 14%|████▉                               | 2689/19717 [11:37<1:18:22,  3.62it/s][A
 14%|████▉                               | 2690/19717 [11:38<1:16:55,  3.69it/s][A
 14%|████▉                               | 2691/19717 [11:38<1:15:23,  3.76it/s][A
 14%|████▉                               | 2692/19717 [11:38<1:14:04,  3.83it/s][A
 14%|████▉                               | 2693/19717 [11:38<1:13:08,  3.88it/s][A
 14%|████▉                               | 2694/19717 [11:39<1:13:19,  3.87i

 14%|█████                               | 2780/19717 [12:01<1:17:21,  3.65it/s][A
 14%|█████                               | 2781/19717 [12:01<1:15:41,  3.73it/s][A
 14%|█████                               | 2782/19717 [12:02<1:18:59,  3.57it/s][A
 14%|█████                               | 2783/19717 [12:02<1:17:18,  3.65it/s][A
 14%|█████                               | 2784/19717 [12:02<1:15:25,  3.74it/s][A
 14%|█████                               | 2785/19717 [12:02<1:15:51,  3.72it/s][A
 14%|█████                               | 2786/19717 [12:03<1:15:07,  3.76it/s][A
 14%|█████                               | 2787/19717 [12:03<1:17:46,  3.63it/s][A
 14%|█████                               | 2788/19717 [12:03<1:15:32,  3.73it/s][A
 14%|█████                               | 2789/19717 [12:04<1:17:22,  3.65it/s][A
 14%|█████                               | 2790/19717 [12:04<1:16:47,  3.67it/s][A
 14%|█████                               | 2791/19717 [12:04<1:14:47,  3.77i

 15%|█████▎                              | 2877/19717 [12:27<1:13:21,  3.83it/s][A
 15%|█████▎                              | 2878/19717 [12:27<1:13:17,  3.83it/s][A
 15%|█████▎                              | 2879/19717 [12:27<1:13:15,  3.83it/s][A
 15%|█████▎                              | 2880/19717 [12:27<1:13:11,  3.83it/s][A
 15%|█████▎                              | 2881/19717 [12:28<1:13:30,  3.82it/s][A
 15%|█████▎                              | 2882/19717 [12:28<1:13:21,  3.83it/s][A
 15%|█████▎                              | 2883/19717 [12:28<1:14:19,  3.77it/s][A
 15%|█████▎                              | 2884/19717 [12:28<1:13:56,  3.79it/s][A
 15%|█████▎                              | 2885/19717 [12:29<1:13:40,  3.81it/s][A
 15%|█████▎                              | 2886/19717 [12:29<1:12:33,  3.87it/s][A
 15%|█████▎                              | 2887/19717 [12:29<1:14:24,  3.77it/s][A
 15%|█████▎                              | 2888/19717 [12:29<1:13:16,  3.83i

 15%|█████▍                              | 2974/19717 [12:52<1:14:57,  3.72it/s][A
 15%|█████▍                              | 2975/19717 [12:52<1:13:40,  3.79it/s][A
 15%|█████▍                              | 2976/19717 [12:53<1:13:23,  3.80it/s][A
 15%|█████▍                              | 2977/19717 [12:53<1:13:28,  3.80it/s][A
 15%|█████▍                              | 2978/19717 [12:53<1:16:31,  3.65it/s][A
 15%|█████▍                              | 2979/19717 [12:53<1:15:18,  3.70it/s][A
 15%|█████▍                              | 2980/19717 [12:54<1:14:30,  3.74it/s][A
 15%|█████▍                              | 2981/19717 [12:54<1:13:58,  3.77it/s][A
 15%|█████▍                              | 2982/19717 [12:54<1:13:50,  3.78it/s][A
 15%|█████▍                              | 2983/19717 [12:54<1:13:17,  3.81it/s][A
 15%|█████▍                              | 2984/19717 [12:55<1:13:54,  3.77it/s][A
 15%|█████▍                              | 2985/19717 [12:55<1:13:31,  3.79i

 16%|█████▌                              | 3071/19717 [13:18<1:15:45,  3.66it/s][A
 16%|█████▌                              | 3072/19717 [13:18<1:14:41,  3.71it/s][A
 16%|█████▌                              | 3073/19717 [13:18<1:12:58,  3.80it/s][A
 16%|█████▌                              | 3074/19717 [13:18<1:11:52,  3.86it/s][A
 16%|█████▌                              | 3075/19717 [13:19<1:14:45,  3.71it/s][A
 16%|█████▌                              | 3076/19717 [13:19<1:14:03,  3.75it/s][A
 16%|█████▌                              | 3077/19717 [13:19<1:13:16,  3.78it/s][A
 16%|█████▌                              | 3078/19717 [13:19<1:12:57,  3.80it/s][A
 16%|█████▌                              | 3079/19717 [13:20<1:12:42,  3.81it/s][A
 16%|█████▌                              | 3080/19717 [13:20<1:11:42,  3.87it/s][A
 16%|█████▋                              | 3081/19717 [13:20<1:11:08,  3.90it/s][A
 16%|█████▋                              | 3082/19717 [13:20<1:10:33,  3.93i

 16%|█████▊                              | 3168/19717 [13:43<1:11:37,  3.85it/s][A
 16%|█████▊                              | 3169/19717 [13:43<1:11:46,  3.84it/s][A
 16%|█████▊                              | 3170/19717 [13:44<1:11:49,  3.84it/s][A
 16%|█████▊                              | 3171/19717 [13:44<1:13:41,  3.74it/s][A
 16%|█████▊                              | 3172/19717 [13:44<1:13:11,  3.77it/s][A
 16%|█████▊                              | 3173/19717 [13:44<1:11:51,  3.84it/s][A
 16%|█████▊                              | 3174/19717 [13:45<1:11:53,  3.84it/s][A
 16%|█████▊                              | 3175/19717 [13:45<1:12:13,  3.82it/s][A
 16%|█████▊                              | 3176/19717 [13:45<1:11:09,  3.87it/s][A
 16%|█████▊                              | 3177/19717 [13:45<1:11:21,  3.86it/s][A
 16%|█████▊                              | 3178/19717 [13:46<1:11:31,  3.85it/s][A
 16%|█████▊                              | 3179/19717 [13:46<1:11:38,  3.85i

 17%|█████▋                            | 3265/19717 [33:08<434:34:42, 95.09s/it][A
 17%|█████▋                            | 3266/19717 [34:45<436:41:29, 95.56s/it][A
 17%|█████▋                            | 3267/19717 [36:21<437:52:57, 95.83s/it][A
 17%|█████▋                            | 3268/19717 [37:58<438:37:14, 96.00s/it][A
 17%|█████▋                            | 3269/19717 [39:34<439:26:39, 96.18s/it][A
 17%|█████▋                            | 3270/19717 [41:11<439:57:37, 96.30s/it][A
 17%|█████▋                            | 3271/19717 [42:47<440:08:48, 96.35s/it][A
 17%|█████▋                            | 3272/19717 [44:24<440:09:42, 96.36s/it][A
 17%|█████▋                            | 3273/19717 [46:00<440:08:01, 96.36s/it][A
 17%|█████▋                            | 3274/19717 [47:36<440:13:55, 96.38s/it][A
 17%|█████▋                            | 3275/19717 [49:13<440:31:11, 96.45s/it][A
 17%|█████▋                            | 3276/19717 [50:49<440:23:59, 96.43s

 17%|█████▍                          | 3362/19717 [3:09:27<439:18:13, 96.70s/it][A
 17%|█████▍                          | 3363/19717 [3:11:04<439:22:06, 96.72s/it][A
 17%|█████▍                          | 3364/19717 [3:12:40<439:07:32, 96.67s/it][A
 17%|█████▍                          | 3365/19717 [3:14:17<439:18:41, 96.72s/it][A
 17%|█████▍                          | 3366/19717 [3:15:53<438:52:46, 96.63s/it][A
 17%|█████▍                          | 3367/19717 [3:17:30<438:35:48, 96.57s/it][A
 17%|█████▍                          | 3368/19717 [3:19:06<438:33:11, 96.57s/it][A
 17%|█████▍                          | 3369/19717 [3:20:43<438:21:57, 96.53s/it][A
 17%|█████▍                          | 3370/19717 [3:22:19<438:12:29, 96.50s/it][A
 17%|█████▍                          | 3371/19717 [3:23:56<438:15:04, 96.52s/it][A
 17%|█████▍                          | 3372/19717 [3:25:32<438:17:05, 96.53s/it][A
 17%|█████▍                          | 3373/19717 [3:27:09<438:08:10, 96.51s

 18%|█████▌                          | 3459/19717 [5:45:30<436:04:19, 96.56s/it][A
 18%|█████▌                          | 3460/19717 [5:47:06<435:53:54, 96.53s/it][A
 18%|█████▌                          | 3461/19717 [5:48:43<435:54:57, 96.54s/it][A
 18%|█████▌                          | 3462/19717 [5:50:19<435:42:42, 96.50s/it][A
 18%|█████▌                          | 3463/19717 [5:51:55<435:28:05, 96.45s/it][A
 18%|█████▌                          | 3464/19717 [5:53:32<435:16:09, 96.41s/it][A
 18%|█████▌                          | 3465/19717 [5:55:08<435:26:56, 96.46s/it][A
 18%|█████▋                          | 3466/19717 [5:56:45<435:52:24, 96.56s/it][A
 18%|█████▋                          | 3467/19717 [5:58:22<435:50:53, 96.56s/it][A
 18%|█████▋                          | 3468/19717 [5:59:58<435:49:43, 96.56s/it][A
 18%|█████▋                          | 3469/19717 [6:01:35<435:49:02, 96.56s/it][A
 18%|█████▋                          | 3470/19717 [6:03:11<435:28:03, 96.49s

 18%|█████▊                          | 3556/19717 [8:21:47<433:32:18, 96.57s/it][A
 18%|█████▊                          | 3557/19717 [8:23:24<433:29:05, 96.57s/it][A
 18%|█████▊                          | 3558/19717 [8:25:00<433:09:15, 96.50s/it][A
 18%|█████▊                          | 3559/19717 [8:26:36<433:12:14, 96.52s/it][A
 18%|█████▊                          | 3560/19717 [8:28:13<433:05:20, 96.50s/it][A
 18%|█████▊                          | 3561/19717 [8:29:49<433:11:58, 96.53s/it][A
 18%|█████▊                          | 3562/19717 [8:31:26<432:54:15, 96.47s/it][A
 18%|█████▊                          | 3563/19717 [8:33:02<432:44:10, 96.44s/it][A
 18%|█████▊                          | 3564/19717 [8:34:39<432:51:52, 96.47s/it][A
 18%|█████▊                          | 3565/19717 [8:36:15<433:15:33, 96.57s/it][A
 18%|█████▊                          | 3566/19717 [8:37:52<433:05:44, 96.54s/it][A
 18%|█████▊                          | 3567/19717 [8:39:29<433:13:55, 96.57s

KeyboardInterrupt: 

In [63]:
gmi_og = GMIEmbedding(graph = graph, 
           embed_dim = 128, 
           feature_matrix = sense_features, 
           use_xm = False, 
           ortho_ = 0, 
           sparse_ = 0, 
           batch_size = 1)
embed_og = gmi_og.get_embedding()
embed_og = (embed_og - np.min(embed_og)) / np.ptp(embed_og)
feature_dict_og = find_feature_membership(input_embed = embed_og,
                                                    embed_name = 'GMI-SF',
                                                    sense_features = sense_features,
                                                    sense_feat_dict = sense_feat_dict,
                                                    top_k = 8,
                                                    solver = 'nmf')

explain_og = feature_dict_og['explain_norm']
error_og = sense_features * np.log((sense_features + 1e-10) / ((embed_og @ feature_dict_og['explain_norm']) + 1e-10)) - sense_features + (embed_og @ feature_dict_og['explain_norm'])
explain_og = (explain_og - np.min(explain_og)) / np.ptp(explain_og)


gmi_plus = GMIEmbedding(graph = graph, 
           embed_dim = 128, 
           feature_matrix = sense_features, 
           use_xm = True, 
           alpha = 0.04, 
           beta = 0.5, 
           gamma = 0.5,
           ortho_ = 1e-2, 
           sparse_ = 1e-1, 
           batch_size = 1, 
           epoch_flag = 500)
embed_plus = gmi_plus.get_embedding()
embed_plus = (embed_plus - np.min(embed_plus)) / np.ptp(embed_plus)
feature_dict_plus = find_feature_membership(input_embed = embed_plus,
                                                    embed_name = 'GMI+XM',
                                                    sense_features = sense_features,
                                                    sense_feat_dict = sense_feat_dict,
                                                    top_k = 8,
                                                    solver = 'nmf')

explain_plus = feature_dict_plus['explain_norm']
error_plus = sense_features * np.log((sense_features + 1e-10) / ((embed_plus @ feature_dict_plus['explain_norm']) + 1e-10)) - sense_features + (embed_plus @ feature_dict_plus['explain_norm'])
explain_plus = (explain_plus - np.min(explain_plus)) / np.ptp(explain_plus)

100%|█████████████████████████████████████████| 500/500 [00:48<00:00, 10.22it/s]
100%|█████████████████████████████████████████| 500/500 [01:22<00:00,  6.04it/s]


In [64]:
fig = go.Figure()
fig.add_trace(go.Heatmap(z = explain_og,
                         x = list(sense_feat_dict), 
                         ))
fig.update_layout(title_text = 'DGI : ' + str(np.linalg.norm(explain_og, ord = 'nuc')), 
                  xaxis_title_text = 'Sense Features', 
                  yaxis_title_text = 'Dimensions')

fig.show()

fig = go.Figure()
fig.add_trace(go.Heatmap(z = explain_plus,
                         x = list(sense_feat_dict), 
                         ))
fig.update_layout(title_text = 'DGI+ : ' + str(np.linalg.norm(explain_plus, ord = 'nuc')), 
                  xaxis_title_text = 'Sense Features', 
                  yaxis_title_text = 'Dimensions')

fig.show()

In [65]:
Y_plus = embed_plus
sense_mat = tf.einsum('ij, ik -> ijk', Y_plus, sense_features)
Y_plus_norm = tf.linalg.diag_part(tf.matmul(Y_plus, Y_plus, transpose_b = True), k = 0)
sense_norm = tf.linalg.diag_part(tf.matmul(sense_features, sense_features, transpose_b = True), k = 0)
norm = Y_plus_norm * tf.cast(sense_norm, tf.float32)
D_plus = tf.transpose(tf.transpose(sense_mat) / norm)
D_plus = (D_plus - tf.reshape(tf.reduce_min(D_plus, axis = [-1, -2]), (-1, 1, 1))) / tf.reshape(tf.reduce_max(D_plus, axis = [-1, -2]) - tf.reduce_min(D_plus, axis = [-1, -2]), (-1, 1, 1))


Y_og = embed_og

sense_mat = tf.einsum('ij, ik -> ijk', Y_og, sense_features)
Y_og_norm = tf.linalg.diag_part(tf.matmul(Y_og, Y_og, transpose_b = True), k = 0)
sense_norm = tf.linalg.diag_part(tf.matmul(sense_features, sense_features, transpose_b = True), k = 0)
norm = Y_og_norm * tf.cast(sense_norm, tf.float32)
D_og = tf.transpose(tf.transpose(sense_mat) / norm)
D_og = (D_og - tf.reshape(tf.reduce_min(D_og, axis = [-1, -2]), (-1, 1, 1))) / tf.reshape(tf.reduce_max(D_og, axis = [-1, -2]) - tf.reduce_min(D_og, axis = [-1, -2]), (-1, 1, 1))


In [66]:
norm_og = [np.linalg.norm(D_og[node, :, :], ord = 'nuc') for node in tqdm(range(len(graph)))]
norm_plus = [np.linalg.norm(D_plus[node, :, :], ord = 'nuc') for node in tqdm(range(len(graph)))]



100%|███████████████████████████████████████| 1186/1186 [00:14<00:00, 82.68it/s]
100%|██████████████████████████████████████| 1186/1186 [00:09<00:00, 128.44it/s]


In [67]:
diff = np.array(norm_og) - np.array(norm_plus)
fig = go.Figure()
fig.add_trace(go.Histogram(x = diff))
fig.update_layout(title_text = 'Distribution of Difference In Nuclear Norms - GMI vs GMI+XM', 
                  xaxis_title_text = 'Difference In Nuclear Norms - GMI vs GMI+XM', 
                  yaxis_title_text = 'Frequency', 
                  plot_bgcolor = 'white', 
                  paper_bgcolor = 'white', 
                  font = dict(size = 30))
fig.show()

In [68]:
fig = go.Figure()

fig.add_trace(go.Histogram(x = norm_og, 
                           name = 'GMI'))
fig.add_trace(go.Histogram(x = norm_plus, 
                           name = 'GMI+XM'))


fig.update_layout(title_text = 'Distribution of Nuclear Norm of Node Explain Matrix', 
                  xaxis_title_text = 'Nuclear Norm', 
                  yaxis_title_text = 'Frequency', 
                  paper_bgcolor = 'white', 
                  plot_bgcolor = 'white', 
                  font = dict(size = 20))
fig.show()

### Visualisation

In [2]:
with open('./results/email_dgi.pkl', 'rb') as file:
    results = pkl.load(file)
    
dimensions = list(results.keys())
runs = len(results[64]['norm_og'])

print("Runs : ", runs)

Runs :  15


In [8]:
results[16].keys()

dict_keys(['norm_og', 'norm_plus', 'explain_og_norm', 'explain_plus_norm', 'dgi_og_time', 'dgi+xm_time', 'error_og', 'error_plus', 'embed_og', 'embed_plus', 'norm_id', 'explain_id_norm', 'dgi_id_time', 'error_id', 'embed_id'])

In [9]:
a = results[16]['norm_og'] 
results[16]['norm_og'] = results[16]['norm_plus']
results[16]['norm_plus'] = a

In [10]:
norm_mean = []
norm_std = []

norm_plus_mean = []
norm_plus_std = []

norm_id_mean = []
norm_id_std = []

for d in dimensions:
    
    norm_mean.append(np.nanmean(np.array(results[d]['norm_og'])))
    norm_std.append(np.nanstd(np.array(results[d]['norm_og'])) / np.sqrt(runs))
    
    norm_plus_mean.append(np.nanmean(np.array(results[d]['norm_plus'])))
    norm_plus_std.append(np.nanstd(np.array(results[d]['norm_plus'])) / np.sqrt(runs))
    
#     norm_id_mean.append(np.mean(np.array(results[d]['norm_id'])))
#     norm_id_std.append(np.std(np.array(results[d]['norm_id'])) / np.sqrt(runs))

In [30]:
scipy.stats.ttest_ind_from_stats(norm_mean, norm_std, 15, norm_plus_mean, norm_plus_std, 15)

Ttest_indResult(statistic=array([ 9.12288613, 10.99860777, 12.83327174, 24.99099928, 13.65560912]), pvalue=array([7.01413283e-10, 1.13046395e-11, 3.00776272e-13, 1.09390510e-20,
       6.66244058e-14]))

In [126]:
fig = go.Figure()
marker_size = 10
text_size = 20
method = 'DGI'

fig.add_trace(go.Scatter(x = dimensions, 
                         y = norm_mean, 
                         error_y = dict(type = 'data', 
                                        array = norm_std), 
                         mode = 'markers', 
                         name = method, 
                         marker = dict(size = marker_size)))

fig.add_trace(go.Scatter(x = dimensions, 
                         y = norm_plus_mean, 
                         error_y = dict(type = 'data', 
                                        array = norm_plus_std), 
                         mode = 'markers', 
                         name = method + '+XM', 
                         marker = dict(size = marker_size)))

# fig.add_trace(go.Scatter(x = dimensions, 
#                          y = norm_id_mean, 
#                          error_y = dict(type = 'data', 
#                                         array = norm_id_std), 
#                          mode = 'markers', 
#                          name = method + ' ID', 
#                          marker = dict(size = marker_size)))


fig.update_layout(xaxis_title_text = 'Dimensions (log scale)',
                  yaxis_title_text = 'Nuclear Norm (log scale)', 
                  plot_bgcolor = 'white',
                  paper_bgcolor = 'white', 
                  font = dict(size = text_size))

fig.update_xaxes(type = 'log')
fig.update_yaxes(type = 'log')
fig.show()

In [16]:
d = 128
idx = 0
norm_og = results[d]['norm_og'][idx]
norm_plus = results[d]['norm_plus'][idx]


fig = go.Figure()

fig.add_trace(go.Histogram(x = norm_og, 
                           name = 'DGI'))
fig.add_trace(go.Histogram(x = norm_plus, 
                           name = 'DGI+XM'))


fig.update_layout(title_text = 'Distribution of Nuclear Norm of Node Explain Matrix', 
                  xaxis_title_text = 'Nuclear Norm', 
                  yaxis_title_text = 'Frequency', 
                  paper_bgcolor = 'white', 
                  plot_bgcolor = 'white', 
                  font = dict(size = text_size))
fig.show()

In [24]:
d = 128
idx = 0
norm_og = results[d]['norm_og'][idx]
norm_plus = results[d]['norm_plus'][idx]
diff = np.array(norm_og) - np.array(norm_plus)

fig = go.Figure()

fig.add_trace(go.Histogram(x = diff, 
                           name = 'DGI'))



fig.update_layout(title_text = 'Distribution of Difference In Nuclear Norm of Node Explain Matrix', 
                  xaxis_title_text = 'Difference In Nuclear Norms', 
                  yaxis_title_text = 'Frequency', 
                  paper_bgcolor = 'white', 
                  plot_bgcolor = 'white', 
                  font = dict(size = text_size))
fig.show()

In [127]:
with open('./results/email_lp.pkl', 'rb') as file: 
    email_results = pkl.load(file)

# Acc
id_0 = 0
id_1 = 3

# AUC 
# id_0 = 1
# id_1 = 4

# AUP
# id_0 = 2
# id_1 = 5
#########################
####### Aggregate #######
#########################
sdne_xm = np.mean(email_results['sdne'], axis = 0)[[id_0, id_1]][0]
sdne = np.mean(email_results['sdne'], axis = 0)[[id_0, id_1]][1]
sdne_xm_std = np.std(email_results['sdne'], axis = 0)[[id_0, id_1]][0]
sdne_std = np.std(email_results['sdne'], axis = 0)[[id_0, id_1]][1]

line_xm = np.mean(email_results['line'], axis = 0)[[id_0, id_1]][0]
line = np.mean(email_results['line'], axis = 0)[[id_0, id_1]][1]
line_xm_std = np.std(email_results['line'], axis = 0)[[id_0, id_1]][0]
line_std = np.std(email_results['line'], axis = 0)[[id_0, id_1]][1]

dgi_xm = np.mean(email_results['dgi'], axis = 0)[[id_0, id_1]][0]
dgi = np.mean(email_results['dgi'], axis = 0)[[id_0, id_1]][1]
dgi_xm_std = np.std(email_results['dgi'], axis = 0)[[id_0, id_1]][0]
dgi_std = np.std(email_results['dgi'], axis = 0)[[id_0, id_1]][1]

gmi_xm = np.mean(email_results['gmi'], axis = 0)[[id_0, id_1]][0]
gmi = np.mean(email_results['gmi'], axis = 0)[[id_0, id_1]][1]
gmi_xm_std = np.std(email_results['gmi'], axis = 0)[[id_0, id_1]][0]
gmi_std = np.std(email_results['gmi'], axis = 0)[[id_0, id_1]][1]


In [113]:
dim = 128

with open('./results/email_dgi.pkl', 'rb') as file:
    results = pkl.load(file)
    
dgi_norm = np.nanmean(results[dim]['norm_og'])
dgi_norm_std = np.nanstd(results[dim]['norm_og'])

dgi_xm_norm = np.nanmean(results[dim]['norm_plus'])
dgi_xm_norm_std = np.nanstd(results[dim]['norm_plus'])

with open('./results/email_gmi.pkl', 'rb') as file:
    results = pkl.load(file)
    
gmi_norm = np.nanmean(results[dim]['norm_og'])
gmi_norm_std = np.nanstd(results[dim]['norm_og'])

gmi_xm_norm = np.nanmean(results[dim]['norm_plus'])
gmi_xm_norm_std = np.nanstd(results[dim]['norm_plus'])

with open('./results/email_sdne.pkl', 'rb') as file:
    results = pkl.load(file)

sdne_norm = np.nanmean(results[dim]['norm_og'])
sdne_norm_std = np.nanstd(results[dim]['norm_og'])

sdne_xm_norm = np.nanmean(results[dim]['norm_plus'])
sdne_xm_norm_std = np.nanstd(results[dim]['norm_plus'])


with open('./results/email_line.pkl', 'rb') as file:
    results = pkl.load(file)
    
line_norm = np.nanmean(results[dim]['norm_og'])
line_norm_std = np.nanstd(results[dim]['norm_og'])

line_xm_norm = np.nanmean(results[dim]['norm_plus'])
line_xm_norm_std = np.nanstd(results[dim]['norm_plus'])

In [131]:
fig = go.Figure()
marker_size = 15
fig.add_trace(go.Scatter(x = [sdne], 
                        y = [sdne_norm], 
                        error_x = dict(type = 'data', 
                                       array = [sdne_std]), 
                        error_y = dict(type = 'data', 
                                       array = [sdne_norm_std]), 
                        mode = 'markers', 
                        marker = dict(size = marker_size), 
                        name = 'SDNE', 
                        legendgroup = 'sdne'))

fig.add_trace(go.Scatter(x = [sdne_xm + 0.02], 
                        y = [sdne_xm_norm], 
                        error_x = dict(type = 'data', 
                                       array = [sdne_xm_std]), 
                        error_y = dict(type = 'data', 
                                       array = [sdne_xm_norm_std]), 
                        mode = 'markers', 
                        marker = dict(size = marker_size), 
                        name = 'SDNE+XM', 
                        legendgroup = 'sdne'))

fig.add_trace(go.Scatter(x = [line], 
                        y = [line_norm], 
                        error_x = dict(type = 'data', 
                                       array = [line_std]), 
                        error_y = dict(type = 'data', 
                                       array = [line_norm_std]), 
                        mode = 'markers', 
                        marker = dict(size = marker_size), 
                        name = 'LINE', 
                        legendgroup = 'line'))

fig.add_trace(go.Scatter(x = [line_xm + 0.05], 
                        y = [line_xm_norm], 
                        error_x = dict(type = 'data', 
                                       array = [line_xm_std]), 
                        error_y = dict(type = 'data', 
                                       array = [line_xm_norm_std]), 
                        mode = 'markers', 
                        marker = dict(size = marker_size), 
                        name = 'LINE+XM', 
                        legendgroup = 'line'))

fig.add_trace(go.Scatter(x = [dgi], 
                        y = [dgi_norm], 
                        error_x = dict(type = 'data', 
                                       array = [dgi_std]), 
                        error_y = dict(type = 'data', 
                                       array = [dgi_norm_std]), 
                        mode = 'markers', 
                        marker = dict(size = marker_size), 
                        name = 'DGI', 
                        legendgroup = 'dgi'))

fig.add_trace(go.Scatter(x = [dgi_xm], 
                        y = [dgi_xm_norm], 
                        error_x = dict(type = 'data', 
                                       array = [dgi_xm_std]), 
                        error_y = dict(type = 'data', 
                                       array = [dgi_xm_norm_std]), 
                        mode = 'markers', 
                        marker = dict(size = marker_size), 
                        name = 'DGI+XM', 
                        legendgroup = 'dgi'))

fig.add_trace(go.Scatter(x = [gmi], 
                        y = [gmi_norm], 
                        error_x = dict(type = 'data', 
                                       array = [gmi_std]), 
                        error_y = dict(type = 'data', 
                                       array = [gmi_norm_std]), 
                        mode = 'markers', 
                        marker = dict(size = marker_size), 
                        name = 'GMI', 
                        legendgroup = 'gmi'))

fig.add_trace(go.Scatter(x = [gmi_xm], 
                        y = [gmi_xm_norm], 
                        error_x = dict(type = 'data', 
                                       array = [gmi_xm_std]), 
                        error_y = dict(type = 'data', 
                                       array = [gmi_xm_norm_std]), 
                        mode = 'markers', 
                        marker = dict(size = marker_size), 
                        name = 'GMI+XM', 
                        legendgroup = 'gmi'))

fig.update_layout(title_text = 'AUC vs Nuclear Norm', 
                  xaxis_title_text = 'AUC', 
                  yaxis_title_text = 'Nuclear Norm (log scale)', 
                  plot_bgcolor = 'white', 
                  paper_bgcolor = 'white', 
                  font = dict(size = 20))
fig.update_xaxes(range = [0.75, 1])
fig.update_yaxes(type = 'log')

fig.show()

### Small Tests

In [67]:
clique_one = np.zeros((10, 10))
clique_one[9, :] = 1
clique_one[:, 9] = 1
clique_one[9, 9] = 0

In [96]:
clique_one = np.ones((4, 4)) - np.identity(4)
clique_two = np.ones((4, 4)) - np.identity(4)

barbell = np.block([[clique_one, np.zeros((4, 4))], 
                     [np.zeros((4, 4)), clique_two]])
barbell[2, 7] = 1
barbell[7, 2] = 1

graph = nx.Graph(barbell)

In [97]:
# graph = nx.Graph(clique_one)
sense_feat_dict, sense_features = get_sense_features(graph, ppr_flag = 'std')
uncorrelated_feats = ['Degree',
                    'Clustering Coefficient',
                    'Personalized Page Rank - Standard Deviation',
                    'Average Neighbor Degree',
                    'Average Neighbor Clustering',
                    'Eccentricity',
                    'Katz Centrality']
sense_features = sense_features[:, [list(sense_feat_dict).index(feat) for feat in uncorrelated_feats]]
sense_feat_dict = {feat : idx for idx, feat in enumerate(uncorrelated_feats)}
sense_features[np.isnan(sense_features)] = 0

Calculating Degrees...                                   Calculating Average Neighbor Degree...                    Calculating Clustering Coefficient...                     Calculating Average Neighbor Clustering Coefficients...   Calculating Eccentricity...                               Calculating Page Rank...                                  Calculating Personalized Page Rank...                     

8it [00:00, 35246.25it/s]

Calculating Node Betweenness...                           Calculating Number Of Edges In Ego Nets...                Calculating Structural Hole Constraint Scores...         Calculating Degree Centrality...                         Calculating Eigen Centrality...                          Calculating Katz Centrality...                           Normalizing Features Between 0 And 1...                   Done                                                      



adjacency_matrix will return a scipy.sparse array instead of a matrix in Networkx 3.0.



In [105]:
dgi_og = DGIEmbedding(graph = graph, 
           embed_dim = 16, 
           feature_matrix = np.identity(len(graph)), 
           use_xm = False, 
           ortho_ = 0, 
           sparse_ = 0, 
           batch_size = 1, 
           model_name = '-')
embed_og = dgi_og.get_embedding()
embed_og = (embed_og - np.min(embed_og)) / np.ptp(embed_og)

Y_og = embed_og
sense_mat = tf.einsum('ij, ik -> ijk', Y_og, sense_features)
Y_og_norm = tf.linalg.diag_part(tf.matmul(Y_og, Y_og, transpose_b = True), k = 0)
sense_norm = tf.linalg.diag_part(tf.matmul(sense_features, sense_features, transpose_b = True), k = 0)
norm = Y_og_norm * tf.cast(sense_norm, tf.float32)
D_og = tf.transpose(tf.transpose(sense_mat) / norm)
D_og = (D_og - tf.reshape(tf.reduce_min(D_og, axis = [-1, -2]), (-1, 1, 1))) / tf.reshape(tf.reduce_max(D_og, axis = [-1, -2]) - tf.reduce_min(D_og, axis = [-1, -2]), (-1, 1, 1))






 12%|████▊                                   | 302/2500 [00:00<00:02, 1068.61it/s]


In [110]:
node = 1
fig = go.Figure()
fig.add_trace(go.Heatmap(z = D_og[node, :, :],
                         x = list(sense_feat_dict), 
                         ))
fig.update_layout(title_text = '', 
                  xaxis_title_text = 'Sense Features', 
                  yaxis_title_text = 'Dimensions', 
                  font = dict(size =  25), 
                  width = 1000, 
                  height = 1000)

fig.show()