### Imports

In [1]:
from scripts.utils import *

  from .autonotebook import tqdm as notebook_tqdm


### SDNE

In [7]:
filename = './data/email.pkl'
run_count = 1
hyp_key = 'hyp_pubmed'
outfile = './pubmed_test.pkl'
# NOTE(review): `filename` points at the email graph while `hyp_key` and
# `outfile` both name PubMed -- confirm which dataset this cell should run on.

# ap = argparse.ArgumentParser()
# ap.add_argument("-g", "--graph_path", required = True, help = 'Path to an nx.Graph object stored as a .pkl file')
# ap.add_argument("-r", "--run_count", required = True, help = "Number of iterations for the experiment", default = 1)
# ap.add_argument("-k", "--hyp_key", required = True, help = "Key to index the hyperparameter json file")
# ap.add_argument("-o", "--outfile", required = True, help = "File name to save results into")

# args = vars(ap.parse_args())

# filename = args['graph_path']
# run_count = int(args['run_count'])
# hyp_key = args['hyp_key']
# outfile = args['outfile']

#################################
######### Read In Graph #########
#################################
# pickle.load can execute arbitrary code -- only point this at trusted files.
with open(filename, 'rb') as file: 
    graph_dict = pkl.load(file)

# The pickle holds either a dict with a 'graph' entry or the graph object
# itself. Checking the type explicitly (instead of a bare `except:`) keeps
# genuine errors visible. Round-tripping through a dense adjacency matrix
# relabels the nodes as 0..n-1.
graph_obj = graph_dict['graph'] if isinstance(graph_dict, dict) else graph_dict
graph = nx.Graph(nx.to_numpy_array(graph_obj))


#################################
#### Generate Sense Features ####
#################################
sense_feat_dict, sense_features = get_sense_features(graph, ppr_flag = 'std')

# Keep only the columns named in `uncorrelated_feats` (in that order) and
# rebuild the name -> column-index mapping for the reduced matrix.
# Assumes the iteration order of `sense_feat_dict` matches the column order
# of `sense_features` -- TODO confirm against get_sense_features.
uncorrelated_feats = [
    'Degree',
    'Clustering Coefficient',
    'Personalized Page Rank - Standard Deviation',
    'Average Neighbor Degree',
    'Average Neighbor Clustering',
    'Eccentricity',
    'Katz Centrality',
]
all_feat_names = list(sense_feat_dict)
selected_cols = [all_feat_names.index(name) for name in uncorrelated_feats]
sense_features = sense_features[:, selected_cols]
sense_feat_dict = dict(zip(uncorrelated_feats, range(len(uncorrelated_feats))))

#################################
######## Hyperparameters ########
#################################

# A non-empty `hyp_key` selects an entry of hyp.json; an empty key falls back
# to the static defaults below.
if hyp_key != '':
    with open('hyp.json', 'r') as file:
        hyp_file = json.load(file)
    hyp = hyp_file[hyp_key]
else:
    hyp = {
        'sdne': dict(alpha = 0.1, beta = 10, gamma = 0, delta = 0,
                     epochs = 200, batch_size = 1024, lr = 1e-3),
        'sdne+xm': dict(alpha = 1, beta = 1, gamma = 10, delta = 10,
                        epochs = 400, batch_size = 1024, lr = 5e-4),
    }


#################################
######## Run Experiment #########
#################################

def _minmax_scale(x):
    """Rescale ``x`` to [0, 1] using its global minimum and peak-to-peak range."""
    return (x - np.min(x)) / np.ptp(x)


def _kl_error(sense_features, embed, explain):
    """Element-wise generalized KL divergence between ``sense_features`` and
    the reconstruction ``embed @ explain``; 1e-10 guards the log and division."""
    recon = embed @ explain
    return sense_features * np.log((sense_features + 1e-10) / (recon + 1e-10)) - sense_features + recon


def _node_explanations(Y, sense_features, sense_norm):
    """Build the per-node explanation tensor and the nuclear norm of each slice.

    D[i] = outer(Y[i], sense_features[i]) / (||Y[i]||^2 * sense_norm[i]),
    where ``sense_norm`` holds the squared row norms of ``sense_features``
    (float32). Returns ``(D, norms)`` with ``norms[i]`` the nuclear norm of D[i].
    """
    sense_mat = tf.einsum('ij, ik -> ijk', Y, sense_features)
    # Squared row norms == diag(Y @ Y.T), without materializing the n x n
    # Gram matrix (O(n^2) time/memory for large graphs).
    Y_norm = tf.reduce_sum(tf.square(Y), axis = 1)
    norm = Y_norm * sense_norm
    # Transpose so the per-node divisor broadcasts along the node axis.
    D = tf.transpose(tf.transpose(sense_mat) / norm)
    # One stacked SVD over all nodes; the nuclear norm is the sum of singular
    # values, which is what np.linalg.norm(..., ord='nuc') computes per matrix.
    norms = np.linalg.svd(np.asarray(D), compute_uv = False).sum(axis = -1)
    return D, list(norms)


dimensions = [16, 32, 64, 256, 512]
results = {d : {} for d in dimensions}
run_time = []

