In [1]:
using MatrixProductBP, MatrixProductBP.Models, MatrixProductBP.MPEMs
using IndexedGraphs
import ProgressMeter; ProgressMeter.ijulia_behavior(:clear)
import Random: GLOBAL_RNG;

┌ Info: Precompiling MatrixProductBP [3d39929c-b583-45fa-b331-3f50b693a38a]
└ @ Base loading.jl:1662


In [2]:
import MatrixProductBP.MPEMs: MPEM1, mpem1, MPEM2, evaluate

In [3]:
"""
    Probt(p)

Non-negative weights over the infection time of a single node. Entry `k` of
`p` is the weight of infection at time `k-1` for `k = 1, …, T+1`; the last
entry (`k = T+2`) is the weight of never being infected (time ∞).
The weights need not sum to one. Construction asserts that every entry is
non-negative.
"""
struct Probt
    p :: Vector{Float64}

    function Probt(p::Vector{Float64})
        @assert all(x->x≥0, p)
        return new(p)
    end
end

# Horizon T of a `Probt`: weights cover times 0, …, T plus ∞, i.e. T+2 entries.
getT(weights::Probt) = lastindex(weights.p) - 2

"""
    rand_probt(T; rng=GLOBAL_RNG)

Draw `T + 2` independent uniform weights (infection at times 0, …, T, and
never) using `rng` and wrap them in a `Probt`. The weights are not
normalized to sum to one.
"""
rand_probt(T::Integer; rng=GLOBAL_RNG) = Probt(rand(rng, T + 2))

"""
    SIChain(q)

Chain factorization of a weight over monotone two-state trajectories.
`q[s]` holds the factor for the trajectory value at position `s`; the
unnormalized weight of a trajectory `x` is the product of `q[s][x[s]]`
(trajectories that decrease get weight zero).
"""
struct SIChain
    q :: Vector{Vector{Float64}}
end

# Horizon T of a chain: one factor per time 0, …, T, i.e. T+1 factors.
getT(chain::SIChain) = lastindex(chain.q) - 1

"""
    _evaluate(p::SIChain, x)

Unnormalized weight of trajectory `x` under chain `p`: the product of
`p.q[s][x[s]]` over all positions `s`, or `0.0` if `x` ever decreases.
Asserts that `x` has exactly one entry per chain factor.
"""
function _evaluate(p::SIChain, x)
    @assert length(x) == length(p.q)
    # Left fold starting at 1.0, multiplying factors in position order.
    weight = foldl((acc, s) -> acc * p.q[s][x[s]], eachindex(p.q); init=1.0)
    # Only non-decreasing trajectories (no return from state 2 to 1) are allowed.
    issorted(x) || return 0.0
    return weight
end

"""
    normalization(p::SIChain)

Sum of the unnormalized weights `_evaluate(p, x)` over every monotone
trajectory of length `L = length(p.q)`: `t` leading 1s followed by `L - t`
2s, for `t = 0, …, L` (`t = L` is the trajectory that never reaches state 2).
"""
function normalization(p::SIChain)
    # Trajectory length must come from the chain itself, never from a global
    # horizon that may have been rebound for a different chain.
    L = length(p.q)
    return sum(_evaluate(p, [fill(1, t); fill(2, L - t)]) for t in 0:L)
end

"""
    evaluate(p::SIChain, x)

Normalized probability of trajectory `x` under chain `p`.
"""
function evaluate(p::SIChain, x)
    unnormalized = _evaluate(p, x)
    return unnormalized / normalization(p)
end

"""
    SIChain(pt::Probt)

Chain factorization of the infection-time weights `pt`. Factor `s` is
`[1, w[s]/w[s+1]]` (computed as `exp(-log(w[s+1]/w[s]))`), with `w = pt.p`.
The trajectory that first enters state 2 at time `t` therefore gets weight
proportional to `w[t+1]`, and the all-1 trajectory gets weight
proportional to `w[end]`.
"""
function SIChain(pt::Probt)
    w = pt.p
    factors = [[1, exp(-log(w[s+1] / w[s]))] for s in 1:length(w)-1]
    return SIChain(factors)
