In [3]:
%matplotlib notebook

In [45]:
import numpy as np
import random
import itertools
from graph_tool import Graph, GraphView
from graph_tool.draw import graph_draw
from tqdm import tqdm

from matplotlib import pyplot as plt

from viz_helpers import lattice_node_pos
from minimum_steiner_tree import min_steiner_tree
from cascade_generator import si, observe_cascade
from eval_helpers import infection_precision_recall
from graph_helpers import remove_filters, load_graph_by_name, gen_random_spanning_tree

from inference import infer_infected_nodes
from experiment import gen_input

In [39]:
graph_name = 'karate'
g = load_graph_by_name(graph_name)

In [40]:
n_rounds = 100
n_samples = 100
stop_fraction = 0.25
obs_fraction = 0.5

In [48]:
sampling_scores = []
sp_trees = [gen_random_spanning_tree(g) for _ in range(n_samples)]
for i in tqdm(range(n_rounds)):
    obs, c = gen_input(g, stop_fraction=stop_fraction, q=obs_fraction)
    preds = infer_infected_nodes(g, obs, method='sampling', sp_trees=sp_trees)
    prec, rec = infection_precision_recall(set(preds), c, obs)
    sampling_scores.append((prec, rec))

100%|██████████| 100/100 [00:36<00:00,  2.77it/s]


In [42]:
mst_scores = []
for i in tqdm(range(n_rounds)):
    obs, c = gen_input(g, stop_fraction=stop_fraction, q=obs_fraction)
    preds = infer_infected_nodes(g, obs, method='min_steiner_tree')
    prec, rec = infection_precision_recall(set(preds), c, obs)
    mst_scores.append((prec, rec))

100%|██████████| 100/100 [00:02<00:00, 38.02it/s]


In [43]:
sampling_scores = np.asarray(sampling_scores)
sampling_scores.mean(axis=0)

array([ 0.19033802,  0.26607143])

In [44]:
mst_scores = np.asarray(mst_scores)
mst_scores.mean(axis=0)

array([ 0.09299603,  0.10905952])