In [12]:
import seaborn as sns
import pandas as pd
import numpy as np
import pickle as pkl
import matplotlib as mpl
from graph_tool.generation import lattice
from graph_tool.draw import graph_draw
from graph_tool import openmp_set_num_threads
from itertools import chain
from collections import Counter
from tqdm import tqdm
from copy import copy
from scipy.spatial.distance import cosine, euclidean


from graph_helpers import load_graph_by_name, extract_edges, extract_nodes
from viz_helpers import lattice_node_pos
from experiment import gen_input
from helpers import infected_nodes
from core import sample_steiner_trees
from tree_stat import TreeBasedStatistics
from random_steiner_tree.util import from_gt

from root_sampler import build_root_sampler_by_pagerank_score
from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score


In [2]:
# Ask graph-tool to use a single OpenMP thread
# (presumably to keep runs single-threaded and comparable — TODO confirm intent).
openmp_set_num_threads(1)

In [17]:
# Experiment identifiers: they pick the graph file loaded here and the
# pickled results file loaded in the next cell.
graph_name = 'grqc'
suffix = 's0.03'
q = '0.1'  # a string: in this notebook it is only interpolated into a file name
g = load_graph_by_name(graph_name, weighted=True, suffix='_'+suffix)

load graph from data/grqc/graph_weighted_s0.03.gt


In [19]:
# Load the precomputed inference results for this graph / suffix / q combination.
# The `with` block guarantees the file handle is closed once the pickle is read.
# NOTE(review): unpickling can execute arbitrary code — only load files produced by this project.
with open('outputs/inf_probas/{}-{}-q{}.pkl'.format(graph_name, suffix, q), 'rb') as f:
    stuff = pkl.load(f)

In [20]:
def accumulate_score(stuff, eval_func):
    """Score the inferred infection probabilities of every run in `stuff`.

    Parameters
    ----------
    stuff : dict
        Maps a root-sampling-method name to an iterable of rows. Each row is
        a mapping with the keys 'c', 'obs', 'st_naive_probas' and
        'st_tree_inc_probas'. 'c' is passed to `infected_nodes` and its length
        is used as the number of nodes; 'obs' holds the node indices to
        exclude from evaluation.
    eval_func : callable
        Called as ``eval_func(y_true, y_score)`` on the non-observed nodes only.

    Returns
    -------
    dict
        Maps each root-sampling-method name to a list with one dict per row,
        keyed by 'random', 'st_naive' and 'st_inc'.
    """
    names = ['random', 'st_naive', 'st_inc']
    scores_by_root_sampling_method = {}
    for root_sampling_method, data in stuff.items():
        scores = []
        for row in tqdm(data):
            c, obs = row['c'], row['obs']
            n_nodes = len(c)

            # Ground truth: 1 for infected nodes, 0 otherwise.
            y_true = np.zeros((n_nodes, ))
            y_true[infected_nodes(c)] = 1

            # Evaluate only on nodes that were not observed. A set makes each
            # membership test O(1) instead of a scan over `obs` per node.
            observed = set(obs)
            mask = np.array([i not in observed for i in range(n_nodes)], dtype=bool)

            # Random baseline, sized from this row rather than from a global
            # graph object so that it always lines up with `mask`.
            random_inf_p = np.random.random(n_nodes)
            candidates = [random_inf_p, row['st_naive_probas'], row['st_tree_inc_probas']]

            scores.append({
                name: eval_func(y_true[mask], inf_probas[mask])
                for name, inf_probas in zip(names, candidates)
            })
        scores_by_root_sampling_method[root_sampling_method] = scores
    return scores_by_root_sampling_method

In [21]:
def threshold_half_accuracy(y_true, y_pred):
    """Return the accuracy obtained by labelling scores >= 0.5 as positive.

    Parameters
    ----------
    y_true : array-like
        Ground-truth labels.
    y_pred : array-like
        Predicted scores; thresholded at 0.5 (inclusive).

    Returns
    -------
    The value returned by `accuracy_score(y_true, thresholded_labels)`.
    """
    # The comparison already yields a boolean array, so the removed `np.bool`
    # alias (deprecated in NumPy 1.20, gone in 1.24) is not needed at all.
    pred_labels = np.asarray(y_pred) >= 0.5
    return accuracy_score(y_true, pred_labels)

In [22]:
# Score every run under each root-sampling method with two metrics from
# sklearn.metrics: average precision and ROC AUC.
ap_scores_by_root_sampling_method = accumulate_score(stuff, average_precision_score)
roc_scores_by_root_sampling_method = accumulate_score(stuff, roc_auc_score)

100%|██████████| 768/768 [00:12<00:00, 62.69it/s]
100%|██████████| 768/768 [00:15<00:00, 50.23it/s]
100%|██████████| 768/768 [00:12<00:00, 60.34it/s]
100%|██████████| 768/768 [00:15<00:00, 46.92it/s]
100%|██████████| 768/768 [00:14<00:00, 52.98it/s]
100%|██████████| 768/768 [00:14<00:00, 51.74it/s]
100%|██████████| 768/768 [00:16<00:00, 66.34it/s]
100%|██████████| 768/768 [00:13<00:00, 57.47it/s]
100%|██████████| 768/768 [00:11<00:00, 65.93it/s]
100%|██████████| 768/768 [00:16<00:00, 47.52it/s]


In [23]:
# Same evaluation, using accuracy after thresholding the probabilities at 0.5.
accuracy_by_root_sampling_method = accumulate_score(stuff, threshold_half_accuracy)

