# Low-rank approximation on $\mathcal{P}(d)$ — the space of $d \times d$ symmetric positive definite (SPD) matrices

In this notebook we build intuition for different approaches to computing low-rank approximations of manifold-valued signals.

In [8]:
using Manifolds
using Manopt
using LinearAlgebra
using Random
using Plots
using LaTeXStrings
using BenchmarkTools

In [9]:
include("../../../src/decompositions/signals/naive_low_rank_approximation.jl")
include("../../../src/decompositions/signals/curvature_corrected_low_rank_approximation.jl")
include("../../../src/decompositions/signals/exact_low_rank_approximation.jl")

include("../../../src/functions/loss_functions/curvature_corrected_loss.jl")
include("../../../src/functions/loss_functions/exact_loss.jl")

exact_loss (generic function with 2 methods)

### Construct manifold and generate synthetic data ###

In [10]:
# construct the manifold P(3) of 3×3 SPD matrices (no data is loaded; it is generated below)
M = SymmetricPositiveDefinite(3)
d = manifold_dimension(M)  # intrinsic dimension of P(3): 3 * (3 + 1) / 2 = 6
n = 100                    # number of samples in the signal


100

In [11]:
# base point: the 3×3 identity matrix in P(3)
e = Matrix{Float64}(I, 3, 3)
# orthonormal basis of the tangent space at e
Θ = get_basis(M, e, DefaultOrthonormalBasis())
#  construct data
τ = 2.   # variance of the signal coefficient along Xₑ (it is scaled by sqrt(τ) below)
σ = .05  # noise scale passed to random_tangent; NOTE(review): Manifolds treats this as a standard deviation, not a variance — confirm
Xₑ = get_vectors(M, e, Θ)[4]  # 4th basis vector (printed value: diag(0, 1, 0))
print(Xₑ)

Random.seed!(31)
# noise-free signal: exp_e of random multiples of the single direction Xₑ
# (randn() consumes the same random stream as the former randn(1)[1], without allocating)
predata = [exp(M, e, sqrt(τ) * randn() * Xₑ) for _ in 1:n]

# perturb every point with Gaussian tangent noise drawn at that point
data = [exp(M, p, random_tangent(M, p, Val(:Gaussian), σ)) for p in predata]; # ∈ P(3)^n


[0.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 0.0]

In [12]:
# Write the first `num_export` samples (all n samples if there are fewer) to an
# Asymptote file so they can be rendered as a slice image
num_export = 10
asymptote_export_SPD(
    "results/artificial1D_orig.asy";
    data = data[1:min(n, num_export)],
    scale_axes = (2, 2, 2),
);

### Construct low-rank approximations ###

In [13]:
# reference point: the Riemannian center of mass of the samples (Manifolds' `mean`)
q = mean(M, data)
# lift every sample to the tangent space at q with the logarithmic map
log_q_data = [log(M, q, x) for x in data];  # ∈ T_q P(3)^n

In [14]:
# quick check of the exact method: rank 2, only 10 iterations;
# returns the factors (eRr_q, eUr) and the cost of every iteration
(eRr_q, eUr), costs = exact_low_rank_approximation(
    M, q, data, 2;
    stepsize = 1/10000,
    max_iter = 10,
)

(([[0.17525404284891413 -0.287957326702762 0.14872841784557383; -0.28795732670276203 -0.00010583998011175689 0.2667710267880373; 0.14872841784557386 0.2667710267880373 0.15378791440207504], [0.13209193428322943 0.09619096703363013 0.03361047693778811; 0.09619096703363016 -22.690965378102366 0.07797447309149758; 0.03361047693778812 0.07797447309149758 0.03476301769969326]], [0.04141185945717856 0.043589454988513694; -0.062377171919476435 0.06746220792835492; … ; 0.08428851980435229 -0.04426547676489212; 0.09977091731476738 0.09542655939149962]), [0.049940970415489165, 0.019723246782690336, 0.05663330241992408, 0.011140240157969297, 0.05555910194147775, 0.013479927280286783, 0.057924414731575044, 0.010207364781930284, 0.05457367466117646, 0.015238115241389493])

In [15]:
# costs

In [16]:
max_iter = 50

