# Low-rank approximation on $\mathcal{P}(d)$ — the space of $d \times d$ symmetric positive definite (SPD) matrices

In this notebook we want to build some intuition for different approaches to computing low-rank approximations of manifold-valued signals.

In [1]:
using Manifolds
using Manopt
using LinearAlgebra
using Random
using Plots
using LaTeXStrings
using BenchmarkTools

using Clustering
using NIfTI

using JLD
using Printf

In [2]:
using StatsBase

In [3]:
include("../../src/functions/jacobi_field/beta.jl")
include("../../src/decompositions/signals/curvature_corrected_low_rank_approximation.jl")

curvature_corrected_low_rank_approximation (generic function with 1 method)

### Functions

In [4]:
"""
    exact_loss(M, q, X, Ξ)

Relative squared reconstruction error of approximating the data points `X` on `M` by the
tangent vectors `Ξ` at the base point `q`:

    Σᵢ d_M(Xᵢ, exp_q(Ξᵢ))² / Σᵢ ‖log_q(Xᵢ)‖²_q
"""
function exact_loss(M::AbstractManifold, q, X, Ξ)
    # Denominator: total squared length of the data expressed in the tangent space at q.
    reference = sum(norm.(Ref(M), Ref(q), log.(Ref(M), Ref(q), X)) .^ 2)
    # Numerator: squared geodesic distance between each point and its reconstruction exp_q(Ξᵢ).
    residual = sum(distance.(Ref(M), X, exp.(Ref(M), Ref(q), Ξ)) .^ 2)
    return residual / reference
end

exact_loss (generic function with 1 method)

In [5]:
"""
    curvature_corrected_loss(M, q, X, Ξ)

Curvature-corrected relative loss of approximating the data points `X` on `M` by the
tangent vectors `Ξ` at the base point `q`.

The data is mapped to `T_q M` with the logarithmic map. The loss is then evaluated by
`curvature_corrected_loss_fromlog`, so both entry points share one implementation.
"""
function curvature_corrected_loss(M::AbstractManifold, q, X, Ξ)
    log_q_X = log.(Ref(M), Ref(q), X)  # ∈ T_q M^n
    return curvature_corrected_loss_fromlog(M, q, log_q_X, Ξ)
end

curvature_corrected_loss (generic function with 1 method)

In [6]:
"""
    curvature_corrected_loss_fromlog(M, q, log_q_X, Ξ)

Curvature-corrected relative loss, computed from precomputed tangent vectors
`log_q_X[i] = log_q(Xᵢ)` and their approximations `Ξ[i]` in `T_q M`.

For every sample `i`, `get_basis` is called with `DiagonalizingOrthonormalBasis(log_q_X[i])`.
Its basis vectors `Θⱼ` and eigenvalues `κⱼ` give the sample's contribution

    Σⱼ β(κⱼ)² ⟨Ξ[i] - log_q_X[i], Θⱼ⟩_q²

where `β` comes from the included `beta.jl`. The sum over all samples is divided by
`Σᵢ ‖log_q_X[i]‖²_q`.

On a `PowerManifold` the basis data is stored per component `k`. There the inner products
use the component manifold `M.manifold` at `q[k]`.
"""
function curvature_corrected_loss_fromlog(M::AbstractManifold, q, log_q_X, Ξ)
    ref_distance = sum(norm.(Ref(M), Ref(q), log_q_X).^2)
    indices = CartesianIndices(size(log_q_X))
    total = 0.0
    if M isa PowerManifold
        D = power_dimensions(M)[1]
        d = manifold_dimension(M.manifold)
        for idx in indices
            bases = get_basis(M, q, DiagonalizingOrthonormalBasis(log_q_X[idx])).data.bases
            for k in 1:D
                θ = bases[k].data.vectors
                κ = bases[k].data.eigenvalues
                residual = Ξ[idx][k] - log_q_X[idx][k]
                total += sum(map(j -> β(κ[j])^2 * inner(M.manifold, q[k], residual, θ[j])^2, 1:d))
            end
        end
    else
        d = manifold_dimension(M)
        for idx in indices
            onb = get_basis(M, q, DiagonalizingOrthonormalBasis(log_q_X[idx]))
            Θ = onb.data.vectors
            κ = onb.data.eigenvalues
            residual = Ξ[idx] - log_q_X[idx]
            total += sum(map(j -> β(κ[j])^2 * inner(M, q, residual, Θ[j])^2, 1:d))
        end
    end
    return total / ref_distance
