This code base uses the Julia Language and DrWatson to make a reproducible scientific project named "Stochastic Approximation with Biased MCMC for Expectation Maximization".
To install the package for development, run:

```julia
using Pkg
Pkg.develop(url="https://github.com/Red-Portal/MCMCSAEM.jl.git")
```
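Consider the logistic regression problem with regressors $x_i \in \mathbb{R}^d$, binary targets $y_i \in \{0, 1\}$, and a Gaussian prior on the coefficients:

$$
\begin{aligned}
\beta_j &\sim \mathcal{N}(\mu, \sigma^2), & j &= 1, \ldots, d, \\
y_i \mid \beta &\sim \mathrm{Bernoulli}\big(\mathrm{logistic}(x_i^\top \beta)\big), & i &= 1, \ldots, n,
\end{aligned}
$$

where $\mathrm{logistic}(s) = 1/(1 + e^{-s})$. We will show how to infer the hyperparameters $\theta = (\mu, \sigma)$ with SAEM, treating the coefficients $\beta$ as latent variables.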
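For reference, in the standard EM formulation that SAEM follows, the hyperparameters $\theta$ are chosen to maximize the marginal likelihood with the coefficients $\beta$ integrated out (conditioning on the regressors is left implicit). The E-step then only needs posterior expectations under $p(\beta \mid y, \theta)$, which is where MCMC comes in:

$$
\hat{\theta} = \arg\max_{\theta} \log p(y \mid \theta) = \arg\max_{\theta} \log \int p(y \mid \beta)\, p(\beta \mid \theta)\, \mathrm{d}\beta
$$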
The joint log density $\log p(y, \beta \mid \theta)$, that is, the likelihood plus the prior on $\beta$, can be specified using the LogDensityProblems interface as follows:
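```julia
using MCMCSAEM
using Distributions
using LinearAlgebra
using LogDensityProblems

# Logistic regression model: design matrix X (n × d) and binary targets y
struct Logistic{Mat <: AbstractMatrix, Vec <: AbstractVector}
    X::Mat
    y::Vec
end

# Order 0: only the log density itself is implemented (no hand-written gradient)
function LogDensityProblems.capabilities(::Type{<:Logistic})
    LogDensityProblems.LogDensityOrder{0}()
end

# Joint log density log p(y, β | θ) for coefficients β and hyperparameters θ = (μ, σ)
function LogDensityProblems.logdensity(
    model::Logistic, β::AbstractVector, θ::AbstractVector
)
    X, y = model.X, model.y
    d = size(X, 2)
    s = X*β
    μ, σ = θ[1], θ[2]
    # Log likelihood of the targets under the logistic link
    ℓp_y = mapreduce(+, s, y) do sᵢ, yᵢ
        logpdf(BernoulliLogit(sᵢ), yᵢ)
    end
    # Gaussian prior βⱼ ~ N(μ, σ²) with isotropic covariance σ²I
    ℓp_β = logpdf(MvNormal(fill(μ, d), σ^2 * I), β)
    ℓp_y + ℓp_β
end
```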
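As a cross-check of what the `logdensity` method computes, here is a stdlib-only sketch that writes the Bernoulli-logit term out as $y_i s_i - \log(1 + e^{s_i})$ with $s = X\beta$, and the Gaussian prior explicitly. The names `logjoint_sketch` and `softplus` are hypothetical helpers for illustration, not part of MCMCSAEM or Distributions.jl.

```julia
# Numerically stable log(1 + exp(s)) (hypothetical helper)
softplus(s) = s > 0 ? s + log1p(exp(-s)) : log1p(exp(s))

# Hypothetical stdlib-only version of the joint log density:
#   Σᵢ log BernoulliLogit(yᵢ | xᵢᵀβ) + Σⱼ log N(βⱼ | μ, σ²)
function logjoint_sketch(X::AbstractMatrix, y::AbstractVector,
                         β::AbstractVector, θ::AbstractVector)
    μ, σ = θ[1], θ[2]
    s = X * β
    # log p(yᵢ | sᵢ) = yᵢ sᵢ - log(1 + exp(sᵢ)) for yᵢ ∈ {0, 1}
    ℓp_y = sum(y .* s .- softplus.(s))
    # Isotropic Gaussian prior on the d coefficients
    d = length(β)
    ℓp_β = -d / 2 * log(2π * σ^2) - sum(abs2, β .- μ) / (2σ^2)
    return ℓp_y + ℓp_β
end

# With β = 0 every target has probability 1/2, so the result is -2log(2) - log(2π)
logjoint_sketch([1.0 0.0; 0.0 1.0], [1, 0], zeros(2), [0.0, 1.0])
```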
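MCMCSAEM expects the user to define the E-step and M-step functions for their model.
For the E-step (computing the sufficient statistic), we only need the first and second moments of each coefficient $\beta_j$, averaged over the MCMC samples stored in the columns of `x`: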
```julia
# E-step: average [β; β.^2] over the MCMC samples (one sample of β per column of x)
function MCMCSAEM.sufficient_statistic(::Logistic, x::AbstractMatrix)
    mean(eachcol(x)) do xi
        vcat(xi, xi.^2)
    end
end
```
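Because each column of `x` is one sample of $\beta$, the statistic is an elementwise sample average of $[\beta; \beta^2]$: the first $d$ entries estimate $\mathbb{E}[\beta_j]$ and the last $d$ estimate $\mathbb{E}[\beta_j^2]$. A vectorized sketch of the same computation makes that layout explicit; the name `suffstat_sketch` is hypothetical.

```julia
using Statistics

# Hypothetical vectorized form of the E-step statistic: the rows of
# vcat(x, x .^ 2) are [β; β.^2], averaged across the sample columns.
suffstat_sketch(x::AbstractMatrix) = vec(mean(vcat(x, x .^ 2); dims = 2))

# Two samples of a 2-dimensional β, one per column
x = [1.0 3.0;
     2.0 4.0]
suffstat_sketch(x)  # → [2.0, 3.0, 5.0, 10.0]
```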
The M-step receives the estimated sufficient statistic and returns the hyperparameters maximizing the EM surrogate. (Refer to the paper for more details.)
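```julia
# M-step: closed-form maximizer of the surrogate given S = [E[β]; E[β²]]
function MCMCSAEM.maximize_surrogate(::Logistic, S::AbstractVector)
    d = div(length(S), 2)
    EX = S[1:d]       # first moments E[βⱼ]
    EX² = S[d+1:end]  # second moments E[βⱼ²]
    μ = mean(EX)
    [μ, sqrt(mean(EX²) - μ^2)]  # θ = [μ, σ]
end
```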
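Assuming the surrogate is the usual EM objective, the likelihood term $p(y \mid \beta)$ does not depend on $\theta$, so only the Gaussian prior matters. Writing $\mathbb{E}[\beta_j]$ and $\mathbb{E}[\beta_j^2]$ for the moments held in `S[1:d]` and `S[d+1:end]`, the objective (up to terms that do not depend on $\theta$) and its maximizer are:

$$
\begin{aligned}
Q(\mu, \sigma) &= -\frac{d}{2}\log(2\pi\sigma^2) - \frac{1}{2\sigma^2}\sum_{j=1}^{d}\Big(\mathbb{E}[\beta_j^2] - 2\mu\,\mathbb{E}[\beta_j] + \mu^2\Big), \\
\mu^\ast &= \frac{1}{d}\sum_{j=1}^{d}\mathbb{E}[\beta_j], \qquad
(\sigma^\ast)^2 = \frac{1}{d}\sum_{j=1}^{d}\mathbb{E}[\beta_j^2] - (\mu^\ast)^2,
\end{aligned}
$$

which is what `maximize_surrogate` returns, with $\sigma^\ast$ taken as the square root.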
SAEM can be executed as follows:
```julia
using ADTypes
using Random
using Plots

function main()
    rng = Random.default_rng()

    # SAEM Settings
    T         = 50                      # Number of SAEM iterations
    T_burn    = 5                       # Number of initial burn-in steps
    γ0        = 1e-0                    # Base stepsize
    γ         = t -> γ0 / sqrt(t)       # Stepsize schedule
    h         = 1e-2                    # MCMC stepsize
    mcmc_type = :ula                    # MCMC algorithm (:ula or :mala)
    ad        = ADTypes.AutoForwardDiff() # autodiff backend

    # Create synthetic dataset
    n      = 500                                               # n_datapoints
    d      = 30                                                # n_regressors
    X      = randn(rng, n, d)                                  # regressors
    θ_true = [-0.5, 2.0]                                       # "True" hyperparameters
    β_true = rand(rng, Normal(θ_true[1], θ_true[2]), d)        # True coefficients
    y      = rand.(rng, BernoulliLogit.(X*β_true))             # Target variables

    # Create Model
    model = Logistic(X, y)

    # Initialize SAEM
    θ0 = [0.0, 5.0]
    x0 = randn(rng, d, 1)

    θ, x, stats = MCMCSAEM.mcmcsaem(
        rng, model, x0, θ0, T, T_burn, γ, h;
        ad,
        show_progress = true,
        mcmc_type     = mcmc_type
    )
    plot(
        [stat.loglike for stat in filter(Base.Fix2(haskey, :loglike), stats)],
        xlabel = "SAEM Iteration",
        ylabel = "Log Joint",
    )
end

main()
```
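To illustrate the kind of loop `MCMCSAEM.mcmcsaem` runs, here is a minimal stdlib-only sketch of SAEM with a single Langevin step per iteration. It is not the package's implementation: the toy Gaussian model, the names `logpost`, `gradlogpost`, `suffstat`, `mstep`, `langevin_step`, and `saem_sketch`, and all defaults are hypothetical, chosen because the maximum marginal likelihood estimate is available in closed form ($\hat\mu = \bar{y}$, $\hat\sigma^2 = s_y^2 - 1$) for checking. It uses the textbook SAEM recursion $S_t = S_{t-1} + \gamma_t\,(s(\beta_t) - S_{t-1})$ with the same $\gamma_0/\sqrt{t}$ schedule as `main()`, omits burn-in, and mimics the two `mcmc_type` choices with a `kernel` keyword.

```julia
using Random, Statistics

# Hypothetical toy model (not the logistic model above):
#   βⱼ ~ N(μ, σ²),  yⱼ | βⱼ ~ N(βⱼ, 1),  j = 1, …, d,  θ = (μ, σ)
# Unnormalized log posterior of β given y and θ, and its gradient in β.
logpost(β, y, θ) = -sum(abs2, y .- β) / 2 - sum(abs2, β .- θ[1]) / (2 * θ[2]^2)
gradlogpost(β, y, θ) = (y .- β) .- (β .- θ[1]) ./ θ[2]^2

# E-step statistic and closed-form M-step, same form as the logistic example
suffstat(β) = vcat(β, β .^ 2)
function mstep(S)
    d = length(S) ÷ 2
    μ = mean(S[1:d])
    return [μ, sqrt(max(mean(S[d+1:end]) - μ^2, eps()))]
end

# One Langevin transition targeting p(β | y, θ).
# :ula returns the proposal unconditionally (biased);
# :mala accepts or rejects it with a Metropolis-Hastings test.
function langevin_step(rng, β, y, θ, h, kernel)
    g = gradlogpost(β, y, θ)
    prop = β .+ h .* g .+ sqrt(2h) .* randn(rng, length(β))
    kernel === :ula && return prop
    gp = gradlogpost(prop, y, θ)
    logq_fwd = -sum(abs2, prop .- β .- h .* g) / (4h)   # log q(prop | β) + const
    logq_bwd = -sum(abs2, β .- prop .- h .* gp) / (4h)  # log q(β | prop) + const
    logα = logpost(prop, y, θ) - logpost(β, y, θ) + logq_bwd - logq_fwd
    return log(rand(rng)) < logα ? prop : β
end

function saem_sketch(rng, y, θ0; T = 5000, h = 1e-2, γ0 = 1.0, kernel = :ula)
    kernel in (:ula, :mala) || throw(ArgumentError("kernel must be :ula or :mala"))
    θ = copy(θ0)
    β = copy(y)                                     # start the chain at the data
    S = suffstat(β)
    for t in 1:T
        β = langevin_step(rng, β, y, θ, h, kernel)  # simulation step
        γ = γ0 / sqrt(t)                            # same schedule as in main()
        S = S .+ γ .* (suffstat(β) .- S)            # stochastic approximation of E[s(β)]
        θ = mstep(S)                                # maximization step
    end
    return θ
end

# Usage: synthetic data with μ = -0.5, σ = 2, compared against the closed-form MLE
rng = MersenneTwister(1)
y = (-0.5 .+ 2.0 .* randn(rng, 500)) .+ randn(rng, 500)
θ_ula  = saem_sketch(rng, y, [0.0, 5.0]; kernel = :ula)
θ_mala = saem_sketch(rng, y, [0.0, 5.0]; kernel = :mala)
θ_mle  = [mean(y), sqrt(var(y; corrected = false) - 1)]
```

The `:ula` branch skips the accept/reject test, so its stationary distribution carries an $O(h)$ bias; `:mala` removes that bias at the cost of an extra gradient and two log-density evaluations per step.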
Running `main()` plots the log joint against the SAEM iteration; the log joint increases as SAEM converges.
The experiments in the paper can be replicated by executing the scripts in `scripts/`.