# Low-rank approximation on $\mathcal{P}(d)$ - the space of $d$-dimensional SPD matrices

In this notebook we want to get some intuition in different approaches for computing low-rank approximations for manifold-valued signals

In [1]:
using Manifolds
using Manopt
using LinearAlgebra
using Random
using Plots
using LaTeXStrings
using BenchmarkTools

using Clustering
using NIfTI

using JLD
using Printf

In [2]:
using StatsBase

In [3]:
# include("../../../src/functions/loss_functions/curvature_corrected_loss.jl")
# include("../../../src/functions/loss_functions/exact_loss.jl")
include("../../../src/functions/jacobi_field/beta.jl")

β (generic function with 1 method)

### Functions

In [4]:
function SNMF_euclidean(X, k, init_const, max_iter)
    # in this function X is d \times n
    # F is d \times k, G is n \times k
    R = kmeans(X, k)
    d = size(X)[1]
    n = size(X)[2]
    G_init = zeros(n, k)
    for i=1:n
        G_init[i,assignments(R)[i]] = 1
    end
    iter = 0
    G = G_init .+ init_const
    F = zeros(d, k)
    while iter < max_iter
        F = X * G * pinv(G' * G)
        A = X' * F
        B = F' * F
        Apos = (abs.(A) + A) ./ 2
        Aneg = (abs.(A) - A) ./ 2
        Bpos = (abs.(B) + B) ./ 2
        Bneg = (abs.(B) - B) ./ 2
        G = G .* sqrt.((Apos + G * Bneg)./(Aneg + G * Bpos))
        G[isnan.(G)].=0
        iter += 1
    end
    return F, G
end

SNMF_euclidean (generic function with 1 method)

In [5]:
function naive_low_rank_approximation(M::AbstractManifold, q, data, k; init_const = 0.2, max_iter=50)
    log_q_data = log.(Ref(M), Ref(q), data);
    X_eucl = reduce(hcat, get_coordinates.(Ref(M), Ref(q), log_q_data, Ref(DefaultOrthonormalBasis())))
    V, U = SNMF_euclidean(X_eucl, k, init_const, max_iter);
    # Vq = get_vector.(Ref(M), Ref(q),[V[:,l] for l=1:k], Ref(DefaultOrthonormalBasis()))
    return U, V'
end

naive_low_rank_approximation (generic function with 1 method)

In [6]:
function curvature_corrected_low_rank_approximation(M::AbstractManifold, q, X, rank; ENMF_const = 0.2, ENMF_iter=50)
    n = size(X)
    d = manifold_dimension(M)
    r = min(n[1], d, rank)

    # compute initialisation 
    U, Vq = naive_low_rank_approximation(M, q, X, r; init_const=ENMF_const, max_iter=ENMF_iter) 
    # construct linear system
    log_q_X = log.(Ref(M), Ref(q), X)  # ∈ T_q M^n
    
    # construct matrix βκB
    tensorU = repeat(reshape(U, (n[1],1,r,1)), outer=(1,d,1,d))
    
    tensorΨq = repeat(reshape([get_vector(M, q, Matrix(I, d, d)[:,k], DefaultOrthonormalBasis()) for k in 1:d], (1,1,1,d)), outer=(n[1],d,r,1)) 
    
    ONB = get_basis.(Ref(M), Ref(q), DiagonalizingOrthonormalBasis.(log_q_X))
    βκΘq = [β(ONB[j₁].data.eigenvalues[j] * (typeof(M) <: AbstractSphere ? distance(M, q, X[j₁])^2 : 1.)) .* ONB[j₁].data.vectors[j] for j₁=1:n[1], j=1:d]
    tensorβκΘq =  repeat(reshape(βκΘq, (n[1],d,1,1)), outer=(1,1,r,d))
    
    tensorβκB = tensorU .* inner.(Ref(M), Ref(q), tensorΨq, tensorβκΘq)
    
    βκB = reshape(tensorβκB, (n[1] * d, r * d))

    # construct matrix A
    A = transpose(βκB) * βκB

    # construct vector βκBb
    tensorb = inner.(Ref(M), Ref(q), repeat(reshape(log_q_X, (n[1],1)), outer=(1,d)), βκΘq)
    b = reshape(tensorb, (n[1] * d))
    βκBb = transpose(βκB) * b

    # solve linear system
    Vₖₗ = A\βκBb
    tensorVₖₗ = reshape(Vₖₗ, (r, d))

    # get ccRr_q
    ccV_q = get_vector.(Ref(M), Ref(q),[tensorVₖₗ[l,:] for l=1:r], Ref(DefaultOrthonormalBasis()))
    return ccV_q, U
end

curvature_corrected_low_rank_approximation (generic function with 1 method)

In [7]:
function curvature_corrected_V_update(M::AbstractManifold, q, X, U, rank)
    n = size(X)
    d = manifold_dimension(M)
    # r = min(n[1], d, rank)
    r = min(n[1], rank)
    
    # construct linear system
    log_q_X = log.(Ref(M), Ref(q), X)  # ∈ T_q M^n
    
    # construct matrix βκB
    tensorU = repeat(reshape(U, (n[1],1,r,1)), outer=(1,d,1,d))
    
    tensorΨq = repeat(reshape([get_vector(M, q, Matrix(I, d, d)[:,k], DefaultOrthonormalBasis()) for k in 1:d], (1,1,1,d)), outer=(n[1],d,r,1)) 
    
    ONB = get_basis.(Ref(M), Ref(q), DiagonalizingOrthonormalBasis.(log_q_X))
    βκΘq = [β(ONB[j₁].data.eigenvalues[j] * (typeof(M) <: AbstractSphere ? distance(M, q, X[j₁])^2 : 1.)) .* ONB[j₁].data.vectors[j] for j₁=1:n[1], j=1:d]
    tensorβκΘq =  repeat(reshape(βκΘq, (n[1],d,1,1)), outer=(1,1,r,d))
    
    tensorβκB = tensorU .* inner.(Ref(M), Ref(q), tensorΨq, tensorβκΘq)
    
    βκB = reshape(tensorβκB, (n[1] * d, r * d))

    # construct matrix A
    A = transpose(βκB) * βκB

    # construct vector βκBb
    tensorb = inner.(Ref(M), Ref(q), repeat(reshape(log_q_X, (n[1],1)), outer=(1,d)), βκΘq)
    b = reshape(tensorb, (n[1] * d))
    βκBb = transpose(βκB) * b

    # solve linear system
    Vₖₗ = A\βκBb
    tensorVₖₗ = reshape(Vₖₗ, (r, d))

    # get ccRr_q
    # ccV_q = get_vector.(Ref(M), Ref(q),[tensorVₖₗ[l,:] for l=1:r], Ref(DefaultOrthonormalBasis()))
    return tensorVₖₗ
end

curvature_corrected_V_update (generic function with 1 method)

In [8]:
function curvature_corrected_U_update(M::AbstractManifold, q, X, U, V, rank; iters = 100, debug_int = 10)
    n = size(X)
    d = manifold_dimension(M)
    # r = min(n[1], d, rank)
    r = min(n[1], rank)
    # construct linear system
    log_q_X = log.(Ref(M), Ref(q), X)  # ∈ T_q M^n
    # Ξ_q_coords = U * V
    # Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
    # construct matrix βκB
    tensorV = repeat(reshape(V, (1, r, 1, d)), outer=(n[1],1,d,1)) # this is j1, i1, j, i
    tensorΨq = repeat(reshape([get_vector(M, q, Matrix(I, d, d)[:,k], DefaultOrthonormalBasis()) for k in 1:d], (1,1,1,d)), outer=(n[1],r,d,1)) # this is j1, i1, j, i
    
    ONB = get_basis.(Ref(M), Ref(q), DiagonalizingOrthonormalBasis.(log_q_X))
    βκΘq = [β(ONB[j₁].data.eigenvalues[j] * (typeof(M) <: AbstractSphere ? distance(M, q, X[j₁])^2 : 1.)) .* ONB[j₁].data.vectors[j] for j₁=1:n[1], j=1:d]
    tensorβκΘq =  repeat(reshape(βκΘq, (n[1],1,d,1)), outer=(1,r,1,d)) # this is j1, i1, j, i
    
    tensorβκB = tensorV .* inner.(Ref(M), Ref(q), tensorΨq, tensorβκΘq)
    
    βκB = reshape(tensorβκB, (n[1], r, d * d)) # this is j1, i1, j, i
    # actually we want dimensions of (n[1] * r, d * d), i.e. j1, i1, j, i

    # construct numerator
    tensorb = inner.(Ref(M), Ref(q), repeat(reshape(log_q_X, (n[1],1)), outer=(1,d)), βκΘq)
    b = reshape(repeat(reshape(tensorb, (n[1], 1, d, 1)), outer=(1, r, 1, d)), (n[1], r, d * d))
    B = dropdims(sum(βκB .* b, dims=3),dims=3)
    
    Bpos = (abs.(B) + B) ./ 2
    Bneg = (abs.(B) - B) ./ 2
    # reshape to (n, 1, d, 1), expand to (n, r, d, d), reshape to (n, r, d * d) which is (j1, i1, j & i)
    # then want (essentially) dot product of betakappaB and this along third dimension

    # construct denominator
    iter = 0
    this_U = U
    while iter < iters
        iter += 1
        Ξ_q_coords = this_U * V
        Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
        tensorc = inner.(Ref(M), Ref(q), repeat(reshape(Ξ_q, (n[1], 1)), outer=(1,d)), βκΘq)
        c = reshape(repeat(reshape(tensorc, (n[1], 1, d, 1)), outer=(1, r, 1, d)), (n[1], r, d * d))
        C = dropdims(sum(βκB .* c, dims=3),dims=3)
        
        Cpos = (abs.(C) + C) ./ 2
        Cneg = (abs.(C) - C) ./ 2
        
        ccU_q = this_U .* sqrt.((Bpos + Cneg)./(Bneg + Cpos))
        ccU_q[isnan.(ccU_q)].=0
        
        if iter % debug_int == 0
            Ξ_q_coords = ccU_q * V
            Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
            CC_loss = curvature_corrected_loss(M, q, X, Ξ_q)
            ex_loss = exact_loss(M, q, X, Ξ_q)
            @printf("\tsubiter #%-4i | CCL: %.15f | exact loss: %.15f | change (U): %.5f\n", iter, CC_loss, ex_loss, norm(ccU_q-this_U))
        end
        this_U = ccU_q
    end
    
#     tensorc = inner.(Ref(M), Ref(q), repeat(reshape(Ξ_q, (n[1], 1)), outer=(1,d)), βκΘq)
#     c = reshape(repeat(reshape(tensorc, (n[1], 1, d, 1)), outer=(1, r, 1, d)), (n[1], r, d * d))
#     C = dropdims(sum(βκB .* c, dims=3),dims=3)
    
#     Bpos = (abs.(B) + B) ./ 2
#     Bneg = (abs.(B) - B) ./ 2
#     Cpos = (abs.(C) + C) ./ 2
#     Cneg = (abs.(C) - C) ./ 2
    
#     ccU_q = U .* sqrt.((Bpos + Cneg)./(Bneg + Bpos))
#     # ccU_q = U .* (Bpos + Cneg)./(Bneg + Bpos)
#     ccU_q[isnan.(ccU_q)].=0
    return this_U
end

curvature_corrected_U_update (generic function with 1 method)

In [9]:
function curvature_corrected_V_update_precomp_ineff(M::AbstractPowerManifold, q, X, log_q_X, U, rank, βκΘq)
    n = size(X)
    d = manifold_dimension(M.manifold)
    D = power_dimensions(M)[1]
    # r = min(n[1], d, rank)
    # r = min(n[1], rank)
    r = rank
    
    # construct linear system

    tensorU = repeat(reshape(U, (n[1],1,r,1)), outer=(1,d*D,1,d*D))
    # Ψ_dict = [get_vector(M.manifold, q[k], Matrix(I, d, d)[:,i], DefaultOrthonormalBasis()) for i=1:d, k=1:D]; # d by D
    Ψ_list = [get_vector(M, q, Matrix(I, d*D, d*D)[:,i], DefaultOrthonormalBasis()) for i=1:d*D];
    # each element of the list is a vector with D elements
    # the list has d*D elements
    # note that the list fixes the D index, then exhausts the d index
    # for fixed D index, all other elements of that basis element are zero
    # we want the tensor of \psi_q to be n x d x r x (d * D), but maybe we hold off on collapsing down the last dimension until later?
    # this corresponds to j1, j, i1, (i,k) in ipad notation
    tensorΨq = repeat(reshape(Ψ_list, (1, 1, 1, d*D)), outer=(n[1], d*D, r, 1))
    
    # βκΘq = [β(ONB[j₁].data.bases[k].data.eigenvalues[j]) .* ONB[j₁].data.bases[k].data.vectors[j] for j₁=1:n[1], j=1:d, k=1:D];
    # each element of fullβκΘq should be a vector with D elements
    # each element of βκΘq is a single point in TpM
    pt = [zeros(3,3) for i=1:D];
    fullβκΘq = Array{Vector{Matrix{Float64}}}(undef, n[1],d*D);
    for j₁ = 1:n[1]
        for j=1:d
            for k=1:D
                this_pt = [zeros(3,3) for i=1:D];
                this_pt[k] = βκΘq[j₁,j,k]
                # so only kth element is nonzero
                fullβκΘq[j₁,(k-1)*d + j] = this_pt;
            end
        end
    end

    tensorβκΘq =  repeat(reshape(fullβκΘq, (n[1],d*D,1,1)), outer=(1,1,r,d*D)) # total dimensions: n, d*D, r, d*D, D with idx j1, j, i1, i

    tensorβκB = tensorU .* inner.(Ref(M), Ref(q), tensorΨq, tensorβκΘq);
    βκB = reshape(tensorβκB, (n[1] * d * D, r * d * D))
    # βκB = reshape(tensorβκB, (n[1] * d, r * d))

    # construct matrix A
    A = transpose(βκB) * βκB
    # construct vector βκBb
    tensorb = inner.(Ref(M), Ref(q), repeat(reshape(log_q_X, (n[1],1)), outer=(1,d*D)), fullβκΘq)

    b = reshape(tensorb, (n[1] * d *D))
    βκBb = transpose(βκB) * b

    # solve linear system
    Vₖₗ = A\βκBb
    tensorVₖₗ = reshape(Vₖₗ, (r, d * D))

    # get ccRr_q
    # ccV_q = get_vector.(Ref(M), Ref(q),[tensorVₖₗ[l,:] for l=1:r], Ref(DefaultOrthonormalBasis()))
    return tensorVₖₗ
end

curvature_corrected_V_update_precomp_ineff (generic function with 1 method)

In [10]:
function curvature_corrected_U_update_precomp_ineff(M::AbstractPowerManifold, q, X, log_q_X, U, V, rank, βκΘq; iters = 100, debug_int = 10)
    
    n = size(X)
    d = manifold_dimension(M.manifold)
    D = power_dimensions(M)[1]
    # r = min(n[1], d, rank)
    # r = min(n[1], rank)
    r = rank
    # construct linear system

    # construct matrix βκB
    tensorV = repeat(reshape(V, (1, r, 1, d*D)), outer=(n[1],1,d*D,1)) # this is j1, i1, j, i
    Ψ_list = [get_vector(M, q, Matrix(I, d*D, d*D)[:,i], DefaultOrthonormalBasis()) for i=1:d*D];
    tensorΨq = repeat(reshape(Ψ_list, (1, 1, 1, d*D)), outer=(n[1], r, d*D, 1))
    
    # ONB = get_basis.(Ref(M), Ref(q), DiagonalizingOrthonormalBasis.(log_q_X))
    # βκΘq = [β(ONB[j₁].data.bases[k].data.eigenvalues[j]) .* ONB[j₁].data.bases[k].data.vectors[j] for j₁=1:n[1], j=1:d, k=1:D];
    # each element of fullβκΘq should be a vector with D elements
    # each element of βκΘq is a single point in TpM
    # tensorβκΘq =  repeat(reshape(βκΘq, (n[1],1,d,1)), outer=(1,r,1,d)) # this is j1, i1, j, i
    
    pt = [zeros(3,3) for i=1:D];
    fullβκΘq = Array{Vector{Matrix{Float64}}}(undef, n[1],d*D);
    for j₁ = 1:n[1]
        for j=1:d
            for k=1:D
                this_pt = [zeros(3,3) for i=1:D];
                this_pt[k] = βκΘq[j₁,j,k]
                # so only kth element is nonzero
                fullβκΘq[j₁,(k-1)*d + j] = this_pt;
            end
        end
    end

    tensorβκΘq =  repeat(reshape(fullβκΘq, (n[1],1,d*D,1)), outer=(1,r,1,d*D)) # total dimensions: n, d*D, r, d*D, D with idx j1, i1, j, i  
    tensorβκB = tensorV .* inner.(Ref(M), Ref(q), tensorΨq, tensorβκΘq)   
    βκB = reshape(tensorβκB, (n[1], r, d*D * d*D)) # this is j1, i1, j, i
    
    # construct numerator
    tensorb = inner.(Ref(M), Ref(q), repeat(reshape(log_q_X, (n[1],1)), outer=(1,d*D)), fullβκΘq)
    b = reshape(repeat(reshape(tensorb, (n[1], 1, d*D, 1)), outer=(1, r, 1, d*D)), (n[1], r, d *D* d*D))
    
    B = dropdims(sum(βκB .* b, dims=3),dims=3)
    
    Bpos = (abs.(B) + B) ./ 2
    Bneg = (abs.(B) - B) ./ 2
    # reshape to (n, 1, d, 1), expand to (n, r, d, d), reshape to (n, r, d * d) which is (j1, i1, j & i)
    # then want (essentially) dot product of betakappaB and this along third dimension

    # construct denominator
    iter = 0
    this_U = U
    while iter < iters
        iter += 1
        Ξ_q_coords = this_U * V
        Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
        tensorc = inner.(Ref(M), Ref(q), repeat(reshape(Ξ_q, (n[1], 1)), outer=(1,d*D)), fullβκΘq)
        c = reshape(repeat(reshape(tensorc, (n[1], 1, d*D, 1)), outer=(1, r, 1, d*D)), (n[1], r, d * D*d*D))
        C = dropdims(sum(βκB .* c, dims=3),dims=3)
        
        Cpos = (abs.(C) + C) ./ 2
        Cneg = (abs.(C) - C) ./ 2
        
        ccU_q = this_U .* sqrt.((Bpos + Cneg)./(Bneg + Cpos))
        ccU_q[isnan.(ccU_q)].=0
        
        if iter % debug_int == 0
            Ξ_q_coords = ccU_q * V
            Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
            CC_loss = curvature_corrected_loss(M, q, X, Ξ_q)
            ex_loss = exact_loss(M, q, X, Ξ_q)
            @printf("\tsubiter #%-4i | CCL: %.15f | exact loss: %.15f | change (U): %.5f\n", iter, CC_loss, ex_loss, norm(ccU_q-this_U))
        end
        this_U = ccU_q
    end

    return this_U
end

curvature_corrected_U_update_precomp_ineff (generic function with 1 method)

In [11]:
function curvature_corrected_V_update_precomp(M::AbstractPowerManifold, q, X, log_q_X, U, rank, βκΘq, tensorβκΘq, tensorΨq)
    n = size(X)
    d = manifold_dimension(M.manifold)
    D = power_dimensions(M)[1]
    # r = min(n[1], d, rank)
    # r = min(n[1], rank)
    r = rank
    
    # construct linear system

    tensorU = repeat(reshape(U, (n[1],1,r,1,1)), outer=(1,d,1,d,D))
    # Ψ_mat = [get_vector(M, q, Matrix(I, d*D, d*D)[:,(k-1)*d + j], DefaultOrthonormalBasis())[k] for j=1:d, k=1:D];
    # tensorΨq = repeat(reshape(Ψ_mat, (1, 1, d,D)), outer=(n[1], d, 1,1));

    # tensorβκΘq = repeat(reshape(βκΘq,(n[1],d,1,D)), outer=(1,1,d,1));

    sΨq = eachslice(tensorΨq,dims=4);
    sβκΘq = eachslice(tensorβκΘq,dims=4);
    pdts = reduce((x,y) -> cat(x,y,dims=4), map((x,y,z) -> inner.(Ref(M.manifold),Ref(z), x, y), sΨq, sβκΘq, q));

    tensorβκB = tensorU .* repeat(reshape(pdts,(n[1],d,1,d,D)),outer=(1,1,r,1,1));

    tensorβκB_exp = reshape(tensorβκB,(n[1],d,1,r,d,D));
    tensorβκB_alt = zeros(n[1],d,D,r,d,D);
    for k=1:D
        tensorβκB_alt[:,:,k,:,:,k] = @view tensorβκB_exp[:,:,1,:,:,k];
    end

    βκB = reshape(tensorβκB_alt,(n[1]*d*D,r*d*D));

    A = transpose(βκB) * βκB;

    log_q_X_expanded = [log_q_X[j₁][k] for j₁=1:n[1],k=1:D];
    log_q_X_tensor = repeat(reshape(log_q_X_expanded, (n[1],1,D)),outer=(1,d,1));
    slogX = eachslice(log_q_X_tensor,dims=3);
    sβκΘq_forlog = eachslice(βκΘq,dims=3);
    tensorb_pre = reduce((x,y) -> cat(x,y,dims=3), map((x,y,z) -> inner.(Ref(M.manifold),Ref(z), x, y), slogX, sβκΘq_forlog, q));
    b = reshape(tensorb_pre,(n[1] * d * D));
    βκBb = transpose(βκB) * b;

    # solve linear system
    Vₖₗ = A\βκBb
    tensorVₖₗ = reshape(Vₖₗ, (r, d * D));
    return tensorVₖₗ
end

curvature_corrected_V_update_precomp (generic function with 1 method)

In [12]:
function curvature_corrected_V_update_precomp!(tensorU, M::AbstractPowerManifold, q, X, log_q_X_tensor, U, rank, βκΘq, tensorβκΘq, tensorΨq)
    # U comes in n[1] by r
    # tensorU is n[1], d, r, d, D
    n = size(X)
    d = manifold_dimension(M.manifold)
    D = power_dimensions(M)[1]
    # r = min(n[1], d, rank)
    # r = min(n[1], rank)
    r = rank
    
    # construct linear system

    # tensorU = repeat(reshape(U, (n[1],1,r,1,1)), outer=(1,d,1,d,D))
    # loop through dimension 2 of U, then 1, and fill with repeat?
    for col=1:r
        for row=1:n[1]
            tensorU[row,:,col,:,:].=view(U,row,col)
        end
    end
    
    # Ψ_mat = [get_vector(M, q, Matrix(I, d*D, d*D)[:,(k-1)*d + j], DefaultOrthonormalBasis())[k] for j=1:d, k=1:D];
    # tensorΨq = repeat(reshape(Ψ_mat, (1, 1, d,D)), outer=(n[1], d, 1,1));

    # tensorβκΘq = repeat(reshape(βκΘq,(n[1],d,1,D)), outer=(1,1,d,1));

    sΨq = eachslice(tensorΨq,dims=4);
    sβκΘq = eachslice(tensorβκΘq,dims=4);
    pdts = reduce((x,y) -> cat(x,y,dims=4), map((x,y,z) -> inner.(Ref(M.manifold),Ref(z), x, y), sΨq, sβκΘq, q));

    tensorβκB = tensorU .* repeat(reshape(pdts,(n[1],d,1,d,D)),outer=(1,1,r,1,1));

    tensorβκB_exp = reshape(tensorβκB,(n[1],d,1,r,d,D));
    tensorβκB_alt = zeros(n[1],d,D,r,d,D);
    for k=1:D
        tensorβκB_alt[:,:,k,:,:,k] .= @view tensorβκB_exp[:,:,1,:,:,k];
    end

    βκB = reshape(tensorβκB_alt,(n[1]*d*D,r*d*D));

    A = transpose(βκB) * βκB;

    # log_q_X_expanded = [log_q_X[j₁][k] for j₁=1:n[1],k=1:D];
    # log_q_X_tensor = repeat(reshape(log_q_X_expanded, (n[1],1,D)),outer=(1,d,1));
    
    slogX = eachslice(log_q_X_tensor,dims=3);
    sβκΘq_forlog = eachslice(βκΘq,dims=3);
    tensorb_pre = reduce((x,y) -> cat(x,y,dims=3), map((x,y,z) -> inner.(Ref(M.manifold),Ref(z), x, y), slogX, sβκΘq_forlog, q));
    b = reshape(tensorb_pre,(n[1] * d * D));
    βκBb = transpose(βκB) * b;

    # solve linear system
    Vₖₗ = A\βκBb # vector of length r * d * D
    tensorVₖₗ = reshape(Vₖₗ, (r, d * D));
    return tensorVₖₗ
end

curvature_corrected_V_update_precomp! (generic function with 1 method)

In [13]:
function curvature_corrected_U_update_precomp(M::AbstractPowerManifold, q, X, log_q_X, U, V, 
        rank, βκΘq, tensorβκΘq, tensorΨq; iters = 100, debug_int = 10)
    n = size(X)
    d = manifold_dimension(M.manifold)
    D = power_dimensions(M)[1]
    # r = min(n[1], rank)
    r = rank
    # construct linear system

    # construct matrix βκB
    tensorV = repeat(reshape(V, (1,r,1,d,D)), outer=(n[1],1,d,1,1)); #n, d, r, d, D
    # Ψ_mat = [get_vector(M, q, Matrix(I, d*D, d*D)[:,(k-1)*d + j], DefaultOrthonormalBasis())[k] for j=1:d, k=1:D];
    # tensorΨq = repeat(reshape(Ψ_mat, (1, 1, d,D)), outer=(n[1], d, 1,1));

    # tensorβκΘq = repeat(reshape(βκΘq,(n[1],d,1,D)), outer=(1,1,d,1));

    
    sΨq = eachslice(tensorΨq,dims=4);
    sβκΘq = eachslice(tensorβκΘq,dims=4);
    pdts = reduce((x,y) -> cat(x,y,dims=4), map((x,y,z) -> inner.(Ref(M.manifold),Ref(z), x, y), sΨq, sβκΘq, q));

    tensorβκB = tensorV .* repeat(reshape(pdts,(n[1],1,d,d,D)),outer=(1,r,1,1,1));

    βκB = reshape(tensorβκB,(n[1],r,d*d*D));
    
    # tensorβκB = tensorV .* inner.(Ref(M), Ref(q), tensorΨq, tensorβκΘq)   
    # βκB = reshape(tensorβκB, (n[1], r, d*D * d*D)) # this is j1, i1, j, i
    
    # construct numerator
    # log_q_X_expanded = [log_q_X[j₁][k] for j₁=1:n[1],k=1:D];
    # log_q_X_tensor = repeat(reshape(log_q_X_expanded, (n[1],1,D)),outer=(1,d,1));
    log_q_X_tensor = repeat(reshape([log_q_X[j₁][k] for j₁=1:n[1],k=1:D], (n[1],1,D)),outer=(1,d,1));
    slogX = eachslice(log_q_X_tensor,dims=3);
    sβκΘq_forlog = eachslice(βκΘq,dims=3);
    tensorb_pre = reduce((x,y) -> cat(x,y,dims=3), map((x,y,z) -> inner.(Ref(M.manifold),Ref(z), x, y), slogX, sβκΘq_forlog, q));
    b = reshape(repeat(reshape(tensorb_pre,(n[1], 1, d, 1, D)), outer=(1,r,1,d,1)),(n[1],r,d*d*D));

    B = dropdims(sum(βκB .* b, dims=3),dims=3);
    
    Bpos = (abs.(B) .+ B) ./ 2
    Bneg = (abs.(B) .- B) ./ 2
    # reshape to (n, 1, d, 1), expand to (n, r, d, d), reshape to (n, r, d * d) which is (j1, i1, j & i)
    # then want (essentially) dot product of betakappaB and this along third dimension

    # construct denominator
    iter = 0
    this_U = U
    Ξ_q_coords = this_U * V
    Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
    Cpos = similar(Bpos);
    Cneg = similar(Bneg);
    ccU_q = similar(this_U);
    while iter < iters
        iter += 1
        # Ξ_q_coords = this_U * V
        # Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
        # Ξ_q_expanded = [Ξ_q[j₁][k] for j₁=1:n[1],k=1:D];
        # Ξ_q_tensor = repeat(reshape(Ξ_q_expanded, (n[1],1,D)),outer=(1,d,1));
        Ξ_q_tensor = repeat(reshape([Ξ_q[j₁][k] for j₁=1:n[1],k=1:D], (n[1],1,D)),outer=(1,d,1));
        sΞ_q = eachslice(Ξ_q_tensor,dims=3);
        sβκΘq_forlog = eachslice(βκΘq,dims=3);
        tensorc_pre = reduce((x,y) -> cat(x,y,dims=3), map((x,y,z) -> inner.(Ref(M.manifold),Ref(z), x, y), sΞ_q, sβκΘq_forlog, q));
        # c = reshape(repeat(reshape(tensorc_pre, (n[1], 1, d, 1, D)), outer=(1, r, 1, d, 1)), (n[1], r, d*d*D))
        # C = dropdims(sum(βκB .* c, dims=3),dims=3);
        C = dropdims(sum(βκB .* reshape(repeat(reshape(tensorc_pre, (n[1], 1, d, 1, D)), outer=(1, r, 1, d, 1)), (n[1], r, d*d*D)), dims=3),dims=3);
        
        Cpos .= (abs.(C) .+ C) ./ 2
        Cneg .= (abs.(C) .- C) ./ 2
        
        ccU_q .= this_U .* sqrt.((Bpos .+ Cneg)./(Bneg .+ Cpos))
        ccU_q[isnan.(ccU_q)].=0
        
        Ξ_q_coords = ccU_q * V
        Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
        if iter % debug_int == 0
            # Ξ_q_coords = ccU_q * V
            # Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
            CC_loss = curvature_corrected_loss_fromlog(M, q, log_q_X, Ξ_q)
            ex_loss = exact_loss(M, q, X, Ξ_q)
            @printf("\tsubiter #%-4i | CCL: %.15f | exact loss: %.15f | change (U): %.5f\n", iter, CC_loss, ex_loss, norm(ccU_q-this_U))
        end
        this_U = ccU_q
        
    end

    return this_U
end

curvature_corrected_U_update_precomp (generic function with 1 method)

In [14]:
function curvature_corrected_U_update_precomp!(tensorV, M::AbstractPowerManifold, q, X, log_q_X, log_q_X_tensor, U, V, 
        rank, βκΘq, tensorβκΘq, tensorΨq; iters = 100, debug_int = 10)
    # V is r by (d * D)
    n = size(X)
    d = manifold_dimension(M.manifold)
    D = power_dimensions(M)[1]
    # r = min(n[1], rank)
    r = rank
    # construct linear system

    # construct matrix βκB
    # tensorV = repeat(reshape(V, (1,r,1,d,D)), outer=(n[1],1,d,1,1)); #n, r, d, d, D
    V_reshape = reshape(V, (r,d,D));
    for i1=1:r
        for i2=1:d
            for i3=1:D
                tensorV[:,i1,:,i2,i3].=view(V_reshape,i1,i2,i3);
            end
        end
    end   
    # Ψ_mat = [get_vector(M, q, Matrix(I, d*D, d*D)[:,(k-1)*d + j], DefaultOrthonormalBasis())[k] for j=1:d, k=1:D];
    # tensorΨq = repeat(reshape(Ψ_mat, (1, 1, d,D)), outer=(n[1], d, 1,1));

    # tensorβκΘq = repeat(reshape(βκΘq,(n[1],d,1,D)), outer=(1,1,d,1));

    
    sΨq = eachslice(tensorΨq,dims=4);
    sβκΘq = eachslice(tensorβκΘq,dims=4);
    pdts = reduce((x,y) -> cat(x,y,dims=4), map((x,y,z) -> inner.(Ref(M.manifold),Ref(z), x, y), sΨq, sβκΘq, q));

    tensorβκB = tensorV .* repeat(reshape(pdts,(n[1],1,d,d,D)),outer=(1,r,1,1,1));

    βκB = reshape(tensorβκB,(n[1],r,d*d*D));
    
    # tensorβκB = tensorV .* inner.(Ref(M), Ref(q), tensorΨq, tensorβκΘq)   
    # βκB = reshape(tensorβκB, (n[1], r, d*D * d*D)) # this is j1, i1, j, i
    
    # construct numerator
    # log_q_X_expanded = [log_q_X[j₁][k] for j₁=1:n[1],k=1:D];
    # log_q_X_tensor = repeat(reshape(log_q_X_expanded, (n[1],1,D)),outer=(1,d,1));
    # log_q_X_tensor = repeat(reshape([log_q_X[j₁][k] for j₁=1:n[1],k=1:D], (n[1],1,D)),outer=(1,d,1));
    slogX = eachslice(log_q_X_tensor,dims=3);
    sβκΘq_forlog = eachslice(βκΘq,dims=3);
    tensorb_pre = reduce((x,y) -> cat(x,y,dims=3), map((x,y,z) -> inner.(Ref(M.manifold),Ref(z), x, y), slogX, sβκΘq_forlog, q));
    b = reshape(repeat(reshape(tensorb_pre,(n[1], 1, d, 1, D)), outer=(1,r,1,d,1)),(n[1],r,d*d*D));

    B = dropdims(sum(βκB .* b, dims=3),dims=3);
    
    Bpos = (abs.(B) .+ B) ./ 2
    Bneg = (abs.(B) .- B) ./ 2
    # reshape to (n, 1, d, 1), expand to (n, r, d, d), reshape to (n, r, d * d) which is (j1, i1, j & i)
    # then want (essentially) dot product of betakappaB and this along third dimension

    # construct denominator
    iter = 0
    this_U = U
    Ξ_q_coords = this_U * V
    Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
    Cpos = similar(Bpos);
    Cneg = similar(Bneg);
    ccU_q = similar(this_U);
    while iter < iters
        iter += 1
        # Ξ_q_coords = this_U * V
        # Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
        # Ξ_q_expanded = [Ξ_q[j₁][k] for j₁=1:n[1],k=1:D];
        # Ξ_q_tensor = repeat(reshape(Ξ_q_expanded, (n[1],1,D)),outer=(1,d,1));
        Ξ_q_tensor = repeat(reshape([Ξ_q[j₁][k] for j₁=1:n[1],k=1:D], (n[1],1,D)),outer=(1,d,1));
        sΞ_q = eachslice(Ξ_q_tensor,dims=3);
        sβκΘq_forlog = eachslice(βκΘq,dims=3);
        tensorc_pre = reduce((x,y) -> cat(x,y,dims=3), map((x,y,z) -> inner.(Ref(M.manifold),Ref(z), x, y), sΞ_q, sβκΘq_forlog, q));
        # c = reshape(repeat(reshape(tensorc_pre, (n[1], 1, d, 1, D)), outer=(1, r, 1, d, 1)), (n[1], r, d*d*D))
        # C = dropdims(sum(βκB .* c, dims=3),dims=3);
        C = dropdims(sum(βκB .* reshape(repeat(reshape(tensorc_pre, (n[1], 1, d, 1, D)), outer=(1, r, 1, d, 1)), (n[1], r, d*d*D)), dims=3),dims=3);
        
        Cpos .= (abs.(C) .+ C) ./ 2
        Cneg .= (abs.(C) .- C) ./ 2
        
        ccU_q .= this_U .* sqrt.((Bpos .+ Cneg)./(Bneg .+ Cpos))
        ccU_q[isnan.(ccU_q)].=0
        
        Ξ_q_coords = ccU_q * V
        Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
        if iter % debug_int == 0
            # Ξ_q_coords = ccU_q * V
            # Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
            CC_loss = curvature_corrected_loss_fromlog(M, q, log_q_X, Ξ_q)
            ex_loss = exact_loss(M, q, X, Ξ_q)
            @printf("\tsubiter #%-4i | CCL: %.15f | exact loss: %.15f | change (U): %.5f\n", iter, CC_loss, ex_loss, norm(ccU_q-this_U))
        end
        this_U = ccU_q
        
    end

    return this_U
end

curvature_corrected_U_update_precomp! (generic function with 1 method)

In [15]:
function exact_loss(M::AbstractManifold, q, X, Ξ)
    log_q_X = log.(Ref(M), Ref(q), X)  # ∈ T_q M^n
    ref_distance = sum(norm.(Ref(M), Ref(q), log_q_X).^2)

    exp_q_Ξ =   exp.(Ref(M), Ref(q), Ξ)
    return sum(distance.(Ref(M), X, exp_q_Ξ).^2) / ref_distance
end

exact_loss (generic function with 1 method)

In [16]:
function curvature_corrected_loss(M::AbstractManifold, q, X, Ξ)
    n = size(X)
    # d = manifold_dimension(M)
    # compute log
    log_q_X = log.(Ref(M), Ref(q), X)  # ∈ T_q M^n
    ref_distance = sum(norm.(Ref(M), Ref(q), log_q_X).^2)
    if typeof(M) <: PowerManifold
        D = power_dimensions(M)[1]
        d = manifold_dimension(M.manifold)
    else
        d = manifold_dimension(M)
    end
    # compute directions
    I = CartesianIndices(n)
    loss = 0.
    for i in I
        if typeof(M) <: PowerManifold
            ONBᵢ = get_basis(M, q, DiagonalizingOrthonormalBasis(log_q_X[i]))
            for k in 1:D
                θₖ = ONBᵢ.data.bases[k].data.vectors
                κₖ = ONBᵢ.data.bases[k].data.eigenvalues
                # compute loss
                loss += sum([β(κₖ[j])^2 * inner(M.manifold, q[k], Ξ[i][k] - log_q_X[i][k], Θₖ[j])^2 for j=1:d])
            end
        else
            ONBᵢ = get_basis(M, q, DiagonalizingOrthonormalBasis(log_q_X[i]))
            Θᵢ = ONBᵢ.data.vectors
            κᵢ = ONBᵢ.data.eigenvalues

            # if typeof(M) <: AbstractSphere # bug in Manifolds.jl
            #     κᵢ .*= distance(M, q, X[i])^2
            # end
            # compute loss
            loss += sum([β(κᵢ[j])^2 * inner(M, q, Ξ[i] - log_q_X[i], Θᵢ[j])^2 for j=1:d])
        end
        
    
    end
    return loss/ref_distance
end

curvature_corrected_loss (generic function with 1 method)

In [17]:
function curvature_corrected_loss_fromlog(M::AbstractManifold, q, log_q_X, Ξ)
    n = size(log_q_X)
    # d = manifold_dimension(M)
    ref_distance = sum(norm.(Ref(M), Ref(q), log_q_X).^2)
    if typeof(M) <: PowerManifold
        D = power_dimensions(M)[1]
        d = manifold_dimension(M.manifold)
    else
        d = manifold_dimension(M)
    end   
    # compute directions
    I = CartesianIndices(n)
    loss = 0.
    for i in I
        if typeof(M) <: PowerManifold
            ONBᵢ = get_basis(M, q, DiagonalizingOrthonormalBasis(log_q_X[i]))
            for k=1:D
                θₖ = ONBᵢ.data.bases[k].data.vectors
                κₖ = ONBᵢ.data.bases[k].data.eigenvalues
                # compute loss
                loss += sum([β(κₖ[j])^2 * inner(M.manifold, q[k], Ξ[i][k] - log_q_X[i][k], θₖ[j])^2 for j=1:d])
            end
        else
            ONBᵢ = get_basis(M, q, DiagonalizingOrthonormalBasis(log_q_X[i]))
            Θᵢ = ONBᵢ.data.vectors
            κᵢ = ONBᵢ.data.eigenvalues

            # if typeof(M) <: AbstractSphere # bug in Manifolds.jl
            #     κᵢ .*= distance(M, q, X[i])^2
            # end
            # compute loss
            loss += sum([β(κᵢ[j])^2 * inner(M, q, Ξ[i] - log_q_X[i], Θᵢ[j])^2 for j=1:d])
        end
        
        
    end
    return loss/ref_distance
end

curvature_corrected_loss_fromlog (generic function with 1 method)

In [18]:
function curvature_corrected_sNMF(M::AbstractManifold, q, X, rank; ENMF_const = 0.2, ENMF_iter = 50, max_iter=50, debug_freq=10, 
        init_type = "eNMF", first_factor = "V", max_U_iter = 100, U_debug=10)
    n = size(X)
    # initialize
    if init_type == "eNMF"
        U₀, V₀ = naive_low_rank_approximation(M, q, X, rank; init_const=ENMF_const, max_iter=ENMF_iter)
    elseif init_type == "kmeans"
        log_q_X = log.(Ref(M), Ref(q), X)
        X_eucl = reduce(hcat, get_coordinates.(Ref(M), Ref(q), log_q_X, Ref(DefaultOrthonormalBasis())))
        R = kmeans(X_eucl, rank)
        d = size(X_eucl)[1]
        this_n = size(X_eucl)[2]
        G_init = zeros(this_n, rank)
        for i=1:this_n
            G_init[i,assignments(R)[i]] = 1
        end
        V₀ = R.centers' # k \times d
        replace!(x -> x == 0 ? ENMF_const : x, G_init)
        U₀ = G_init ./ sum(G_init;dims=2)
    elseif init_type == "rand"
        G_init = rand(Float64, (n[1],rank)); # TODO make this general
        U₀ = G_init ./ sum(G_init;dims=2)
        d=manifold_dimension(M);
        V₀ = randn(Float64, (rank, d)); # k \times d
        
    else
       throw(ArgumentError("init_type must be eNMF, kmeans, or rand"))
    end
    println("computed initialization")
    iter = 0
    Uₖ = copy(U₀)
    Vₖ = copy(V₀)
    Ξ_q_coords = Uₖ * Vₖ
    Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
    CC_loss = curvature_corrected_loss(M, q, X, Ξ_q)
    ex_loss = exact_loss(M, q, X, Ξ_q)
    @printf("iter #0        | CCL: %.15f | exact loss: %.15f\n", CC_loss, ex_loss)
    # Ξ_q = []
    while iter < max_iter
        iter += 1
        Uprev = Uₖ
        Vprev = Vₖ
        U_iter = 0
        if first_factor == "V" 
            Vₖ = curvature_corrected_V_update(M, q, X, Uₖ, rank)
            if iter % debug_freq == 0
                Ξ_q_coords = Uₖ * Vₖ
                Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
                CC_loss = curvature_corrected_loss(M, q, X, Ξ_q)
                ex_loss = exact_loss(M, q, X, Ξ_q)
                @printf("iter #%-4i (V) | CCL: %.15f | exact loss: %.15f | change (V): %.5f\n", iter, CC_loss, ex_loss, norm(Vₖ-Vprev))
                # println("iter #$(iter) (V) | CCL: $(CC_loss) | exact loss: $(ex_loss)")
            end

            Uprev=Uₖ
            Uₖ = curvature_corrected_U_update(M, q, X, Uprev, Vₖ, rank;iters = max_U_iter,debug_int=U_debug)
            if iter % debug_freq == 0
                Ξ_q_coords = Uₖ * Vₖ
                Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
                CC_loss = curvature_corrected_loss(M, q, X, Ξ_q)
                ex_loss = exact_loss(M, q, X, Ξ_q)
                @printf("iter #%-4i (U) | CCL: %.15f | exact loss: %.15f | change (U): %.5f\n", iter, CC_loss, ex_loss, norm(Uₖ-Uprev))
                # println("iter #$(iter) (U) | CCL: $(CC_loss) | exact loss: $(ex_loss)")
            end
            # U_iter += 1

        else
            Uₖ = curvature_corrected_U_update(M, q, X, Uprev, Vₖ, rank)
            if iter % debug_freq == 0
                Ξ_q_coords = Uₖ * Vₖ
                Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
                CC_loss = curvature_corrected_loss(M, q, X, Ξ_q)
                ex_loss = exact_loss(M, q, X, Ξ_q)
                @printf("iter #%-4i (U) | CCL: %.15f | exact loss: %.15f | change (U): %.5f\n", iter, CC_loss, ex_loss, norm(Uₖ-Uprev))
                # println("iter #$(iter) (U) | CCL: $(CC_loss) | exact loss: $(ex_loss)")
            end
            Vₖ = curvature_corrected_V_update(M, q, X, Uₖ, rank)
            if iter % debug_freq == 0
                Ξ_q_coords = Uₖ * Vₖ
                Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
                CC_loss = curvature_corrected_loss(M, q, X, Ξ_q)
                ex_loss = exact_loss(M, q, X, Ξ_q)
                @printf("iter #%-4i (V) | CCL: %.15f | exact loss: %.15f | change (V): %.5f\n", iter, CC_loss, ex_loss, norm(Vₖ-Vprev))
                # println("iter #$(iter) (V) | CCL: $(CC_loss) | exact loss: $(ex_loss)")
            end
        end
    end
    if init_type == "kmeans"
        U₀ = G_init;
    end
    return Uₖ, Vₖ, Ξ_q, U₀, V₀
end

curvature_corrected_sNMF (generic function with 1 method)

In [19]:
function curvature_corrected_sNMF_precomp(M::AbstractPowerManifold, q, X, log_q_X, rank; ENMF_const = 0.2, ENMF_iter = 50, max_iter=50, 
        debug_freq=10, init_type = "eNMF", max_U_iter = 100, U_debug=10)
    n = size(X)
    # initialize
    if init_type == "eNMF"
        U₀, V₀ = naive_low_rank_approximation(M, q, X, rank; init_const=ENMF_const, max_iter=ENMF_iter)
    elseif init_type == "kmeans"
        # log_q_X = log.(Ref(M), Ref(q), X)
        X_eucl = reduce(hcat, get_coordinates.(Ref(M), Ref(q), log_q_X, Ref(DefaultOrthonormalBasis())))
        R = kmeans(X_eucl, rank)
        # d = size(X_eucl)[1]
        this_n = size(X_eucl)[2]
        G_init = zeros(this_n, rank)
        for i=1:this_n
            G_init[i,assignments(R)[i]] = 1
        end
        V₀ = R.centers' # k \times d
        replace!(x -> x == 0 ? ENMF_const : x, G_init)
        U₀ = G_init ./ sum(G_init;dims=2)
    elseif init_type == "rand"
        G_init = rand(Float64, (n[1],rank)); # TODO make this general
        U₀ = G_init ./ sum(G_init;dims=2)
        d=manifold_dimension(M);
        V₀ = randn(Float64, (rank, d)); # k \times d
    else
       throw(ArgumentError("init_type must be eNMF, kmeans, or rand"))
    end
    println("computed initialization")
    iter = 0
    Uₖ = copy(U₀)
    Vₖ = copy(V₀)
    Ξ_q_coords = Uₖ * Vₖ
    Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
    CC_loss = curvature_corrected_loss_fromlog(M, q, log_q_X, Ξ_q)
    ex_loss = exact_loss(M, q, X, Ξ_q)
    @printf("iter #0        | CCL: %.15f | exact loss: %.15f\n", CC_loss, ex_loss)
    # Ξ_q = []
    ONB = get_basis.(Ref(M), Ref(q), DiagonalizingOrthonormalBasis.(log_q_X))
    d = manifold_dimension(M.manifold);
    D = power_dimensions(M)[1];

    βκΘq = [β(ONB[j₁].data.bases[k].data.eigenvalues[j]) .* ONB[j₁].data.bases[k].data.vectors[j] for j₁=1:n[1], j=1:d, k=1:D];
    tensorβκΘq = repeat(reshape(βκΘq,(n[1],d,1,D)), outer=(1,1,d,1));

    Ψ_mat = [get_vector(M, q, Matrix(I, d*D, d*D)[:,(k-1)*d + j], DefaultOrthonormalBasis())[k] for j=1:d, k=1:D];
    tensorΨq = repeat(reshape(Ψ_mat, (1, 1, d,D)), outer=(n[1], d, 1,1));
    
    tensorU = zeros(n[1],d,rank,d,D);
    tensorV = zeros(n[1],rank,d, d, D);
    log_q_X_tensor = repeat(reshape([log_q_X[j₁][k] for j₁=1:n[1],k=1:D], (n[1],1,D)),outer=(1,d,1));
    while iter < max_iter
        iter += 1
        Uprev = Uₖ
        Vprev = Vₖ

        # Vₖ = curvature_corrected_V_update_precomp(M, q, X, log_q_X, Uₖ, rank, βκΘq, tensorβκΘq, tensorΨq)
        Vₖ = curvature_corrected_V_update_precomp!(tensorU, M, q, X, log_q_X_tensor, Uₖ, rank, βκΘq, tensorβκΘq, tensorΨq)

        if iter % debug_freq == 0
            Ξ_q_coords = Uₖ * Vₖ
            Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
            CC_loss = curvature_corrected_loss_fromlog(M, q, log_q_X, Ξ_q)
            ex_loss = exact_loss(M, q, X, Ξ_q)
            @printf("iter #%-4i (V) | CCL: %.15f | exact loss: %.15f | change (V): %.5f\n", iter, CC_loss, ex_loss, norm(Vₖ-Vprev))
            # println("iter #$(iter) (V) | CCL: $(CC_loss) | exact loss: $(ex_loss)")
        end

        Uprev=Uₖ
        # Uₖ = curvature_corrected_U_update_precomp(M, q, X, log_q_X, Uprev, Vₖ, rank, βκΘq, tensorβκΘq, tensorΨq;iters = max_U_iter,debug_int=U_debug)
        Uₖ = curvature_corrected_U_update_precomp!(tensorV, M, q, X, log_q_X, log_q_X_tensor, Uprev, Vₖ, rank, βκΘq, tensorβκΘq, tensorΨq;iters = max_U_iter,debug_int=U_debug)
        if iter % debug_freq == 0
            Ξ_q_coords = Uₖ * Vₖ
            Ξ_q = get_vector.(Ref(M), Ref(q), [Ξ_q_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
            CC_loss = curvature_corrected_loss_fromlog(M, q, log_q_X, Ξ_q)
            ex_loss = exact_loss(M, q, X, Ξ_q)
            @printf("iter #%-4i (U) | CCL: %.15f | exact loss: %.15f | change (U): %.5f\n", iter, CC_loss, ex_loss, norm(Uₖ-Uprev))
            # println("iter #$(iter) (U) | CCL: $(CC_loss) | exact loss: $(ex_loss)")
        end

    end
    if init_type == "kmeans"
        U₀ = G_init;
    end
    return Uₖ, Vₖ, Ξ_q, U₀, V₀
end

curvature_corrected_sNMF_precomp (generic function with 1 method)

In [20]:
function get_corrected_factors(M::AbstractManifold, q, U, V)
    # U is n \times k
    # V is k \times d
    K = size(U)[2]
    V_TM = get_vector.(Ref(M), Ref(q), [V[l,:] for l=1:K], Ref(DefaultOrthonormalBasis()))
    # proj_components = zeros(K)
    proj_components = reduce(hcat, [[minimum([inner(M, q, V_TM[k], V_TM[l]),0])/inner(M, q, V_TM[k], V_TM[k]) for k=1:K] for l=1:K])';
    H = U + U * proj_components
    return H
end

get_corrected_factors (generic function with 1 method)

In [21]:
function get_label_mat(A)
    # assuming A is n \times k, where n is number of points and k is number of classes
    _, idxs = findmax(A;dims=2)
    B = zeros(size(A))
    B[idxs] .= 1.0
    return B
end

get_label_mat (generic function with 1 method)

In [22]:
function get_used_labels(A)
    idxs = findall(x -> x > 0,sum(A;dims=1))
    cols_to_keep = [idxs[i][2] for i=1:length(idxs)]
    clean_A = A[:,cols_to_keep]
    return clean_A
end

get_used_labels (generic function with 1 method)

In [23]:
function get_f1(X,Y; avg_type="equal")
    # X is n \times k label matrix of ground truth
    # Y is n \times r label matrix
    clean_X = get_used_labels(X)
    clean_Y = get_used_labels(Y)
    k=size(clean_X)[2]
    r=size(clean_Y)[2]
    inner_pdts = clean_X' * clean_Y
    X_tot = repeat(sum(clean_X;dims=1)', outer=(1,r))
    Y_tot = repeat(sum(clean_Y;dims=1), outer=(k,1))
    F_mat = 2.0*inner_pdts ./ (X_tot + Y_tot);
    F = maximum(F_mat;dims=2)
    if avg_type=="equal"
        return mean(F)
    elseif avg_type=="weighted"
        return mean(F,weights(sum(clean_X;dims=1)))
    elseif avg_type=="none"
        return F
    else
        throw(ArgumentError("avg_type must be equal, weighted, or none"))
    end
end

get_f1 (generic function with 1 method)

In [24]:
function get_rand_idx(X,Y)
    # assuming X is n \times k, Y is n \times r, both label matrices
    X_clean = get_used_labels(X);
    Y_clean = get_used_labels(Y);
    X_same_counts = X_clean * X_clean';
    Y_same_counts = Y_clean * Y_clean';
    N = size(X)[1]
    same_clusters = X_same_counts .* Y_same_counts;
    a = sum(triu(same_clusters,1))
    b = sum(triu((ones(N,N)-X_same_counts) .* (ones(N,N) - Y_same_counts),1))
    return (a + b)/binomial(N,2)
end

get_rand_idx (generic function with 1 method)

### Load data and construct manifold ###

In [25]:
ni_xx = niread("../../../data/IITmean_xx.nii.gz") * 1e3;
ni_yx = niread("../../../data/IITmean_yx.nii.gz") * 1e3;
ni_yy = niread("../../../data/IITmean_yy.nii.gz") * 1e3;
ni_zx = niread("../../../data/IITmean_zx.nii.gz") * 1e3;
ni_zy = niread("../../../data/IITmean_zy.nii.gz") * 1e3;
ni_zz = niread("../../../data/IITmean_zz.nii.gz") * 1e3;
size(ni_xx)

(182, 218, 182)

In [26]:
d1, d2, d3  = size(ni_xx)

(182, 218, 182)

In [27]:
predata = [ # data ordered as [xx, yx, yy, zx, zy, zz]
    [
    [ni_xx[i,j,k] + 1e-5;; ni_yx[i,j,k];; ni_zx[i,j,k]]; 
    [ni_yx[i,j,k];; ni_yy[i,j,k] + 1e-5;; ni_zy[i,j,k]]; 
    [ni_zx[i,j,k];; ni_zy[i,j,k];; ni_zz[i,j,k] + 1e-5]
    ] for i=1:d1, j=1:d2, k=1:d3];

In [28]:
# all_det = det.(predata);

In [29]:
# minimum(findall(x -> x > 2e-15, all_det[:,:,25:124]))

In [30]:
# maximum(findall(x -> x > 2e-15, all_det[:,:,25:124]))

In [31]:
# predata_masked = predata[18:164,15:203,1:161];
# predata_masked = predata[41:140,61:160,56];
# predata_masked = predata[41:140,61:160,25:124];

# predata_masked = predata[41:140,61:160,25:44];
# smaller patches for testing
# predata_masked = predata[81:100,101:120,25:44];

# small center patch
# predata_masked = predata[86:95,106:115,25:44];

# used in previous experiments:
# small corner patch
# predata_masked = @view predata[101:110,151:160,51:100]; # used for saved data
# predata_masked = predata[101:110,151:160,51]; # one slice

In [32]:
function get_step_range(start, stop, step)
    n_steps = ((stop-start) + 1) / step;
    this_range = [convert(Int64,start + (i-1) * step) for i=1:(n_steps+1)]
    return this_range
end

get_step_range (generic function with 1 method)

In [33]:
# let's take 4 x 4 x 4 voxels
# first let's try the piece from the previous (approximately):
x1 = 101;
x2 = 112;
y1 = 151;
y2 = 162;
z1 = 51;
z2 = 102;
step = 4;
x_range = get_step_range(x1,x2,step);
y_range = get_step_range(y1,y2,step);
z_range = get_step_range(z1,z2,step);

In [34]:
predata_masked = [reshape(predata[x_range[i]:x_range[i+1]-1,y_range[j]:y_range[j+1]-1,z_range[k]:z_range[k+1]-1],
        (1,:)) for i=1:length(x_range)-1,
        j=1:length(y_range)-1,k=1:length(z_range)-1];

In [35]:
size(predata_masked)

(3, 3, 13)

In [36]:
data=vec(predata_masked);
size(data)

(117,)

In [37]:
size(data[1])

(1, 64)

In [38]:
ni_xx = nothing;
ni_yx = nothing;
ni_yy = nothing;
ni_zx = nothing;
ni_zy = nothing;
ni_zz = nothing;

In [39]:
# construct data manifold
n = size(data)[1] 
M = PowerManifold(SymmetricPositiveDefinite(3), NestedPowerRepresentation(), size(data[1])[2])
d = manifold_dimension(M);
print(d)

384

In [40]:
D = power_dimensions(M)[1];
print(D)

64

In [41]:
predata_masked=nothing;

In [42]:
predata=nothing;

In [43]:
GC.gc()

In [44]:
# Export slice image
# filename = "full_IIT_test.asy"
# # asymptote_export_SPD(filename, data = reshape(data, (10, 10)), scale_axes=(2, 2, 2));
# asymptote_export_SPD(filename, data = predata_masked, scale_axes=(1, 1, 1));

In [45]:
# render_asymptote(filename);

In [46]:
# filename = "IIT_test.asy"
# asymptote_export_SPD(filename, data=[predata_masked[47,138,86]]);

### Construct low rank approximation ###

In [47]:
q = [Matrix(I, 3, 3) .* 1e-5 for i=1:D];

In [48]:
# q = mean(M, data)
# q = Matrix(I, 3, 3) .* 1e-4
log_q_data = log.(Ref(M), Ref(q), data);  # ∈ T_q P(3)^n

In [49]:
offset = 0.1
# E_iter = 50
CC_iter = 20;
debug_int = 1;
U_iters = 10;
U_debug_int = 3;

In [52]:
ranks=[20,40]
# ranks=7
# nL = []
ccL = []
# nU_q = []
# nV_q = []

ccU_q = []
ccV_q = []
for k=ranks
    println("computing rank $(k) approximation")
    # Uₖ, Vₖ, Ξ_q, U₀, V₀ = curvature_corrected_sNMF_precomp(M, q, data, log_q_data, k; 
    #     max_iter = CC_iter, debug_freq = debug_int, init_type="kmeans",
    #     ENMF_const = offset, 
    #     max_U_iter=U_iters, U_debug=U_debug_int);
    Uₖ, Vₖ, Ξ_q, U₀, V₀ = curvature_corrected_sNMF_precomp(M, q, data, log_q_data, k; 
        max_iter = CC_iter, debug_freq = debug_int, init_type="kmeans",
        ENMF_const = offset, 
        max_U_iter=U_iters, U_debug=U_debug_int);
    # push!(nU_q, U₀)
    # push!(nV_q, V₀)
    # Ξ_naive_coords = U₀ * V₀
    # Ξ_naive = get_vector.(Ref(M), Ref(q), [Ξ_naive_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
    push!(ccU_q,Uₖ)
    push!(ccV_q,Vₖ)
    # loss_naive = exact_loss(M, q, data, Ξ_naive);
    loss_cc = exact_loss(M, q, data, Ξ_q);
    # push!(nL, loss_naive)
    push!(ccL, loss_cc)
    # println("computing rank $(k) approximation (U first)")
    # Uₖ, Vₖ, Ξ_q, U₀, V₀ = curvature_corrected_sNMF(M, q, data, k; ENMF_const = offset, ENMF_iter = E_iter, max_iter = CC_iter, debug_freq = debug_int, init_type="kmeans", first_factor="U");
    # push!(ccU_q_uf,Uₖ)
    # push!(ccV_q_uf,Vₖ)
    # loss_cc = exact_loss(M, q, data, Ξ_q);
    # push!(ccL_uf, loss_cc)
end

computing rank 20 approximation
computed initialization
iter #0        | CCL: 0.001255518777412 | exact loss: 0.001254790874380
iter #1    (V) | CCL: 0.000428747619051 | exact loss: 0.000428327989921 | change (V): 75.78108
	subiter #3    | CCL: 0.000393624819833 | exact loss: 0.000393211974086 | change (U): 0.00000
	subiter #6    | CCL: 0.000390877039759 | exact loss: 0.000390469682492 | change (U): 0.00000
	subiter #9    | CCL: 0.000388671449745 | exact loss: 0.000388269326048 | change (U): 0.00000
iter #1    (U) | CCL: 0.000387947044276 | exact loss: 0.000387546637449 | change (U): 0.02540
iter #2    (V) | CCL: 0.000387922411594 | exact loss: 0.000387523162908 | change (V): 0.61248
	subiter #3    | CCL: 0.000385748219825 | exact loss: 0.000385354090261 | change (U): 0.00000
	subiter #6    | CCL: 0.000383610567873 | exact loss: 0.000383221459169 | change (U): 0.00000
	subiter #9    | CCL: 0.000381508485744 | exact loss: 0.000381124302801 | change (U): 0.00000
iter #2    (U) | CCL: 0.0

In [53]:
save("IIT_power_voxels_2040.jld","ccV_q",ccV_q,"ccU_q",ccU_q,"ccL",ccL)

In [None]:
ccU_q_proj = []
for k=ranks
    push!(ccU_q_proj, get_corrected_factors(M, q, ccU_q_vf[k-9], ccV_q_vf[k-9]))
end

In [None]:
U_diffs = []
for k=ranks
    push!(U_diffs, norm(ccU_q_proj[k-9] - ccU_q_vf[k-9]))
end

In [None]:
maximum(U_diffs)

In [None]:
plot(ranks,[nL, ccL_vf],label=["exact loss (naive approx)" "exact loss (CC)"],xaxis="approximation rank",yaxis="exact loss")

In [None]:
plot(ranks, ccL_vf)

In [None]:
this_V = ccV_q[1];

In [None]:
size(this_V)

In [None]:
V_TM = get_vector.(Ref(M), Ref(q), [this_V[l,:] for l=1:10], Ref(DefaultOrthonormalBasis()));

In [None]:
size(V_TM)

In [None]:
V_TM[1][10]

In [None]:
this_U = ccU_q[1];

In [None]:
U_max = maximum(this_U; dims=1);

In [None]:
U_max

In [None]:
folder_name = "IIT_power"

In [None]:
for k=ranks
    ccU = ccU_q[k-9]
    ccV = ccV_q[k-9]

    ccV_TM = get_vector.(Ref(M), Ref(q), [ccV[l,:] for l=1:k], Ref(DefaultOrthonormalBasis()))
    ccU_max = maximum(ccU; dims=1)
    overall_factor_list_cc = exp.(Ref(M), Ref(q), [ccU_max[i] * ccV_TM[i] for i=1:k]) 
    # factors
    
    for j=1:k
        # factor_naive = factor_list_naive[j];
        # # factor_2D_naive = reshape(factor_naive, (10, 10)); # only for patches
        # factor_2D_naive = reshape(factor_naive, (48, 48)); # for full data
        # fig_filename_naive = @sprintf("%s/k%d/factor%d_naive.asy", folder_name, k, j);
        # asymptote_export_SPD(fig_filename_naive, data=factor_2D_naive, scale_axes=(2,2,2)); 
        # render_asymptote(fig_filename_naive)
        factor_cc_unscaled = overall_factor_list_cc[j];
        # factor_cc = overall_factor_list_cc[j];
        
        eigmaxes = opnorm.(factor_cc_unscaled, 2);
        if maximum(eigmaxes) < 1
            factor_cc = factor_cc_unscaled / maximum(eigmaxes);
        else
            factor_cc = factor_cc_unscaled
        end
        factor_2D_cc = reshape(factor_cc, (10, 10));
        fig_filename_cc = @sprintf("%s/k%d/cc_factor%d.asy", folder_name, k, j);
        asymptote_export_SPD(fig_filename_cc, data=factor_2D_cc, scale_axes=(2,2,2)); 
        render_asymptote(fig_filename_cc)     
        
    end
    
end

In [None]:
ni_atlas = niread("../../../data/IIT_WM_atlas.nii.gz");

In [None]:
size(ni_atlas)

The data is from z = 56, but let's try with a different slice.

In [None]:
sub_atlas = ni_atlas[41:140,61:160,56,:];

In [None]:
size(sub_atlas)

In [None]:
GM_desikan_prob = niread("../../../data/IIT_GM_Desikan_prob.nii.gz");

In [None]:
size(GM_desikan_prob)

In [None]:
sub_GM_desikan_prob = GM_desikan_prob[41:140,61:160,56,:];

In [None]:
size(sub_GM_desikan_prob)

In [None]:
GM_desikan_mask = niread("../../../data/IIT_GM_Desikan_mask.nii");

In [None]:
size(GM_desikan_mask)

In [None]:
sub_GM_desikan_mask = GM_desikan_mask[41:140,61:160,56];

In [None]:
heatmap(sub_GM_desikan_mask)

In [None]:
vec_mask = reshape(sub_GM_desikan_mask,(10000,));

In [None]:
vec_GM_probs = reshape(sub_GM_desikan_prob,(10000,84));

In [None]:
size(vec_GM_probs)

In [None]:
GM_probs_masked = vec_GM_probs[vec_mask .> 0,:];

In [None]:
size(GM_probs_masked)

In [None]:
test_data = vec(predata[41:140,61:160,56]);

In [None]:
this_V = ccV_q_vf[10];

In [None]:
log_test_data = log.(Ref(M), Ref(q), test_data);

In [None]:
this_V_TM = get_vector.(Ref(M), Ref(q), [this_V[l,:] for l=1:10], Ref(DefaultOrthonormalBasis()));

In [None]:
n

In [None]:
all_dists = [norm.(Ref(log_test_data[i]).-this_V_TM) for i=1:n];

In [None]:
size(all_dists)

In [None]:
ris = []
vms_0p1 = []
vms_0p01 = []
vms_0p5 = []
vms_0p1_cc = []
vms_0p01_cc = []
vms_0p5_cc = []
ris_cc = []
vms_cc = []
vec_mask = reshape(sub_GM_desikan_mask,(10000,));
vec_GM_probs = reshape(sub_GM_desikan_prob,(10000,84));
GM_probs_masked = vec_GM_probs[vec_mask .> 0,:];
atlas_assign = mapslices(argmax, GM_probs_masked, dims=2)[:];
# log_test_data = log.(Ref(M), Ref(q), test_data[vec_mask .> 0]);
# n_test = size(log_test_data)[1];
for r=ranks
    # ccV = ccV_q_vf[r-9]
    # ccV_TM = get_vector.(Ref(M), Ref(q), [ccV[l,:] for l=1:r], Ref(DefaultOrthonormalBasis()));
    # nV = nV_q[r-9]
    # nV_TM = get_vector.(Ref(M), Ref(q), [nV[l,:] for l = 1:r], Ref(DefaultOrthonormalBasis()));
    # rep_data = repeat(log_test_data,outer=(1,r));
    this_U = nU_q[r-9]
    this_U_masked = this_U[vec_mask .> 0,:];
    n_assign = mapslices(argmax, this_U_masked, dims=2)[:];
    # cc_assign=argmin.([[curvature_corrected_loss_fromlog(M, q, [log_test_data[i]], [ccV_TM[j]]) for j=1:r] for i=1:n_test]);
    # n_assign=argmin.([[curvature_corrected_loss_fromlog(M, q, [log_test_data[i]], [nV_TM[j]]) for j=1:r] for i=1:n_test]);
    
    # cc_assign = mapslices(argmin, norm.(rep_data.-ccV_TM');dims=2)[vec_mask .> 0];
    # n_assign = mapslices(argmin, norm.(rep_data.-nV_TM');dims=2)[vec_mask .> 0];
    
    # ri = get_rand_idx(get_label_mat(GM_probs_masked),get_label_mat(this_U_masked));
    # vm_0p5 = vmeasure(mapslices(argmax, GM_probs_masked, dims=2), mapslices(argmax, this_U_masked, dims=2);β=0.5)
    # vm_0p1 = vmeasure(mapslices(argmax, GM_probs_masked, dims=2), mapslices(argmax, this_U_masked, dims=2);β=0.1)
    # vm_0p01 = vmeasure(mapslices(argmax, GM_probs_masked, dims=2), mapslices(argmax, this_U_masked, dims=2);β=0.01)

    ri = randindex(atlas_assign, n_assign)
    vm_0p5 = vmeasure(atlas_assign, n_assign;β=0.5)
    vm_0p1 = vmeasure(atlas_assign, n_assign;β=0.1)
    vm_0p01 = vmeasure(atlas_assign, n_assign;β=0.01)
    push!(ris, ri)
    push!(vms_0p1, vm_0p1)
    push!(vms_0p01, vm_0p01)
    push!(vms_0p5, vm_0p5)
    this_U_cc = ccU_q_vf[r-9]
    this_U_masked_cc = this_U_cc[vec_mask .> 0,:];
    cc_assign = mapslices(argmax, this_U_masked_cc, dims=2)[:];
    # ri_cc = get_rand_idx(get_label_mat(GM_probs_masked),get_label_mat(this_U_masked_cc));
    # vm_cc = vmeasure(mapslices(argmax, GM_probs_masked, dims=2), mapslices(argmax, this_U_masked_cc, dims=2);β=0.5)
    # vm_0p5 = vmeasure(mapslices(argmax, GM_probs_masked, dims=2), mapslices(argmax, this_U_masked_cc, dims=2);β=0.5)
    # vm_0p1 = vmeasure(mapslices(argmax, GM_probs_masked, dims=2), mapslices(argmax, this_U_masked_cc, dims=2);β=0.1)
    # vm_0p01 = vmeasure(mapslices(argmax, GM_probs_masked, dims=2), mapslices(argmax, this_U_masked_cc, dims=2);β=0.01)
    ri_cc = randindex(atlas_assign, cc_assign)
    vm_0p5 = vmeasure(atlas_assign, cc_assign;β=0.5)
    vm_0p1 = vmeasure(atlas_assign, cc_assign;β=0.1)
    vm_0p01 = vmeasure(atlas_assign, cc_assign;β=0.01)

    push!(ris_cc, ri_cc)
    # push!(vms_cc, vm_cc)
    push!(vms_0p1_cc, vm_0p1)
    push!(vms_0p01_cc, vm_0p01)
    push!(vms_0p5_cc, vm_0p5)
    @printf("rank %i done\n", r)
end

In [None]:
size(test_data)

In [None]:
typeof(atlas_assign[:])

In [None]:
test_ris = reduce(hcat, ris);

In [None]:
size(test_ris)

In [None]:
test_ris[1:4]

In [None]:
plot(ranks,[[ris[i-9][2] for i=ranks] [ris_cc[i-9][2] for i=ranks]],label=["ri" "ri cc"])

In [None]:
plot(ranks,[vms_0p5,vms_0p1,vms_0p01],label=["vms 0p5" "vms 0p1" "vms 0p01"])

In [None]:
plot(ranks,[vms_0p5_cc,vms_0p1_cc,vms_0p01_cc],label=["vms 0p5" "vms 0p1" "vms 0p01"])

In [None]:
heatmap(reshape(nU_q[1][:,1],(100,100)))

In [None]:
heatmap(reshape(ccU_q_vf[1][:,1],(100,100)))

In [None]:
heatmap(reshape(nU_q[1][:,2],(100,100)))

In [None]:
heatmap(reshape(ccU_q_vf[1][:,2],(100,100)))

In [None]:
heatmap(reshape(nU_q[1][:,3],(100,100)))

In [None]:
heatmap(reshape(ccU_q_vf[1][:,3],(100,100)))

In [None]:
heatmap(reshape(nU_q[1][:,4],(100,100)))

In [None]:
heatmap(reshape(ccU_q_vf[1][:,4],(100,100)))

In [None]:
heatmap(reshape(nU_q[1][:,5],(100,100)))

In [None]:
heatmap(reshape(ccU_q_vf[1][:,5],(100,100)))

In [None]:
heatmap(reshape(nU_q[1][:,6],(100,100)))

In [None]:
heatmap(reshape(ccU_q_vf[1][:,6],(100,100)))

In [None]:
heatmap(reshape(nU_q[1][:,7],(100,100)))

In [None]:
heatmap(reshape(ccU_q_vf[1][:,7],(100,100)))

In [None]:
for k=ranks
    ccU = ccU_q[k-1]
    ccV = ccV_q[k-1]
    ccU_proj = ccU_q_proj[k-1]
    nU = nU_q[k-1]
    nV = nV_q[k-1]
    Ξ_naive_coords = nU * nV
    Ξ_cc_coords = ccU * ccV
    Ξ_cc_proj_coords = ccU_proj * ccV
    Ξ_naive = get_vector.(Ref(M), Ref(q), [Ξ_naive_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
    Ξ_cc = get_vector.(Ref(M), Ref(q), [Ξ_cc_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
    Ξ_cc_proj = get_vector.(Ref(M), Ref(q), [Ξ_cc_proj_coords[l,:] for l=1:n[1]], Ref(DefaultOrthonormalBasis()))
    nV_TM = get_vector.(Ref(M), Ref(q), [nV[l,:] for l=1:k], Ref(DefaultOrthonormalBasis()))
    ccV_TM = get_vector.(Ref(M), Ref(q), [ccV[l,:] for l=1:k], Ref(DefaultOrthonormalBasis()))
    nU_max = maximum(nU; dims=1)
    ccU_max = maximum(ccU; dims=1)
    ccU_proj_max = maximum(ccU_proj; dims=1)
    overall_factor_list_naive = exp.(Ref(M), Ref(q), [nU_max[i] * nV_TM[i] for i=1:k])
    overall_factor_list_cc = exp.(Ref(M), Ref(q), [ccU_max[i] * ccV_TM[i] for i=1:k])
    overall_factor_list_cc_proj = exp.(Ref(M), Ref(q), [ccU_proj_max[i] * ccV_TM[i] for i=1:k])
    fig_filename_naive = @sprintf("%s/k%d/overall_factors_naive.asy", folder_name, k)
    fig_filename_cc = @sprintf("%s/k%d/overall_factors_cc.asy", folder_name, k)
    fig_filename_cc_proj = @sprintf("%s/k%d/overall_factors_cc_proj.asy", folder_name, k)
    # asymptote_export_SPD(fig_filename_naive, data=overall_factor_list_naive, scale_axes=(2, 2, 2));
    asymptote_export_SPD(fig_filename_naive, data=overall_factor_list_naive, scale_axes=(axes_scale_n, axes_scale_n, axes_scale_n));
    render_asymptote(fig_filename_naive)
    # asymptote_export_SPD(fig_filename_cc, data=overall_factor_list_cc, scale_axes=(2, 2, 2)); 
    asymptote_export_SPD(fig_filename_cc, data=overall_factor_list_cc, scale_axes=(axes_scale, axes_scale, axes_scale)); 
    render_asymptote(fig_filename_cc)
    
    asymptote_export_SPD(fig_filename_cc_proj, data=overall_factor_list_cc_proj, scale_axes=(axes_scale, axes_scale, axes_scale)); 
    render_asymptote(fig_filename_cc_proj)
    
    # reconstructions
    reconst_1D_naive = exp.(Ref(M), Ref(q), Ξ_naive);
    reconst_1D_cc = exp.(Ref(M), Ref(q), Ξ_cc);
    reconst_1D_cc_proj = exp.(Ref(M), Ref(q), Ξ_cc_proj);
    # factor_M = exp.(Ref(M), Ref(q), get_vector.(Ref(M), Ref(q), [this_U[i,1] * transpose(this_V)[1,:] for i=1:n], Ref(DefaultOrthonormalBasis())));
    # reconst_2D_naive = reshape(reconst_1D_naive, (10, 10));
    reconst_2D_naive = reshape(reconst_1D_naive, (48, 48));
    # reconst_2D_cc = reshape(reconst_1D_cc, (10, 10));
    reconst_2D_cc = reshape(reconst_1D_cc, (48, 48));
    reconst_2D_cc_proj = reshape(reconst_1D_cc_proj, (48, 48));
    fig_filename_naive = @sprintf("%s/k%d/reconst_naive.asy", folder_name, k)
    fig_filename_cc = @sprintf("%s/k%d/reconst_cc.asy", folder_name, k)
    fig_filename_cc_proj = @sprintf("%s/k%d/reconst_cc_proj.asy", folder_name, k)
    asymptote_export_SPD(fig_filename_naive, data=reconst_2D_naive, scale_axes=(2, 2, 2)); 
    render_asymptote(fig_filename_naive)
    asymptote_export_SPD(fig_filename_cc, data=reconst_2D_cc, scale_axes=(2, 2, 2)); 
    render_asymptote(fig_filename_cc)
    asymptote_export_SPD(fig_filename_cc_proj, data=reconst_2D_cc_proj, scale_axes=(2, 2, 2)); 
    render_asymptote(fig_filename_cc_proj)
    
    # factors
    factor_list_naive = [exp.(Ref(M), Ref(q), [nU[i,j] * nV_TM[j] for i=1:n[1]]) for j=1:k];
    factor_list_cc = [exp.(Ref(M), Ref(q), [ccU[i,j] * ccV_TM[j] for i=1:n[1]]) for j=1:k];
    factor_list_cc_proj = [exp.(Ref(M), Ref(q), [ccU_proj[i,j] * ccV_TM[j] for i=1:n[1]]) for j=1:k];
    
    for j=1:k
        # factor_naive = factor_list_naive[j];
        # # factor_2D_naive = reshape(factor_naive, (10, 10)); # only for patches
        # factor_2D_naive = reshape(factor_naive, (48, 48)); # for full data
        # fig_filename_naive = @sprintf("%s/k%d/factor%d_naive.asy", folder_name, k, j);
        # asymptote_export_SPD(fig_filename_naive, data=factor_2D_naive, scale_axes=(2,2,2)); 
        # render_asymptote(fig_filename_naive)
        factor_cc_unscaled = factor_list_cc[j];
        eigmaxes = opnorm.(factor_cc_unscaled, 2);
        if maximum(eigmaxes) < 1
            factor_cc = factor_cc_unscaled / maximum(eigmaxes);
        else
            factor_cc = factor_cc_unscaled
        end
        # factor_2D_cc = reshape(factor_cc, (10, 10));
        factor_2D_cc = reshape(factor_cc, (48, 48));
        fig_filename_cc = @sprintf("%s/k%d/cc_factor%d.asy", folder_name, k, j);
        asymptote_export_SPD(fig_filename_cc, data=factor_2D_cc, scale_axes=(2,2,2)); 
        render_asymptote(fig_filename_cc)
        
        factor_cc_proj_unscaled = factor_list_cc_proj[j];
        eigmaxes = opnorm.(factor_cc_proj_unscaled, 2);
        if maximum(eigmaxes) < 1
            factor_cc_proj = factor_cc_proj_unscaled / maximum(eigmaxes);
        else
            factor_cc_proj = factor_cc_proj_unscaled
        end
        # factor_2D_cc = reshape(factor_cc, (10, 10));
        factor_2D_cc_proj = reshape(factor_cc_proj, (48, 48));
        fig_filename_cc_proj = @sprintf("%s/k%d/cc_proj_factor%d.asy", folder_name, k, j);
        asymptote_export_SPD(fig_filename_cc_proj, data=factor_2D_cc_proj, scale_axes=(2,2,2)); 
        render_asymptote(fig_filename_cc_proj)       
        
    end
    
end

## IIT Atlas

In [None]:
# filename = "IIT/true.asy"
# # asymptote_export_SPD(filename, data = reshape(data, (10, 10)), scale_axes=(2, 2, 2));
# asymptote_export_SPD(filename, data = predata2[67:116,85:134,100], scale_axes=(2, 2, 2));

# render_asymptote(filename);

In [None]:
q = Matrix(I, 3, 3) .* 1e-5
# q = Matrix(I, 3, 3) .* 1e-4
log_q_data = log.(Ref(M), Ref(q), data2);  # ∈ T_q P(3)^n

So the idea is that the coordinates should all be in the same..... higher-dimensional quadrant? Or do we check the inner products directly? Or is this equivalent?

In [None]:
data_coords_mat = reduce(hcat, get_coordinates.(Ref(M), Ref(q), log_q_data, Ref(DefaultOrthonormalBasis())));
inner_pdts = data_coords_mat' * data_coords_mat;

In [None]:
size(log_q_data)[1]^2 - count(inner_pdts .>= 0)

In [None]:
ris

In [None]:
vms

In [None]:
vms_cc

In [None]:
plot(ranks,ris_cc)

In [None]:
plot(ranks,[vms_0p01,vms_0p1,vms_0p5],label=["beta = 0.01" "beta = 0.1" "beta = 0.5"])

In [None]:
f1_total_mat = reduce(hcat, f1_total);

In [None]:
size(f1_total_mat)

In [None]:
heatmap(f1_total_mat)

In [None]:
get_f1(get_label_mat(GM_probs_masked),get_label_mat(this_U_masked);avg_type="equal")

In [None]:
heatmap(get_used_labels(get_label_mat(GM_probs_masked)))

In [None]:
heatmap(get_used_labels(get_label_mat(this_U_masked)))

## Figuring out random Julia stuff

In [None]:
k=10;

In [None]:
@time curvature_corrected_sNMF_precomp(M, q, data, log_q_data, k; 
        max_iter = CC_iter, debug_freq = debug_int, init_type="kmeans",
        ENMF_const = offset, 
        max_U_iter=U_iters, U_debug=U_debug_int);

In [None]:
@time curvature_corrected_sNMF_precomp(M, q, data, log_q_data, k; 
        max_iter = CC_iter, debug_freq = debug_int, init_type="kmeans",
        ENMF_const = offset, 
        max_U_iter=U_iters, U_debug=U_debug_int);