for run_idx in tqdm(range(run_count)):

    run_start = time.time()

    for d in dimensions:

        # Embed

        # Standard SDNE: SDNE_plus driven by the hyp['sdne'] settings
        sdne_start = time.time()
        sdne = SDNE_plus(graph, 
                          hidden_size = [32, d], 
                          lr = hyp['sdne']['lr'],
                          sense_features = sense_features.astype(np.float32),
                          alpha = hyp['sdne']['alpha'], 
                          beta = hyp['sdne']['beta'], 
                          gamma = hyp['sdne']['gamma'], 
                          delta = hyp['sdne']['delta'])
        history = sdne.train(epochs = hyp['sdne']['epochs'], batch_size = hyp['sdne']['batch_size'])
        e = sdne.get_embeddings()
        embed_og = _minmax_scale(np.array([e[node_name] for node_name in graph.nodes()]))
        # Wall-clock seconds per epoch (includes model construction and embedding extraction)
        sdne_time = (time.time() - sdne_start) / hyp['sdne']['epochs']

        # SDNE+XM, warm-started from the weights of the SDNE model above
        sdne_plus_start = time.time()
        sdne_plus = SDNE_plus(graph, 
                              hidden_size = [32, d], 
                              lr = hyp['sdne+xm']['lr'],
                              sense_features = sense_features.astype(np.float32),
                              alpha = hyp['sdne+xm']['alpha'], 
                              beta = hyp['sdne+xm']['beta'], 
                              gamma = hyp['sdne+xm']['gamma'], 
                              delta = hyp['sdne+xm']['delta'])

        sdne_plus.model.set_weights(sdne.model.get_weights())
        history = sdne_plus.train(epochs = hyp['sdne+xm']['epochs'], batch_size = hyp['sdne+xm']['batch_size'])
        e = sdne_plus.get_embeddings()
        embed_plus = _minmax_scale(np.array([e[node_name] for node_name in graph.nodes()]))
        sdne_plus_time = (time.time() - sdne_plus_start) / hyp['sdne+xm']['epochs']

        # Generate Graph Explanations
        # NOTE(review): here the nuclear norm is taken AFTER min-max scaling the
        # explanation matrix, while the LINE cell takes it on the raw matrix --
        # confirm which is intended before comparing 'explain_*_norm' across models.
        feature_dict_og = find_feature_membership(input_embed = embed_og,
                                                  embed_name = 'SDNE',
                                                  sense_features = sense_features,
                                                  sense_feat_dict = sense_feat_dict,
                                                  top_k = 8,
                                                  solver = 'nmf')

        explain_og = _minmax_scale(feature_dict_og['explain_norm'])
        explain_og_norm = np.linalg.norm(explain_og, ord = 'nuc')

        feature_dict_plus = find_feature_membership(input_embed = embed_plus,
                                                    embed_name = 'SDNE+ Init',
                                                    sense_features = sense_features,
                                                    sense_feat_dict = sense_feat_dict,
                                                    top_k = 8,
                                                    solver = 'nmf')

        explain_plus = _minmax_scale(feature_dict_plus['explain_norm'])
        explain_plus_norm = np.linalg.norm(explain_plus, ord = 'nuc')

        # Reconstruction error uses the unscaled explanation matrices.
        error_og = _kl_error(sense_features, embed_og, feature_dict_og['explain_norm'])
        error_plus = _kl_error(sense_features, embed_plus, feature_dict_plus['explain_norm'])

        # Generate Node Explanations (sense-feature row norms shared by both models)
        sense_norm = tf.cast(tf.reduce_sum(tf.square(sense_features), axis = 1), tf.float32)
        D_og, norm_og = _node_explanations(embed_og, sense_features, sense_norm)
        D_plus, norm_plus = _node_explanations(embed_plus, sense_features, sense_norm)

        # Accumulate per-run metrics; lists are created on first use.
        per_dim = results[d]
        for key, value in (('norm_og', norm_og),
                           ('norm_plus', norm_plus),
                           ('explain_og_norm', explain_og_norm),
                           ('explain_plus_norm', explain_plus_norm),
                           ('sdne_time', sdne_time),
                           ('sdne+xm_time', sdne_plus_time),
                           ('error_og', error_og),
                           ('error_plus', error_plus)):
            per_dim.setdefault(key, []).append(value)

        # Only the most recent run's embeddings are kept.
        per_dim['embed_og'] = embed_og
        per_dim['embed_plus'] = embed_plus

    with open(outfile, 'wb') as file: 
        pkl.dump(results, file)

    run_time.append(time.time() - run_start)
    mean_run_time = np.mean(run_time)
    # run_idx is 0-based and this run has finished, so this many runs remain.
    runs_left = run_count - run_idx - 1

    with open(outfile + '_progress.txt', 'w') as file: 
        string = 'Current Run : ' + str(run_idx)
        string += '\nLast Iteration Time : ' + str(run_time[-1]) + 's'
        string += '\nAverage Iteration Time : ' + str(mean_run_time) + 's'
        string += '\nEstimated Time Left : ' + str(mean_run_time * runs_left) + 's'
        file.write(string)

Calculating Personalized Page Rank...                     

19717it [10:41, 30.71it/s]


