In [1]:
import sys
sys.path.insert(0,'..')
import sib
import pandas as pd
import numpy as np

## Order of parameters

### Contacts:

in order: (i, j, t, lambda)

### Observations:
(i,s,t) which is in order:

- node
- state
- time

In [2]:
from pathlib import Path
import data_load

folder_data = Path("data/tree_check/")

params,contacts,observ,epidem = data_load.load_exported_data(folder_data)

n_inst = len(observ)
print("Number of instances: {}".format(len(observ)))
params

Number of instances: 50


{'type_graph': 'TREE',
 't_limit': 8,
 'lambda': 0.9,
 'mu': 0.3,
 'seed': 3,
 'n': 31,
 'p_tested': 0.01,
 'p_asymptomatic': 0.5,
 'p_delay_symptoms': {'1': 0.7, '2': 0.30000000000000004}}

In [3]:
contacts = contacts[["i","j","t","lambda"]]
#contacts.iloc[:30]

In [4]:
obs_all_df = []
for obs in observ:
    obs_df = data_load.convert_obs_to_df(obs)
    obs_df = obs_df[["i","st","t"]]
    obs_all_df.append(obs_df)

### Take one epidemic instance

In [5]:
inst = 8

In [9]:
contacts_sib = list(contacts.to_records(index=False))
sib.set_num_threads(1)
for inst in range(len(epidem)):
    epi_i = epidem[inst]
    obs_dict = observ[inst]
    obs_df = obs_all_df[inst]
    obs_sib = list(obs_df.to_records(index=False))
    mu = params["mu"]
    sib_pars = sib.Params(prob_r=sib.Gamma(mu=mu))
    fg = sib.FactorGraph(sib_pars, contacts_sib, obs_sib)
    sib.iterate(fg,100,3e-6,callback=None)
    print([list(n.bt) for n in fg.nodes])


[[0.01352693525286666, 0.005046349811381039, 0.04967803970741806, 0.23869212198346174, 0.6918120627483799, 0.0011403542860645933, 7.250373274605298e-05, 7.363175140343491e-06, 5.391870260235816e-07, 2.3730115515664794e-05], [0.005254614221858782, 0.07030288194008948, 0.2556382730285242, 0.6665969311064196, 0.0021780973391999103, 1.2471477061761612e-05, 5.4866063822798146e-06, 1.7156353636741048e-06, 2.984476692531352e-07, 9.23019743107716e-06], [0.001068757237763706, 0.01157514844204395, 0.009758602246881642, 0.05316092515483638, 0.21585479457239323, 0.7083733506455403, 0.00020602895272766114, 1.0601570799971824e-07, 1.0961216913499393e-08, 2.275770888144916e-06], [0.07058116229362223, 0.27905641153343635, 0.6485628949576409, 0.0012004818051152112, 0.00038935501103615016, 0.00016543047348836293, 1.4262929833983403e-05, 2.0790226496588585e-06, 4.765390127105349e-07, 2.744543416437417e-05], [0.00099841081864889, 0.0034419040966680746, 0.07091171967566515, 0.23743929327791577, 0.643449979

### Setup sib

In [7]:
def callback(t, err, f):
    print(f"{t:4d}, {err:3.2e}", end="\r")

In [8]:
contacts_sib = list(contacts.to_records(index=False))


   8, 0.00e+00


In [14]:
beliefs = []
for n in fg.nodes:
    res = (np.array(n.bt),np.array(n.bg))
    beliefs.append(np.stack(res))
    #print(beliefs[-1].shape)

In [15]:
names = [str(i) for i in range(len(fg.nodes))]

In [16]:
np.savez(f"beliefs_inst_{inst}",**dict(zip(names,beliefs)))

In [17]:
data = np.load("beliefs_inst_8.npz")

In [18]:
data.files

['0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 '10',
 '11',
 '12',
 '13',
 '14',
 '15',
 '16',
 '17',
 '18',
 '19',
 '20',
 '21',
 '22',
 '23',
 '24',
 '25',
 '26',
 '27',
 '28',
 '29',
 '30']

## Batch processing

In [63]:

names = [str(i) for i in range(params["n"])]

In [78]:
obs_all_df[0].sort_values(by="t", inplace=True)
obs_all_df[0]

Unnamed: 0,i,st,t
0,6,0,0
0,9,0,2
0,17,1,4
0,12,1,7
1,21,1,7
2,22,1,7
0,10,2,7
0,14,1,8
1,25,1,8


In [10]:

def load_marginals(filename):
    data = np.load(filename)
    all_marg_load = []
    for inst in range(n_inst):
        margist = [data[f"{inst}_{n}"] for n in range(num_nodes)]
        all_marg_load.append(margist)

        return all_marg_load

In [19]:
contacts_sib = list(contacts.to_records(index=False))
all_marginals = []
for i in range(10):
    obs_sib = list(obs.to_records(index=False))
    mu = params["mu"]
    sib_pars = sib.Params(prob_r=sib.Gamma(mu=mu))
    fg = sib.FactorGraph(sib_pars, contacts_sib, obs_sib)

    sib.iterate(fg,100,3e-6,callback=None)
    beliefs = []
    for n in fg.nodes:
        res = (np.array(n.bt),np.array(n.bg))
        beliefs.append(np.stack(res))
    all_marginals.append(tuple(beliefs))
    

In [27]:
[data[f] for f in data.files]

[array([[1.35269353e-02, 5.04634981e-03, 4.96780397e-02, 2.38692122e-01,
         6.91812063e-01, 1.14035429e-03, 7.25037327e-05, 7.36317514e-06,
         5.39187026e-07, 2.37301155e-05],
        [1.34147042e-04, 3.26279580e-03, 4.48459372e-03, 1.58147625e-02,
         6.93900007e-02, 2.34986462e-01, 1.74143111e-01, 1.29010287e-01,
         3.68750110e-01, 2.37301155e-05]]),
 array([[5.25461422e-03, 7.03028819e-02, 2.55638273e-01, 6.66596931e-01,
         2.17809734e-03, 1.24714771e-05, 5.48660638e-06, 1.71563536e-06,
         2.98447669e-07, 9.23019743e-06],
        [2.20595653e-06, 4.94980584e-04, 1.54824441e-02, 6.91584499e-02,
         2.25765236e-01, 1.78597230e-01, 1.32309268e-01, 9.80175390e-02,
         2.80163416e-01, 9.23019743e-06]]),
 array([[1.06875724e-03, 1.15751484e-02, 9.75860225e-03, 5.31609252e-02,
         2.15854795e-01, 7.08373351e-01, 2.06028953e-04, 1.06015708e-07,
         1.09612169e-08, 2.27577089e-06],
        [2.58030321e-06, 2.40348586e-04, 2.62925729e-03,

In [65]:
margin_names = [f"{inst}_{n}"  for inst in range(n_inst) for n in range(params["n"])]

In [66]:
dict_marg = {}
for inst, margist in enumerate(all_marginals):
    for n, marginals in enumerate(margist):
        dict_marg[f"{inst}_{n}"] = marginals

In [67]:
dict_marg["2_3"] == all_marginals[2][3]

array([[ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True],
       [ True,  True,  True,  True,  True,  True,  True,  True,  True,
         True]])

In [68]:
np.savez("beliefs_tree",**dict_marg)

In [36]:
num_nodes = params["n"]

In [39]:
len(all_marg_load)

50

In [40]:
len(all_marg_load[0])

31