In [1]:
%matplotlib notebook
# starter code
import random
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
import networkx as nx
from graph_tool.all import load_graph, shortest_distance, GraphView
from networkx.drawing.nx_agraph import graphviz_layout

from cascade import gen_nontrivial_cascade
from utils import get_rank_index

seed = 123456
random.seed(seed)
np.random.seed(seed)

# Graph family and the instance directory the graph files are read from.
gtype = 'grid'
graph_dir = 'data/{}/2-6'.format(gtype)
g = load_graph('{}/graph.gt'.format(graph_dir))
gnx = nx.read_graphml('{}/graph.graphml'.format(graph_dir))
# read_graphml yields string node ids by default; relabel them to ints
# (presumably so they line up with the graph_tool vertex indices -- confirm).
# `nodes()` works on networkx 1.x and 2.x; `nodes_iter()` was removed in 2.0.
gnx = nx.relabel_nodes(gnx, {i: int(i) for i in gnx.nodes()})
N1, N2 = 100, 100
p, q = 0.5, 0.2

pos = graphviz_layout(gnx, root=0)

# Reference usage of the imported generator (unused in this notebook):
# infection_times, source, obs_nodes, tree = gen_nontrivial_cascade(g, p, q, model='ic', return_tree=True)

In [56]:
import numpy as np
import random
from graph_tool.all import shortest_distance, shortest_path

from utils import edges2graph
from feasibility import is_arborescence

def gen_cascade(g, scale=1.0, source=None, stop_fraction=0.5, return_tree=False):
    """Simulate a cascade with i.i.d. exponential edge delays.

    Each edge gets a delay drawn from ``np.random.exponential(scale)``. A
    vertex's infection time is its delay-weighted shortest-path distance from
    ``source``. The cascade is cut off at the ``stop_fraction`` percentile of
    those times: vertices at or below it are infected, all others get -1.

    Parameters
    ----------
    g : graph_tool Graph
        Graph the cascade spreads over.
    scale : float
        Scale (mean) of the exponential edge delays.
    source : int, optional
        Source vertex index; drawn uniformly at random when None.
    stop_fraction : float
        Fraction of vertices to infect. ``np.percentile`` rejects values
        outside [0, 1] (i.e. q outside [0, 100]).
    return_tree : bool
        If True, also return the infection tree built by ``edges2graph``.

    Returns
    -------
    tuple
        ``(source, infection_times)``, or ``(source, infection_times, tree)``
        when ``return_tree`` is True. ``infection_times`` is a float array
        indexed by vertex; uninfected vertices hold -1.
    """
    rands = np.random.exponential(scale, g.num_edges())
    delays = g.new_edge_property('float')
    delays.set_2d_array(rands)

    if source is None:
        # Uniform over vertex indices without materializing an index array.
        source = random.randrange(g.num_vertices())

    dist, pred = shortest_distance(g, source=source, weights=delays, pred_map=True)

    # NOTE(review): assumes every vertex is reachable from `source` (true for a
    # connected graph); unreachable vertices would presumably get a huge/infinite
    # distance and distort the percentile cut-off.
    percentile = np.percentile(dist.a, stop_fraction * 100)
    infected_mask = dist.a <= percentile
    infected_nodes = np.nonzero(infected_mask)[0]

    infection_times = np.array(dist.a)  # copy, so dist is left untouched
    infection_times[~infected_mask] = -1

    rets = (source, infection_times)
    if return_tree:
        # A vertex's shortest-path predecessor never has a larger distance
        # (delays are non-negative), so it is infected whenever the vertex is.
        # The infection tree is therefore exactly one (pred, child) edge per
        # infected non-source vertex; no walk back to the root is needed.
        tree_edges = {(pred[n], n) for n in infected_nodes if pred[n] != n}
        tree = edges2graph(g, tree_edges)
        rets += (tree, )
    return rets

# Sanity-check gen_cascade over several stop fractions: the source is infected
# at time 0, the returned tree is an arborescence, and the infected fraction
# matches stop_fraction.
n_vertices = g.num_vertices()
# Keeping every time <= the linearly-interpolated percentile of n distinct
# times infects floor(f*(n-1)) + 1 vertices, i.e. a fraction within 1/n of f.
frac_tol = 1.0 / n_vertices + 1e-9
# Integer percents avoid np.arange's float-step accumulation
# (0.1 + 2*0.1 == 0.30000000000000004) and its fragile endpoint handling.
for pct in range(10, 100, 10):
    stop_fraction = pct / 100
    for _ in range(10):
        source, infection_times, tree = gen_cascade(g, scale=1.0, source=None,
                                                    stop_fraction=stop_fraction, return_tree=True)
        assert infection_times[source] == 0
        assert is_arborescence(tree)
        infected_frac = np.count_nonzero(infection_times != -1) / n_vertices
        np.testing.assert_allclose(infected_frac, stop_fraction, rtol=0, atol=frac_tol)

In [49]:
# NOTE(review): leftover interactive help lookup (IPython `obj?` syntax) from an
# out-of-order cell; it only shows the docstring and affects no results --
# remove before sharing the notebook.
np.testing.assert_almost_equal?