Calculating Katz Centrality...                            

  A = nx.adjacency_matrix(G, nodelist=nodelist, weight=weight).todense().T


Normalizing Features Between 0 And 1...                   Done                                                      

  0%|                                                     | 0/1 [00:00<?, ?it/s]

Metal device set to: Apple M1 Max


2023-01-11 18:37:43.829551: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-01-11 18:37:43.829671: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)




2023-01-11 18:37:44.055194: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz
2023-01-11 18:37:44.084872: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.




2023-01-11 18:37:45.198547: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2023-01-11 18:37:46.129235: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2023-01-11 18:37:50.212241: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


Epoch 1/100
5s - loss:  214.3038 - 2nd_loss:  214.1941 - 1st_loss:  0.0609 - ortho_loss :  0.0000 - sparse_loss :  0.0000
Epoch 2/100
3s - loss:  172.6321 - 2nd_loss:  170.6372 - 1st_loss:  1.9474 - ortho_loss :  0.0000 - sparse_loss :  0.0000
Epoch 3/100
3s - loss:  117.8922 - 2nd_loss:  113.0795 - 1st_loss:  4.7465 - ortho_loss :  0.0000 - sparse_loss :  0.0000
Epoch 4/100
3s - loss:  88.4219 - 2nd_loss:  84.2189 - 1st_loss:  4.1279 - ortho_loss :  0.0000 - sparse_loss :  0.0000
Epoch 5/100
3s - loss:  77.0123 - 2nd_loss:  73.3802 - 1st_loss:  3.5558 - ortho_loss :  0.0000 - sparse_loss :  0.0000
Epoch 6/100
3s - loss:  72.5408 - 2nd_loss:  69.4113 - 1st_loss:  3.0554 - ortho_loss :  0.0000 - sparse_loss :  0.0000
Epoch 7/100
3s - loss:  70.3458 - 2nd_loss:  67.5856 - 1st_loss:  2.6889 - ortho_loss :  0.0000 - sparse_loss :  0.0000
Epoch 8/100
3s - loss:  69.2218 - 2nd_loss:  66.7743 - 1st_loss:  2.3788 - ortho_loss :  0.0000 - sparse_loss :  0.0000
Epoch 9/100
3s - loss:  68.2685 - 

Epoch 69/100
3s - loss:  61.2043 - 2nd_loss:  60.9715 - 1st_loss:  0.1983 - ortho_loss :  0.0000 - sparse_loss :  0.0000
Epoch 70/100
3s - loss:  61.1981 - 2nd_loss:  60.9692 - 1st_loss:  0.1947 - ortho_loss :  0.0000 - sparse_loss :  0.0000
Epoch 71/100
3s - loss:  61.1699 - 2nd_loss:  60.9449 - 1st_loss:  0.1911 - ortho_loss :  0.0000 - sparse_loss :  0.0000
Epoch 72/100
3s - loss:  61.1615 - 2nd_loss:  60.9401 - 1st_loss:  0.1878 - ortho_loss :  0.0000 - sparse_loss :  0.0000
Epoch 73/100
3s - loss:  61.1383 - 2nd_loss:  60.9206 - 1st_loss:  0.1843 - ortho_loss :  0.0000 - sparse_loss :  0.0000
Epoch 74/100
3s - loss:  61.1295 - 2nd_loss:  60.9153 - 1st_loss:  0.1811 - ortho_loss :  0.0000 - sparse_loss :  0.0000
Epoch 75/100
3s - loss:  61.1037 - 2nd_loss:  60.8931 - 1st_loss:  0.1778 - ortho_loss :  0.0000 - sparse_loss :  0.0000
Epoch 76/100
3s - loss:  61.0942 - 2nd_loss:  60.8867 - 1st_loss:  0.1749 - ortho_loss :  0.0000 - sparse_loss :  0.0000
Epoch 77/100
3s - loss:  61.0892

2023-01-11 18:43:15.114289: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.




2023-01-11 18:43:16.354676: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2023-01-11 18:43:16.936964: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
2023-01-11 18:43:20.998382: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


Epoch 1/200
5s - loss:  4926.6935 - 2nd_loss:  60.7219 - 1st_loss:  1.6128 - ortho_loss :  2431.0929 - sparse_loss :  2433.2384
Epoch 2/200
3s - loss:  3061.0865 - 2nd_loss:  62.6511 - 1st_loss:  2.9585 - ortho_loss :  1248.4230 - sparse_loss :  1747.0233
Epoch 3/200
3s - loss:  2257.4667 - 2nd_loss:  65.1503 - 1st_loss:  4.7753 - ortho_loss :  795.9141 - sparse_loss :  1391.5929
Epoch 4/200
3s - loss:  1817.1407 - 2nd_loss:  65.9483 - 1st_loss:  6.8487 - ortho_loss :  569.5246 - sparse_loss :  1174.7821
Epoch 5/200
3s - loss:  1537.7627 - 2nd_loss:  65.7747 - 1st_loss:  9.1536 - ortho_loss :  436.4413 - sparse_loss :  1026.3536
Epoch 6/200
3s - loss:  1332.2102 - 2nd_loss:  65.3804 - 1st_loss:  11.7363 - ortho_loss :  344.7327 - sparse_loss :  910.3188
Epoch 7/200
3s - loss:  1170.5989 - 2nd_loss:  65.0221 - 1st_loss:  14.5035 - ortho_loss :  276.6566 - sparse_loss :  814.3726
Epoch 8/200
3s - loss:  1049.6441 - 2nd_loss:  64.6827 - 1st_loss:  17.3415 - ortho_loss :  228.3695 - sparse

