# Low-rank approximation on $\mathcal{P}(d)$ — the space of $d \times d$ symmetric positive definite (SPD) matrices

In this notebook we want to get some intuition for different approaches to computing low-rank approximations of manifold-valued signals.

In [29]:
using Manifolds
using Manopt
using LinearAlgebra
using NIfTI
using Plots
using LaTeXStrings

In [16]:
include("../../../src/decompositions/naive_SVD.jl")
include("../../../src/decompositions/curvature_corrected_low_rank_approximation.jl")

curvature_corrected_low_rank_approximation (generic function with 1 method)

### Load data and construct manifold ###

In [17]:
# Load the diffusion-tensor NIfTI volume and multiply every entry by 1e9
# (presumably a unit rescaling so entries are of order one — TODO confirm).
# An alternative input used earlier was "data/nifti_dt.nii.gz".
ni = let nifti_path = "../../../data/nifti_dt_nonlinear.nii.gz"
    niread(nifti_path) * 1e9
end;
size(ni)

(112, 112, 50, 1, 6)

In [18]:
 d1, d2, d3, _, _ = size(ni)

# Assemble one symmetric 3×3 tensor per voxel from the six channels stored along
# the 5th dimension, ordered as [xx, yx, yy, zx, zy, zz]. 1e-5 is added to each
# diagonal entry (i.e. + 1e-5 * I) for numerical stability.
predata = [
    [ni[i,j,k,1,1] + 1e-5  ni[i,j,k,1,2]         ni[i,j,k,1,4];
     ni[i,j,k,1,2]         ni[i,j,k,1,3] + 1e-5  ni[i,j,k,1,5];
     ni[i,j,k,1,4]         ni[i,j,k,1,5]         ni[i,j,k,1,6] + 1e-5] for i in 1:d1, j in 1:d2, k in 1:d3];

# Take the 48×48 patch [34:81, 32:79] of slice k = 15 and flatten it
# (column-major, via `vec`) into a 1D signal of tensors.
# TODO: instead collect all non-zero voxels into a single 1D array.
data = vec(predata[34:81, 32:79, 15]);
print(size(data))

# Data manifold: 3×3 SPD matrices; d is its manifold dimension (6 for SPD(3)).
n = length(data)
M = SymmetricPositiveDefinite(3)
d = manifold_dimension(M);

(2304,)

In [19]:
# Export slice image
# Export (at most) the first `num_export` tensors of the signal as an Asymptote file.
num_export = 20
asymptote_export_SPD("results/Camino1D_orig.asy"; data=first(data, num_export), scale_axes=(2, 2, 2));

### Construct low-rank approximations ###

In [20]:
# Riemannian mean of the signal (Manifolds.jl `mean`); it serves as the base point
# of the tangent space in which all low-rank models below are built.
q = mean(M, data)
# Lift every tensor to the tangent space at q: log_q_data ∈ T_q P(3)^n.
log_q_data = [log(M, q, p) for p in data];

In [25]:
# Probe run: rank-1 curvature-corrected approximation; the result is only displayed.
let solver_options = (stepsize=1/200, max_iter=200, change_tol=1e-3)
    curvature_corrected_low_rank_approximation(M, q, data, 1; solver_options...)
end