# Results per approximation rank: entry r of each vector belongs to rank r.
nR_q, nU = [], []               # naive method: factors R_q and U
ccR_q, ccU = [], []             # curvature corrected method: factors R_q and U
eR_q, eU, eCosts = [], [], []   # exact method: factors plus the per-iteration costs

for r in 1:d
    println("#$(r) | computing naive low-rank approximation")
    R_naive, U_naive = naive_low_rank_approximation(M, q, data, r)
    push!(nR_q, R_naive)
    push!(nU, U_naive)

    println("#$(r) | computing curvature corrected low-rank approximation")
    R_cc, U_cc = curvature_corrected_low_rank_approximation(M, q, data, r)
    push!(ccR_q, R_cc)
    push!(ccU, U_cc)

    println("#$(r) | computing exact low-rank approximation")
    (eRr_q, eUr), cost_history = exact_low_rank_approximation(
        M, q, data, r; stepsize=1/100000, max_iter=max_iter)
    push!(eR_q, eRr_q)
    push!(eU, eUr)
    push!(eCosts, cost_history)
end

#1 | computing naive low-rank approximation
#1 | computing curvature corrected low-rank approximation
#1 | computing exact low-rank approximation
#2 | computing naive low-rank approximation
#2 | computing curvature corrected low-rank approximation
#2 | computing exact low-rank approximation
#3 | computing naive low-rank approximation
#3 | computing curvature corrected low-rank approximation
#3 | computing exact low-rank approximation
#4 | computing naive low-rank approximation
#4 | computing curvature corrected low-rank approximation
#4 | computing exact low-rank approximation
#5 | computing naive low-rank approximation
#5 | computing curvature corrected low-rank approximation
#5 | computing exact low-rank approximation
#6 | computing naive low-rank approximation
#6 | computing curvature corrected low-rank approximation
#6 | computing exact low-rank approximation


In [17]:
"""
    tangent_reconstruction(R, U, r, n)

Assemble a rank-`r` tangent-space approximation of the signal: for every sample
`k in 1:n` return `Symmetric(R[1] * U[k, 1] + … + R[r] * U[k, r])`.
"""
tangent_reconstruction(R, U, r, n) =
    [Symmetric(sum(R[i] * U[k, i] for i in 1:r)) for k in 1:n]

# normalization constant: total squared distance of the samples to q
ref_distance = sum(distance.(Ref(M), Ref(q), data).^2)

naive_tangent_distances_r = zeros(d)
predicted_naive_distances_r = zeros(d)
true_naive_distances_r = zeros(d)

curvature_corrected_tangent_distances_r = zeros(d)
predicted_curvature_corrected_distances_r = zeros(d)
true_curvature_corrected_distances_r = zeros(d)

exact_tangent_distances_r = zeros(d)
exact_distances_r = zeros(d)

for rank in 1:d
    # rank-`rank` approximations of log_q(data) in the tangent space at q
    naive_log_q_data_r = tangent_reconstruction(nR_q[rank], nU[rank], rank, n)
    curvature_corrected_log_q_data_r = tangent_reconstruction(ccR_q[rank], ccU[rank], rank, n)
    exact_log_q_data_r = tangent_reconstruction(eR_q[rank], eU[rank], rank, n)

    # compute relative tangent space error
    naive_tangent_distances_r[rank] = sum(norm.(Ref(M), Ref(q),  log_q_data - naive_log_q_data_r).^2) / ref_distance
    curvature_corrected_tangent_distances_r[rank] = sum(norm.(Ref(M), Ref(q),  log_q_data - curvature_corrected_log_q_data_r).^2) / ref_distance
    exact_tangent_distances_r[rank] = sum(norm.(Ref(M), Ref(q),  log_q_data - exact_log_q_data_r).^2) / ref_distance

    # compute manifold error: predicted by the curvature corrected loss, measured by the exact loss
    # NOTE(review): unlike the tangent-space errors above, these values are not divided by
    # ref_distance here, yet they are plotted next to them as ε_rel — confirm that
    # curvature_corrected_loss / exact_loss already return relative values.
    predicted_naive_distances_r[rank] = curvature_corrected_loss(M, q, data, naive_log_q_data_r)
    true_naive_distances_r[rank] = exact_loss(M, q, data, naive_log_q_data_r)
    predicted_curvature_corrected_distances_r[rank] = curvature_corrected_loss(M, q, data, curvature_corrected_log_q_data_r)
    true_curvature_corrected_distances_r[rank] = exact_loss(M, q, data, curvature_corrected_log_q_data_r)
    exact_distances_r[rank] = exact_loss(M, q, data, exact_log_q_data_r)