Epoch 66/200
3s - loss:  294.8731 - 2nd_loss:  57.3787 - 1st_loss:  73.8349 - ortho_loss :  9.7997 - sparse_loss :  153.7889
Epoch 67/200
3s - loss:  292.9228 - 2nd_loss:  57.3559 - 1st_loss:  73.9395 - ortho_loss :  9.5731 - sparse_loss :  151.9835
Epoch 68/200
3s - loss:  290.9946 - 2nd_loss:  57.3214 - 1st_loss:  73.9781 - ortho_loss :  9.3658 - sparse_loss :  150.2583
Epoch 69/200
3s - loss:  289.1585 - 2nd_loss:  57.3010 - 1st_loss:  74.0330 - ortho_loss :  9.1627 - sparse_loss :  148.5908
Epoch 70/200
3s - loss:  287.4265 - 2nd_loss:  57.2159 - 1st_loss:  74.0272 - ortho_loss :  8.9900 - sparse_loss :  147.1223
Epoch 71/200
3s - loss:  285.7145 - 2nd_loss:  57.1041 - 1st_loss:  74.0382 - ortho_loss :  8.8185 - sparse_loss :  145.6826
Epoch 72/200
3s - loss:  284.1176 - 2nd_loss:  57.0915 - 1st_loss:  74.0008 - ortho_loss :  8.6600 - sparse_loss :  144.2941
Epoch 73/200
3s - loss:  282.5181 - 2nd_loss:  57.0693 - 1st_loss:  73.9889 - ortho_loss :  8.4951 - sparse_loss :  142.8937


Epoch 132/200
3s - loss:  221.2570 - 2nd_loss:  54.6756 - 1st_loss:  60.5790 - ortho_loss :  4.3643 - sparse_loss :  101.5567
Epoch 133/200
3s - loss:  220.4682 - 2nd_loss:  54.6477 - 1st_loss:  60.2894 - ortho_loss :  4.3271 - sparse_loss :  101.1224
Epoch 134/200
3s - loss:  219.6933 - 2nd_loss:  54.6283 - 1st_loss:  59.9521 - ortho_loss :  4.2970 - sparse_loss :  100.7340
Epoch 135/200
3s - loss:  218.8394 - 2nd_loss:  54.5253 - 1st_loss:  59.6608 - ortho_loss :  4.2626 - sparse_loss :  100.3085
Epoch 136/200
3s - loss:  218.0343 - 2nd_loss:  54.4654 - 1st_loss:  59.3824 - ortho_loss :  4.2282 - sparse_loss :  99.8758
Epoch 137/200
3s - loss:  217.2582 - 2nd_loss:  54.4352 - 1st_loss:  59.1046 - ortho_loss :  4.1939 - sparse_loss :  99.4418
Epoch 138/200
3s - loss:  216.4982 - 2nd_loss:  54.4117 - 1st_loss:  58.8374 - ortho_loss :  4.1597 - sparse_loss :  99.0063
Epoch 139/200
3s - loss:  215.7644 - 2nd_loss:  54.4012 - 1st_loss:  58.5600 - ortho_loss :  4.1274 - sparse_loss :  98.5

Epoch 198/200
3s - loss:  172.5539 - 2nd_loss:  52.2450 - 1st_loss:  43.2988 - ortho_loss :  2.3417 - sparse_loss :  74.5673
Epoch 199/200
3s - loss:  171.4759 - 2nd_loss:  52.1869 - 1st_loss:  42.9632 - ortho_loss :  2.2994 - sparse_loss :  73.9249
Epoch 200/200
3s - loss:  170.5186 - 2nd_loss:  52.1273 - 1st_loss:  42.5260 - ortho_loss :  2.2738 - sparse_loss :  73.4897


  explain_norm_softmax = np.array([np.exp(x) / sum(np.exp(x)) for x in explain_norm])
  explain_norm_softmax = np.array([np.exp(x) / sum(np.exp(x)) for x in explain_norm])
  explain_norm_softmax = np.array([np.exp(x) / sum(np.exp(x)) for x in explain_norm])
  explain_norm_softmax = np.array([np.exp(x) / sum(np.exp(x)) for x in explain_norm])
  0%|                                                     | 0/1 [32:20<?, ?it/s]

KeyboardInterrupt



### LINE

In [7]:
filename = './data/email.pkl'
run_count = 1
hyp_key = 'hyp_email'
outfile = './email_line.pkl'

# ap = argparse.ArgumentParser()
# ap.add_argument("-g", "--graph_path", required = True, help = 'Path to an nx.Graph object stored as a .pkl file')
# ap.add_argument("-r", "--run_count", required = True, help = "Number of iterations for the experiment", default = 1)
# ap.add_argument("-k", "--hyp_key", required = True, help = "Key to index the hyperparameter json file")
# ap.add_argument("-o", "--outfile", required = True, help = "File name to save results into")