end

"""
    mpem1(p::SIChain)

Convert chain `p` into a single-variable `MPEM1` with bond dimension 2 and
call `normalize!` on it. The bond index aᵗ carries the state at time t-1.
This lets tensor t enforce monotonicity (aᵗ ≤ xᵗ) and apply the factor
`q[t][aᵗ]` of the previous time. The last tensor also applies the factor of
time T. For T == 0 there is a single 1×1×2 tensor holding `q[1]`.
"""
function mpem1(p::SIChain)
    q = p.q
    T = length(q) - 1
    tensors = map(0:T) do t
        if T == 0
            # Single time step: no bond to carry, only the factor of time 0.
            reshape(Float64[q[1][x⁰] for x⁰ in 1:2], 1, 1, 2)
        elseif t == 0
            Float64[(x⁰==a¹) for _ in 1:1, a¹ in 1:2, x⁰ in 1:2]
        elseif t < T
            Float64[(aᵗ≤xᵗ)*q[t][aᵗ]*(xᵗ==aᵗ⁺¹) for aᵗ in 1:2, aᵗ⁺¹ in 1:2, xᵗ in 1:2]
        else
            Float64[(aᵀ≤xᵀ)*q[t][aᵀ]*q[t+1][xᵀ] for aᵀ in 1:2, _ in 1:1, xᵀ in 1:2]
        end
    end
    A = MPEM1(tensors)
    normalize!(A)
    return A
end

"""
    rand_sis_mpem1(T; rng=GLOBAL_RNG)

Draw random infection-time weights on horizon `T` with `rng`, build their
chain factorization, and convert it to an `MPEM1` via `mpem1`.
"""
rand_sis_mpem1(T::Integer; rng=GLOBAL_RNG) = mpem1(SIChain(rand_probt(T; rng)))

"""
    message_eq(b::SIChain)

Two-variable `MPEM2` for the "equal" term ∏ₜ 𝟙[xᵢᵗ = xⱼᵗ] b̃ᵗ(xᵢᵗ), with xᵢ
monotone. The bond index aᵗ carries xᵢ at time t-1. This lets tensor t
enforce aᵗ ≤ xᵢᵗ and apply the factor `b.q[t][aᵗ]` of the previous time.
The last tensor also applies the factor of time T. For T == 0 there is a
single 1×1×2×2 tensor. The result is not normalized.
"""
function message_eq(b::SIChain)
    T = getT(b)
    tensors = map(0:T) do t
        if T == 0
            # Single time step: no bond to carry, only the factor of time 0.
            reshape(Float64[(xᵢ⁰==xⱼ⁰)*b.q[1][xᵢ⁰] for xᵢ⁰ in 1:2, xⱼ⁰ in 1:2], 1, 1, 2, 2)
        elseif t == 0
            Float64[(xᵢ⁰==xⱼ⁰)*(xᵢ⁰==a¹) for _ in 1:1, a¹ in 1:2, xᵢ⁰ in 1:2, xⱼ⁰ in 1:2]
        elseif t < T
            Float64[(xⱼᵗ==xᵢᵗ)*(xᵢᵗ==aᵗ⁺¹)*(aᵗ≤xᵢᵗ)*b.q[t][aᵗ]
                for aᵗ in 1:2, aᵗ⁺¹ in 1:2, xᵢᵗ in 1:2, xⱼᵗ in 1:2]
        else
            Float64[(xᵢᵀ==xⱼᵀ)*(aᵀ≤xᵢᵀ)*b.q[t][aᵀ]*b.q[t+1][xᵢᵀ]
                for aᵀ in 1:2, _ in 1:1, xᵢᵀ in 1:2, xⱼᵀ in 1:2]
        end
    end
    # NOTE(review): intentionally not normalized; presumably so the relative
    # weight of this term is kept when it is summed with the ≤ term — confirm.
    return MPEM2(tensors)
end

# All joint bond states (aᵢ, aⱼ) ∈ {1,2}², first component varying fastest
# (column-major order of `Iterators.product(1:2, 1:2)`).
twostates = [(aᵢ, aⱼ) for aⱼ in 1:2 for aᵢ in 1:2]

