In [2]:
%matplotlib notebook
# starter code
import random
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
import networkx as nx
from graph_tool.all import load_graph, shortest_distance, GraphView
from networkx.drawing.nx_agraph import graphviz_layout

from cascade import gen_nontrivial_cascade
from utils import get_rank_index

# Seed every RNG in use so the cascade generated below is reproducible.
seed = 123456
random.seed(seed)
np.random.seed(seed)

# The same graph is loaded twice: a graph_tool copy (used for cascade
# simulation and tree reconstruction) and a networkx copy (used for layout
# and plotting).
gtype = 'grid'
data_dir = 'data/{}/2-6'.format(gtype)
g = load_graph('{}/graph.gt'.format(data_dir))
gnx = nx.read_graphml('{}/graph.graphml'.format(data_dir))
# GraphML node ids are read back as strings; relabel them to ints, presumably
# so they line up with graph_tool vertex indices. ``nodes()`` works on both
# networkx 1.x and 2.x, whereas ``nodes_iter()`` was removed in networkx 2.0.
gnx = nx.relabel_nodes(gnx, {i: int(i) for i in gnx.nodes()})
N1, N2 = 100, 100  # NOTE(review): not used in the cells visible here
p, q = 0.5, 0.1  # cascade parameters forwarded to gen_nontrivial_cascade

pos = graphviz_layout(gnx, root=0)

infection_times, source, obs_nodes = gen_nontrivial_cascade(
    g, p, q, model='ic', return_tree=False, source_includable=True)

In [3]:
# Cascade model and tree-reconstruction method used by the cells below.
model, method = 'ic', 'mst'

In [13]:
def get_tree(g, infection_times, source, obs_nodes, method):
    """Reconstruct a cascade tree over the observed nodes with the chosen method.

    The tree is rooted at the observed node with the smallest infection time.

    Parameters
    ----------
    g : graph
        Graph the cascade ran on; forwarded to the tree builder.
    infection_times : indexable
        Per-vertex infection times, indexed by vertex id.
    source : int
        Cascade source; forwarded to the tree builder.
    obs_nodes : sequence of int
        Observed vertex ids. Must be non-empty.
    method : {'mst', 'mst-k', 'tbfs'}
        'mst' and 'mst-k' use ``steiner_tree_mst`` ('mst-k' additionally
        passes ``k = max(1, int(0.5 * len(obs_nodes)))``); 'tbfs' uses
        ``temporal_bfs``.

    Returns
    -------
    Whatever the selected builder returns (consumed elsewhere as a graph with
    ``edges()`` / ``num_edges()``).

    Raises
    ------
    ValueError
        If ``method`` is not recognised or ``obs_nodes`` is empty.
    """
    known_methods = ('mst', 'mst-k', 'tbfs')
    if method not in known_methods:
        # Previously an unknown method fell through every branch and surfaced
        # as an UnboundLocalError on ``tree``.
        raise ValueError('unknown method {!r}; expected one of {}'.format(
            method, ', '.join(known_methods)))
    if len(obs_nodes) == 0:
        raise ValueError('obs_nodes is empty; cannot choose a root')

    # NOTE(review): elsewhere in this notebook an infection time of -1 marks
    # an uninfected node; if obs_nodes can contain such nodes, min() would pick
    # one as the root -- confirm gen_nontrivial_cascade only observes infected nodes.
    root = min(obs_nodes, key=infection_times.__getitem__)

    if method == 'tbfs':
        from steiner_tree_order import temporal_bfs
        return temporal_bfs(g, root, infection_times[root], infection_times,
                            source, obs_nodes, debug=False)

    from steiner_tree_mst import steiner_tree_mst
    extra_kwargs = {}
    if method == 'mst-k':
        # Half of the observations, but never fewer than one.
        extra_kwargs['k'] = int(len(obs_nodes) * 0.5) or 1
    return steiner_tree_mst(g, root, infection_times, source, obs_nodes,
                            debug=False, **extra_kwargs)

In [20]:
import pandas as pd
def run_k_rounds(g, p, q, model, method, k=100):
    """Simulate ``k`` cascades, reconstruct a tree for each, and summarise the scores.

    Each round draws a fresh cascade with ``gen_nontrivial_cascade``, builds a
    tree with ``get_tree(..., method)`` and scores it with
    ``evaluate_performance``.

    Returns
    -------
    pandas.DataFrame
        ``describe()`` summary of the per-round ``mmc``, ``prec``, ``rec`` and
        ``obj`` columns. Rounds whose tree is falsy are skipped, so the
        ``count`` row can be below ``k``.
    """
    per_round_scores = []
    for _round in range(k):
        cascade = gen_nontrivial_cascade(
            g, p, q, model=model,
            return_tree=False, source_includable=True)
        infection_times, source, obs_nodes = cascade
        tree = get_tree(g, infection_times, source, obs_nodes, method)
        if not tree:
            continue
        # NOTE(review): this round's infection_times is not handed to
        # evaluate_performance, so its true labels may come from a
        # notebook-level cascade rather than this one -- verify.
        per_round_scores.append(evaluate_performance(g, tree, obs_nodes))
    score_columns = ['mmc', 'prec', 'rec', 'obj']
    return pd.DataFrame(per_round_scores, columns=score_columns).describe()

