## Max-Sum

In [2]:
using StaticArrays

In [7]:
hamming(x,y) = count_ones((x-1) ⊻ (y-1))

hamming (generic function with 1 method)

In [17]:
msg_sum(u1::SVector, u2::SVector) = u1 .+ u2
normalize_max(u::SVector) = u .- maximum(u)
function neutral_prob_ms(Q::Integer)
    v = -Inf*ones(SVector{Q})
    v = setindex(v, 0.0, 1)
end

neutral_prob_ms (generic function with 1 method)

In [72]:
function iter_var_ms(us, s::Int, Q)
    h = - hamming.(SVector{Q}(1:Q), s)
    for u in us
        h = msg_sum(h, u)
    end
    h .- maximum(h)
end

iter_var_ms (generic function with 2 methods)

In [28]:
function msg_maxconv_gfq!(u::MVector{Q,T}, h1::SVector{Q,T}, h2::SVector{Q,T}) where {Q,T}
    for x1 in eachindex(h1), x2 in eachindex(h2)
        # adjust for indices starting at 1 instead of 0
        u[((x1-1) ⊻ (x2-1)) + 1] = max(u[((x1-1) ⊻ (x2-1)) + 1], h1[x1]+h2[x2])
    end
    SVector(u)
end

msg_maxconv_gfq! (generic function with 1 method)

In [90]:
function iter_factor_ms(hs, Hs, H0, gfmult, gfdiv, Q)
    uaux = -Inf*ones(MVector{Q,Float64})
    u_tilde = neutral_prob_ms(Q)
    for (h,H) in zip(hs,Hs)
        h_tilde = h[SVector{Q}(gfdiv[:,H])]
        uaux .= -Inf
        u_tilde = msg_maxconv_gfq!(uaux, u_tilde, h_tilde)
    end
    u = u_tilde[SVector{Q}(gfmult[:,H0])]
    u = u .- maximum(u)
end

iter_factor_ms (generic function with 2 methods)

In [30]:
 gfmult = SA[1 1 1 1;
                1 2 3 4;
                1 3 4 2;
                1 4 2 3]
    gfdiv = SA[0 1 1 1;
               0 2 4 3;
               0 3 2 4;
               0 4 3 2]  

4×4 SMatrix{4, 4, Int64, 16} with indices SOneTo(4)×SOneTo(4):
 0  1  1  1
 0  2  4  3
 0  3  2  4
 0  4  3  2

In [31]:
h1 = SVector{Q}(1.0,2.0,3.0,4.0)
h2 = SVector{Q}(5.0,6.0,7.0,8.0)
Hs = [3,4]
H0 = 2
hs = [h1, h2]
iter_factor_ms(hs, Hs, H0, gfmult, gfdiv)

4-element SVector{4, Float64} with indices SOneTo(4):
 11.0
 11.0
 10.0
 12.0

In [107]:
function RS_gfq(Λ, Pk, Q, popP, popQ, gfmult, gfdiv; maxiter=10^2, tol=1e-5, 
        toliter=1/length(popP)) 
    ks = [k for k in eachindex(Pk) if Pk[k] > tol]
    ds = [d for d in eachindex(Λ) if Λ[d] > tol]
    @assert sum(Pk[ks]) ≈ 1 && sum(Λ[ds]) ≈ 1    
    Λ_red = [d*Λ[d] for d in eachindex(Λ)]; Λ_red ./= sum(Λ_red)
    P_red = [d*Pk[d] for d in eachindex(Pk)]; P_red ./= sum(P_red)
    
    N = length(popP)
    avgP = mean(popP)
    for it in 1:maxiter
        # update Q
        for i in 1:N
            d = sample(eachindex(Λ_red), weights(Λ_red))
            s = rand(1:Q)
            idx = sample(1:N, d-1)
            popQ[i] = iter_var_ms(popP[idx], s, Q)
        end
        # update P(u)
        for i in 1:N
            k = sample(eachindex(P_red), weights(P_red))
            Hs = rand(2:Q,k-1)
            H0 = rand(2:Q)
            idx = sample(1:N, k-1)
            popP[i] = iter_factor_ms(popQ[idx], Hs, H0, gfmult, gfdiv, Q)
        end
        avgP_old = avgP
        avgP = mean(popP)
        err = maximum(abs.(avgP_old.-avgP))
        @show err, toliter
        if err < toliter
            break
        end
    end
    
