In [45]:
import sys
sys.path.insert(0,'..')
import sib
import pandas as pd
import numpy as np

## Order of parameters

### Contacts:

in order: (i, j, t, lambda)

### Observations:
(i,s,t) which is in order:

- node
- state
- time

In [46]:
from pathlib import Path
import data_load

folder_data = Path("data/tree_check/")

params,contacts,observ,epidem = data_load.load_exported_data(folder_data)

n_inst = len(observ)
print("Number of instances: {}".format(len(observ)))
params

Number of instances: 50


{'type_graph': 'TREE',
 't_limit': 8,
 'lambda': 0.9,
 'mu': 0.3,
 'seed': 3,
 'n': 31,
 'p_tested': 0.01,
 'p_asymptomatic': 0.5,
 'p_delay_symptoms': {'1': 0.7, '2': 0.30000000000000004}}

In [47]:
contacts = contacts[["i","j","t","lambda"]]
#contacts.iloc[:30]

In [48]:
obs_all_df = []
for obs in observ:
    obs_df = data_load.convert_obs_to_df(obs)
    obs_df = obs_df[["i","st","t"]]
    obs_all_df.append(obs_df)

### Take one epidemic instance

In [49]:
inst = 8

In [50]:
epi_i = epidem[inst]
obs_dict = observ[inst]
print(obs_dict)

{'S': {'4': [27]}, 'I': {'2': [9], '3': [4], '4': [3, 19], '6': [2, 6, 7], '7': [5, 6, 16], '8': [12, 13]}, 'R': {}}


In [51]:
obs_df = obs_all_df[inst]

In [52]:
print(epi_i)

[[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 1 0 0 0 0 1 0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0]
 [0 1 0 0 1 0 0 0 0 1 1 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0]
 [1 1 0 1 1 0 0 0 0 1 2 0 0 0 0 0 0 0 0 1 1 1 1 0 0 0 0 0 0 0 0]
 [2 1 1 1 1 0 0 1 1 1 2 0 0 0 0 0 0 0 0 1 1 2 2 0 0 0 0 0 0 0 0]
 [2 1 1 2 2 1 1 1 2 2 2 0 0 0 0 1 1 1 1 1 1 2 2 0 0 0 0 0 0 0 0]
 [2 1 1 2 2 1 1 1 2 2 2 1 1 1 1 1 1 2 1 1 1 2 2 0 0 0 0 0 0 0 0]
 [2 1 1 2 2 1 2 1 2 2 2 1 1 1 1 2 1 2 2 2 2 2 2 1 0 1 1 1 1 1 0]]


In [53]:
src = np.where(epi_i[0])[0][0]
print("Source is ",src)

Source is  20


### Setup sib

In [54]:
def callback(t, err, f):
    print(f"{t:4d}, {err:3.2e}", end="\r")

In [55]:
contacts_sib = list(contacts.to_records(index=False))
obs_sib = list(obs_df.to_records(index=False))

In [56]:
mu = params["mu"]
sib_pars = sib.Params(prob_r=sib.Gamma(mu=mu))
fg = sib.FactorGraph(sib_pars, contacts_sib, obs_sib)

In [57]:
sib.iterate(fg,100,3e-6,callback=callback)

   0, inf   1, 9.09e-01   2, 4.82e-01   3, 6.33e-01   4, 3.40e-01   5, 1.69e-01   6, 1.23e-01   7, 7.74e-02   8, 0.00e+00

In [58]:
beliefs = []
for n in fg.nodes:
    res = (np.array(n.bt),np.array(n.bg))
    beliefs.append(np.stack(res))
    #print(beliefs[-1].shape)

In [59]:
names = [str(i) for i in range(len(fg.nodes))]

In [60]:
np.savez(f"beliefs_inst_{inst}",**dict(zip(names,beliefs)))

In [61]:
data = np.load("beliefs_inst_8.npz")

In [62]:
data.files

['0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 '10',
 '11',
 '12',
 '13',
 '14',
 '15',
 '16',
 '17',
 '18',
 '19',
 '20',
 '21',
 '22',
 '23',
 '24',
 '25',
 '26',
 '27',
 '28',
 '29',
 '30']