In [21]:
# describe() summary over k=100 simulated cascades, MST reconstruction.
# Ending the cell with the frame renders it as a rich table instead of print().
scores_stat = run_k_rounds(g, p, q, model, 'mst')
scores_stat


  'precision', 'predicted', average, warn_for)


              mmc        prec         rec        obj
count  100.000000  100.000000  100.000000  100.00000
mean     0.034615    0.558626    0.103930    8.29000
std      0.162908    0.350214    0.082136    5.25183
min     -0.395461    0.000000    0.000000    0.00000
25%     -0.029854    0.326923    0.028571    3.00000
50%      0.030155    0.612500    0.089572    9.00000
75%      0.139599    0.833333    0.172689   13.00000
max      0.394600    1.000000    0.294118   17.00000
             mmc       prec        rec        obj
count  84.000000  84.000000  84.000000  84.000000
mean    0.022795   0.484672   0.100419   7.833333
std     0.175730   0.392663   0.103924   6.530435
min    -0.489219   0.000000   0.000000   1.000000
25%    -0.099278   0.000000   0.000000   2.000000
50%     0.000000   0.535885   0.085714   6.000000
75%     0.175419   0.833333   0.172689  12.000000
max     0.337019   1.000000   0.400000  26.000000


In [22]:
# describe() summary over k=100 simulated cascades, temporal-BFS reconstruction.
# Ending the cell with the frame renders it as a rich table instead of print().
scores_stat = run_k_rounds(g, p, q, model, 'tbfs')
scores_stat

  'precision', 'predicted', average, warn_for)


             mmc       prec        rec        obj
count  90.000000  90.000000  90.000000  90.000000
mean    0.014374   0.522599   0.119198   9.466667
std     0.184398   0.364440   0.113930   7.531507
min    -0.492366   0.000000   0.000000   1.000000
25%    -0.128926   0.175000   0.027778   3.000000
50%     0.000000   0.545455   0.086975   8.000000
75%     0.159279   0.851190   0.181818  14.000000
max     0.404520   1.000000   0.411765  33.000000


In [16]:
from sklearn.metrics import matthews_corrcoef, precision_score, recall_score

def evaluate_performance(g, tree, obs_nodes, infection_times=None):
    """Score a reconstructed tree against the true infection status of hidden nodes.

    A node is predicted infected if it is an endpoint of any edge in ``tree``.
    It is truly infected if its infection time is not -1. Observed nodes are
    excluded, so only the unobserved nodes are scored.

    Parameters
    ----------
    g : graph
        Full graph; only ``g.num_vertices()`` is used.
    tree : graph
        Reconstructed tree; ``edges()`` (each edge yielding its two endpoints)
        and ``num_edges()`` are used.
    obs_nodes : iterable of int
        Observed vertex ids, excluded from scoring.
    infection_times : array-like, optional
        Per-vertex infection times, -1 meaning "not infected". If omitted, the
        notebook-level ``infection_times`` global is used, which was the
        original behaviour. Pass it explicitly whenever the cascade being
        evaluated was generated elsewhere, e.g. inside a simulation loop.

    Returns
    -------
    tuple
        ``(mcc, precision, recall, tree.num_edges())``. The first three are
        computed on the unobserved nodes only.

    Raises
    ------
    ValueError
        If ``infection_times`` is omitted and no global of that name exists.
    """
    if infection_times is None:
        # Backward compatibility: older callers relied on the global cascade.
        infection_times = globals().get('infection_times')
        if infection_times is None:
            raise ValueError('infection_times was not given and no notebook-level '
                             'infection_times is defined')

    n_vertices = g.num_vertices()
    true_labels = (np.asarray(infection_times) != -1).astype(int)
    inferred_labels = np.zeros(n_vertices, dtype=int)
    for e in tree.edges():
        u, v = map(int, e)
        inferred_labels[u] = 1
        inferred_labels[v] = 1

    # Observed nodes are known, not inferred, so score only the hidden ones.
    # setdiff1d returns a sorted integer index even when it is empty.
    hidden = np.setdiff1d(np.arange(n_vertices), list(obs_nodes))
    true_labels = true_labels[hidden]
    inferred_labels = inferred_labels[hidden]

    mcc = matthews_corrcoef(true_labels, inferred_labels)
    prec = precision_score(true_labels, inferred_labels)
    rec = recall_score(true_labels, inferred_labels)
    obj = tree.num_edges()
    return (mcc, prec, rec, obj)

In [None]:
from plot_utils import plot_snapshot
from utils import extract_edges


# ``root`` and ``tree`` are locals of get_tree and are not otherwise defined at
# notebook level, so derive them here for the notebook's cascade. Without this,
# the cell only ran against leftover kernel state.
root = min(obs_nodes, key=infection_times.__getitem__)
tree = get_tree(g, infection_times, source, obs_nodes, method)

fig, ax = plt.subplots(2, 2, figsize=(10, 10))
# NOTE(review): the tree root (earliest observed node) is passed as
# source_node rather than the true ``source`` -- confirm this is intended.
plot_snapshot(gnx, pos, infection_times, queried_nodes=obs_nodes, source_node=root, with_labels=True,
              ax=ax[0, 0],
              edges=extract_edges(tree))
ax[0, 0].set_title('closure graph')