end

RS_gfq (generic function with 1 method)

In [100]:
function freenrj_factor(popQ, Pk, gfmult, gfdiv, Q)
    N=length(popQ)
    k = sample(eachindex(Pk), weights(Pk))
    Hs = rand(2:Q,k)
    idx = sample(1:N, k)
    hs=popQ[idx]
    u = iter_factor_ms(hs, Hs, 1, gfmult, gfdiv, Q)
    Fa = -u[1]
end
function freenrj_edge(popP, popQ)
    N=length(popP)
    h=popQ[rand(1:N)]
    u=popP[rand(1:N)]
    Fia = -maximum(msg_sum(h, u))
end
function freenrj_var(Λ, popP, Q)
    N = length(popP)
    d = sample(eachindex(Λ), weights(Λ))
    s = rand(1:Q)
    idx = sample(1:N, d)
    us = popP[idx]
    Fi = -maximum(iter_var_ms(us, s, Q))
end

function freenrj(Λ, Pk, Q, popP, popQ, gfmult, gfdiv; samples=10^3)
    mK = sum(k*Pk[k] for k=eachindex(Pk))
    mΛ = sum(d*Λ[d] for d=eachindex(Λ))
    α = mΛ/mK
    Fa=Fi=Fia=0 
    for t=1:samples
        Fa += freenrj_factor(popQ, Pk, gfmult, gfdiv, Q)
        Fi += freenrj_var(Λ, popP, Q)
        Fia += freenrj_edge(popP, popQ)
        F = (Fi + α*Fa - mΛ*Fia)/t/log2(Q)
        @show F
    end
    F = (Fi + α*Fa - mΛ*Fia)/samples/log2(Q)
end

freenrj (generic function with 1 method)

In [101]:
popP = fill(zero(SVector{Q}), N)
mean(popP)

4-element SVector{4, Float64} with indices SOneTo(4):
 0.0
 0.0
 0.0
 0.0

In [108]:
N=1000
maxiter = 1000
popP = fill(zero(SVector{Q}), N)
popQ = fill(zero(SVector{Q}), N)
Λ = [0.5,0.5]
Pk = [0,0,1]
RS_gfq(Λ, Pk, Q, popP, popQ, gfmult, gfdiv, maxiter=maxiter)