end

curvature_corrected_loss_fromlog (generic function with 1 method)

### Load data and construct manifold ###

In [7]:
# Read the six distinct entries of the symmetric 3×3 tensors (one NIfTI volume each) and
# rescale every volume by 1e3.
ni_xx, ni_yx, ni_yy, ni_zx, ni_zy, ni_zz = map(path -> niread(path) * 1e3, (
    "../../data/IITmean_xx.nii.gz",
    "../../data/IITmean_yx.nii.gz",
    "../../data/IITmean_yy.nii.gz",
    "../../data/IITmean_zx.nii.gz",
    "../../data/IITmean_zy.nii.gz",
    "../../data/IITmean_zz.nii.gz",
));
size(ni_xx)

(182, 218, 182)

In [8]:
(d1, d2, d3) = size(ni_xx)  # voxel grid dimensions

(182, 218, 182)

In [9]:
# Assemble one symmetric 3×3 matrix per voxel from the components [xx, yx, yy, zx, zy, zz].
# The diagonal is shifted by 1e-5.
predata = [
    [(ni_xx[i, j, k] + 1e-5) ni_yx[i, j, k]           ni_zx[i, j, k];
     ni_yx[i, j, k]          (ni_yy[i, j, k] + 1e-5)  ni_zy[i, j, k];
     ni_zx[i, j, k]          ni_zy[i, j, k]           (ni_zz[i, j, k] + 1e-5)] for i in 1:d1, j in 1:d2, k in 1:d3
];

In [10]:
"""
    get_step_range(start, stop, step)

Return the block boundaries `[start, start + step, start + 2step, …]`.

Consecutive entries `b[i]:(b[i+1] - 1)`, for `i in 1:length(b)-1`, are blocks of `step`
indices tiling `start:stop`. The last boundary is at most `stop + 1`. If `stop - start + 1`
is not a multiple of `step`, the trailing partial block is dropped.

The boundaries are returned as a `Vector{Int64}`.

# Throws
- `ArgumentError` if `step` is not positive.
"""
function get_step_range(start, stop, step)
    step > 0 || throw(ArgumentError("step must be positive, got $step"))
    # Plain integer range arithmetic: equivalent to the boundary count floor((stop-start+1)/step)+1
    # without routing the indices through floating point.
    return [convert(Int64, b) for b in start:step:(stop + 1)]
end

get_step_range (generic function with 1 method)

In [11]:
# Region of interest: x ∈ 101:112, y ∈ 151:162, z ∈ 51:102. It is cut into blocks of `step`
# voxels per axis (4 × 4 × 4 voxels per block), giving a 3 × 3 × 13 grid of blocks.
# NOTE(review): the global name `step` shadows `Base.step` in this notebook.
x1, x2 = 101, 112;
y1, y2 = 151, 162;
z1, z2 = 51, 102;
step = 4;
x_range, y_range, z_range = map(lims -> get_step_range(lims..., step), ((x1, x2), (y1, y2), (z1, z2)));

In [12]:
# Cut the region of interest into blocks. Each block is flattened into a 1 × (voxels per
# block) matrix of SPD matrices, which forms one sample.
predata_masked = let blocks = r -> [r[b]:(r[b + 1] - 1) for b in 1:length(r)-1]
    [reshape(predata[bx, by, bz], (1, :)) for bx in blocks(x_range),
        by in blocks(y_range), bz in blocks(z_range)]
end;

In [13]:
predata_masked |> size  # grid of blocks

(3, 3, 13)

In [14]:
# Flatten the grid of blocks into a vector of samples.
data = reshape(predata_masked, :);
size(data)

(117,)

In [15]:
size(first(data))  # shape of a single sample

(1, 64)

