In [None]:
# Presumably ExpFamily provides the Gaussian natural-parameter types and EPInference
# the ParamsEP / ep* drivers used below — confirm against the packages.
using ExpFamily
using EPInference
using GR
# Switch GR to inline output so plots render in the notebook.
GR.inline()

## Creating artificial data

In [None]:
# Seed the global RNG so the synthetic data set is reproducible.
# NOTE: the order of the rand calls below fixes the data; do not reorder them.
srand(12)

# When true, P starts from a random diagonal matrix instead of a dense one.
diagCov = false 

nObs    = 10000   # number of observations
dim     = 5       # dimension of each observation and of the weight vector
nSites  = 4       # number of sites the observations are sharded across

# Random matrix with entries uniform in [-1, 1) (diagonal only when diagCov),
# then symmetrised as P*P' and shifted by 0.1*I so it is positive definite.
P   = diagCov?diagm(-1+2*rand(dim)):(-1+2*rand(dim,dim))
P   = P * P' + 0.1*eye(dim)
mu  = rand(dim)

# Full-covariance Gaussian prior on the weights; w is one draw from it and is the
# ground-truth weight vector the experiments below measure their error against.
priorNP = GaussianNatParam(mean=rand(dim), cov=10*eye(dim))
w       = rand(priorNP)

# Diagonal-covariance counterpart of the prior (used by the epdNP experiment).
priordNP = DiagGaussianNatParam(mean=rand(dim), cov=10*ones(dim))