"""
    message_leq(c::SIChain)

Two-variable `MPEM2` for the "less-or-equal" term
∏ₜ 𝟙[xᵢᵗ ≤ xⱼᵗ] c̃ᵗ(xᵢᵗ), with both xᵢ and xⱼ monotone. The bond index
(aᵢᵗ, aⱼᵗ) ∈ `twostates` carries both states at time t-1. This lets
tensor t enforce monotonicity of each trajectory and apply the factor
`c.q[t][aᵢᵗ]` of the previous time. Only xᵢ is weighted; xⱼ is
constrained but unweighted. For T == 0 there is a single 1×1×2×2 tensor.
The result is not normalized.
"""
function message_leq(c::SIChain)
    T = getT(c)
    tensors = map(0:T) do t
        if T == 0
            # Single time step: no bond to carry, only the factor of time 0.
            reshape(Float64[(xᵢ⁰≤xⱼ⁰)*c.q[1][xᵢ⁰] for xᵢ⁰ in 1:2, xⱼ⁰ in 1:2], 1, 1, 2, 2)
        elseif t == 0
            Float64[(xᵢ⁰≤xⱼ⁰)*(xᵢ⁰==aᵢ¹)*(xⱼ⁰==aⱼ¹)
                for _ in 1:1, (aᵢ¹,aⱼ¹) in twostates, xᵢ⁰ in 1:2, xⱼ⁰ in 1:2]
        elseif t < T
            Float64[(xᵢᵗ≤xⱼᵗ)*(xᵢᵗ==aᵢᵗ⁺¹)*(xⱼᵗ==aⱼᵗ⁺¹)*(aᵢᵗ≤xᵢᵗ)*(aⱼᵗ≤xⱼᵗ)*c.q[t][aᵢᵗ]
                for (aᵢᵗ,aⱼᵗ) in twostates, (aᵢᵗ⁺¹,aⱼᵗ⁺¹) in twostates, xᵢᵗ in 1:2, xⱼᵗ in 1:2]
        else
            Float64[(xᵢᵀ≤xⱼᵀ)*(aᵢᵀ≤xᵢᵀ)*(aⱼᵀ≤xⱼᵀ)*c.q[t][aᵢᵀ]*c.q[t+1][xᵢᵀ]
                for (aᵢᵀ,aⱼᵀ) in twostates, _ in 1:1, xᵢᵀ in 1:2, xⱼᵀ in 1:2]
        end
    end
    # NOTE(review): intentionally not normalized; presumably so the relative
    # weight of this term is kept when it is summed with the = term — confirm.
    return MPEM2(tensors)
end

message_leq (generic function with 1 method)

Compute the message
$$
m_{i\to j}(x_{i},x_{j})\propto\prod_{t}\mathbb{1}\left[x_{i}^{t}=x_{j}^{t}\right]\mathbb{1}\left[x_{i}^{t+1}\ge x_{i}^{t}\right]\tilde{b}_{i\to j}^{t}(x_{i}^{t})+\prod_{t}\mathbb{1}\left[x_{i}^{t}\le x_{j}^{t}\right]\mathbb{1}\left[x_{i}^{t+1}\ge x_{i}^{t}\right]\mathbb{1}\left[x_{j}^{t+1}\ge x_{j}^{t}\right]\tilde{c}_{i\to j}^{t}(x_{i}^{t})
$$
where $\tilde{b}^{t}_{i\to j}$ and $\tilde{c}^{t}_{i\to j}$ are the factors in the chain factorization of two arbitrary distributions of the infection time, $b(t_i)$ and $c(t_i)$.

In [4]:
"""
    message(c::Probt, b::Probt)

Matrix-product form of the message m_{i→j}(xᵢ, xⱼ). It is the sum of:
- the "equal" term ∏ₜ 𝟙[xᵢᵗ = xⱼᵗ] b̃ᵗ(xᵢᵗ), built from `b`, and
- the "less-or-equal" term ∏ₜ 𝟙[xᵢᵗ ≤ xⱼᵗ] c̃ᵗ(xᵢᵗ), built from `c`.

Both trajectories are constrained to be monotone. Neither term is
normalized before the sum.
"""
function message(c::Probt, b::Probt)
    # b pairs with the equality constraint, c with the ≤ constraint.
    B = message_eq(SIChain(b))
    C = message_leq(SIChain(c))
    return B + C
