# Low-rank approximation on $\mathcal{P}(d)$ - the space of $d$-dimensional SPD matrices

In this notebook we want to get some intuition in different approaches for computing low-rank approximations for manifold-valued signals

In [1]:
using Manifolds
using Manopt
using LinearAlgebra
using Random
using Plots
using LaTeXStrings
using BenchmarkTools

using Clustering
using NIfTI

using JLD
using Printf

In [2]:
using StatsBase

In [3]:
include("../../src/functions/jacobi_field/beta.jl")

β (generic function with 1 method)

### Functions

In [4]:
function SNMF_euclidean(X, k, init_const, max_iter)
    # in this function X is d \times n
    # F is d \times k, G is n \times k
    R = kmeans(X, k)
    d = size(X)[1]
    n = size(X)[2]
    G_init = zeros(n, k)
    for i=1:n
        G_init[i,assignments(R)[i]] = 1
    end
    iter = 0
    G = G_init .+ init_const
    F = zeros(d, k)
    while iter < max_iter
        F = X * G * pinv(G' * G)
        A = X' * F
        B = F' * F
        Apos = (abs.(A) + A) ./ 2
        Aneg = (abs.(A) - A) ./ 2
        Bpos = (abs.(B) + B) ./ 2
        Bneg = (abs.(B) - B) ./ 2
        G = G .* sqrt.((Apos + G * Bneg)./(Aneg + G * Bpos))
        G[isnan.(G)].=0
        iter += 1
    end
    return F, G
end

SNMF_euclidean (generic function with 1 method)

In [5]:
function naive_sNMF(M::AbstractManifold, q, data, k; init_const = 0.2, max_iter=50)
    log_q_data = log.(Ref(M), Ref(q), data);
    X_eucl = reduce(hcat, get_coordinates.(Ref(M), Ref(q), log_q_data, Ref(DefaultOrthonormalBasis())))
    V, U = SNMF_euclidean(X_eucl, k, init_const, max_iter);
    return U, V'
end

naive_sNMF (generic function with 1 method)

In [6]:
function curvature_corrected_V_update_precomp!(tensorU, M::AbstractPowerManifold, q, X, log_q_X_tensor, U, rank, βκΘq, tensorβκΘq, tensorΨq)
    # U comes in n[1] by r
    # tensorU is n[1], d, r, d, D
    n = size(X)
    d = manifold_dimension(M.manifold)
    D = power_dimensions(M)[1]
    r = rank
    
    # construct linear system
    for col=1:r
        for row=1:n[1]
            tensorU[row,:,col,:,:].=view(U,row,col)
        end
    end

    sΨq = eachslice(tensorΨq,dims=4);
    sβκΘq = eachslice(tensorβκΘq,dims=4);
    pdts = reduce((x,y) -> cat(x,y,dims=4), map((x,y,z) -> inner.(Ref(M.manifold),Ref(z), x, y), sΨq, sβκΘq, q));

    tensorβκB = tensorU .* repeat(reshape(pdts,(n[1],d,1,d,D)),outer=(1,1,r,1,1));

    tensorβκB_exp = reshape(tensorβκB,(n[1],d,1,r,d,D));
    tensorβκB_alt = zeros(n[1],d,D,r,d,D);
    for k=1:D
        tensorβκB_alt[:,:,k,:,:,k] .= @view tensorβκB_exp[:,:,1,:,:,k];
    end

    βκB = reshape(tensorβκB_alt,(n[1]*d*D,r*d*D));

    A = transpose(βκB) * βκB;
    
    slogX = eachslice(log_q_X_tensor,dims=3);
    sβκΘq_forlog = eachslice(βκΘq,dims=3);
    tensorb_pre = reduce((x,y) -> cat(x,y,dims=3), map((x,y,z) -> inner.(Ref(M.manifold),Ref(z), x, y), slogX, sβκΘq_forlog, q));
    b = reshape(tensorb_pre,(n[1] * d * D));
    βκBb = transpose(βκB) * b;

    # solve linear system
    Vₖₗ = A\βκBb # vector of length r * d * D
    tensorVₖₗ = reshape(Vₖₗ, (r, d * D));
    return tensorVₖₗ
end

curvature_corrected_V_update_precomp! (generic function with 1 method)

In [7]:
function curvature_corrected_U_update_precomp!(tensorV, M::AbstractPowerManifold, q, X, log_q_X, log_q_X_tensor, U, V, 
        rank, βκΘq, tensorβκΘq, tensorΨq; iters = 100, debug_int = 10)
    # V is r by (d * D)
    n = size(X)
    d = manifold_dimension(M.manifold)
    D = power_dimensions(M)[1]
    r = rank
    # construct linear system

    # construct matrix βκB
    V_reshape = reshape(V, (r,d,D));
    for i1=1:r
        for i2=1:d
            for i3=1:D
                tensorV[:,i1,:,i2,i3].=view(V_reshape,i1,i2,i3);
            end
        end
    end   

    
    sΨq = eachslice(tensorΨq,dims=4);
    sβκΘq = eachslice(tensorβκΘq,dims=4);
    pdts = reduce((x,y) -> cat(x,y,dims=4), map((x,y,z) -> inner.(Ref(M.manifold),Ref(z), x, y), sΨq, sβκΘq, q));

    tensorβκB = tensorV .* repeat(reshape(pdts,(n[1],1,d,d,D)),outer=(1,r,1,1,1));

    βκB = reshape(tensorβκB,(n[1],r,d*d*D));
    
    # construct numerator
    slogX = eachslice(log_q_X_tensor,dims=3);
    sβκΘq_forlog = eachslice(βκΘq,dims=3);
    tensorb_pre = reduce((x,y) -> cat(x,y,dims=3), map((x,y,z) -> inner.(Ref(M.manifold),Ref(z), x, y), slogX, sβκΘq_forlog, q));
    b = reshape(repeat(reshape(tensorb_pre,(n[1], 1, d, 1, D)), outer=(1,r,1,d,1)),(n[1],r,d*d*D));

    B = dropdims(sum(βκB .* b, dims=3),dims=3);
    
    Bpos = (abs.(B) .+ B) ./ 2
    Bneg = (abs.(B) .- B) ./ 2
    # reshape to (n, 1, d, 1), expand to (n, r, d, d), reshape to (n, r, d * d) which is (j1, i1, j & i)
    # then want (essentially) dot product of betakappaB and this along third dimension
    # construct denominator
    iter = 0
    this_U = U
    Ξ_q_coords = this_U * V
    Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
    Cpos = similar(Bpos);
    Cneg = similar(Bneg);
    ccU_q = similar(this_U);
    while iter < iters
        iter += 1
        Ξ_q_tensor = repeat(reshape([Ξ_q[j₁][k] for j₁=1:n[1],k=1:D], (n[1],1,D)),outer=(1,d,1));
        sΞ_q = eachslice(Ξ_q_tensor,dims=3);
        sβκΘq_forlog = eachslice(βκΘq,dims=3);
        tensorc_pre = reduce((x,y) -> cat(x,y,dims=3), map((x,y,z) -> inner.(Ref(M.manifold),Ref(z), x, y), sΞ_q, sβκΘq_forlog, q));
        C = dropdims(sum(βκB .* reshape(repeat(reshape(tensorc_pre, (n[1], 1, d, 1, D)), outer=(1, r, 1, d, 1)), (n[1], r, d*d*D)), dims=3),dims=3);
        
        Cpos .= (abs.(C) .+ C) ./ 2
        Cneg .= (abs.(C) .- C) ./ 2
        
        ccU_q .= this_U .* sqrt.((Bpos .+ Cneg)./(Bneg .+ Cpos))
        ccU_q[isnan.(ccU_q)].=0
        
        Ξ_q_coords = ccU_q * V
        Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
        if iter % debug_int == 0
            CC_loss = curvature_corrected_loss_fromlog(M, q, log_q_X, Ξ_q)
            ex_loss = exact_loss(M, q, X, Ξ_q)
            @printf("\tsubiter #%-4i | CCL: %.15f | exact loss: %.15f | change (U): %.5f\n", iter, CC_loss, ex_loss, norm(ccU_q-this_U))
        end
        this_U = ccU_q
        
    end

    return this_U
end

curvature_corrected_U_update_precomp! (generic function with 1 method)

In [8]:
function exact_loss(M::AbstractManifold, q, X, Ξ)
    log_q_X = log.(Ref(M), Ref(q), X)  # ∈ T_q M^n
    ref_distance = sum(norm.(Ref(M), Ref(q), log_q_X).^2)
    exp_q_Ξ =   exp.(Ref(M), Ref(q), Ξ)
    return sum(distance.(Ref(M), X, exp_q_Ξ).^2) / ref_distance
end

exact_loss (generic function with 1 method)

In [9]:
function curvature_corrected_loss(M::AbstractManifold, q, X, Ξ)
    n = size(X)
    # compute log
    log_q_X = log.(Ref(M), Ref(q), X)  # ∈ T_q M^n
    ref_distance = sum(norm.(Ref(M), Ref(q), log_q_X).^2)
    if typeof(M) <: PowerManifold
        D = power_dimensions(M)[1]
        d = manifold_dimension(M.manifold)
    else
        d = manifold_dimension(M)
    end
    # compute directions
    I = CartesianIndices(n)
    loss = 0.
    for i in I
        if typeof(M) <: PowerManifold
            ONBᵢ = get_basis(M, q, DiagonalizingOrthonormalBasis(log_q_X[i]))
            for k in 1:D
                θₖ = ONBᵢ.data.bases[k].data.vectors
                κₖ = ONBᵢ.data.bases[k].data.eigenvalues
                # compute loss
                loss += sum([β(κₖ[j])^2 * inner(M.manifold, q[k], Ξ[i][k] - log_q_X[i][k], Θₖ[j])^2 for j=1:d])
            end
        else
            ONBᵢ = get_basis(M, q, DiagonalizingOrthonormalBasis(log_q_X[i]))
            Θᵢ = ONBᵢ.data.vectors
            κᵢ = ONBᵢ.data.eigenvalues
            loss += sum([β(κᵢ[j])^2 * inner(M, q, Ξ[i] - log_q_X[i], Θᵢ[j])^2 for j=1:d])
        end
    end
    return loss/ref_distance
end

curvature_corrected_loss (generic function with 1 method)

In [10]:
function curvature_corrected_loss_fromlog(M::AbstractManifold, q, log_q_X, Ξ)
    n = size(log_q_X)
    ref_distance = sum(norm.(Ref(M), Ref(q), log_q_X).^2)
    if typeof(M) <: PowerManifold
        D = power_dimensions(M)[1]
        d = manifold_dimension(M.manifold)
    else
        d = manifold_dimension(M)
    end   
    # compute directions
    I = CartesianIndices(n)
    loss = 0.
    for i in I
        if typeof(M) <: PowerManifold
            ONBᵢ = get_basis(M, q, DiagonalizingOrthonormalBasis(log_q_X[i]))
            for k=1:D
                θₖ = ONBᵢ.data.bases[k].data.vectors
                κₖ = ONBᵢ.data.bases[k].data.eigenvalues
                # compute loss
                loss += sum([β(κₖ[j])^2 * inner(M.manifold, q[k], Ξ[i][k] - log_q_X[i][k], θₖ[j])^2 for j=1:d])
            end
        else
            ONBᵢ = get_basis(M, q, DiagonalizingOrthonormalBasis(log_q_X[i]))
            Θᵢ = ONBᵢ.data.vectors
            κᵢ = ONBᵢ.data.eigenvalues
            loss += sum([β(κᵢ[j])^2 * inner(M, q, Ξ[i] - log_q_X[i], Θᵢ[j])^2 for j=1:d])
        end
        
        
    end
    return loss/ref_distance
end

curvature_corrected_loss_fromlog (generic function with 1 method)

In [11]:
function curvature_corrected_sNMF_precomp(M::AbstractPowerManifold, q, X, log_q_X, rank; ENMF_const = 0.2, ENMF_iter = 50, max_iter=50, 
        debug_freq=10, init_type = "eNMF", max_U_iter = 100, U_debug=10)
    n = size(X)
    # initialize
    if init_type == "eNMF"
        U₀, V₀ = naive_low_rank_approximation(M, q, X, rank; init_const=ENMF_const, max_iter=ENMF_iter)
    elseif init_type == "kmeans"
        # log_q_X = log.(Ref(M), Ref(q), X)
        X_eucl = reduce(hcat, get_coordinates.(Ref(M), Ref(q), log_q_X, Ref(DefaultOrthonormalBasis())))
        R = kmeans(X_eucl, rank)
        # d = size(X_eucl)[1]
        this_n = size(X_eucl)[2]
        G_init = zeros(this_n, rank)
        for i=1:this_n
            G_init[i,assignments(R)[i]] = 1
        end
        V₀ = R.centers' # k \times d
        replace!(x -> x == 0 ? ENMF_const : x, G_init)
        U₀ = G_init ./ sum(G_init;dims=2)
    elseif init_type == "rand"
        G_init = rand(Float64, (n[1],rank)); # TODO make this general
        U₀ = G_init ./ sum(G_init;dims=2)
        d=manifold_dimension(M);
        V₀ = randn(Float64, (rank, d)); # k \times d
    else
       throw(ArgumentError("init_type must be eNMF, kmeans, or rand"))
    end
    println("computed initialization")
    iter = 0
    Uₖ = copy(U₀)
    Vₖ = copy(V₀)
    Ξ_q_coords = Uₖ * Vₖ
    Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
    CC_loss = curvature_corrected_loss_fromlog(M, q, log_q_X, Ξ_q)
    ex_loss = exact_loss(M, q, X, Ξ_q)
    @printf("iter #0        | CCL: %.15f | exact loss: %.15f\n", CC_loss, ex_loss)
    # Ξ_q = []
    ONB = get_basis.(Ref(M), Ref(q), DiagonalizingOrthonormalBasis.(log_q_X))
    d = manifold_dimension(M.manifold);
    D = power_dimensions(M)[1];

    βκΘq = [β(ONB[j₁].data.bases[k].data.eigenvalues[j]) .* ONB[j₁].data.bases[k].data.vectors[j] for j₁=1:n[1], j=1:d, k=1:D];
    tensorβκΘq = repeat(reshape(βκΘq,(n[1],d,1,D)), outer=(1,1,d,1));

    Ψ_mat = [get_vector(M, q, Matrix(I, d*D, d*D)[:,(k-1)*d + j], DefaultOrthonormalBasis())[k] for j=1:d, k=1:D];
    tensorΨq = repeat(reshape(Ψ_mat, (1, 1, d,D)), outer=(n[1], d, 1,1));
    
    tensorU = zeros(n[1],d,rank,d,D);
    tensorV = zeros(n[1],rank,d, d, D);
    log_q_X_tensor = repeat(reshape([log_q_X[j₁][k] for j₁=1:n[1],k=1:D], (n[1],1,D)),outer=(1,d,1));
    while iter < max_iter
        iter += 1
        Uprev = Uₖ
        Vprev = Vₖ
        Vₖ = curvature_corrected_V_update_precomp!(tensorU, M, q, X, log_q_X_tensor, Uₖ, rank, βκΘq, tensorβκΘq, tensorΨq)

        if iter % debug_freq == 0
            Ξ_q_coords = Uₖ * Vₖ
            Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
            CC_loss = curvature_corrected_loss_fromlog(M, q, log_q_X, Ξ_q)
            ex_loss = exact_loss(M, q, X, Ξ_q)
            @printf("iter #%-4i (V) | CCL: %.15f | exact loss: %.15f | change (V): %.5f\n", iter, CC_loss, ex_loss, norm(Vₖ-Vprev))
        end

        Uprev=Uₖ
        Uₖ = curvature_corrected_U_update_precomp!(tensorV, M, q, X, log_q_X, log_q_X_tensor, Uprev, Vₖ, rank, βκΘq, tensorβκΘq, tensorΨq;iters = max_U_iter,debug_int=U_debug)
        if iter % debug_freq == 0
            Ξ_q_coords = Uₖ * Vₖ
            Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
            CC_loss = curvature_corrected_loss_fromlog(M, q, log_q_X, Ξ_q)
            ex_loss = exact_loss(M, q, X, Ξ_q)
            @printf("iter #%-4i (U) | CCL: %.15f | exact loss: %.15f | change (U): %.5f\n", iter, CC_loss, ex_loss, norm(Uₖ-Uprev))
            # println("iter #$(iter) (U) | CCL: $(CC_loss) | exact loss: $(ex_loss)")
        end

    end
    if init_type == "kmeans"
        U₀ = G_init;
    end
    return Uₖ, Vₖ, Ξ_q, U₀, V₀
end

curvature_corrected_sNMF_precomp (generic function with 1 method)

### Load data and construct manifold ###

In [17]:
ni_xx = niread("../../data/IITmean_xx.nii.gz") * 1e3;
ni_yx = niread("../../data/IITmean_yx.nii.gz") * 1e3;
ni_yy = niread("../../data/IITmean_yy.nii.gz") * 1e3;
ni_zx = niread("../../data/IITmean_zx.nii.gz") * 1e3;
ni_zy = niread("../../data/IITmean_zy.nii.gz") * 1e3;
ni_zz = niread("../../data/IITmean_zz.nii.gz") * 1e3;
size(ni_xx)

(182, 218, 182)

In [18]:
d1, d2, d3  = size(ni_xx)

(182, 218, 182)

In [19]:
predata = [ # data ordered as [xx, yx, yy, zx, zy, zz]
    [
    [ni_xx[i,j,k] + 1e-5;; ni_yx[i,j,k];; ni_zx[i,j,k]]; 
    [ni_yx[i,j,k];; ni_yy[i,j,k] + 1e-5;; ni_zy[i,j,k]]; 
    [ni_zx[i,j,k];; ni_zy[i,j,k];; ni_zz[i,j,k] + 1e-5]
    ] for i=1:d1, j=1:d2, k=1:d3];

In [20]:
function get_step_range(start, stop, step)
    n_steps = ((stop-start) + 1) / step;
    this_range = [convert(Int64,start + (i-1) * step) for i=1:(n_steps+1)]
    return this_range
end

get_step_range (generic function with 1 method)

In [21]:
# let's take 4 x 4 x 4 voxels
x1 = 101;
x2 = 112;
y1 = 151;
y2 = 162;
z1 = 51;
z2 = 102;
step = 4;
x_range = get_step_range(x1,x2,step);
y_range = get_step_range(y1,y2,step);
z_range = get_step_range(z1,z2,step);

In [22]:
predata_masked = [reshape(predata[x_range[i]:x_range[i+1]-1,y_range[j]:y_range[j+1]-1,z_range[k]:z_range[k+1]-1],
        (1,:)) for i=1:length(x_range)-1,
        j=1:length(y_range)-1,k=1:length(z_range)-1];

In [23]:
size(predata_masked)

(3, 3, 13)

In [24]:
data=vec(predata_masked);
size(data)

(117,)

In [25]:
size(data[1])

(1, 64)

In [26]:
ni_xx = nothing;
ni_yx = nothing;
ni_yy = nothing;
ni_zx = nothing;
ni_zy = nothing;
ni_zz = nothing;

In [27]:
# construct data manifold
n = size(data)[1] 
M = PowerManifold(SymmetricPositiveDefinite(3), NestedPowerRepresentation(), size(data[1])[2])
d = manifold_dimension(M);
print(d)

384

In [28]:
D = power_dimensions(M)[1];
print(D)

64

In [29]:
predata_masked=nothing;

In [30]:
predata=nothing;

In [31]:
GC.gc()

### Construct low rank approximation ###

In [32]:
q = [Matrix(I, 3, 3) .* 1e-5 for i=1:D];

In [33]:
log_q_data = log.(Ref(M), Ref(q), data);  # ∈ T_q P(3)^n

In [34]:
offset = 0.1
CC_iter = 20;
debug_int = 1;
U_iters = 10;
U_debug_int = 3;

In [35]:
ranks=[20]
ccL = []
nU_q = []
nV_q = []
ccU_q = []
ccV_q = []
for k=ranks
    println("computing rank $(k) approximation")
    Uₖ, Vₖ, Ξ_q, U₀, V₀ = curvature_corrected_sNMF_precomp(M, q, data, log_q_data, k; 
        max_iter = CC_iter, debug_freq = debug_int, init_type="kmeans",
        ENMF_const = offset, 
        max_U_iter=U_iters, U_debug=U_debug_int);
    push!(nU_q, U₀)
    push!(nV_q, V₀)
    push!(ccU_q,Uₖ)
    push!(ccV_q,Vₖ)
    loss_cc = exact_loss(M, q, data, Ξ_q);
    push!(ccL, loss_cc)
end

computing rank 20 approximation
computed initialization
iter #0        | CCL: 0.001284228800866 | exact loss: 0.001283499094757
iter #1    (V) | CCL: 0.000464347653054 | exact loss: 0.000463826006140 | change (V): 79.06175
	subiter #3    | CCL: 0.000434986661225 | exact loss: 0.000434471547466 | change (U): 0.00000
	subiter #6    | CCL: 0.000432197686194 | exact loss: 0.000431688720581 | change (U): 0.00000
	subiter #9    | CCL: 0.000429858761897 | exact loss: 0.000429355790111 | change (U): 0.00000
iter #1    (U) | CCL: 0.000429089080726 | exact loss: 0.000428588079903 | change (U): 0.02497
iter #2    (V) | CCL: 0.000429058296025 | exact loss: 0.000428559109645 | change (V): 0.63470
	subiter #3    | CCL: 0.000426742247065 | exact loss: 0.000426248920763 | change (U): 0.00000
	subiter #6    | CCL: 0.000424461337366 | exact loss: 0.000423973764071 | change (U): 0.00000
	subiter #9    | CCL: 0.000422214735529 | exact loss: 0.000421732806937 | change (U): 0.00000
iter #2    (U) | CCL: 0.0

### Figures (example)

In [36]:
foldername="IIT_figs"

"IIT_figs"

In [37]:
n=size(data);
for r=1:length(ranks)
    k=ranks[r]
    ccU = ccU_q[r]
    ccV = ccV_q[r]
    # current setup of semi-NMF code has V as a real-valued matrix
    ccV_TM = get_vector.(Ref(M), Ref(q), [ccV[l,:] for l=1:k], Ref(DefaultOrthonormalBasis()))
    ccU_max = maximum(ccU; dims=1)
    overall_factor_list_cc = exp.(Ref(M), Ref(q), [ccU_max[i] * ccV_TM[i] for i=1:k]) 
    # factors
    
    for j=1:k
        factor_cc_unscaled = overall_factor_list_cc[j];
        eigmaxes = opnorm.(factor_cc_unscaled, 2);
        if maximum(eigmaxes) < 1
            factor_cc = factor_cc_unscaled / maximum(eigmaxes);
        else
            factor_cc = factor_cc_unscaled
        end
        factor_2D_cc = reshape(factor_cc, (4, 4, 4));
        fig_filename_cc = @sprintf("%s/k%d/cc_factor%d.asy", foldername, k, j);
        asymptote_export_SPD(fig_filename_cc, data=factor_2D_cc, scale_axes=(1.5,1.5,1.5), camera_position=(-2., 6., 14.)); 
        render_asymptote(fig_filename_cc)     
    end
end