In [65]:
import pandas as pd
import sib

# Load Data

In [66]:
# Contacts are directed: a bidirectional contact must appear in the CSV once per
# direction (i->j and j->i). Only the columns needed downstream are kept.
df_contacts = pd.read_csv("./data/dummy/transmissions.csv").loc[:, ["i", "j", "t", "lambda"]]
contacts = list(df_contacts.to_records(index=False))
# Node ids start at 0, so the node count is the largest id seen in either column plus one.
N = max(df_contacts["i"].max(), df_contacts["j"].max()) + 1
print(f"Number of nodes: {N}")
df_contacts

Number of nodes: 4


Unnamed: 0,i,j,t,lambda
0,2,3,0,0.1
1,3,2,0,0.1
2,1,2,0,0.1
3,2,1,0,0.1
4,2,1,1,0.1
5,1,2,1,0.1
6,0,1,3,0.1
7,1,0,3,0.1


In [67]:
# Observed states are encoded as S=0, I=1, R=2; each record is (node, state, time).
df_observ = pd.read_csv("./data/dummy/observations.csv")
observ = [*df_observ.to_records(index=False)]
observ

[(3, 0, 0), (1, 1, 1), (3, 0, 1), (2, 1, 2)]

## Generate the factor graph

Parameters of the SIR model used for inference

In [80]:
# Recovery rate, passed to sib.Exponential as the recovery law
# (presumably an exponential recovery-time distribution — confirm in sib docs).
mu = 0.1
# Probability to have sources; if this is a per-node prior, N * pseed = 1 source is expected.
pseed = 1 / N
params = sib.Params(
    prob_r=sib.Exponential(mu=mu),
    pseed=pseed,
)

### Dummy observations for computing marginals
We want to compute the marginal probabilities that each node is S, I or R at times t=0 and t=3.

To do so, we add a dummy observation for every node at each of these two times to the observation list.

This forces every node to have these times in the factor graph, so its marginals can be read out at them.

In [81]:
# Times at which node marginals are read out below.
t1 = 0
t2 = 3
# One dummy observation (state -1) per node and per readout time, so that every node
# has these times in the factor graph.
# NOTE(review): -1 presumably means "state not observed" to sib — confirm.
dummy_observ = [(i, -1, t) for t in (t1, t2) for i in range(N)]
# Rebuild from the loaded observations instead of appending with `+=`: the cell is then
# idempotent, and re-running it no longer accumulates duplicate dummy observations.
# Sort by time (third field); sorted() is stable, so same-time entries keep their order.
observ = sorted(
    list(df_observ.to_records(index=False)) + dummy_observ,
    key=lambda obs: obs[2],
)

#### factor graph

In [82]:
# Build the factor graph from the directed contacts and the real + dummy observations.
f = sib.FactorGraph(
    params=params,
    contacts=contacts,
    observations=observ,
)

# Iterate the BP equations

In [83]:
sib.iterate(f, tol=1e-4, maxit=100)  # at most 100 BP sweeps; presumably stops early once the update error is below tol

sib.iterate(damp=0.0): 3/100 0.000e+00/0.0001      


# Print the marginals

In [84]:
# Marginal probabilities (S, I, R) of every node at time t1.
for n in f.nodes:
    m = sib.marginal_t(n, t1)
    # Clamp tiny negative values (presumably numerical round-off) so they do not print
    # as "-0.000"; max(0.0, x) also turns -0.0 into 0.0.
    p_s, p_i, p_r = (max(0.0, float(m[k])) for k in range(3))
    print(f"node:{n.index} -- Marginals (S,I,R): ({p_s:.3f}, {p_i:.3f}, {p_r:.3f})")

0 node:0 -- Marginals (S,I,R): (0.657, 0.310, 0.033)
1 node:1 -- Marginals (S,I,R): (0.082, 0.918, 0.000)
2 node:2 -- Marginals (S,I,R): (0.179, 0.821, 0.000)
3 node:3 -- Marginals (S,I,R): (0.951, 0.000, 0.049)


In [85]:
# Marginal probabilities (S, I, R) of every node at time t2.
for n in f.nodes:
    m = sib.marginal_t(n, t2)
    # Clamp tiny negative values (presumably numerical round-off) so they do not print
    # as "-0.000"; max(0.0, x) also turns -0.0 into 0.0.
    p_s, p_i, p_r = (max(0.0, float(m[k])) for k in range(3))
    print(f"node:{n.index} -- Marginals (S,I,R): ({p_s:.3f}, {p_i:.3f}, {p_r:.3f})")

0 node:0 -- Marginals (S,I,R): (0.657, 0.230, 0.113)
1 node:1 -- Marginals (S,I,R): (0.000, 0.814, 0.186)
2 node:2 -- Marginals (S,I,R): (0.000, 0.905, 0.095)
3 node:3 -- Marginals (S,I,R): (0.947, -0.000, 0.053)