end

In [33]:
# Error-vs-rank figures. The curve labelled "theoretical lower bound" is the relative
# tangent-space error of the naive approximation; the remaining curves are the manifold
# errors of the naive, curvature corrected and exact approximations.
let labels = ["theoretical lower bound" "naive" "curvature corrected" "exact"],
    curves = [naive_tangent_distances_r, true_naive_distances_r,
              true_curvature_corrected_distances_r, exact_distances_r],
    ε_label = L"$\varepsilon_{rel}$",
    offset = 1e-4  # shift so that zero errors stay visible on a log axis

    # linear axis
    plot(1:d, curves, label=labels, ylims=(0, 1), xlims=(1, d),
         xaxis=("approximation rank"), yaxis=(ε_label))
    savefig("results/artificial1D_errors_by_rank.svg")

    # logarithmic axis
    plot(1:d, [c .+ offset for c in curves], label=labels, ylims=(offset, 1), xlims=(1, d),
         xaxis=("approximation rank"), yaxis=(ε_label, :log), legend=:bottomleft)
    savefig("results/artificial1D_logerrors_by_rank.svg")

    # cost per iteration of the exact method, one curve for each rank 1, …, d-1
    plot(1:length(eCosts[1]), eCosts[1], label="rank 1",
         ylims=(offset, 1), yaxis=(ε_label, :log))
    for r in 2:d-1
        plot!(1:length(eCosts[r]), eCosts[r], label="rank $(r)",
              ylims=(offset, 1), yaxis=(ε_label, :log))
    end
    savefig("results/artificial1D_exact_iterate_loss.svg")
end

"/Users/wdiepeveen/Documents/PhD/Projects/8 - Manifold-valued tensor decomposition/src/manifold-valued-tensors/experiments/1D/P3/results/artificial1D_exact_iterate_loss.svg"

In [34]:
# Discrepancy between the manifold loss predicted by the curvature corrected loss and
# the exact loss, for the curvature corrected approach only. It is normalized by
# ε^(3/2) * sqrt(ref_distance), with ε the relative tangent-space error. Full rank d is
# left out; the 1e-16 terms presumably guard against 0/0.
let ranks = 1:d-1
    ε = curvature_corrected_tangent_distances_r[ranks]
    Δ = predicted_curvature_corrected_distances_r[ranks] .- true_curvature_corrected_distances_r[ranks]
    δ_rel = (Δ .+ 1e-16) ./ (ε .* sqrt.(ε .* ref_distance) .+ 1e-16)
    plot(ranks, δ_rel, label=("curvature corrected"), xlims=(1, d-1),
         xaxis=("approximation rank"), yaxis=(L"$\delta_{rel}$"), color=3)
end
savefig("results/artificial1D_discrepancy_by_rank.svg")

"/Users/wdiepeveen/Documents/PhD/Projects/8 - Manifold-valued tensor decomposition/src/manifold-valued-tensors/experiments/1D/P3/results/artificial1D_discrepancy_by_rank.svg"

### Benchmark different methods ###

In [20]:
# t = @benchmark naive_low_rank_approximation(M, q, data, 2)
# mean(t).time
# std(t).time
# # TODO also do the benchmarking in here similarly as run above | also do this for 3 different step sizes for exact | if benchmark=true
# @benchmark exact_low_rank_approximation(M, q, data, 2; stepsize=1/10000, max_iter=50) 
# @benchmark curvature_corrected_low_rank_approximation(M, q, data, 2) 

In [21]:
# `$`-interpolate the globals so that untyped global access is not part of the timing
@benchmark naive_low_rank_approximation($M, $q, $data, 2)