## Batch processing

In [63]:
contacts_sib = list(contacts.to_records(index=False))
names = [str(i) for i in range(params["n"])]

In [78]:
obs_all_df[0].sort_values(by="t", inplace=True)
obs_all_df[0]

Unnamed: 0,i,st,t
0,6,0,0
0,9,0,2
0,17,1,4
0,12,1,7
1,21,1,7
2,22,1,7
0,10,2,7
0,14,1,8
1,25,1,8


In [72]:
all_marginals = []
for obs in obs_all_df:
    obs.sort_values(by="t")
    obs_sib = list(obs.to_records(index=False))
    mu = params["mu"]
    sib_pars = sib.Params(prob_r=sib.Gamma(mu=mu))
    fg = sib.FactorGraph(sib_pars, contacts_sib, obs_sib)

    sib.iterate(fg,100,3e-6,callback=callback)
    print("\nFinished")

    beliefs = []
    for n in fg.nodes:
        res = (np.array(n.bt),np.array(n.bg))
        beliefs.append(np.stack(res))
    all_marginals.append(tuple(beliefs))
    #print(beliefs[-1].shape)

0 inf
1 0.9
2 0.4908200468638564
3 0.3188670991376257
4 0.3410215246556428
5 0.2577644552270978
6 0.11417624608391663
7 0.10218456952920552
8 0.09011872960944872
9 0.0

Finished
0 inf
1 0.9090909090909091
2 0.7199650936978632
3 0.4485398335348392
4 0.2886688258672616
5 0.2886688258672616
6 0.0864980057658798
7 0.027218772673373443
8 0.009844251252036008
9 0.0

Finished
0 inf
1 0.9090909090909091
2 0.7721416869068866
3 0.3246052542382018
4 0.3114715821600669
5 0.2801806490772911
6 0.11069673431012772
7 0.04517753142833904
8 0.003439557341772148
9 0.0

Finished
0 inf
1 0.9090909090909091
2 0.5437278206387481
3 0.4045531453623853
4 0.2826881640554885
5 0.2618116398039921
6 0.2526821639833079
7 0.0

Finished
0 inf
1 0.9090909090909091
2 0.8835680488351797
3 0.8491950362701252
4 0.34478835417616305
5 0.1807550216726972
6 0.08507677154165882
7 0.037617991092635084
8 0.0

Finished
0 inf
1 0.9090909090909091
2 0.4977550830737674
3 0.31471443075710215
4 0.25312902263023795
5 0.24899407474041813

4 0.4851254760862802
5 0.4027789153570699
6 0.15977854614467701
7 0.09345270984994036
8 0.0

Finished
0 inf
1 0.9090909090909091
2 0.6011223252183613
3 0.579758681924888
4 0.5608622300809127
5 0.2197625011786762
6 0.022171415668488903
7 0.018175149368044208
8 0.015294848803784178
9 0.0

Finished
0 inf
1 0.9
2 0.6503641874465766
3 0.3734291923655084
4 0.3412362289820796
5 0.3648428076972042
6 0.33193515288703723
7 0.14260389500126858
8 0.0

Finished


In [65]:
margin_names = [f"{inst}_{n}"  for inst in range(n_inst) for n in range(params["n"])]

In [66]:
dict_marg = {}
for inst, margist in enumerate(all_marginals):
    for n, marginals in enumerate(margist):
        dict_marg[f"{inst}_{n}"] = marginals

In [67]:
dict_marg["2_3"] == all_marginals[2][3]

array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True]])

In [68]:
np.savez("beliefs_tree",**dict_marg)

In [36]:
num_nodes = params["n"]

In [37]:

def load_marginals(filename):
    data = np.load(filename)
    all_marg_load = []
    for inst in range(n_inst):
        margist = [data[f"{inst}_{n}"] for n in range(num_nodes)]
        all_marg_load.append(margist)
    return all_marg_load

In [38]:
all_marg_load = load_marginals("beliefs_tree.npz")

In [39]:
len(all_marg_load)

50

In [40]:
len(all_marg_load[0])

31