In [107]:
import pickle as pkl
import os
import numpy as np

from graph_tool import load_graph
from graph_helpers import extract_edges
from random_steiner_tree.util import from_gt

from sample_pool import TreeSamplePool
from root_sampler import get_root_sampler_by_name
from sklearn.metrics import average_precision_score

from inf_helpers import infer_edge_frequency
from eval_helpers import eval_edge_map

In [78]:
# Weighted 1024-node lattice used by every experiment in this notebook.
GRAPH_PATH = 'data/lattice-1024/graph_weighted_0.1.gt'
g = load_graph(GRAPH_PATH)

In [104]:
def one_run_for_edge(g, edge_weights, input_path, output_dir, method='our',
                     **kwargs):
    """Infer per-edge scores for one cascade file and pickle them to ``output_dir``.

    The result is written to ``os.path.join(output_dir, basename(input_path))``.
    If that file already exists the run is skipped, so the function can be
    called repeatedly over a directory of cascades to resume a batch.

    Parameters
    ----------
    g : graph_tool.Graph
        The graph the cascade lives on.
    edge_weights : edge property map
        Per-edge weights; ``-log`` of their array is used as the edge cost, so
        they presumably lie in (0, 1] -- TODO confirm against the data generator.
    input_path : str
        Pickle holding a 3-tuple whose first two items are ``obs`` and ``c``
        (the third item is ignored here).
    output_dir : str
        Directory for the result; created if it does not exist.
    method : {'our', 'min-steiner-tree'}
        ``'our'``: edge frequencies from ``infer_edge_frequency``.
        ``'min-steiner-tree'``: every edge of one minimum Steiner tree over
        ``obs`` is mapped to 1.
    **kwargs
        Used only by ``method='our'``: ``root_sampler_name`` (passed to
        ``get_root_sampler_by_name``) and ``n_samples`` (default 1000; the
        older spelling ``n_sample`` is still accepted).

    Returns
    -------
    None
        The output is the pickled dict ``{'edge_freq': edge_freq}``.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported values.
    """
    # Validate before any I/O so a typo in `method` fails immediately.
    if method not in ('our', 'min-steiner-tree'):
        raise ValueError('unsupported method')

    basename = os.path.basename(input_path)
    output_path = os.path.join(output_dir, basename)

    if os.path.exists(output_path):
        # Already processed by an earlier run.
        return

    # NOTE: pickle.load can execute arbitrary code; only use on trusted files.
    with open(input_path, 'rb') as f:
        obs, c, _ = pkl.load(f)

    # Additive costs: a sum of -log(w) corresponds to a product of the weights w.
    nlog_edge_weights = g.new_edge_property('float')
    nlog_edge_weights.a = -np.log(edge_weights.a)

    if method == 'our':
        root_sampler_name = kwargs.get('root_sampler_name')
        root_sampler = get_root_sampler_by_name(root_sampler_name, g=g, obs=obs, c=c,
                                                weights=nlog_edge_weights)
        # `n_samples` matches infer_edge_frequency's own keyword; the legacy
        # `n_sample` spelling is honoured for backward compatibility.
        n_samples = kwargs.get('n_samples', kwargs.get('n_sample', 1000))
        edge_freq = infer_edge_frequency(
            g, edge_weights=edge_weights, obs=obs,
            root_sampler=root_sampler,
            n_samples=n_samples,
            log=False)
    else:  # method == 'min-steiner-tree'
        from minimum_steiner_tree import min_steiner_tree
        edges = min_steiner_tree(g, obs,
                                 p=nlog_edge_weights,
                                 return_type='edges')

        # Binary indicator: every edge in the tree gets frequency 1.
        edge_freq = {e: 1 for e in edges}

    os.makedirs(output_dir, exist_ok=True)
    # Write to a temp file and rename, so an interrupted dump never leaves a
    # truncated file that the skip-if-exists check above would then trust.
    tmp_path = output_path + '.tmp'
    with open(tmp_path, 'wb') as f:
        pkl.dump({'edge_freq': edge_freq}, f)
    os.replace(tmp_path, output_path)

In [106]:
# Define `edge_weights` here so this cell does not rely on a cell further down
# the notebook having been executed first (Restart & Run All would fail).
edge_weights = g.edge_properties['weights']
# NOTE(review): only cascade 0.pkl is processed here; the evaluation below
# reads the whole prediction directory -- confirm the other files exist.
one_run_for_edge(g, edge_weights,
                 'cascade-with-edges/lattice-1024-mic-s0.1-o0.5-omuniform/0.pkl',
                 'output-edges/min-steiner-tree/lattice-1024-mic-s0.1-o0.5-omuniform/',
                 method='min-steiner-tree')

In [111]:
# Ground-truth cascades vs. the min-Steiner-tree predictions written above.
cascade_dir = 'cascade-with-edges/lattice-1024-mic-s0.1-o0.5-omuniform/'
prediction_dir = 'output-edges/min-steiner-tree/lattice-1024-mic-s0.1-o0.5-omuniform/'
eval_edge_map(g, cascade_dir, prediction_dir)

100%|██████████| 100/100 [00:01<00:00, 65.82it/s]


[0.18560863142765888]

In [79]:
cascade_path = 'cascade-with-edges/lattice-1024-mic-s0.1-o0.5-omuniform/0.pkl'
# The cascade pickle is a 3-tuple: (obs, c, tree_edges).
# A `with` block closes the file instead of leaking the handle.
with open(cascade_path, 'rb') as cascade_file:
    obs, c, tree_edges = pkl.load(cascade_file)

In [80]:
# `g.ep` is graph-tool's shorthand for `g.edge_properties`.
edge_weights = g.ep['weights']

In [81]:
# Sample count passed as `n_samples`; no root_sampler is given here, so
# presumably infer_edge_frequency falls back to its own default -- verify.
N_INFER_SAMPLES = 2000
edge_freq = infer_edge_frequency(g, edge_weights, obs,
                                 n_samples=N_INFER_SAMPLES)

In [82]:
# `evaluate_edge_prediction` is not provided by the imports cell, so this cell
# raised NameError on a fresh kernel. It presumably lives next to
# `eval_edge_map` in eval_helpers -- TODO verify the module and move this
# import to the top imports cell once confirmed.
from eval_helpers import evaluate_edge_prediction

evaluate_edge_prediction(g, tree_edges, edge_freq, average_precision_score)

0.5395074746964879