# args = vars(ap.parse_args())

# filename = args['graph_path']
# run_count = int(args['run_count'])
# hyp_key = args['hyp_key']
# outfile = args['outfile']

#################################
######### Read In Graph #########
#################################
# pickle.load can execute arbitrary code -- only point this at trusted files.
with open(filename, 'rb') as file: 
    graph_dict = pkl.load(file)

# The pickle holds either a dict with a 'graph' entry or the graph object
# itself. Checking the type explicitly (instead of a bare `except:`) keeps
# genuine errors visible. Round-tripping through a dense adjacency matrix
# relabels the nodes as 0..n-1.
graph_obj = graph_dict['graph'] if isinstance(graph_dict, dict) else graph_dict
graph = nx.Graph(nx.to_numpy_array(graph_obj))

#################################
#### Generate Sense Features ####
#################################
sense_feat_dict, sense_features = get_sense_features(graph, ppr_flag = 'std')

# Keep only the columns named in `uncorrelated_feats` (in that order) and
# rebuild the name -> column-index mapping for the reduced matrix.
# Assumes the iteration order of `sense_feat_dict` matches the column order
# of `sense_features` -- TODO confirm against get_sense_features.
uncorrelated_feats = [
    'Degree',
    'Clustering Coefficient',
    'Personalized Page Rank - Standard Deviation',
    'Average Neighbor Degree',
    'Average Neighbor Clustering',
    'Eccentricity',
    'Katz Centrality',
]
all_feat_names = list(sense_feat_dict)
selected_cols = [all_feat_names.index(name) for name in uncorrelated_feats]
sense_features = sense_features[:, selected_cols]
sense_feat_dict = dict(zip(uncorrelated_feats, range(len(uncorrelated_feats))))

#################################
######## Hyperparameters ########
#################################

# A non-empty `hyp_key` selects an entry of hyp.json; an empty key falls back
# to the static defaults below.
if hyp_key != '':
    with open('hyp.json', 'r') as file:
        hyp_file = json.load(file)
    hyp = hyp_file[hyp_key]
else:
    hyp = {
        'line': dict(alpha = 0.1, ortho = 0, sparse = 0,
                     epochs = 15, batch_size = 1024, lr = 1e-3),
        'line+xm': dict(alpha = 100, ortho = 10, sparse = 10,
                        epochs = 50, batch_size = 1024, lr = 5e-4),
    }


#################################
######## Run Experiment #########
#################################

def _minmax_scale(x):
    """Rescale ``x`` to [0, 1] using its global minimum and peak-to-peak range."""
    return (x - np.min(x)) / np.ptp(x)


def _kl_error(sense_features, embed, explain):
    """Element-wise generalized KL divergence between ``sense_features`` and
    the reconstruction ``embed @ explain``; 1e-10 guards the log and division."""
    recon = embed @ explain
    return sense_features * np.log((sense_features + 1e-10) / (recon + 1e-10)) - sense_features + recon


def _node_explanations(Y, sense_features, sense_norm):
    """Build the per-node explanation tensor and the nuclear norm of each slice.

    D[i] = outer(Y[i], sense_features[i]) / (||Y[i]||^2 * sense_norm[i]),
    where ``sense_norm`` holds the squared row norms of ``sense_features``
    (float32). Returns ``(D, norms)`` with ``norms[i]`` the nuclear norm of D[i].
    """
    sense_mat = tf.einsum('ij, ik -> ijk', Y, sense_features)
    # Squared row norms == diag(Y @ Y.T), without materializing the n x n
    # Gram matrix (O(n^2) time/memory for large graphs).
    Y_norm = tf.reduce_sum(tf.square(Y), axis = 1)
    norm = Y_norm * sense_norm
    # Transpose so the per-node divisor broadcasts along the node axis.
    D = tf.transpose(tf.transpose(sense_mat) / norm)
    # One stacked SVD over all nodes; the nuclear norm is the sum of singular
    # values, which is what np.linalg.norm(..., ord='nuc') computes per matrix.
    norms = np.linalg.svd(np.asarray(D), compute_uv = False).sum(axis = -1)
    return D, list(norms)


dimensions = [16, 32, 64, 256, 512]
results = {d : {} for d in dimensions}
run_time = []