BenchmarkTools.Trial: 284 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m15.472 ms[22m[39m … [35m25.371 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 23.77%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m16.809 ms              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m17.589 ms[22m[39m ± [32m 1.794 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m2.40% ±  6.43%

  [39m [39m [39m [39m [39m▃[39m▃[39m█[39m▇[39m█[34m▂[39m[39m▃[39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▂[39m▁[39m▂[39m▇[39m█[39m█[3

In [22]:
# `$`-interpolate the globals so that untyped global access is not part of the timing
@benchmark curvature_corrected_low_rank_approximation($M, $q, $data, 2)

BenchmarkTools.Trial: 126 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m33.512 ms[22m[39m … [35m119.995 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 6.14%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m36.329 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m0.00%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m39.672 ms[22m[39m ± [32m  9.210 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m2.51% ± 5.09%

  [39m [39m [39m█[39m▂[39m [34m [39m[39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m▄[39m█[39m█[39m█[39m▇[

In [23]:
# exact method, a single iteration; globals `$`-interpolated so their access is not timed
@benchmark exact_low_rank_approximation($M, $q, $data, 2; stepsize=1/100000, max_iter=1)

BenchmarkTools.Trial: 35 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m115.566 ms[22m[39m … [35m186.172 ms[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m0.00% … 3.63%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m146.114 ms               [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m3.73%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m143.192 ms[22m[39m ± [32m 22.015 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m2.62% ± 2.01%

  [39m█[39m█[39m [39m▃[39m [39m▃[39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [32m [39m[34m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m█[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m 
  [39m█[39m█[39m▁[39m█

In [24]:
# exact method, 50 iterations; globals `$`-interpolated so their access is not timed
@benchmark exact_low_rank_approximation($M, $q, $data, 2; stepsize=1/100000, max_iter=50)

BenchmarkTools.Trial: 2 samples with 1 evaluation.
 Range [90m([39m[36m[1mmin[22m[39m … [35mmax[39m[90m):  [39m[36m[1m2.623 s[22m[39m … [35m  2.763 s[39m  [90m┊[39m GC [90m([39mmin … max[90m): [39m3.16% … 3.04%
 Time  [90m([39m[34m[1mmedian[22m[39m[90m):     [39m[34m[1m2.693 s              [22m[39m[90m┊[39m GC [90m([39mmedian[90m):    [39m3.10%
 Time  [90m([39m[32m[1mmean[22m[39m ± [32mσ[39m[90m):   [39m[32m[1m2.693 s[22m[39m ± [32m98.974 ms[39m  [90m┊[39m GC [90m([39mmean ± σ[90m):  [39m3.10% ± 0.08%

  [34m█[39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [32m [39m[39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m [39m█[39m [39m 
  [34m█[39m[39m▁[39m▁[39m▁[39m▁[39m▁[39m▁[39m▁[39m▁[

In [25]:
# mean and standard deviation of the naive method's runtime per rank, in seconds
# (BenchmarkTools reports nanoseconds, hence the division by 1e9)
nT = Float64[]
nΣ = Float64[]

for i in 1:d
    # interpolate all globals, not only the rank, so global access is not timed
    nbm = @benchmark naive_low_rank_approximation($M, $q, $data, $i)
    push!(nT, mean(nbm).time / 1e9)
    push!(nΣ, std(nbm).time / 1e9)
end

In [26]:
# mean and standard deviation of the curvature corrected method's runtime per rank, in
# seconds (BenchmarkTools reports nanoseconds, hence the division by 1e9)
ccT = Float64[]
ccΣ = Float64[]

for i in 1:d
    # interpolate all globals, not only the rank, so global access is not timed
    ccbm = @benchmark curvature_corrected_low_rank_approximation($M, $q, $data, $i)
    push!(ccT, mean(ccbm).time / 1e9)
    push!(ccΣ, std(ccbm).time / 1e9)
end

In [27]:
# mean and standard deviation of the exact method's runtime (1 iteration) per rank, in
# seconds (BenchmarkTools reports nanoseconds, hence the division by 1e9)
eT1 = Float64[]
eΣ1 = Float64[]

for i in 1:d
    # interpolate all globals, not only the rank, so global access is not timed
    ebm1 = @benchmark exact_low_rank_approximation($M, $q, $data, $i; stepsize=1/100000, max_iter=1)
    push!(eT1, mean(ebm1).time / 1e9)
    push!(eΣ1, std(ebm1).time / 1e9)
end

In [28]:
# mean and standard deviation of the exact method's runtime (50 iterations) per rank, in
# seconds (BenchmarkTools reports nanoseconds, hence the division by 1e9)
eT = Float64[]
eΣ = Float64[]

for i in 1:d
    # interpolate all globals, not only the rank, so global access is not timed
    ebm = @benchmark exact_low_rank_approximation($M, $q, $data, $i; stepsize=1/100000, max_iter=50)
    push!(eT, mean(ebm).time / 1e9)
    push!(eΣ, std(ebm).time / 1e9)
end

In [29]:
# # write latex table 
# for i in 1:d
#         println("$(i) & " * raw"$" * "$(Float16(nT[i]))" * raw"\pm" * "$(Float16(nΣ[i]))" * raw"$" * " & " * raw"$" * "$(Float16(ccT[i])) " * raw"\pm" * " $(Float16(ccΣ[i]))" * raw"$" * " & " * raw"$" * "$(Float16(eT1[i])) " * raw"\pm" * " $(Float16(eΣ1[i]))" * raw"$" * " & " * raw"$" * "$(Float16(eT[i])) " * raw"\pm" * " $(Float16(eΣ[i]))" * raw"$ " * raw"\\ ")
# end

# for i in 1:d
#         println("$(i) & " * raw"$" * "$(Float16(nT[i]))" * raw"$" * " & " * raw"$" * "$(Float16(ccT[i]))" * raw"$" * " & " * raw"$" * "$(Float16(eT1[i]))" * raw"$" * " & " * raw"$" * "$(Float16(eT[i]))" * raw"$ " * raw"\\ ")
# end


1 & $0.01753\pm0.001637$ & $0.03223 \pm 0.00377$ & $0.11334 \pm 0.01108$ & $2.586 \pm 0.1497$ \\ 
2 & $0.01822\pm0.001984$ & $0.03687 \pm 0.003742$ & $0.1268 \pm 0.00673$ & $2.717 \pm 0.1677$ \\ 
3 & $0.0184\pm0.002047$ & $0.04303 \pm 0.004223$ & $0.1345 \pm 0.00805$ & $3.09 \pm 0.002518$ \\ 
4 & $0.01817\pm0.00192$ & $0.0486 \pm 0.004436$ & $0.1432 \pm 0.00746$ & $3.656 \pm 0.11285$ \\ 
5 & $0.01834\pm0.002106$ & $0.05658 \pm 0.005733$ & $0.1644 \pm 0.007713$ & $3.941 \pm 0.357$ \\ 
6 & $0.01932\pm0.002438$ & $0.0622 \pm 0.005535$ & $0.10474 \pm 0.01435$ & $0.08514 \pm 0.005672$ \\ 
1 & $0.01753$ & $0.03223$ & $0.11334$ & $2.586$ \\ 
2 & $0.01822$ & $0.03687$ & $0.1268$ & $2.717$ \\ 
3 & $0.0184$ & $0.04303$ & $0.1345$ & $3.09$ \\ 
4 & $0.01817$ & $0.0486$ & $0.1432$ & $3.656$ \\ 
5 & $0.01834$ & $0.05658$ & $0.1644$ & $3.941$ \\ 
6 & $0.01932$ & $0.0622$ & $0.10474$ & $0.08514$ \\ 


In [30]:
# LaTeX table: one row per method, one column per approximation rank (mean times in seconds)
"""
    latex_timing_row(label, times)

Return a LaTeX table row made of `label` followed, for every `t` in `times`, by ` & `
and `Float16(t)` in inline math mode, terminated by a LaTeX line break.
"""
latex_timing_row(label, times) =
    label * join(" & " * raw"$" * string(Float16(t)) * raw"$" for t in times) * raw"\\ "

println(latex_timing_row("tHOSVD", nT[1:d]))
println(latex_timing_row("CC-tHOSVD", ccT[1:d]))
println(latex_timing_row("MC-tHOSVD (1 iteration)", eT1[1:d]))
println(latex_timing_row("MC-tHOSVD (50 iterations)", eT[1:d]))

tHOSVD & $0.01753$ & $0.01822$ & $0.0184$ & $0.01817$ & $0.01834$ & $0.01932$\\ 
CC-tHOSVD & $0.03223$ & $0.03687$ & $0.04303$ & $0.0486$ & $0.05658$ & $0.0622$\\ 
MC-tHOSVD (1 iteration) & $0.11334$ & $0.1268$ & $0.1345$ & $0.1432$ & $0.1644$ & $0.10474$\\ 
MC-tHOSVD (50 iteration) & $2.586$ & $2.717$ & $3.09$ & $3.656$ & $3.941$ & $0.08514$\\ 