100%|██████████| 768/768 [00:13<00:00, 56.19it/s]
100%|██████████| 768/768 [00:11<00:00, 66.15it/s]
100%|██████████| 768/768 [00:09<00:00, 84.49it/s]
100%|██████████| 768/768 [00:11<00:00, 69.50it/s]
100%|██████████| 768/768 [00:09<00:00, 83.72it/s]


In [24]:
def describe_stuff(scores_by_root_sampling_method):
    """Summarise per-run scores with `DataFrame.describe()`, one table per key.

    `scores_by_root_sampling_method` maps a name to a list of score records
    (dicts). The result maps each name to the `describe()` table of the
    DataFrame built from its records.
    """
    return {
        method: pd.DataFrame.from_records(score_records).describe()
        for method, score_records in scores_by_root_sampling_method.items()
    }
    

In [25]:
# Reduce each metric's per-run scores to a describe() summary table,
# one table per root-sampling method.
ap_df_by_root_sampling_method = describe_stuff(ap_scores_by_root_sampling_method)
roc_df_by_root_sampling_method = describe_stuff(roc_scores_by_root_sampling_method)
accuracy_df_by_root_sampling_method = describe_stuff(accuracy_by_root_sampling_method)

In [26]:
def print_result(df_by_keys, keys=None):
    """Print the entries of `df_by_keys` one after another, in a fixed order.

    Parameters
    ----------
    df_by_keys : mapping
        Maps a root-sampling-method name to the object to print.
    keys : iterable of str, optional
        Names to print, in order. Defaults to the five root-sampling methods
        used in this notebook.

    Raises
    ------
    KeyError
        If a requested key is missing from `df_by_keys`.
    """
    if keys is None:
        keys = ['pagerank-eps0.0', 'pagerank-eps0.5', 'pagerank-eps1.0', 'random_root', 'true root']
    for k in keys:
        df = df_by_keys[k]
        print(k)
        print('-' * 10)
        print(df)

In [27]:
# Average-precision summary for each root-sampling method.
print_result(ap_df_by_root_sampling_method)

pagerank-eps0.0
----------
           random      st_inc    st_naive
count  768.000000  768.000000  768.000000
mean     0.028739    0.519638    0.492277
std      0.025484    0.294506    0.236692
min      0.001526    0.007117    0.009651
25%      0.005489    0.292504    0.384354
50%      0.013812    0.590268    0.492354
75%      0.052583    0.694487    0.586402
max      0.103090    1.000000    1.000000
pagerank-eps0.5
----------
           random      st_inc    st_naive
count  768.000000  768.000000  768.000000
mean     0.028348    0.423003    0.413134
std      0.025292    0.281057    0.184759
min      0.001288    0.005211    0.011940
25%      0.005620    0.105163    0.291376
50%      0.014365    0.537865    0.460842
75%      0.051643    0.630235    0.522219
max      0.107173    0.960233    0.871361
pagerank-eps1.0
----------
           random      st_inc    st_naive
count  768.000000  768.000000  768.000000
mean     0.028389    0.422234    0.408661
std      0.025063    0.279375    0.18

In [28]:
# ROC AUC summary for each root-sampling method.
print_result(roc_df_by_root_sampling_method)

pagerank-eps0.0
----------
           random      st_inc    st_naive
count  768.000000  768.000000  768.000000
mean     0.501072    0.929740    0.899488
std      0.056799    0.052914    0.065862
min      0.173045    0.704641    0.595108
25%      0.476061    0.906527    0.862507
50%      0.501188    0.928863    0.888946
75%      0.523646    0.972005    0.957618
max      0.756600    1.000000    1.000000
pagerank-eps0.5
----------
           random      st_inc    st_naive
count  768.000000  768.000000  768.000000
mean     0.501437    0.925527    0.899648
std      0.054115    0.051726    0.067759
min      0.275326    0.674889    0.457908
25%      0.476606    0.901602    0.861109
50%      0.501240    0.926297    0.886834
75%      0.522488    0.963392    0.959234
max      0.755846    0.999823    0.999498
pagerank-eps1.0
----------
           random      st_inc    st_naive
count  768.000000  768.000000  768.000000
mean     0.502069    0.922745    0.895775
std      0.051316    0.058173    0.07

In [29]:
# Thresholded-accuracy summary for each root-sampling method.
print_result(accuracy_df_by_root_sampling_method)

pagerank-eps0.0
----------
           random      st_inc    st_naive
count  768.000000  768.000000  768.000000
mean     0.499934    0.962198    0.975322
std      0.007463    0.022937    0.022501
min      0.476491    0.916848    0.906592
25%      0.495299    0.943895    0.954040
50%      0.499578    0.956538    0.988921
75%      0.504843    0.980330    0.995668
max      0.523256    1.000000    0.998797
pagerank-eps0.5
----------
           random      st_inc    st_naive
count  768.000000  768.000000  768.000000
mean     0.499830    0.952981    0.974924
std      0.007846    0.013649    0.022685
min      0.471126    0.915230    0.915938
25%      0.494704    0.942639    0.953806
50%      0.499516    0.953323    0.988198
75%      0.505543    0.964679    0.994946
max      0.522640    0.999519    0.998556
pagerank-eps1.0
----------
           random      st_inc    st_naive
count  768.000000  768.000000  768.000000
mean     0.500010    0.952808    0.975122
std      0.007149    0.013118    0.02