In [16]:
# Drop the references to the raw component volumes so their memory can be reclaimed;
# `predata` holds its own copies of the values.
ni_xx = ni_yx = ni_yy = ni_zx = ni_zy = ni_zz = nothing;

In [17]:
# construct data manifold
n = size(data)[1] 
M = PowerManifold(SymmetricPositiveDefinite(3), NestedPowerRepresentation(), size(data[1])[2])
d = manifold_dimension(M);
print(d)

384

In [18]:
# Number of SPD factors in the power manifold (voxels per sample).
D = first(power_dimensions(M));
print(D)

64

In [19]:
predata_masked = nothing;  # no longer needed once `data` has been built

In [20]:
predata = nothing;  # release the full per-voxel matrix grid

In [21]:
GC.gc(true)  # full collection to reclaim the released arrays

### Construct low rank approximation ###

In [22]:
# Base point of the tangent-space approximation: the `mean` of all samples on M
# (presumably the Riemannian mean provided by Manifolds.jl — confirm which method dispatches).
q = mean(M, data);

In [23]:
# Express every sample in the tangent space at q; each entry lies in T_q P(3)^D.
log_q_data = map(p -> log(M, q, p), data);

In [24]:
# For every requested rank: compute the curvature-corrected low-rank approximation, rebuild
# the tangent vectors Ξ_q[k] = Σᵢ U[k, i] · Sym(Rᵢ), and record both losses.
ranks = [20]  # ranks to evaluate
ccL = Any[]    # curvature-corrected losses
eL = Any[]     # exact losses
ccR_q = Any[]  # tangent-space factors per rank
ccU = Any[]    # coefficient matrices per rank
for rank in ranks
    println("computing rank $(rank) approximation")
    ccRr_q, ccUr = curvature_corrected_low_rank_approximation(M, q, data, rank)
    push!(ccR_q, ccRr_q)
    push!(ccU, ccUr)
    Ξ_q = [sum([Symmetric.(ccRr_q[i]) * ccUr[k, i] for i in 1:rank]) for k in 1:n]
    loss_cc = curvature_corrected_loss_fromlog(M, q, log_q_data, Ξ_q)
    loss_exact = exact_loss(M, q, data, Ξ_q)
    println("curvature corrected loss = $(loss_cc) and exact loss = $(loss_exact)")
    push!(ccL, loss_cc)
    push!(eL, loss_exact)
end

computing rank 20 approximation
curvature corrected loss = 0.01814373470609508 and exact loss = 0.018142407125046965


### Figures (example)

In [25]:
foldername = "IIT_figs_CCSVD"  # root output directory for the exported figures

"IIT_figs_CCSVD"

In [26]:
# Export every factor of each computed approximation as an Asymptote figure.
# - Factors live in T_q M and are mapped to the manifold with exp at q.
# - A factor whose largest eigenvalue over all voxels is below 1 is rescaled so that this
#   eigenvalue becomes 1.
# This cell deliberately does not reassign the global sample count `n`, which the
# approximation cell uses as `1:n`.
for (r, rank) in enumerate(ranks)
    ccR = ccR_q[r]
    overall_factor_list_cc = exp.(Ref(M), Ref(q), [ccR[k] for k in 1:rank])
    for k in 1:rank
        factor_cc_unscaled = overall_factor_list_cc[k]
        # The spectral norm of an SPD matrix is its largest eigenvalue.
        λmax = maximum(opnorm.(factor_cc_unscaled, 2))
        factor_cc = λmax < 1 ? factor_cc_unscaled / λmax : factor_cc_unscaled
        # Undo the per-sample flattening; assumes 4 × 4 × 4 blocks (step = 4).
        factor_3D_cc = reshape(factor_cc, (4, 4, 4))
        fig_filename_cc = @sprintf("%s/k%d/cc_factor%d.asy", foldername, rank, k)
        # Make sure the per-rank output directory exists before exporting into it.
        mkpath(dirname(fig_filename_cc))
        asymptote_export_SPD(fig_filename_cc, data=factor_3D_cc, scale_axes=(1.5,1.5,1.5), camera_position=(-2., 6., 14.));
        render_asymptote(fig_filename_cc)
    end
end