end

message (generic function with 1 method)

In [5]:
# Horizon T (times 0, …, T) for the single-chain checks below.
T = 10
# NOTE(review): N does not appear to be used in the cells shown here.
N = 2
# Random infection-time weights (not normalized) over times 0, …, T and ∞.
pt = rand_probt(T);

In [6]:
# Chain factorization of the infection-time weights.
p = SIChain(pt);

In [7]:
# Trajectory of length T+1: state 1 at times 0, …, t-1, state 2 from time t on.
t = 4
x = [fill(1,t); fill(2,T-t+1)]
# Normalized probability of that trajectory under the chain.
evaluate(p, x)

0.12334346655324532

In [8]:
# Same distribution represented as a single-variable MPEM.
A = mpem1(p);

In [9]:
# The MPEM must reproduce the chain's normalized probability on the sample
# trajectory and on every monotone trajectory: switch to state 2 at time
# t = 0, …, T, plus t = T+1, the trajectory that never leaves state 1.
@assert evaluate(p, x) ≈ evaluate(A, x)
@assert all( evaluate(p,xx) ≈ evaluate(A, xx) for xx in ([fill(1,t); fill(2,T-t+1)] for t in 0:T+1))

In [10]:
# Longer horizon for the message tests; note this rebinds the global T.
T = 20
b = rand_probt(T)
c = rand_probt(T)
M = message(c, b)
# Bond dimensions of the uncompressed message, displayed as a row.
bond_dims(M)'

1×20 adjoint(::Vector{Int64}) with eltype Int64:
 6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6  6

In [11]:
# Compress a copy with truncation threshold 0.0 (presumably keeps every
# nonzero singular value — confirm) and compare the resulting bond dimensions.
M4 = compress!(deepcopy(M); svd_trunc=TruncThresh(0.0))
bond_dims(M4)'

1×20 adjoint(::Vector{Int64}) with eltype Int64:
 3  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4  4

In [15]:
# Compress a copy of the message to at most bond dimension 3.
svd_trunc = TruncBondMax(3)
M3 = compress!(deepcopy(M); svd_trunc)
# Error recorded by the truncator during compression (presumably the largest
# discarded weight — confirm against TruncBondMax's documentation).
svd_trunc.maxerr

1-element Vector{Float64}:
 2.0093282396114493e-16

In [31]:
# Inspect the second tensor of the bond-dimension-3 compressed message.
M3[2]

3×3×2×2 Array{Float64, 4}:
[:, :, 1, 1] =
 0.945144     0.0583304   -0.0005193
 0.320793     0.019798    -0.000176256
 0.000376307  2.32241e-5  -2.06758e-7

[:, :, 2, 1] =
 0.0  0.0  0.0
 0.0  0.0  0.0
 0.0  0.0  0.0

[:, :, 1, 2] =
  0.113713  -0.285299  -0.011449
 -0.335156   0.840887   0.0337446
  0.107859  -0.270613  -0.0108596

[:, :, 2, 2] =
  0.0161231  -0.032812    0.0866386
 -0.0473106   0.0962817  -0.254227
 -0.164008    0.333774   -0.881314

In [28]:
# Build the two message terms separately (≤ term from c, = term from b) and
# compress a copy of the ≤ term to at most bond dimension 3.
C = message_leq(SIChain(c))
B = message_eq(SIChain(b))
svd_trunc = TruncBondMax(3)
CC = compress!(deepcopy(C); svd_trunc)
svd_trunc.maxerr

1-element Vector{Float64}:
 0.0

In [25]:
# Compress a copy of the = term to at most bond dimension 2 and report the error.
svd_trunc = TruncBondMax(2)
BB = compress!(deepcopy(B); svd_trunc)
svd_trunc.maxerr

1-element Vector{Float64}:
 0.0