# Low-rank approximation on $\mathcal{P}(d)$: the space of $d \times d$ SPD matrices

In this notebook we build some intuition for different approaches to computing low-rank approximations of manifold-valued signals.

In [1]:
using Manifolds
using Manopt
using LinearAlgebra
using NIfTI
using Plots

### Load data ###

In [2]:
# load data
# ni = niread("data/nifti_dt.nii.gz") * 1e9;
ni = niread("data/nifti_dt_nonlinear.nii.gz") * 1e9;

In [3]:
size(ni)

(112, 112, 50, 1, 6)

In [4]:
x = 49
y = 60
z = 25

ni[x,y,z,1,:]  # gets xx, yx, yy, zx, zy, zz

6-element Vector{Float64}:
  1.456667231281017
  0.08766011527772122
  1.8108553545559403
 -0.10191358174038712
 -0.15621183513392367
  1.4246920310156952

### Construct manifold ###

In [5]:
d1, d2, d3, _, d = size(ni)

(112, 112, 50, 1, 6)

In [6]:
# construct data as points on the manifold
predata = [
    [
    [ni[i,j,k,1,1];; ni[i,j,k,1,2];; ni[i,j,k,1,4]]; 
    [ni[i,j,k,1,2];; ni[i,j,k,1,3];; ni[i,j,k,1,5]]; 
    [ni[i,j,k,1,4];; ni[i,j,k,1,5];; ni[i,j,k,1,6]]
    ] for i=1:d1, j=1:d2, k=1:d3];
# predata = [array2mat(ni[i,j,k,1,:]) for i=1:d1, j=1:d2, k=1:d3];
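The commented-out line above calls an `array2mat` helper that is not defined in this notebook. Here is a minimal sketch of what such a helper could look like. It assumes the six components are stored in the order xx, yx, yy, zx, zy, zz, as in the comment on `ni[x,y,z,1,:]`. The name and signature are hypothetical.

```julia
using LinearAlgebra

# Hypothetical helper: map the six stored tensor components
# v = [xx, yx, yy, zx, zy, zz] (lower triangle, row by row)
# to the full 3×3 symmetric matrix.
function array2mat(v::AbstractVector{<:Real})
    length(v) == 6 || throw(ArgumentError("expected 6 components, got $(length(v))"))
    return [v[1] v[2] v[4];
            v[2] v[3] v[5];
            v[4] v[5] v[6]]
end
```

With this definition, the commented-out comprehension builds the same matrices as the explicit `[... ;; ...]` construction in the cell.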

In [7]:
print(size(predata))

(112, 112, 50)

In [45]:
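# pick a 3D subvolume, x: [34,81], y: [32,79], z: [8,30]
# (the HOSVD below uses all three modes, so `data` must be 3D;
#  the stored output has 20 z-slices, so it came from a narrower z-range than 8:30)
data = predata[34:81, 32:79, 8:30]
# data = predata[34:81, 32:79, 15]  # single 2D slice
D1, D2, D3 = size(data)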

M = SymmetricPositiveDefinite(3)
println(size(data))
check_point.(Ref(M), data)

(48, 48, 20)


48×48×20 Array{Union{Nothing, DomainError}, 3}:
[:, :, 1] =
 nothing  nothing  nothing  nothing  …  nothing  nothing  nothing  nothing
 nothing  nothing  nothing  nothing     nothing  nothing  nothing  nothing
 nothing  nothing  nothing  nothing     nothing  nothing  nothing  nothing
 nothing  nothing  nothing  nothing     nothing  nothing  nothing  nothing
 nothing  nothing  nothing  nothing     nothing  nothing  nothing  nothing
 nothing  nothing  nothing  nothing  …  nothing  nothing  nothing  nothing
 nothing  nothing  nothing  nothing     nothing  nothing  nothing  nothing
 nothing  nothing  nothing  nothing     nothing  nothing  nothing  nothing
 nothing  nothing  nothing  nothing     nothing  nothing  nothing  nothing
 nothing  nothing  nothing  nothing     nothing  nothing  nothing  nothing
 ⋮                                   ⋱           ⋮                 
 nothing  nothing  nothing  nothing     nothing  nothing  nothing  nothing
 nothing  nothing  nothing  nothing  …  nothing
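`check_point` returns `nothing` for a valid point and an error object otherwise. The `nothing` entries above therefore mark voxels that are valid SPD matrices. Because the display is truncated, `all(isnothing, check_point.(Ref(M), data))` gives a more compact summary.

For intuition, here is a stdlib-only sketch of the two properties involved: symmetry and positive definiteness. `is_spd` and its tolerance are hypothetical choices and are not the Manifolds.jl implementation.

```julia
using LinearAlgebra

# Hypothetical helper: symmetric (up to a relative tolerance) and
# positive definite (isposdef attempts a Cholesky factorization).
function is_spd(P::AbstractMatrix; rtol = 1e-10)
    size(P, 1) == size(P, 2) || return false
    isapprox(P, P'; rtol = rtol) || return false
    return isposdef(Symmetric(P))
end
```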

In [None]:
# Export slice image

In [55]:
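q = 1. * Matrix(I, 3, 3)  # base point: the identity matrix
# map the data to the tangent space at q
log_q_data = log.(Ref(M), Ref(q), data);  # ∈ T_q P(3)^(D1×D2×D3)
# sanity check: exp_q(log_q(p)) should give valid SPD points again
exp_log_q_data = Symmetric.(exp.(Ref(M), Ref(q), log_q_data));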
check_point.(Ref(M), exp_log_q_data)

48×48×20 Array{Union{Nothing, DomainError}, 3}:
[:, :, 1] =
 nothing  nothing  nothing  nothing  …  nothing  nothing  nothing  nothing
 nothing  nothing  nothing  nothing     nothing  nothing  nothing  nothing
 nothing  nothing  nothing  nothing     nothing  nothing  nothing  nothing
 nothing  nothing  nothing  nothing     nothing  nothing  nothing  nothing
 nothing  nothing  nothing  nothing     nothing  nothing  nothing  nothing
 nothing  nothing  nothing  nothing  …  nothing  nothing  nothing  nothing
 nothing  nothing  nothing  nothing     nothing  nothing  nothing  nothing
 nothing  nothing  nothing  nothing     nothing  nothing  nothing  nothing
 nothing  nothing  nothing  nothing     nothing  nothing  nothing  nothing
 nothing  nothing  nothing  nothing     nothing  nothing  nothing  nothing
 ⋮                                   ⋱           ⋮                 
 nothing  nothing  nothing  nothing     nothing  nothing  nothing  nothing
 nothing  nothing  nothing  nothing  …  nothing
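Choosing $q = I$ makes the tangent-space data easy to interpret. Assuming the default affine-invariant metric on `SymmetricPositiveDefinite(3)`, the maps at a base point $p$ are $\log_p(P) = p^{1/2}\log(p^{-1/2} P p^{-1/2})\,p^{1/2}$ and $\exp_p(X) = p^{1/2}\exp(p^{-1/2} X p^{-1/2})\,p^{1/2}$. At $p = I$ they reduce to the matrix logarithm and exponential.

The stdlib sketch below mirrors the round-trip check in this cell. `spd_log` and `spd_exp` are hypothetical names, not Manifolds.jl functions.

```julia
using LinearAlgebra

# Affine-invariant log map: log_p(P) = p^{1/2} log(p^{-1/2} P p^{-1/2}) p^{1/2}
function spd_log(p, P)
    s = sqrt(Symmetric(p))   # p^{1/2}
    si = inv(s)              # p^{-1/2}
    return Symmetric(s * log(Symmetric(si * P * si)) * s)
end

# Affine-invariant exp map: exp_p(X) = p^{1/2} exp(p^{-1/2} X p^{-1/2}) p^{1/2}
function spd_exp(p, X)
    s = sqrt(Symmetric(p))
    si = inv(s)
    return Symmetric(s * exp(Symmetric(si * X * si)) * s)
end
```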

In [57]:
# inspect one tangent vector; note that `log` returns it wrapped in `transpose`
log_q_data[10,10,19]

3×3 transpose(::Matrix{Float64}) with eltype Float64:
 -0.888727  -0.129218   0.115751
 -0.129218  -0.266916  -0.253125
  0.115751  -0.253125   0.085243

First, we compute a truncated HOSVD of the tangent-space data at $q$, without curvature reweighting.
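In formulas, write $X_{ijk} = \log_q(p_{ijk}) \in T_q\mathcal{P}(3)$ for the tangent vector of voxel $(i,j,k)$. The mode-1 Gram matrix sums the metric inner products over the two remaining indices; modes 2 and 3 are analogous:

$$
G^{(1)}_{kl} = \sum_{j=1}^{D_2} \sum_{m=1}^{D_3} \langle X_{kjm}, X_{ljm} \rangle_q .
$$

Let $U^{(n)}_r$ collect the eigenvectors of $G^{(n)}$ for its $r_n$ largest eigenvalues (`U1r`, `U2r`, `U3r` in the code). The core tensor and the reconstruction are then

$$
R_{abc} = \sum_{i=1}^{D_1} \sum_{j=1}^{D_2} \sum_{k=1}^{D_3} (U^{(1)}_r)_{ia} (U^{(2)}_r)_{jb} (U^{(3)}_r)_{kc} \, X_{ijk},
\qquad
\hat X_{ijk} = \sum_{a=1}^{r_1} \sum_{b=1}^{r_2} \sum_{c=1}^{r_3} (U^{(1)}_r)_{ia} (U^{(2)}_r)_{jb} (U^{(3)}_r)_{kc} \, R_{abc},
$$

and the approximation on the manifold is $\hat p_{ijk} = \exp_q(\hat X_{ijk})$.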

In [58]:
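# mode-n Gram matrices: inner products (in the metric at q) between mode-n slices of the tangent data
Gramm1  = Symmetric([sum(inner.(Ref(M), Ref(q), log_q_data[k,:,:], log_q_data[l,:,:])) for k=1:D1, l=1:D1]);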
Gramm2  = Symmetric([sum(inner.(Ref(M), Ref(q), log_q_data[:,k,:], log_q_data[:,l,:])) for k=1:D2, l=1:D2]);
Gramm3  = Symmetric([sum(inner.(Ref(M), Ref(q), log_q_data[:,:,k], log_q_data[:,:,l])) for k=1:D3, l=1:D3]);

In [59]:
(_, U1) = eigen(Gramm1);
(_, U2) = eigen(Gramm2);
(_, U3) = eigen(Gramm3);
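Below is a stdlib-only sketch of the same Gram-matrix computation for a general mode `n`. It assumes $q = I$, where the affine-invariant inner product reduces to the Frobenius inner product $\operatorname{tr}(XY)$. `mode_gram` and `leading_vectors` are hypothetical names.

Because `eigen` on a `Symmetric` matrix returns eigenvalues in ascending order, the leading eigenvectors are the last columns. This is why the next cell takes `U1[:,end-r1+1:end]`.

```julia
using LinearAlgebra

# Mode-n Gram matrix of a 3-way array of tangent vectors (hypothetical helper):
# G[k, l] = sum of <X[.., k, ..], X[.., l, ..]> over the other two indices,
# with the Frobenius inner product dot(X, Y) = tr(X'Y) (the metric at q = I).
function mode_gram(X::AbstractArray{<:AbstractMatrix,3}, n::Integer)
    Dn = size(X, n)
    G = zeros(Dn, Dn)
    for k in 1:Dn, l in k:Dn
        G[k, l] = sum(dot.(selectdim(X, n, k), selectdim(X, n, l)))
        G[l, k] = G[k, l]
    end
    return Symmetric(G)
end

# The r leading eigenvectors; eigen sorts eigenvalues ascending for Symmetric input
leading_vectors(G::Symmetric, r::Integer) = eigen(G).vectors[:, end-r+1:end]
```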

In [60]:
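target_rank = 10
r1, r2, r3 = target_rank .* (1,1,1)
# eigen returns eigenvalues in ascending order, so the leading r_n eigenvectors are the last columns
U1r = U1[:,end-r1+1:end]
U2r = U2[:,end-r2+1:end]
U3r = U3[:,end-r3+1:end]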


# core tensor R_q = log_q_data ×1 U1r' ×2 U2r' ×3 U3r', one mode at a time
R_q_1 = Symmetric.([sum(U1r[:,i] .* log_q_data[:,j,k]) for i=1:r1, j=1:D2, k=1:D3])
R_q_12 = Symmetric.([sum(U2r[:,j] .* R_q_1[i,:,k]) for i=1:r1, j=1:r2, k=1:D3])
R_q = Symmetric.([sum(U3r[:,k] .* R_q_12[i,j,:]) for i=1:r1, j=1:r2, k=1:r3]);


In [61]:
# reconstruction log_q_data_r = R_q ×1 U1r ×2 U2r ×3 U3r, one mode at a time
log_q_data_rrr = Symmetric.([sum(U1r[i,:] .* R_q[:,j,k]) for i=1:D1, j=1:r2, k=1:r3])
log_q_data_rr = Symmetric.([sum(U2r[j,:] .* log_q_data_rrr[i,:,k]) for i=1:D1, j=1:D2, k=1:r3])
log_q_data_r = Symmetric.([sum(U3r[k,:] .* log_q_data_rr[i,j,:]) for i=1:D1, j=1:D2, k=1:D3]);
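The six comprehensions above compute $n$-mode products one mode at a time. The first three build the core tensor with the transposed factors, and the last three build the reconstruction with the factors themselves.

Below is a generic sketch that assumes the tangent vectors are stored as matrices in a 3-way array. `mode_product`, `hosvd_core` and `hosvd_reconstruct` are hypothetical names, not Manopt or Manifolds functions.

```julia
using LinearAlgebra

# n-mode product of a 3-way array of matrices with a matrix A (hypothetical helper):
# Y[.., i, ..] = sum_l A[i, l] * X[.., l, ..], requires size(A, 2) == size(X, n)
function mode_product(X::AbstractArray{<:AbstractMatrix,3}, A::AbstractMatrix, n::Integer)
    size(A, 2) == size(X, n) || throw(DimensionMismatch("size(A, 2) must equal size(X, $n)"))
    dims = ntuple(d -> d == n ? size(A, 1) : size(X, d), 3)
    Y = Array{Matrix{Float64}}(undef, dims)
    for idx in CartesianIndices(Y)
        acc = zeros(size(first(X)))
        for l in 1:size(X, n)
            src = ntuple(d -> d == n ? l : idx[d], 3)  # same index, with mode n set to l
            acc += A[idx[n], l] * X[src...]
        end
        Y[idx] = acc
    end
    return Y
end

# Core tensor R = X ×1 U1' ×2 U2' ×3 U3'
function hosvd_core(X, Us)
    R = X
    for (n, U) in enumerate(Us)
        R = mode_product(R, U', n)
    end
    return R
end

# Reconstruction X̂ = R ×1 U1 ×2 U2 ×3 U3
function hosvd_reconstruct(R, Us)
    Xr = R
    for (n, U) in enumerate(Us)
        Xr = mode_product(Xr, U, n)
    end
    return Xr
end
```

With square orthogonal factors the reconstruction is exact. Keeping only the first $r_n$ columns of each factor gives a rank-$(r_1, r_2, r_3)$ approximation like the one computed above.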

In [62]:
log_q_data_r[1,1,1]

3×3 Symmetric{Float64, Matrix{Float64}}:
 -0.016155     0.00362718  -0.00202822
  0.00362718  -0.0178575   -0.00397492
 -0.00202822  -0.00397492  -0.0287388

In [63]:
# relative approximation error, measured with the norm of T_q P(3)
sqrt(sum(norm.(Ref(M), Ref(q), log_q_data_r - log_q_data) .^ 2))/ sqrt(sum(norm.(Ref(M), Ref(q), log_q_data) .^ 2))

0.8540885626205013
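The relative error above is measured entirely in $T_q\mathcal{P}(3)$. Assuming the affine-invariant metric, the norm of a tangent vector at $q$ is $\|X\|_q = \|q^{-1/2} X q^{-1/2}\|_F$. At $q = I$ this is the Frobenius norm, so the number above is the usual relative Frobenius error of the tangent tensor.

`tangent_norm` and `relative_error` in the stdlib sketch below are hypothetical names.

```julia
using LinearAlgebra

# ||X||_q = ||q^{-1/2} X q^{-1/2}||_F under the affine-invariant metric
function tangent_norm(q::AbstractMatrix, X::AbstractMatrix)
    si = inv(sqrt(Symmetric(q)))   # q^{-1/2}
    return norm(si * X * si)       # norm of a matrix is the Frobenius norm
end

# Relative error between two arrays of tangent vectors at the same base point q
function relative_error(q, Xr::AbstractArray, X::AbstractArray)
    num = sum(tangent_norm.(Ref(q), Xr .- X) .^ 2)
    den = sum(tangent_norm.(Ref(q), X) .^ 2)
    return sqrt(num / den)
end
```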

In [67]:
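# map the low-rank tangent vectors back to P(3) with exp_q
data_r = Symmetric.(exp.(Ref(M), Ref(q), Symmetric.(log_q_data_r)))
check_point.(Ref(M), data_r)
# manifold_error = sqrt(sum(distance.(Ref(M), data_r, data) .^2))
# compare one reconstructed voxel with the original
println(data_r[10,10,19])
println(data[10,10,19])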

[0.8003595064433119 -0.0040658567064974695 9.599457682010412e-5; -0.0040658567064974695 0.8198837594860204 -0.08422259910721691; 9.599457682010412e-5 -0.08422259910721691 0.8512047555091191]
[0.4202338477199419 -0.08564433190416665 0.0937187827343422; -0.08564433190416665 0.7998522955077192 -0.2412395283535318; 0.0937187827343422 -0.2412395283535318 1.1267979882489954]
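The commented-out `manifold_error` line measures the error on the manifold itself rather than in $T_q\mathcal{P}(3)$. Assuming the default affine-invariant metric, the geodesic distance is $d(P, Q) = \|\log(P^{-1/2} Q P^{-1/2})\|_F$. We assume this is what `distance(M, p, q)` computes for `SymmetricPositiveDefinite(3)`.

Below is a stdlib sketch applied to the two voxels printed above. `spd_distance` is a hypothetical name.

```julia
using LinearAlgebra

# Affine-invariant geodesic distance d(P, Q) = ||log(P^{-1/2} Q P^{-1/2})||_F
function spd_distance(P::AbstractMatrix, Q::AbstractMatrix)
    si = inv(sqrt(Symmetric(P)))               # P^{-1/2}
    return norm(log(Symmetric(si * Q * si)))   # Frobenius norm of the symmetric log
end

# the two voxels printed above: data_r[10,10,19] and data[10,10,19]
P_r = [0.8003595064433119 -0.0040658567064974695 9.599457682010412e-5;
       -0.0040658567064974695 0.8198837594860204 -0.08422259910721691;
       9.599457682010412e-5 -0.08422259910721691 0.8512047555091191]
P = [0.4202338477199419 -0.08564433190416665 0.0937187827343422;
     -0.08564433190416665 0.7998522955077192 -0.2412395283535318;
     0.0937187827343422 -0.2412395283535318 1.1267979882489954]

println(spd_distance(P_r, P))
```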


Next, let's see what happens if we use the reweighted metric.