Initial  F(x): 4.22480413859 | 
# 1     change: 0.144691585 |  F(x): 1.48136430851 | 
# 2     change: 0.059092687 |  F(x): 0.93593636067 | 
# 3     change: 0.034747709 |  F(x): 0.73630113408 | 
# 4     change: 0.023065619 |  F(x): 0.64600678305 | 
# 5     change: 0.016237493 |  F(x): 0.60052639611 | 
# 6     change: 0.011848076 |  F(x): 0.57605126237 | 
# 7     change: 0.008847833 |  F(x): 0.56230403114 | 
# 8     change: 0.006709775 |  F(x): 0.55435870493 | 
# 9     change: 0.005143115 |  F(x): 0.54967312036 | 
# 10    change: 0.003974773 |  F(x): 0.54686546007 | 
# 11    change: 0.003094858 |  F(x): 0.54515741590 | 
# 12    change: 0.002429751 |  F(x): 0.54410005435 | 
# 13    change: 0.001927963 |  F(x): 0.54343036639 | 
# 14    change: 0.001552198 |  F(x): 0.54299271739 | 
# 15    change: 0.001274539 |  F(x): 0.54269442511 | 
# 16    change: 0.001073304 |  F(x): 0.54248007712 | 
# 17    change: 0.000930950 |  F(x): 0.54231646327 | 
The algorithm performed a step with a change (0.00

([[25.45799600883235 -0.3532508618166673 -0.524517306486419; -0.3532508618166673 24.024464890091515 -0.3074860251427185; -0.524517306486419 -0.3074860251427185 23.409309229634868]], [0.0011466320135393726; -0.0020504948129943506; … ; 0.00037673717853216955; -0.0034605523130220336;;])

In [26]:
# Naive low-rank model: SVD-based decomposition of the tangent data at q.
nR_q, nU = naive_SVD(M, q, data);

# Curvature-corrected approximations for every rank r = 1, …, d
# (d = manifold_dimension(M)). Each call returns a pair (R_q, U); the parts are
# collected into element-typed vectors so that ccR_q[r] and ccU[r] hold the rank-r fit
# (instead of growing untyped `[]` containers with `push!`).
cc_options = (stepsize=1/200, max_iter=200, change_tol=1e-3)
cc_approximations = map(1:d) do r
    curvature_corrected_low_rank_approximation(M, q, data, r; cc_options...)
end
ccR_q = [R for (R, _) in cc_approximations]
ccU = [U for (_, U) in cc_approximations];

Initial  F(x): 4.22480413859 | 
# 1     change: 0.144691585 |  F(x): 1.48136430851 | 
# 2     change: 0.059092687 |  F(x): 0.93593636067 | 
# 3     change: 0.034747709 |  F(x): 0.73630113408 | 
# 4     change: 0.023065619 |  F(x): 0.64600678305 | 
# 5     change: 0.016237493 |  F(x): 0.60052639611 | 
# 6     change: 0.011848076 |  F(x): 0.57605126237 | 
# 7     change: 0.008847833 |  F(x): 0.56230403114 | 
# 8     change: 0.006709775 |  F(x): 0.55435870493 | 
# 9     change: 0.005143115 |  F(x): 0.54967312036 | 
# 10    change: 0.003974773 |  F(x): 0.54686546007 | 
# 11    change: 0.003094858 |  F(x): 0.54515741590 | 
# 12    change: 0.002429751 |  F(x): 0.54410005435 | 
# 13    change: 0.001927963 |  F(x): 0.54343036639 | 
# 14    change: 0.001552198 |  F(x): 0.54299271739 | 
# 15    change: 0.001274539 |  F(x): 0.54269442511 | 
# 16    change: 0.001073304 |  F(x): 0.54248007712 | 
# 17    change: 0.000930950 |  F(x): 0.54231646327 | 
The algorithm performed a step with a change (0.00

# 50    change: 0.003390879 |  F(x): 0.49315353583 | 
# 51    change: 0.003314371 |  F(x): 0.49098047549 | 
# 52    change: 0.003241721 |  F(x): 0.48890097089 | 
# 53    change: 0.003172688 |  F(x): 0.48690847138 | 
# 54    change: 0.003107047 |  F(x): 0.48499698994 | 
# 55    change: 0.003044591 |  F(x): 0.48316104791 | 
# 56    change: 0.002985123 |  F(x): 0.48139562601 | 
# 57    change: 0.002928464 |  F(x): 0.47969612067 | 
# 58    change: 0.002874442 |  F(x): 0.47805830521 | 
# 59    change: 0.002822901 |  F(x): 0.47647829501 | 
# 60    change: 0.002773689 |  F(x): 0.47495251642 | 
# 61    change: 0.002726670 |  F(x): 0.47347767875 | 
# 62    change: 0.002681711 |  F(x): 0.47205074914 | 
# 63    change: 0.002638691 |  F(x): 0.47066892988 | 
# 64    change: 0.002597493 |  F(x): 0.46932963798 | 
# 65    change: 0.002558012 |  F(x): 0.46803048673 | 
# 66    change: 0.002520144 |  F(x): 0.46676926892 | 
# 67    change: 0.002483794 |  F(x): 0.46554394180 | 
# 68    change: 0.002448874 

 F(x): 0.23642195297 | 
# 70    change: 0.001333384 |  F(x): 0.23606907393 | 
# 71    change: 0.001312933 |  F(x): 0.23572688875 | 
# 72    change: 0.001293154 |  F(x): 0.23539488970 | 
# 73    change: 0.001274023 |  F(x): 0.23507259730 | 
# 74    change: 0.001255515 |  F(x): 0.23475955846 | 
# 75    change: 0.001237608 |  F(x): 0.23445534489 | 
# 76    change: 0.001220279 |  F(x): 0.23415955154 | 
# 77    change: 0.001203508 |  F(x): 0.23387179522 | 
# 78    change: 0.001187274 |  F(x): 0.23359171327 | 
# 79    change: 0.001171559 |  F(x): 0.23331896232 | 
# 80    change: 0.001156342 |  F(x): 0.23305321716 | 
# 81    change: 0.001141606 |  F(x): 0.23279416970 | 
# 82    change: 0.001127334 |  F(x): 0.23254152796 | 
# 83    change: 0.001113509 |  F(x): 0.23229501513 | 
# 84    change: 0.001100115 |  F(x): 0.23205436872 | 
# 85    change: 0.001087135 |  F(x): 0.23181933976 | 
# 86    change: 0.001074556 |  F(x): 0.23158969199 | 
# 87    change: 0.001062363 |  F(x): 0.23136520123 | 
# 88

In [27]:
# Normaliser for all relative errors: total squared distance of the signal to q.
ref_distance = sum(distance.(Ref(M), Ref(q), data).^2)

naive_tangent_distances_r = zeros(d)
naive_distances_r = zeros(d)
curvature_corrected_tangent_distances_r = zeros(d)
curvature_corrected_distances_r = zeros(d)

for r in 1:d
    # Rank-r reconstructions in T_q P(3), one symmetric tangent vector per data point.
    # The naive model combines the *last* r components returned by naive_SVD
    # (indices d-r+1:d); the curvature-corrected model uses all r components of its
    # rank-r fit.
    # NOTE(review): this assumes naive_SVD orders components by increasing
    # singular value — confirm in naive_SVD.jl.
    # Generator sums avoid allocating a temporary array per data point.
    naive_log_q_data_r =
        [Symmetric(sum(nR_q[i] * nU[k, i] for i in d-r+1:d)) for k in 1:n]
    curvature_corrected_log_q_data_r =
        [Symmetric(sum(ccR_q[r][i] * ccU[r][k, i] for i in 1:r)) for k in 1:n]

    # Exponentiate back onto the manifold at base point q.
    naive_data_r = exp.(Ref(M), Ref(q), naive_log_q_data_r)
    curvature_corrected_data_r = exp.(Ref(M), Ref(q), curvature_corrected_log_q_data_r)

    # Relative tangent-space error: sum of squared norms (at q) of the residuals.
    naive_tangent_distances_r[r] =
        sum(norm.(Ref(M), Ref(q), log_q_data - naive_log_q_data_r).^2) / ref_distance
    curvature_corrected_tangent_distances_r[r] =
        sum(norm.(Ref(M), Ref(q), log_q_data - curvature_corrected_log_q_data_r).^2) /
        ref_distance

    # Relative manifold error: sum of squared Riemannian distances to the data.
    naive_distances_r[r] = sum(distance.(Ref(M), data, naive_data_r).^2) / ref_distance
    curvature_corrected_distances_r[r] =
        sum(distance.(Ref(M), data, curvature_corrected_data_r).^2) / ref_distance
end

In [30]:
# Plot, per approximation rank, (1) the tangent-space error of the naive model
# (labelled "theoretical lower bound"), (2) the naive manifold error and
# (3) the curvature-corrected manifold error — once on a linear and once on a
# log scale. The log-scale curves are shifted by 1e-16 so that exact zeros stay plottable.
let labels = ["theoretical lower bound" "naive" "curvature corrected"],
    errors = [naive_tangent_distances_r, naive_distances_r, curvature_corrected_distances_r],
    ylabel = L"$\varepsilon_{rel}$"

    plot(1:d, errors; label=labels, ylims=(0, 1), xlims=(1, d),
         xaxis="approximation rank", yaxis=ylabel)
    savefig("results/Camino1D_errors_by_rank.png")

    shifted_errors = [e .+ 1e-16 for e in errors]
    plot(1:d, shifted_errors; label=labels, ylims=(1e-16, 1), xlims=(1, d),
         xaxis="approximation rank", yaxis=(ylabel, :log), legend=:bottomleft)
    savefig("results/Camino1D_logerrors_by_rank.png")
end

"/Users/wdiepeveen/Documents/PhD/Projects/8 - Manifold-valued tensor decomposition/src/manifold-valued-tensors/experiments/1D/P3/results/Camino1D_logerrors_by_rank.png"