In [1]:
using MatrixProductBP, MatrixProductBP.Models, MatrixProductBP.MPEMs
using IndexedGraphs
import ProgressMeter; ProgressMeter.ijulia_behavior(:clear)
import Random: GLOBAL_RNG;

In [2]:
import MatrixProductBP.MPEMs: MPEM1, evaluate

In [53]:
# probability of infection at time 0, 1, 2, ..., T, Inf
# joint probability for N particles
struct Probt{N}
    p :: Array{Float64, N}
    
    function Probt(p::Array{Float64}) where {N<:Integer}
        @assert all(x->x≥0, p)
        @assert sum(p) ≈ 1
        new{length(size(p))}(p)
    end
end

function rand_probt(T::Integer, N::Integer=1; rng=GLOBAL_RNG)
    p = rand(rng, fill(T+2,N)...)
    p ./= sum(p)
    Probt(p)
end

struct SIChain
   h :: Vector{Float64}
end

function _evaluate(p::SIChain, x)
    @assert length(x) == length(p.h)
    prob = 1.0
    for t in eachindex(p.h)
        x[t]==2 && (prob *= exp(-p.h[t]))
    end
    for t in 1:length(x)-1
        x[t+1] < x[t] && return 0.0
    end    
    return prob
end

function normalization(p::SIChain)
    sum(_evaluate(p, x) for x in ([fill(1,t); fill(2,T-t+1)] for t in 0:length(p.h)))
end

evaluate(p::SIChain, x) = _evaluate(p, x) / normalization(p)

function SIChain(pt::Probt{1})
    h = [log(pt.p[t+1]/pt.p[t]) for t in 1:length(pt.p)-1]
    SIChain(h)   
end

function MPEM1(p::SIChain)
    T = length(p.h) - 1
    tensors = map(0:T) do t
        if t == 0
            Float64[(x⁰==a¹) for _ in 1:1, a¹ in 1:2, x⁰ in 1:2]
        elseif t == T
            Float64[(xᵀ≥aᵀ)*exp(-p.h[end-1]*aᵀ)*exp(-p.h[end]*xᵀ) for aᵀ in 1:2, _ in 1:1, xᵀ in 1:2]
        else
            Float64[(xᵗ≥aᵗ)*exp(-p.h[t]*aᵗ)*(xᵗ==aᵗ⁺¹) for aᵗ in 1:2, aᵗ⁺¹ in 1:2, xᵗ in 1:2]
        end
    end
    A = MPEM1(tensors)
    normalize!(A)
    A
end

function rand_sis_mpem1(T::Integer, N::Integer=1; rng=GLOBAL_RNG)
    pt = rand_probt(t, N; rng)
    p = SIChain(pt)
    MPEM1(p)
end

rand_sis_mpem1 (generic function with 2 methods)

In [65]:
T = 10
N = 2
pt = rand_probt(T, 1);

In [66]:
p = SIChain(pt);

In [67]:
t = 4
x = [fill(1,t); fill(2,T-t+1)]
evaluate(p, x)

0.15465581085044877

In [68]:
A = MPEM1(p);

In [70]:
@assert evaluate(p, x) ≈ evaluate(A, x)
@assert all( evaluate(p,xx) ≈ evaluate(A, xx) for xx in ([fill(1,t); fill(2,T-t+1)] for t in 0:T))

In [59]:
A = rand_sis_mpem1(T, 1)

MPEM1{Float64}([[0.8553442963572924 0.0;;; 0.0 0.8553442963572924], [0.8553442963572924 0.0; 0.0 0.0;;; 0.0 0.8553442963572924; 0.0 0.23545356059996514], [0.5358051573621487 0.0; 0.0 0.0;;; 0.0 0.5358051573621487; 0.0 0.8553442963572924], [0.8553442963572924 0.0; 0.0 0.0;;; 0.0 0.8553442963572924; 0.0 0.41569773394269377], [0.8553442963572924; 0.0;;; 0.3447426816836672; 0.7200205379462378]])