for run_idx in tqdm(range(run_count)):

    run_start = time.time()

    for d in dimensions:

        # Embed

        # Standard LINE (second-order proximity)
        line_start = time.time()
        line = LINE(graph, 
                    embedding_size = d,
                    sense_features = sense_features,
                    alpha = hyp['line']['alpha'], 
                    ortho = hyp['line']['ortho'], 
                    sparse = hyp['line']['sparse'],
                    learning_rate =  hyp['line']['lr'],
                    order = 'second', 
                    batch_size = hyp['line']['batch_size'])

        history = line.train(epochs = hyp['line']['epochs'])

        e = line.get_embeddings()
        embed_og = _minmax_scale(np.array([e[node_name] for node_name in graph.nodes()]))
        # Wall-clock seconds per epoch (includes model construction and embedding extraction)
        line_time = (time.time() - line_start) / hyp['line']['epochs']

        feature_dict_og = find_feature_membership(input_embed = embed_og,
                                                  embed_name = 'LINE',
                                                  sense_features = sense_features,
                                                  sense_feat_dict = sense_feat_dict,
                                                  top_k = 8,
                                                  solver = 'nmf')

        # NOTE(review): here the nuclear norm is taken on the RAW explanation
        # matrix, while the SDNE cell min-max scales it first; the scaled
        # `explain_og` is not used afterwards in this cell -- confirm which is
        # intended before comparing 'explain_*_norm' across models.
        explain_og = feature_dict_og['explain_norm']
        explain_og_norm = np.linalg.norm(explain_og, ord = 'nuc')
        error_og = _kl_error(sense_features, embed_og, feature_dict_og['explain_norm'])
        explain_og = _minmax_scale(explain_og)

        # LINE+XM
        line_plus_start = time.time()
        line_plus = LINE(graph, 
                         embedding_size = d,
                         sense_features = sense_features,
                         alpha = hyp['line+xm']['alpha'], 
                         ortho = hyp['line+xm']['ortho'], 
                         sparse = hyp['line+xm']['sparse'],
                         learning_rate =  hyp['line+xm']['lr'],
                         order = 'second', 
                         batch_size = hyp['line+xm']['batch_size'])

        history = line_plus.train(epochs = hyp['line+xm']['epochs'])

        e = line_plus.get_embeddings()
        embed_plus = _minmax_scale(np.array([e[node_name] for node_name in graph.nodes()]))
        line_plus_time = (time.time() - line_plus_start) / hyp['line+xm']['epochs']

        feature_dict_plus = find_feature_membership(input_embed = embed_plus,
                                                    embed_name = 'LINE+XM',
                                                    sense_features = sense_features,
                                                    sense_feat_dict = sense_feat_dict,
                                                    top_k = 8,
                                                    solver = 'nmf')

        explain_plus = feature_dict_plus['explain_norm']
        explain_plus_norm = np.linalg.norm(explain_plus, ord = 'nuc')
        error_plus = _kl_error(sense_features, embed_plus, feature_dict_plus['explain_norm'])
        explain_plus = _minmax_scale(explain_plus)

        # Generate Node Explanations (sense-feature row norms shared by both models)
        sense_norm = tf.cast(tf.reduce_sum(tf.square(sense_features), axis = 1), tf.float32)
        D_og, norm_og = _node_explanations(embed_og, sense_features, sense_norm)
        D_plus, norm_plus = _node_explanations(embed_plus, sense_features, sense_norm)

        # Accumulate per-run metrics; lists are created on first use.
        per_dim = results[d]
        for key, value in (('norm_og', norm_og),
                           ('norm_plus', norm_plus),
                           ('explain_og_norm', explain_og_norm),
                           ('explain_plus_norm', explain_plus_norm),
                           ('line_time', line_time),
                           ('line+xm_time', line_plus_time),
                           ('error_og', error_og),
                           ('error_plus', error_plus)):
            per_dim.setdefault(key, []).append(value)

        # Only the most recent run's embeddings are kept.
        per_dim['embed_og'] = embed_og
        per_dim['embed_plus'] = embed_plus

    # Checkpoint once per run (the results were previously dumped twice in a row).
    with open(outfile, 'wb') as file: 
        pkl.dump(results, file)

    run_time.append(time.time() - run_start)
    mean_run_time = np.mean(run_time)
    # run_idx is 0-based and this run has finished, so this many runs remain.
    runs_left = run_count - run_idx - 1

    with open(outfile + '_progress.txt', 'w') as file: 
        string = 'Current Run : ' + str(run_idx)
        string += '\nLast Iteration Time : ' + str(run_time[-1]) + 's'
        string += '\nAverage Iteration Time : ' + str(mean_run_time) + 's'
        string += '\nEstimated Time Left : ' + str(mean_run_time * runs_left) + 's'
        file.write(string)

### DGI

In [None]:
filename = './data/email.pkl'
run_count = 1
hyp_key = 'hyp_email'
outfile = './email_test.pkl'

# ap = argparse.ArgumentParser()
# ap.add_argument("-g", "--graph_path", required = True, help = 'Path to an nx.Graph object stored as a .pkl file')
# ap.add_argument("-r", "--run_count", required = True, help = "Number of iterations for the experiment", default = 1)
# ap.add_argument("-k", "--hyp_key", required = True, help = "Key to index the hyperparameter json file")
# ap.add_argument("-o", "--outfile", required = True, help = "File name to save results into")

# args = vars(ap.parse_args())

# filename = args['graph_path']
# run_count = int(args['run_count'])
# hyp_key = args['hyp_key']
# outfile = args['outfile']

#################################
######### Read In Graph #########
#################################
# pickle.load can execute arbitrary code -- only point this at trusted files.
with open(filename, 'rb') as file: 
    graph_dict = pkl.load(file)

