In [45]:
import pandas as pd
import sib

# Load Data

In [46]:
# Each row is one *directed* contact i -> j at time t with transmission
# parameter `lambda`. In this dummy dataset both directions of every contact
# are already listed as separate rows, so nothing is mirrored here.
# Column order matters: the records are later passed to sib as (i, j, t, lambda).
df_contacts = (
    pd.read_csv("./data/dummy/transmissions.csv")
    .loc[:, ["i", "j", "t", "lambda"]]
)
df_contacts

Unnamed: 0,i,j,t,lambda
0,2,3,0,0.02
1,3,2,0,0.02
2,1,2,0,0.02
3,2,1,0,0.02
4,0,2,0,0.02
5,2,0,0,0.02
6,2,1,1,0.02
7,1,2,1,0.02
8,0,1,3,0.02
9,1,0,3,0.02


In [47]:
# sib takes the contacts as a list of (i, j, t, lambda) records, one per row.
contacts = [record for record in df_contacts.to_records(index=False)]
# Node ids are assumed to be 0..N-1, so N is the largest id seen in a contact + 1.
# NOTE(review): a node that appears only in the observations is not counted here.
N = df_contacts[["i", "j"]].to_numpy().max() + 1
print(f"Number of nodes: {N}")

Number of nodes: 4


In [65]:
# Observations: node `i` was observed in state `s` at time `t_test`.
df_observ = pd.read_csv(
    "./data/dummy/observations.csv"
)
df_observ

Unnamed: 0,i,s,t_test
0,0,0,0
1,1,1,1
2,2,1,2


In [66]:
# Create sib tests from the observations: one (node, sib.Test, time) tuple each.
# The state code `s` is 0 = S, 1 = I, 2 = R. The three sib.Test arguments are
# the likelihoods of the observation given S, I and R, so the observed state
# gets 1.0 and the other two get 0.0.
# Any other code would silently produce Test(0, 0, 0), an observation with zero
# likelihood under every state, so such rows are rejected up front.
if not df_observ["s"].isin([0, 1, 2]).all():
    raise ValueError(
        "observations contain unknown state codes (expected 0=S, 1=I, 2=R):\n"
        f"{df_observ[~df_observ['s'].isin([0, 1, 2])]}"
    )

# zip over the columns walks the rows in order, yielding plain Python scalars.
tests = [
    (i, sib.Test(float(s == 0), float(s == 1), float(s == 2)), t_test)
    for i, s, t_test in zip(df_observ["i"], df_observ["s"], df_observ["t_test"])
]

## Generate the factor graph

Parameters of the SIR model used for inference

In [67]:
# SIR model parameters used by the inference.
mu = 1 / 50   # rate of recovery, used as the parameter of an exponential prob_r
pseed = 1e-2  # probability for a node to be a source
params = sib.Params(
    prob_r=sib.Exponential(mu=mu),
    pseed=pseed,
)

### Dummy observations for computing marginals
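We want to compute the marginal probabilities of each node being S, I or R at times t=0 and t=3.

To do so, we add dummy observations to the observations list.

These dummy tests give the same likelihood to S, I and R, so they carry no information; they only force every node to have these times in the factor graph.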

In [68]:
# Times at which the marginals are read out below.
t1 = 0
t2 = 3
# Test(1, 1, 1) gives the same likelihood to S, I and R, so it adds no evidence.
# It only makes every node 0..N-1 have times t1 and t2 in the factor graph.
# The combined list is then stably sorted into chronological order.
# NOTE(review): this cell reads and reassigns `tests`; re-running it without
# first re-running the cell that builds `tests` appends the dummy tests twice.
tests = sorted(
    tests + [(i, sib.Test(1, 1, 1), t) for t in (t1, t2) for i in range(N)],
    key=lambda test: test[2],
)

#### Build the factor graph

In [69]:
# Build the factor graph from the SIR parameters, the contact list and the
# (observed + dummy) tests prepared above.
f = sib.FactorGraph(params=params, contacts=contacts, tests=tests)

# Iterate the BP equations

In [70]:
# Run the BP message updates on the factor graph `f` in place; `maxit=10`
# caps the number of iterations.
# NOTE(review): 10 iterations is enough for this 4-node toy graph, but
# convergence is not checked here — confirm it for larger inputs.
sib.iterate(f, maxit=10)




# Print the marginals

In [71]:
# Marginal probabilities (S, I, R) of every node at time t1.
for n in f.nodes:
    m = sib.marginal_t(n, t1)
    # The node index is already in the message; printing it again only adds noise.
    print(f"node:{n.index} -- Marginals (S,I,R): ({m[0]:.3f}, {m[1]:.3f}, {m[2]:.3f})")

0 node:0 -- Marginals (S,I,R): (1.000, 0.000, 0.000)
1 node:1 -- Marginals (S,I,R): (0.249, 0.751, 0.000)
2 node:2 -- Marginals (S,I,R): (0.502, 0.498, 0.000)
3 node:3 -- Marginals (S,I,R): (0.976, 0.024, 0.000)


In [72]:
# Marginal probabilities (S, I, R) of every node at time t2.
for n in f.nodes:
    m = sib.marginal_t(n, t2)
    # The node index is already in the message; printing it again only adds noise.
    print(f"node:{n.index} -- Marginals (S,I,R): ({m[0]:.3f}, {m[1]:.3f}, {m[2]:.3f})")

0 node:0 -- Marginals (S,I,R): (0.990, 0.009, 0.001)
1 node:1 -- Marginals (S,I,R): (0.000, 0.961, 0.039)
2 node:2 -- Marginals (S,I,R): (0.000, 0.980, 0.020)
3 node:3 -- Marginals (S,I,R): (0.966, 0.031, 0.002)