(err, toliter) = (0.844, 0.001)
(err, toliter) = (0.19699999999999995, 0.001)
(err, toliter) = (0.08599999999999997, 0.001)
(err, toliter) = (0.06299999999999994, 0.001)
(err, toliter) = (0.052999999999999936, 0.001)
(err, toliter) = (0.07299999999999995, 0.001)
(err, toliter) = (0.03600000000000003, 0.001)
(err, toliter) = (0.040999999999999925, 0.001)
(err, toliter) = (0.06299999999999994, 0.001)
(err, toliter) = (0.04599999999999993, 0.001)
(err, toliter) = (0.04799999999999993, 0.001)
(err, toliter) = (0.030000000000000027, 0.001)
(err, toliter) = (0.09999999999999998, 0.001)
(err, toliter) = (0.041000000000000036, 0.001)
(err, toliter) = (0.06799999999999995, 0.001)
(err, toliter) = (0.05699999999999994, 0.001)
(err, toliter) = (0.052999999999999936, 0.001)
(err, toliter) = (0.04999999999999993, 0.001)
(err, toliter) = (0.06599999999999995, 0.001)
(err, toliter) = (0.02300000000000002, 0.001)
(err, toliter) = (0.06699999999999995, 0.001)
(err, toliter) = (0.02499999999999991, 0.00

(err, toliter) = (0.10399999999999998, 0.001)
(err, toliter) = (0.03799999999999992, 0.001)
(err, toliter) = (0.05500000000000005, 0.001)
(err, toliter) = (0.04400000000000004, 0.001)
(err, toliter) = (0.02199999999999991, 0.001)
(err, toliter) = (0.05399999999999994, 0.001)
(err, toliter) = (0.04600000000000004, 0.001)
(err, toliter) = (0.016999999999999904, 0.001)
(err, toliter) = (0.040999999999999925, 0.001)
(err, toliter) = (0.04299999999999993, 0.001)
(err, toliter) = (0.051999999999999935, 0.001)
(err, toliter) = (0.03500000000000003, 0.001)
(err, toliter) = (0.025000000000000022, 0.001)
(err, toliter) = (0.038999999999999924, 0.001)
(err, toliter) = (0.03500000000000003, 0.001)
(err, toliter) = (0.04899999999999993, 0.001)
(err, toliter) = (0.02300000000000002, 0.001)
(err, toliter) = (0.039999999999999925, 0.001)
(err, toliter) = (0.04999999999999993, 0.001)
(err, toliter) = (0.07199999999999995, 0.001)
(err, toliter) = (0.11699999999999999, 0.001)
(err, toliter) = (0.10199999

(err, toliter) = (0.03400000000000003, 0.001)
(err, toliter) = (0.07199999999999995, 0.001)
(err, toliter) = (0.052000000000000046, 0.001)
(err, toliter) = (0.03299999999999992, 0.001)
(err, toliter) = (0.07899999999999996, 0.001)
(err, toliter) = (0.04700000000000004, 0.001)
(err, toliter) = (0.041999999999999926, 0.001)
(err, toliter) = (0.05899999999999994, 0.001)
(err, toliter) = (0.019999999999999907, 0.001)
(err, toliter) = (0.018000000000000016, 0.001)
(err, toliter) = (0.041999999999999926, 0.001)
(err, toliter) = (0.08899999999999997, 0.001)
(err, toliter) = (0.06600000000000006, 0.001)
(err, toliter) = (0.07799999999999996, 0.001)
(err, toliter) = (0.05500000000000005, 0.001)
(err, toliter) = (0.07699999999999996, 0.001)
(err, toliter) = (0.06400000000000006, 0.001)
(err, toliter) = (0.09299999999999997, 0.001)
(err, toliter) = (0.017000000000000015, 0.001)
(err, toliter) = (0.06699999999999995, 0.001)
(err, toliter) = (0.040999999999999925, 0.001)
(err, toliter) = (0.0350000

(err, toliter) = (0.05400000000000005, 0.001)
(err, toliter) = (0.04200000000000004, 0.001)
(err, toliter) = (0.018000000000000016, 0.001)
(err, toliter) = (0.027000000000000024, 0.001)
(err, toliter) = (0.03199999999999992, 0.001)
(err, toliter) = (0.03699999999999992, 0.001)
(err, toliter) = (0.06599999999999995, 0.001)
(err, toliter) = (0.04300000000000004, 0.001)
(err, toliter) = (0.05799999999999994, 0.001)
(err, toliter) = (0.08599999999999997, 0.001)
(err, toliter) = (0.051999999999999935, 0.001)
(err, toliter) = (0.04699999999999993, 0.001)
(err, toliter) = (0.025999999999999912, 0.001)
(err, toliter) = (0.029000000000000026, 0.001)
(err, toliter) = (0.06899999999999995, 0.001)
(err, toliter) = (0.05499999999999994, 0.001)
(err, toliter) = (0.05799999999999994, 0.001)
(err, toliter) = (0.08299999999999996, 0.001)
(err, toliter) = (0.08899999999999997, 0.001)
(err, toliter) = (0.08399999999999996, 0.001)
(err, toliter) = (0.06299999999999994, 0.001)
(err, toliter) = (0.044999999

(err, toliter) = (0.04899999999999993, 0.001)
(err, toliter) = (0.014000000000000012, 0.001)
(err, toliter) = (0.04499999999999993, 0.001)
(err, toliter) = (0.04699999999999993, 0.001)
(err, toliter) = (0.051999999999999935, 0.001)
(err, toliter) = (0.03399999999999992, 0.001)
(err, toliter) = (0.02100000000000002, 0.001)
(err, toliter) = (0.06699999999999995, 0.001)
(err, toliter) = (0.05600000000000005, 0.001)
(err, toliter) = (0.039999999999999925, 0.001)
(err, toliter) = (0.09599999999999997, 0.001)
(err, toliter) = (0.052000000000000046, 0.001)
(err, toliter) = (0.07699999999999996, 0.001)
(err, toliter) = (0.041999999999999926, 0.001)
(err, toliter) = (0.025999999999999912, 0.001)
(err, toliter) = (0.05399999999999994, 0.001)
(err, toliter) = (0.05699999999999994, 0.001)
(err, toliter) = (0.07099999999999995, 0.001)
(err, toliter) = (0.040999999999999925, 0.001)
(err, toliter) = (0.02300000000000002, 0.001)
(err, toliter) = (0.018000000000000016, 0.001)
(err, toliter) = (0.040000

(err, toliter) = (0.03400000000000003, 0.001)
(err, toliter) = (0.04799999999999993, 0.001)
(err, toliter) = (0.04399999999999993, 0.001)
(err, toliter) = (0.031000000000000028, 0.001)
(err, toliter) = (0.027999999999999914, 0.001)
(err, toliter) = (0.019000000000000017, 0.001)
(err, toliter) = (0.020000000000000018, 0.001)
(err, toliter) = (0.04299999999999993, 0.001)
(err, toliter) = (0.06899999999999995, 0.001)
(err, toliter) = (0.02100000000000002, 0.001)
(err, toliter) = (0.06999999999999995, 0.001)
(err, toliter) = (0.04399999999999993, 0.001)
(err, toliter) = (0.05799999999999994, 0.001)
(err, toliter) = (0.10599999999999998, 0.001)
(err, toliter) = (0.06700000000000006, 0.001)
(err, toliter) = (0.04400000000000004, 0.001)
(err, toliter) = (0.029999999999999916, 0.001)
(err, toliter) = (0.051999999999999935, 0.001)
(err, toliter) = (0.02199999999999991, 0.001)
(err, toliter) = (0.025000000000000022, 0.001)
(err, toliter) = (0.06999999999999995, 0.001)
(err, toliter) = (0.0480000

In [109]:
freenrj(Λ, Pk, Q, popP, popQ, gfmult, gfdiv, samples=10^2)

F = -0.75
F = -0.375
F = -0.25
F = -0.1875
F = -0.3
F = -0.25
F = -0.21428571428571427
F = -0.1875
F = -0.16666666666666666
F = -0.225
F = -0.2727272727272727
F = -0.3125
F = -0.34615384615384615
F = -0.32142857142857145
F = -0.3
F = -0.328125
F = -0.3088235294117647
F = -0.3333333333333333
F = -0.35526315789473684
F = -0.375
F = -0.35714285714285715
F = -0.3409090909090909
F = -0.358695652173913
F = -0.375
F = -0.36
F = -0.375
F = -0.3888888888888889
F = -0.4017857142857143
F = -0.41379310344827586
F = -0.45
F = -0.4596774193548387
F = -0.46875
F = -0.45454545454545453
F = -0.4411764705882353
F = -0.42857142857142855
F = -0.4166666666666667
F = -0.40540540540540543
F = -0.39473684210526316
F = -0.40384615384615385
F = -0.39375
F = -0.38414634146341464
F = -0.375
F = -0.36627906976744184
F = -0.39204545454545453
F = -0.38333333333333336
F = -0.391304347826087
F = -0.39893617021276595
F = -0.40625
F = -0.3979591836734694
F = -0.405
F = -0.4117647058823529
F = -0.4182692307692308
F = -0.

-0.4125

In [99]:
mK = sum(k*Pk[k] for k=eachindex(Pk))
mΛ = sum(d*Λ[d] for d=eachindex(Λ))
α = mΛ/mK
R=1-α

0.5

In [35]:
using StatsBase