# The pickle holds either a dict with a 'graph' entry or the graph object
# itself. Checking the type explicitly (instead of a bare `except:`) keeps
# genuine errors visible. Round-tripping through a dense adjacency matrix
# relabels the nodes as 0..n-1.
graph_obj = graph_dict['graph'] if isinstance(graph_dict, dict) else graph_dict
graph = nx.Graph(nx.to_numpy_array(graph_obj))

#################################
#### Generate Sense Features ####
#################################
sense_feat_dict, sense_features = get_sense_features(graph, ppr_flag = 'std')

# Keep only the columns named in `uncorrelated_feats` (in that order) and
# rebuild the name -> column-index mapping for the reduced matrix.
# Assumes the iteration order of `sense_feat_dict` matches the column order
# of `sense_features` -- TODO confirm against get_sense_features.
uncorrelated_feats = [
    'Degree',
    'Clustering Coefficient',
    'Personalized Page Rank - Standard Deviation',
    'Average Neighbor Degree',
    'Average Neighbor Clustering',
    'Eccentricity',
    'Katz Centrality',
]
all_feat_names = list(sense_feat_dict)
selected_cols = [all_feat_names.index(name) for name in uncorrelated_feats]
sense_features = sense_features[:, selected_cols]
sense_feat_dict = dict(zip(uncorrelated_feats, range(len(uncorrelated_feats))))

#################################
######## Hyperparameters ########
#################################

def _as_bool(value):
    """Interpret a 'use_xm' flag read from hyp.json as a real boolean.

    Strings count as True only when they read 'true' (case-insensitive,
    surrounding whitespace ignored); JSON booleans/numbers go through ``bool``.
    """
    if isinstance(value, str):
        return value.strip().lower() == 'true'
    return bool(value)

# Define static ones to override or read in from a file

if hyp_key == '':
    hyp = {'dgi' : {'use_xm' : False, 
                    'ortho' : 0, 
                    'sparse' : 0, 
                    'lr' : 1e-3}, 

           'dgi+xm': {'use_xm' : True, 
                      'ortho' : 0.1, 
                      'sparse' : 0.1, 
                      'lr' : 1e-3}
          }
else: 
    with open('hyp.json', 'r') as file: 
        hyp_file = json.load(file)
    hyp = hyp_file[hyp_key]

    # Convert To Bool. Comparing against the string 'True' alone would turn a
    # JSON `true` into False, since a bool never equals a string.
    for model_key in ('dgi', 'dgi+xm'):
        hyp[model_key]['use_xm'] = _as_bool(hyp[model_key]['use_xm'])


#################################
######## Run Experiment #########
#################################

dimensions = [16, 32, 64, 256, 512]
results = {d : {} for d in dimensions}
run_time = []

