In [1]:
%matplotlib inline

from collections import defaultdict

import networkx as nx
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from scipy.stats import hmean
from tqdm import tqdm

from graph_generator import grid_2d, add_p_and_delta
from ic import sample_graph_from_infection, make_partial_cascade
from plot_utils import plot_snapshot, add_colorbar
from utils import infeciton_time2weight

In [None]:
# Build the base graph and attach the p / d parameters in a single expression.
# NOTE(review): grid_2d(10) presumably yields a 10x10 grid whose nodes are
# (row, col) tuples, and 'd' is the attribute later used as the shortest-path
# weight -- confirm against graph_generator.
g = add_p_and_delta(grid_2d(10), p=0.7, d=1)

In [None]:
# Plot layout: every node label is converted to an array and used as that
# node's position.
pos = dict((node, np.array(node)) for node in g.nodes())

In [None]:
def infection_time_estimation(g, n_rounds, mean_method='harmonic'):
    """
    Estimate, for each node taken as the source, the mean infection time of
    every node over `n_rounds` sampled graphs.

    Each round draws one graph with `sample_graph_from_infection(g)` and
    computes all-pairs shortest path lengths on it, using the edge attribute
    'd' as weight. A node that is missing from a source's shortest-path
    result in a round contributes an infection time of infinity for that
    round.

    Args:
    g: graph handed to `sample_graph_from_infection`
    n_rounds: number of graphs to sample
    mean_method: how the per-round times are aggregated
        'harmonic': harmonic mean of the non-zero times (0 if all are zero)
        'arithmetic': arithmetic mean after replacing infinity by
            (largest finite time observed for any source/node pair) + 1

    Returns:
    tuple (est, s2n_times)
    est: dict source -> dict node -> aggregated infection time
    s2n_times: dict source -> dict node -> list of per-round infection times

    Raises:
    ValueError: if `mean_method` is neither 'harmonic' nor 'arithmetic'
    """
    # Validate up front so a misspelled method does not cost a full
    # sampling + shortest-path run before failing.
    if mean_method not in ('harmonic', 'arithmetic'):
        raise ValueError('{"harmonic", "arithmetic"} accepted')

    sampled_graphs = [sample_graph_from_infection(g)
                      for _ in range(n_rounds)]
    s2t_len_list = Parallel(n_jobs=-1)(
        delayed(nx.shortest_path_length)(sampled, weight='d')
        for sampled in sampled_graphs)

    # source -> node -> list holding one infection time per round
    s2n_times = defaultdict(lambda: defaultdict(list))

    # `sampled` (not `g`) so the function argument is not shadowed
    for sampled, s2t_len in tqdm(zip(sampled_graphs, s2t_len_list),
                                 total=len(sampled_graphs)):
        for s in s2t_len:
            for n in sampled.nodes_iter():
                s2n_times[s][n].append(s2t_len[s].get(n, float('inf')))

    if mean_method == 'harmonic':
        def mean_func(times):
            times = np.array(times)
            # zero times are dropped before taking the harmonic mean
            times = times[np.nonzero(times)]
            if times.shape[0] > 0:
                return hmean(times)
            else:  # all zeros
                return 0
    else:  # 'arithmetic'
        all_times = np.array([times
                              for n2times in s2n_times.values()
                              for times in n2times.values()])
        all_times = np.ravel(all_times)
        finite_times = all_times[np.invert(np.isinf(all_times))]
        # Infinity is replaced by a value one larger than any finite time.
        # With nothing sampled there is no maximum to take; the fallback is
        # never used for averaging because `s2n_times` is then empty too.
        inf_value = finite_times.max() + 1 if finite_times.size else 1

        def mean_func(times):
            times = np.array(times)
            times[np.isinf(times)] = inf_value
            return times.mean()

    est = defaultdict(dict)
    for s, n2times in tqdm(s2n_times.items()):
        for n, times in n2times.items():
            est[s][n] = mean_func(times)
    return est, s2n_times

In [None]:
# Estimate infection times with the arithmetic mean over 100 sampled graphs.
est, s2n_times = infection_time_estimation(g, 100, mean_method='arithmetic')
# Spot check: estimated time for node (0, 1) when (0, 0) is the source.
print(est[(0, 0)][(0, 1)])
# NOTE(review): infeciton_time2weight presumably maps infection times to
# plotting weights -- confirm against utils.
weights = infeciton_time2weight(est[(0, 0)])
plot_snapshot(g, pos, weights, source_node=(0, 0))
add_colorbar(np.array(list(weights.values())))

