In [1]:
versioninfo()

Julia Version 1.8.2
Commit 36034abf260 (2022-09-29 15:21 UTC)
Platform Info:
  OS: macOS (arm64-apple-darwin21.3.0)
  CPU: 10 × Apple M1 Pro
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, apple-m1)
  Threads: 1 on 8 virtual cores


In [2]:
# load packages
using BenchmarkTools, DataFrames, DynamicHMC, DynamicHMC.Diagnostics, 
    LinearAlgebra, LogDensityProblems, MCMCChains, MCMCDiagnosticTools,
    Parameters, Profile, ProximalOperators, CSV,
    Random, Revise, Roots, SparseArrays, Statistics, StatsPlots
import ProximalOperators: prox, prox!

# Proximal MCMC for matrix completion

We first implement the projection to the $\ell_1$ norm epigraph $\{(y, t): \|y\|_1 \le t\}$. To project a point $(x,s)$, the key computation is to find the root of 
$$
\phi(\lambda) = \|S_\lambda(x)\|_1 - \lambda - s,
$$
where $S$ is the sof-thresholding operator. This can be achieved by bisection or the sum-of-max algorithm (Algorithm 2 of <http://proceedings.mlr.press/v48/wangh16.pdf>). Both have linear complexity.

In [3]:
"""
    ϕ(x::Vector, λ)

Return the value `sum(abs, Sλ(x[1:end-1])) - λ - x[end]`.
"""
function ϕ(x::AbstractVector{T}, λ::T) where T <: Real
    out = zero(T)
    @inbounds for i in 1:(length(x) - 1)
        δi = abs(x[i]) - λ
        out += δi > 0 ? abs(δi) : zero(T)
    end
    out -= λ + x[end]
    out
end

ϕ

Now we can implement the projector to the $\ell_1$ norm epigraph
$$
\text{prox}_{\chi_{\mathcal{E}}}(x, s) = \text{proj}_{\mathcal{E}}(x) = \begin{cases}
(x, s) & \|x\|_1 \le s \\
(S_{\lambda^\star}(x), s + \lambda^\star) & \|x\|_1 > s
\end{cases},
$$
where $\lambda^\star$ is the root of $\phi(\lambda)$.

In [4]:
struct NormL1Epigraph <: Function end

is_convex(f::NormL1Epigraph) = true
is_set(f::NormL1Epigraph) = true

function prox!(
        y :: AbstractVector{T}, 
        f :: NormL1Epigraph, 
        x :: AbstractVector{T}, 
        γ :: T = T(1.0)
    ) where T <: Real
    if isinf(x[end]) || sum(abs, x) ≤ x[end] + abs(x[end])
        copyto!(y, x)
        return zero(T)
    end
    @views λ₀ = find_zero(λ -> ϕ(x, λ), (0, maximum(abs, x[1:end-1])), Bisection())
    for i in 1:(length(y) - 1)
        δi   = abs(x[i]) - λ₀
        y[i] = δi > 0 ? copysign(δi, x[i]) : zero(T) 
    end
    y[end] = x[end] + λ₀
    zero(T)
end

prox(f::NormL1Epigraph, x::AbstractVector{T}, γ::T = T(1.0)) where T<:Real = 
    prox!(similar(x), f, x, γ)

prox (generic function with 4 methods)

# Define the matrix completion problem

In [5]:
struct MatrixCompletionProblem{T <: Real}
    Y        :: Matrix{T}   # response
    X        :: Matrix{T}   # 
    Ω        :: Vector{Int} # Y[Ω] are observed
    λ        :: T           # Moreau-Yosida envelope parameter
    σ²prior  :: Tuple{T, T} # IG(r, s) prior for σ²
    αprior   :: Tuple{T, T} # IG(r, s) prior for α
    res      :: Vector{T}   # storage for residuals Y-X
    svα      :: Vector{T}
    svα_prox :: Vector{T}
end

# constructor
function MatrixCompletionProblem(Y::Matrix{T}, Ω::Vector{Int}, λ::T, 
        σ²prior::Tuple{T, T}, αprior::Tuple{T, T}) where T <: Real
    X         = similar(Y)
    res       = Vector{T}(undef, length(Ω))
    svα       = Vector{T}(undef, minimum(size(X)) + 1)
    svα_prox  = similar(svα)
    MatrixCompletionProblem{T}(Y, X, Ω, λ, σ²prior, αprior, res, svα, svα_prox)
end

MatrixCompletionProblem

In [6]:
(problem::MatrixCompletionProblem)(θ) = LogDensityProblems.logdensity(problem, θ)

LogDensityProblems.capabilities(::Type{<:MatrixCompletionProblem}) =
    LogDensityProblems.LogDensityOrder{1}()

LogDensityProblems.dimension(problem::MatrixCompletionProblem) = length(problem.Y) + 2


function LogDensityProblems.logdensity_and_gradient(problem::MatrixCompletionProblem{T}, θ) where T <: Real
    @unpack Y, X, Ω, λ, σ²prior, αprior, res, svα, svα_prox = problem
    rσ², sσ² = σ²prior
    rα , sα  = αprior        
    mn           = length(Y)
    copyto!(X, 1, θ, 1, mn)
    logα         = θ[mn + 1]
    logσ²        = θ[mn + 2]
    α            = exp(logα)
    σ²           = exp(logσ²)
    invσ²        = inv(σ²)
    invλ         = inv(λ)
    # log-likelihood + log(prior), with constant terms dropped
    @views res  .= Y[Ω] .- X[Ω]
    qf           = (abs2(norm(res)) + 2rσ²) / (2σ²)
    logl         = - qf - (length(Ω) / 2 + sσ²) * logσ² - rα / α - sα * logα
    ∇            = fill(T(0), length(θ))
    @views ∇[Ω] .= res .* invσ²               # ∇X
    ∇[mn + 1]    = rα / α - sα                # ∇logα
    ∇[mn + 2]    = qf - (length(Ω) / 2 + sσ²) # ∇logσ²
    # proximal mapping of (X, α)
    Xsvd         = svd!(X)
    copyto!(svα, Xsvd.S)
    svα[end]     = α
    prox!(svα_prox, NormL1Epigraph(), svα)
    svα_prox    .= svα .- svα_prox
    logl        -= abs2(norm(svα_prox)) / (2λ)
    @views ∇[1:mn] .-= invλ .* vec(Xsvd.U * Diagonal(svα_prox[1:end-1]) * Xsvd.Vt)
    ∇[mn + 1]   -= invλ * α * svα_prox[end]
    logl, ∇
end

## Simulation

**Since the computation requires lots of singular value decomposition, we recommend running the following code in a script rather than a notebook. For this reason we load the saved results below.**

Generate data $Y$ from $N(X, \sigma^2 I_n)$. $X$ is a randomly generated low rank matrix.

In [7]:
Random.seed!(123)
m, n = 100, 100
rtrue = 2
Xtrue = randn(m, rtrue) * randn(rtrue, n)
σ = 0.5 # noise s.d.
Y = Xtrue + σ * randn(m, n)
# linear indices of observed entries
Random.seed!(5)
Ω = findall(rand(m * n) .≤ 0.5);

In [8]:
λ = 1e-2
σ²prior = (0.01, 0.01) # a flat IG(r, s) prior for σ²
αprior  = (1., length(Y) + 1.) 
matcomp = MatrixCompletionProblem(Y, Ω, λ, σ²prior, αprior);
Ωᶜ = setdiff(1:m*n, Ω);
α₀  = sum(svdvals(Y))
σ²₀ = 1
Yinit = copy(Y)
Yinit[Ωᶜ] .= 0
paras = vcat(vec(Yinit), log(α₀), log(σ²₀));

@time results = mcmc_with_warmup(Random.GLOBAL_RNG, matcomp, 1000, reporter=ProgressMeterReport(),
                                 initialization = (q = paras,));

# diagnostics
summarize_tree_statistics(results.tree_statistics)

# using DelimitedFiles
# writedlm("results/MatrixCompletion-250x200.csv", results_matrix, ',')

1151.838253 seconds (44.75 M allocations: 984.558 GiB, 1.27% gc time, 0.19% compilation time)


Hamiltonian Monte Carlo sample of length 1000
  acceptance rate mean: 0.94, 5/25/50/75/95%: 0.76 0.91 0.97 0.99 1.0
  termination: divergence => 0%, max_depth => 1%, turning => 99%
  depth: 0 => 0%, 1 => 0%, 2 => 0%, 3 => 0%, 4 => 0%, 5 => 0%, 6 => 0%, 7 => 71%, 8 => 22%, 9 => 1%, 10 => 6%

In [9]:
results_matrix = copy(results.posterior_matrix)
results_matrix[end-1:end, :] = exp.(results_matrix[end-1:end, :])
results_matrix = Matrix(transpose(results_matrix))
mean(results_matrix[:,end-1])

678.3185680515084

In [10]:
chn = Chains(results_matrix);
result_quantile = DataFrame(quantile(chn));

In [11]:
mean(abs.(result_quantile[Ωᶜ, 4] - Xtrue[Ωᶜ]))

0.2066725570889349

In [12]:
mean(abs.(Y[Ωᶜ] - Xtrue[Ωᶜ]))

0.4020701146882816

In [13]:
Ωᶜ = setdiff(1:m*n, Ω);

In [14]:
df = DataFrame(truth = Xtrue[Ωᶜ], median = result_quantile[Ωᶜ, 4], lower = result_quantile[Ωᶜ, 2], upper = result_quantile[Ωᶜ, 6])

Row,truth,median,lower,upper
Unnamed: 0_level_1,Float64,Float64,Float64,Float64
1,-0.459588,-0.232287,-1.72584,1.23557
2,-0.657325,-0.486503,-2.04893,1.07626
3,-1.58982,-1.36117,-2.71808,0.259262
4,-1.96528,-1.81983,-3.30036,-0.225327
5,-3.13052,-2.55166,-4.20848,-1.02856
6,-1.70447,-1.74853,-3.26693,-0.148948
7,2.32103,2.25866,0.714479,3.88961
8,-2.1822,-1.94672,-3.54008,-0.348789
9,-0.800283,-0.368844,-1.92468,1.20987
10,-0.527647,-0.479862,-2.0712,0.959612


In [15]:
# Calculates percent parameters covered
cover = zeros(size(df, 1))
widths = zeros(size(df, 1))
for i in 1:length(cover)
    cover[i] = df[i, 3] ≤ df[i, 1] ≤ df[i, 4] ? 1 : 0
    widths[i] = df[i,4] - df[i,3]
end
println(sum(cover)/length(cover)) 
mean(widths)

0.9994014365522745


2.991571412803108

In [16]:
cov_matrix = cov(results_matrix)
F = svd(cov_matrix)
v_strongest = F.V[:,1]
slowest = results_matrix * v_strongest

1000-element Vector{Float64}:
 613.4012433102819
 627.5975237203226
 622.1965416659449
 628.582699714467
 632.5286714915078
 633.5855591888915
 633.4031183760666
 634.4721288868826
 639.2929842139288
 636.5475301036289
 638.6725609522714
 639.7637400832965
 632.7163336288592
   ⋮
 638.1366633472546
 640.941089690765
 638.7266918621178
 640.1700441709384
 633.1208093979956
 628.5237878467225
 621.8626060962819
 628.8479558373161
 625.8083919379297
 623.7916792842612
 622.916883216086
 624.6062984165462

In [18]:
chn_slowest = MCMCChains.Chains(slowest, ["Yslowest"], thin = 1);

In [19]:
# number of samples = 1000
ess_slowest = ess_rhat(chn_slowest)

ESS
 [1m parameters [0m [1m      ess [0m [1m    rhat [0m
 [90m     Symbol [0m [90m  Float64 [0m [90m Float64 [0m

    Yslowest   208.8082    1.0024