for run_idx in tqdm(range(run_count)):
    
    run_start = time.time()

    for d in dimensions: 
    
        # Embed 
        
        # DGI-ID
        dgi_id = DGIEmbedding(graph = graph, 
                   embed_dim = d, 
                   feature_matrix = np.identity(len(graph)), 
                   use_xm = hyp['dgi']['use_xm'], 
                   ortho_ = hyp['dgi']['ortho'], 
                   sparse_ = hyp['dgi']['sparse'], 
                   batch_size = 1)
        embed_id = dgi_id.get_embedding()
        embed_id = (embed_id - np.min(embed_id)) / np.ptp(embed_id)
        feature_dict_id = find_feature_membership(input_embed = embed_id,
                                                            embed_name = 'DGI',
                                                            sense_features = sense_features,
                                                            sense_feat_dict = sense_feat_dict,
                                                            top_k = 8,
                                                            solver = 'nmf')

        explain_id = feature_dict_id['explain_norm']
        error_id = sense_features * np.log((sense_features + 1e-10) / ((embed_id @ feature_dict_id['explain_norm']) + 1e-10)) - sense_features + (embed_id @ feature_dict_id['explain_norm'])
        explain_id = (explain_id - np.min(explain_id)) / np.ptp(explain_id)
        
        # DGI-SF
        dgi_og = DGIEmbedding(graph = graph, 
                   embed_dim = d, 
                   feature_matrix = sense_features, 
                   use_xm = hyp['dgi']['use_xm'], 
                   ortho_ = hyp['dgi']['ortho'], 
                   sparse_ = hyp['dgi']['sparse'], 
                   batch_size = 1)
        embed_og = dgi_og.get_embedding()
        embed_og = (embed_og - np.min(embed_og)) / np.ptp(embed_og)
        feature_dict_og = find_feature_membership(input_embed = embed_og,
                                                            embed_name = 'DGI-SF',
                                                            sense_features = sense_features,
                                                            sense_feat_dict = sense_feat_dict,
                                                            top_k = 8,
                                                            solver = 'nmf')

        explain_og = feature_dict_og['explain_norm']
        error_og = sense_features * np.log((sense_features + 1e-10) / ((embed_og @ feature_dict_og['explain_norm']) + 1e-10)) - sense_features + (embed_og @ feature_dict_og['explain_norm'])
        explain_og = (explain_og - np.min(explain_og)) / np.ptp(explain_og)
        
        # DGI+XM: DGI embedding on the sense features, configured by hyp['dgi+xm'].
        dgi_plus = DGIEmbedding(graph=graph,
                                embed_dim=d,
                                feature_matrix=sense_features,
                                use_xm=hyp['dgi+xm']['use_xm'],
                                ortho_=hyp['dgi+xm']['ortho'],
                                sparse_=hyp['dgi+xm']['sparse'],
                                batch_size=1)
        embed_plus = dgi_plus.get_embedding()
        # Global min-max scaling of the embedding (single min/range over all entries).
        embed_plus = (embed_plus - np.min(embed_plus)) / np.ptp(embed_plus)
        feature_dict_plus = find_feature_membership(input_embed=embed_plus,
                                                    embed_name='DGI+XM',
                                                    sense_features=sense_features,
                                                    sense_feat_dict=sense_feat_dict,
                                                    top_k=8,
                                                    solver='nmf')

        # Reconstruction of the sense features; computed once and reused by the error term.
        explain_plus = feature_dict_plus['explain_norm']
        recon_plus = embed_plus @ explain_plus
        # Element-wise generalized KL (I-)divergence; 1e-10 offsets avoid log(0) / div-by-zero.
        error_plus = sense_features * np.log((sense_features + 1e-10) / (recon_plus + 1e-10)) - sense_features + recon_plus
        # Min-max scale the explanation matrix (the error above uses the unscaled matrix).
        explain_plus = (explain_plus - np.min(explain_plus)) / np.ptp(explain_plus)

        # Generate Node Explanations
        # For node i with embedding row y_i and sense-feature row s_i the explanation is
        #     D[i] = outer(y_i, s_i) / (||y_i||^2 * ||s_i||^2)
        # and the nuclear norm of each D[i] is recorded.
        # Squared row norms use a row-wise sum of squares rather than diag(X @ X^T), which
        # materialised an (n_nodes x n_nodes) matrix just to read its diagonal (O(n^2) memory).
        # The sense-feature norms do not depend on the embedding, so compute them once.
        sense_sq_norm = tf.cast(tf.reduce_sum(tf.square(sense_features), axis=1), tf.float32)

        def _node_explanations(Y):
            """Return the rank-3 tensor D with D[i] = outer(Y[i], s_i) / (||Y[i]||^2 * ||s_i||^2)."""
            sense_mat = tf.einsum('ij, ik -> ijk', Y, sense_features)
            Y_sq_norm = tf.reduce_sum(tf.square(Y), axis=1)
            norm = Y_sq_norm * sense_sq_norm
            # Scale each node's (embed_dim x n_features) slice by that node's scalar norm.
            return sense_mat / norm[:, tf.newaxis, tf.newaxis]

        def _nuclear_norms(D):
            """Return the nuclear norm of every per-node slice D[i] as a list (one per node)."""
            # Batched SVD over the stacked matrices instead of a Python loop over nodes.
            return list(np.linalg.norm(np.asarray(D), ord='nuc', axis=(1, 2)))

        D_og = _node_explanations(embed_og)
        D_plus = _node_explanations(embed_plus)
        D_id = _node_explanations(embed_id)

        norm_og = _nuclear_norms(D_og)
        norm_plus = _nuclear_norms(D_plus)
        norm_id = _nuclear_norms(D_id)

        # Record this run's metrics for embedding dimension d: each metric is a list with
        # one entry per run, created on first use via setdefault. (A bare try/except whose
        # fallback re-initialised every list could silently drop earlier runs and hide
        # unrelated errors.)
        # The 'explain_*_norm' entries hold explain_og / explain_plus / explain_id, which at
        # this point are the min-max normalised explanation matrices.
        per_run_metrics = {
            'norm_og': norm_og,
            'norm_plus': norm_plus,
            'norm_id': norm_id,
            'explain_og_norm': explain_og,
            'explain_plus_norm': explain_plus,
            'explain_id_norm': explain_id,
            'dgi_id_time': dgi_id.time_per_epoch,
            'dgi_og_time': dgi_og.time_per_epoch,
            'dgi+xm_time': dgi_plus.time_per_epoch,
            'error_og': error_og,
            'error_plus': error_plus,
            'error_id': error_id,
        }
        for metric, value in per_run_metrics.items():
            results[d].setdefault(metric, []).append(value)

        # Embeddings are overwritten each run, so only the most recent run's are kept.
        results[d]['embed_og'] = embed_og
        results[d]['embed_plus'] = embed_plus
        results[d]['embed_id'] = embed_id
        
    
    # Checkpoint everything collected so far (the file is overwritten after every run).
    # Previously the identical pickle was written twice in a row.
    with open(outfile, 'wb') as file:
        pkl.dump(results, file)

    run_time.append(time.time() - run_start)

    # Human-readable progress report written next to the results file.
    # NOTE(review): if run_idx is 0-based, (run_count - run_idx) counts the run that just
    # finished, over-estimating the remaining time by one iteration — confirm.
    mean_time = np.mean(run_time)
    with open(outfile + '_progress.txt', 'w') as file:
        file.write('Current Run : ' + str(run_idx)
                   + '\nLast Iteration Time : ' + str(run_time[-1]) + 's'
                   + '\nAverage Iteration Time : ' + str(mean_time) + 's'
                   + '\nEstimated Time Left : ' + str(mean_time * (run_count - run_idx)) + 's')