In [None]:
# Estimates from source (0, 0) to three nearby nodes, printed on one line.
print(*(est[(0, 0)][node] for node in [(0, 1), (0, 2), (1, 1)]))

In [None]:
# Same experiment with the harmonic mean; this rebinds `est` and `s2n_times`
# from the arithmetic run.
est, s2n_times = infection_time_estimation(g, 100, mean_method='harmonic')
# NOTE(review): infeciton_time2weight presumably maps infection times to
# plotting weights -- confirm against utils.
weights = infeciton_time2weight(est[(0, 0)])
plot_snapshot(g, pos, weights, source_node=(0, 0))
add_colorbar(np.array(list(weights.values())))

In [None]:
# Estimates from source (0, 0) to three nearby nodes, printed on one line.
print(*(est[(0, 0)][node] for node in [(0, 1), (0, 2), (1, 1)]))

In [None]:
# Generate one cascade on g together with a partial observation of it.
# NOTE(review): 0.05 is presumably the fraction of nodes observed and
# 'late_nodes' the observation strategy -- confirm against
# ic.make_partial_cascade. The fourth return value is discarded.
source, obs_nodes, infection_times, _ = make_partial_cascade(g, 0.05, 'late_nodes')

In [None]:
# Plot the cascade's infection times on the grid, passing the observed nodes
# as `queried_nodes`.
plot_snapshot(g, pos, infeciton_time2weight(infection_times), queried_nodes=obs_nodes)

In [None]:
# Sanity check of how hmean treats an infinite entry: since 1/inf == 0, the
# expected result is 2 / (0 + 1/100) = 200, i.e. a single infinite time does
# not make the harmonic mean infinite.
hmean([float('inf'), 100])

In [6]:
# faster version
import itertools
from scipy.sparse import csr_matrix

In [47]:
# Setup for the faster version: number of sampled graphs, a fresh graph, and
# a node -> integer id lookup following the graph's node iteration order.
n_rounds = 100
g = add_p_and_delta(grid_2d(10), p=0.7, d=1)
node2id = dict((node, idx) for idx, node in enumerate(g.nodes_iter()))
def run_one_round(sampled_g, node2id):
    """
    Compute all-pairs shortest path lengths on one sampled graph and return
    them as integer triples.

    Args:
    sampled_g: graph to run `nx.shortest_path_length` on
    node2id: dict mapping each node to its integer id

    Returns:
    int32 array with one row [source id, node id, path length] per
    (source, node) pair present in the shortest-path result
    """
    # NOTE(review): called without a weight argument, unlike
    # infection_time_estimation which passes weight='d' -- presumably
    # equivalent here because d=1 is used; confirm.
    lengths_by_source = nx.shortest_path_length(sampled_g)
    triples = []
    for src in lengths_by_source:
        for node, length in lengths_by_source[src].items():
            triples.append([node2id[src], node2id[node], length])
    return np.array(triples, dtype=np.int32)

In [59]:
# One task per round: sample a graph and turn its all-pairs path lengths into
# an array of (source id, node id, time) rows via run_one_round.
snt_list_list = Parallel(n_jobs=-1)(delayed(run_one_round)(sample_graph_from_infection(g), node2id)
                                    for _ in range(n_rounds))

# Stack the per-round arrays with a single concatenate instead of chaining
# them into a Python list of individual 3-element row arrays first.
# The leading empty (0, 3) array keeps this valid when there are no rounds,
# and reshape(-1, 3) lets a round that produced no rows contribute nothing.
df = pd.DataFrame(
    np.concatenate([np.empty((0, 3), dtype=np.int32)] +
                   [snt.reshape(-1, 3) for snt in snt_list_list]),
    columns=['source', 'node', 'time'],
    dtype=np.uint16)

In [60]:
# Report column dtypes and the frame's memory footprint.
df.info(memory_usage='deep')

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 923118 entries, 0 to 923117
Data columns (total 3 columns):
source    923118 non-null uint16
node      923118 non-null uint16
time      923118 non-null uint16
dtypes: uint16(3)
memory usage: 5.3 MB