# Covariates: nObs draws from a Gaussian specified by natural parameters (P*mu, -P).
# Presumably a dim x nObs matrix with one observation per column — confirm the
# natural-parameter convention in ExpFamily.
X = rand(GaussianNatParam(P*mu, -P), nObs)
# Logistic probabilities sigma(w'x) for every observation, as a column vector.
s = (1./(1.+exp.(-w'*X)))'
# Labels in {-1, +1}: +1 with probability s (comparison against uniform draws).
y = (s .> rand(nObs)) * 2.0 - 1.0
;

In [None]:
# Shard data: split the observation indices 1:nObs into nSites contiguous blocks.
# `siteIndices` holds the last observation index of sites 1, ..., nSites-1; the last
# site takes everything after siteIndices[end], including any remainder.
1 <= nSites <= nObs ||
    throw(ArgumentError("need 1 <= nSites <= nObs, got nSites=$nSites, nObs=$nObs"))
# Floor division rather than rounding: rounding up (e.g. nObs = 6, nSites = 4 gives 2)
# pushes the boundaries to nObs and leaves the last site with an empty batch.
frac        = div(nObs, nSites)
siteIndices = frac:frac:((nSites-1)*frac)

"""
    batches(i::Int)

Return the observations owned by site `i` as `(X[:, rge], y[rge])`.

The sites split `1:nObs` into contiguous ranges delimited by `siteIndices`, which
holds the last observation index of sites `1, ..., nSites-1`. The last site receives
everything after `siteIndices[end]`, so it also absorbs any remainder. The function
reads the globals `X`, `y`, `nObs`, `nSites` and `siteIndices`.

Throws `ArgumentError` if `i` is not in `1:nSites`.
"""
function batches(i::Int)::Tuple{Matrix{Float64},Vector{Float64}}
    1 <= i <= nSites || throw(ArgumentError("wrong batch index"))
    # Site 1 starts at the first observation; any other site starts right after
    # the previous site's last index.
    first_obs = i == 1 ? 1 : siteIndices[i-1] + 1
    # The last site runs to nObs. Deciding each end independently also covers
    # nSites == 1 (empty siteIndices), where the single site owns 1:nObs.
    last_obs = i == nSites ? nObs : siteIndices[i]
    rge = first_obs:last_obs
    return (X[:, rge], y[rge])
end

"""
    loglogistic(u::Float64)

Return `log(sigma(u)) = -log(1 + exp(-u))`, the log of the logistic sigmoid.

Evaluated stably: for `u < 0` the identity `-log(1 + exp(-u)) = u - log(1 + exp(u))`
avoids overflowing `exp(-u)` (the naive form returns `-Inf` once `-u` exceeds ~709),
and `log1p` keeps full precision when the exponential term is tiny.
"""
loglogistic(u::Float64) = u >= 0 ? -log1p(exp(-u)) : u - log1p(exp(u))

"""
    logfactor_blr(siteIndex, points)

Evaluate the logistic-regression log-likelihood of site `siteIndex`'s observations at
every column of `points` (a D x M matrix, one candidate weight vector per column).

Entry `m` of the returned length-M vector is the sum over the site's observations
`(x_n, y_n)` (taken from `batches(siteIndex)`) of `loglogistic(y_n * x_n' * points[:, m])`.
"""
function logfactor_blr(siteIndex, points)
    Xsite, ysite = batches(siteIndex)
    nPoints = size(points, 2)
    result = zeros(nPoints)
    for m = 1:nPoints
        # signed margins y_n * <x_n, points[:, m]> for every observation of this site
        margins = ysite .* (Xsite' * points[:, m])
        result[m] += sum(loglogistic.(margins))
    end
    return result
end

# One log-factor closure per site: factors[k](points) evaluates logfactor_blr(k, points).
factors = map(k -> (pts -> logfactor_blr(k, pts)), 1:nSites);
;

## NP-like algorithms

In [None]:
@time begin
    # EP with NP-style updates on the per-site logistic-regression factors.
    # NOTE(review): the meaning of the positional numbers (500, 100, 0.9, 1e-2) is
    # defined by ParamsEP — confirm before tuning.
    params = ParamsEP(priorNP, factors, 500, 100, 0.9, 1e-2)
    approx_np, memapprox_np = epNP(params)
    # relative L2 distance between the mean of the approximation and the true w
    println(norm(mean(approx_np) - w) / norm(w))
end

In [None]:
# Relative error of every approximation stored in memapprox_np (presumably one per iteration).
err = map(k -> norm(mean(memapprox_np[k]) - w) / norm(w), 1:length(memapprox_np))
plot(err)

In [None]:
@time begin
    # Same factors, but starting from the diagonal-covariance prior priordNP.
    # NOTE: overwrites the globals `params` and `approx_np` set by the epNP cell.
    params = ParamsEP(priordNP, factors, 200, 100, 0.01)
    approx_np = epdNP(params)
    println(norm(mean(approx_np) - w) / norm(w))
end

## MP-like algorithms

In [None]:
@time begin
    # EP with MP-style updates; epMP returns the final approximation and, presumably,
    # the sequence of intermediate approximations.
    params2 = ParamsEP(priorNP, factors, 100, 5, 0.1)
    approx_mp, memapprox_mp = epMP(params2)
    println(norm(mean(approx_mp) - w) / norm(w))
end

In [None]:
# Relative error of every approximation stored in memapprox_mp (presumably one per iteration).
err = map(k -> norm(mean(memapprox_mp[k]) - w) / norm(w), 1:length(memapprox_mp))
plot(err)

In [None]:
@time begin
    # Second MP run with different settings.
    # NOTE(review): the final ParamsEP argument is negative (-0.001) — confirm intended.
    params2 = ParamsEP(priorNP, factors, 100, 50, -0.001)
    approx_mp, memapprox_mp = epMP(params2)
    println(norm(mean(approx_mp) - w) / norm(w))
end
# Relative error trajectory of the stored approximations.
err = map(k -> norm(mean(memapprox_mp[k]) - w) / norm(w), 1:length(memapprox_mp))
plot(err)

## SNEP-like algorithms

In [None]:
@time begin
    # SNEP-style run; its results are stored in the same approx_mp / memapprox_mp globals.
    params2 = ParamsEP(priorNP, factors, 300, 10, 0.03)
    approx_mp, memapprox_mp = epSNEP(params2)
    println(norm(mean(approx_mp) - w) / norm(w))
end
# Relative error trajectory of the stored approximations.
err = map(k -> norm(mean(memapprox_mp[k]) - w) / norm(w), 1:length(memapprox_mp))
plot(err)

In [None]:
@time begin
    # Variant epSNEP2; its results are also stored in approx_mp / memapprox_mp.
    params2 = ParamsEP(priorNP, factors, 200, 10, 0.037)
    approx_mp, memapprox_mp = epSNEP2(params2)
    println(norm(mean(approx_mp) - w) / norm(w))
end
# Relative error trajectory of the stored approximations.
err = map(k -> norm(mean(memapprox_mp[k]) - w) / norm(w), 1:length(memapprox_mp))
plot(err)