In [1]:
%matplotlib inline

In [56]:
import matplotlib as mpl
mpl.use('Agg')

import pickle as pkl
import math
from graph_tool.draw import graph_draw
from matplotlib import  pyplot as plt

from graph_helpers import remove_filters, load_graph_by_name, gen_random_spanning_tree
from viz_helpers import lattice_node_pos, QueryIllustrator
from experiment import gen_input, one_round_experiment, remove_filters
from query_selection import RandomQueryGenerator, OurQueryGenerator, PRQueryGenerator


because the backend has already been chosen;
matplotlib.use() must be called *before* pylab, matplotlib.pyplot,
or matplotlib.backends is imported for the first time.



In [23]:
graph_name = 'lattice'
g = load_graph_by_name(graph_name)

In [24]:
if graph_name == 'lattice':
    pos = lattice_node_pos(g, shape=(10, 10))


In [25]:
n_samples = 100
stop_fraction = 0.25
n_queries = 15

In [30]:
obs, c = gen_input(g, stop_fraction=stop_fraction, p=0.5)

In [31]:
if True:
    pkl.dump((obs, c), open('/tmp/cascade_example.pkl', 'wb'))
else:
    (obs, c) = pkl.load(open('/tmp/cascade_example.pkl', 'rb'))

In [32]:
n_plots = n_queries
n_row = 1
n_col = int(math.ceil(n_plots / n_row))

def create_fig_axes():
    plt.clf()
    fig, axes = plt.subplots(n_row, n_col, sharex=True, sharey=True)
    return fig, axes

In [57]:
# for our method
gv = remove_filters(g)
q_gen = OurQueryGenerator(gv, obs, num_spt=n_samples,
                            method='entropy',
                            use_resample=False)
scores, queries, eval_details = one_round_experiment(
    g, obs, c, q_gen, query_method='ours', inference_method='sampling', 
    n_spanning_tree_samples=n_samples,
    n_queries=n_queries,
    return_details=True, log=True)



100%|██████████| 15/15 [00:39<00:00,  2.51s/it]


In [58]:
plt.switch_backend('cairo')
qi = QueryIllustrator(g, obs, c, pos)
fig, axes = create_fig_axes()
for i, ax in enumerate(axes):
    qi.plot_snapshot(queries[i], n_samples=n_samples, ax=ax)

In [59]:
fig.set_size_inches(75, 5, forward=True)
fig.savefig('figs/query_process_our.pdf')

In [60]:
# for pagerank
gv = remove_filters(g)
q_gen = PRQueryGenerator(gv, obs)
scores, queries, eval_details = one_round_experiment(
    g, obs, c, q_gen, query_method='pagerank', inference_method='sampling',
    n_spanning_tree_samples=n_samples,
    n_queries=n_queries, 
    return_details=True, log=True)

100%|██████████| 15/15 [00:18<00:00,  1.20s/it]


In [61]:
qi = QueryIllustrator(g, obs, c, pos)
fig, axes = create_fig_axes()
for i, ax in enumerate(axes):
    qi.plot_snapshot(queries[i], n_samples=n_samples, ax=ax)
fig.set_size_inches(75, 5, forward=True)
fig.savefig('figs/query_process_pagerank.pdf')    

In [62]:
# for random
gv = remove_filters(g)
q_gen = RandomQueryGenerator(gv, obs)
scores, queries, eval_details = one_round_experiment(
    g, obs, c, q_gen, query_method='random', inference_method='sampling', n_queries=n_queries,
    n_spanning_tree_samples=n_samples,
    return_details=True)

In [63]:
qi = QueryIllustrator(g, obs, c, pos)
fig, axes = create_fig_axes()
for i, ax in enumerate(axes):
    qi.plot_snapshot(queries[i], n_samples=n_samples, ax=ax)
fig.set_size_inches(75, 5, forward=True)
fig.savefig('figs/query_process_random.pdf')    