# Low-rank approximation on $\mathcal{P}(d)$ - the space of $d \times d$ SPD matrices

In this notebook we build intuition for different approaches to computing low-rank approximations of manifold-valued signals.

In [1]:
using Manifolds
using Manopt
using LinearAlgebra
using Random
using Plots
using LaTeXStrings
using BenchmarkTools

In [20]:
include("../../../src/decompositions/signals/naive_low_rank_approximation.jl")
include("../../../src/decompositions/signals/curvature_corrected_low_rank_approximation.jl")
include("../../../src/decompositions/signals/exact_low_rank_approximation.jl")

include("../../../src/functions/loss_functions/curvature_corrected_loss.jl")
include("../../../src/functions/loss_functions/exact_loss.jl")

exact_loss (generic function with 2 methods)

### Load data and construct manifold ###

In [3]:
# load data
M = SymmetricPositiveDefinite(3)
d = manifold_dimension(M)
n = 100  # number of data points


100

In [4]:
e = 1. * Matrix(I, 3, 3)
# compute basis
Θ = get_basis(M, e, DefaultOrthonormalBasis())
# construct data
τ = 2.  # variance of the signal along Xₑ (sqrt(τ) scales a standard normal sample)
σ = .05  # noise level passed to random_tangent
Xₑ = Θ.data[4]
print(Xₑ)

Random.seed!(31)
predata = [exp(M, e, sqrt(τ) * randn(1)[1] * Xₑ) for i in 1:n]

data = [exp(M, predata[i], random_tangent(M, predata[i], Val(:Gaussian), σ)) for i in 1:n]; # ∈ P(3)^n


[0.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 0.0]
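
The noise-free points `predata` all lie on one geodesic through the identity in the direction `Xₑ`. As a minimal sketch, assuming `SymmetricPositiveDefinite` uses its default affine-invariant metric, the exponential map at the identity reduces to the matrix exponential, so the geodesic can be reproduced with `LinearAlgebra` alone. The helper name `geodesic_point` is hypothetical and not part of the notebook.

```julia
using LinearAlgebra

# Tangent direction at the identity (same matrix as the printed Xₑ)
X = [0.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 0.0]

# Assumption: under the affine-invariant metric, exp at the identity
# is the ordinary matrix exponential of the symmetric tangent vector.
geodesic_point(t) = exp(Symmetric(t * X))

p = geodesic_point(0.5)  # only the (2,2) entry moves away from 1
```

Each noise-free point is therefore a diagonal matrix whose middle entry is `exp(t)`, with `t` drawn from a normal distribution with variance `τ`.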

In [32]:
# Export slice image
num_export = 10
asymptote_export_SPD("results/artificial1D_orig.asy", data=data[1:min(num_export,n)], scale_axes=(2,2,2)); 

### Construct low rank approximation ###

In [5]:
q = mean(M, data)
log_q_data = log.(Ref(M), Ref(q), data);  # ∈ T_q P(3)^n
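
Once the data is mapped to the tangent space at `q`, a Euclidean low-rank approximation becomes available. The sketch below shows a truncated SVD on a matrix of tangent coordinates; it is an illustration of the general idea, not necessarily what `naive_low_rank_approximation` does internally. The matrix `A` is random stand-in data, and the names `A`, `U`, `R`, `rel_err` and `tail` are hypothetical.

```julia
using LinearAlgebra, Random

Random.seed!(0)
n, d, r = 100, 6, 2
A = randn(n, d)                         # hypothetical n × d matrix of tangent coordinates

F = svd(A)
U = F.U[:, 1:r] * Diagonal(F.S[1:r])    # n × r coefficients, one row per data point
R = F.Vt[1:r, :]                        # r × d basis directions in the tangent space
A_r = U * R                             # best rank-r approximation in Frobenius norm

# Relative squared error of the truncation
rel_err = norm(A - A_r)^2 / norm(A)^2
# Same quantity from the discarded singular values (Eckart-Young)
tail = sum(F.S[r+1:end] .^ 2) / sum(F.S .^ 2)
```

The two quantities `rel_err` and `tail` agree, which is why the tangent-space error of a truncated SVD can serve as a lower bound among rank-`r` approximations in the tangent space.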

In [8]:
(eRr_q, eUr), costs = exact_low_rank_approximation(M, q, data, 2; stepsize=1/10000, max_iter=10)

(([[0.17525404284891413 -0.287957326702762 0.14872841784557383; -0.28795732670276203 -0.00010583998011175689 0.2667710267880373; 0.14872841784557386 0.2667710267880373 0.15378791440207504], [0.13209193428322943 0.09619096703363013 0.03361047693778811; 0.09619096703363016 -22.690965378102366 0.07797447309149758; 0.03361047693778812 0.07797447309149758 0.03476301769969326]], [0.04141185945717856 0.043589454988513694; -0.062377171919476435 0.06746220792835492; … ; 0.08428851980435229 -0.04426547676489212; 0.09977091731476738 0.09542655939149962]), [0.049940970415489165, 0.019723246782690336, 0.05663330241992408, 0.011140240157969297, 0.05555910194147775, 0.013479927280286783, 0.057924414731575044, 0.010207364781930284, 0.05457367466117646, 0.015238115241389493])

In [31]:
# costs

10-element Vector{Float64}:
 0.05179777061223446
 0.015948606491658798
 0.05660493646913368
 0.009919243079205645
 0.05294981080865168
 0.01597051307193173
 0.05776102150791481
 0.0090382802715633
 0.051570914309466524
 0.018407499028757636

In [24]:
max_iter = 50

nR_q = []
nU = []
ccR_q = []
ccU = []
eR_q = []
eU = []
eCosts = []
for i in 1:d  
    println("#$(i) | computing naive low-rank approximation")
    nRr_q, nUr = naive_low_rank_approximation(M, q, data, i)
    push!(nR_q, nRr_q)
    push!(nU, nUr)
    println("#$(i) | computing curvature corrected low-rank approximation")
    ccRr_q, ccUr = curvature_corrected_low_rank_approximation(M, q, data, i); 
    push!(ccR_q, ccRr_q)
    push!(ccU, ccUr)
    println("#$(i) | computing exact low-rank approximation")
    (eRr_q, eUr), eCostsr = exact_low_rank_approximation(M, q, data, i; stepsize=1/100000, max_iter=max_iter); 
    push!(eR_q, eRr_q)
    push!(eU, eUr)
    push!(eCosts, eCostsr)
end

#1 | computing naive low-rank approximation
#1 | computing curvature corrected low-rank approximation
#1 | computing exact low-rank approximation
#2 | computing naive low-rank approximation
#2 | computing curvature corrected low-rank approximation
#2 | computing exact low-rank approximation
#3 | computing naive low-rank approximation
#3 | computing curvature corrected low-rank approximation
#3 | computing exact low-rank approximation
#4 | computing naive low-rank approximation
#4 | computing curvature corrected low-rank approximation
#4 | computing exact low-rank approximation
#5 | computing naive low-rank approximation
#5 | computing curvature corrected low-rank approximation
#5 | computing exact low-rank approximation
#6 | computing naive low-rank approximation
#6 | computing curvature corrected low-rank approximation
#6 | computing exact low-rank approximation


In [25]:
ref_distance = sum(distance.(Ref(M), Ref(q), data).^2)

naive_tangent_distances_r = zeros(d)
predicted_naive_distances_r= zeros(d)
true_naive_distances_r= zeros(d)

curvature_corrected_tangent_distances_r = zeros(d)
predicted_curvature_corrected_distances_r = zeros(d)
true_curvature_corrected_distances_r = zeros(d)

exact_tangent_distances_r = zeros(d)
exact_distances_r= zeros(d)

for rank in 1:d
    naive_log_q_data_r = Symmetric.([sum([nR_q[rank][i] * nU[rank][k,i] for i in 1:rank]) for k in 1:n])
    curvature_corrected_log_q_data_r = Symmetric.([sum([ccR_q[rank][i] * ccU[rank][k,i] for i in 1:rank]) for k in 1:n])
    exact_log_q_data_r = Symmetric.([sum([eR_q[rank][i] * eU[rank][k,i] for i in 1:rank]) for k in 1:n])
    
    # exponentiate back to the manifold
    naive_data_r = exp.(Ref(M), Ref(q), naive_log_q_data_r)
    curvature_corrected_data_r = exp.(Ref(M), Ref(q), curvature_corrected_log_q_data_r)
    exact_data_r = exp.(Ref(M), Ref(q), exact_log_q_data_r)


    # compute relative tangent space error
    naive_tangent_distances_r[rank] = sum(norm.(Ref(M), Ref(q),  log_q_data - naive_log_q_data_r).^2) / ref_distance
    curvature_corrected_tangent_distances_r[rank] = sum(norm.(Ref(M), Ref(q),  log_q_data - curvature_corrected_log_q_data_r).^2) / ref_distance
    exact_tangent_distances_r[rank] = sum(norm.(Ref(M), Ref(q),  log_q_data - exact_log_q_data_r).^2) / ref_distance


    # compute relative manifold error
    predicted_naive_distances_r[rank] = curvature_corrected_loss(M, q, data, naive_log_q_data_r)
    true_naive_distances_r[rank] = exact_loss(M, q, data, naive_log_q_data_r)
    predicted_curvature_corrected_distances_r[rank] = curvature_corrected_loss(M, q, data, curvature_corrected_log_q_data_r)
    true_curvature_corrected_distances_r[rank] = exact_loss(M, q, data, curvature_corrected_log_q_data_r)
    exact_distances_r[rank] = exact_loss(M, q, data, exact_log_q_data_r)
    
end
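
For reference, a relative manifold error of the kind computed above can be sketched with `LinearAlgebra` alone. This assumes the affine-invariant distance on SPD matrices and the same normalisation as `ref_distance`; the actual definition of `exact_loss` lives in the included source file and may differ. The names `spd_distance` and `relative_error` are hypothetical.

```julia
using LinearAlgebra

# Affine-invariant distance between SPD matrices p and x (assumed metric)
function spd_distance(p, x)
    p_inv_sqrt = inv(sqrt(Symmetric(p)))
    # Symmetric(...) removes round-off asymmetry before taking the matrix logarithm
    return norm(log(Symmetric(p_inv_sqrt * x * p_inv_sqrt)))
end

# Sum of squared distances between approximations and data,
# relative to the sum of squared distances from q to the data
function relative_error(q, data, approx)
    num = sum(spd_distance(a, x)^2 for (a, x) in zip(approx, data))
    den = sum(spd_distance(q, x)^2 for x in data)
    return num / den
end

q = Matrix(1.0I, 3, 3)
data = [Matrix(Diagonal([1.0, exp(t), 1.0])) for t in (-1.0, 0.5, 2.0)]

err_perfect = relative_error(q, data, data)        # approximation equals the data
err_trivial = relative_error(q, data, fill(q, 3))  # every point replaced by q
```

With this normalisation, an error near 0 means the approximation reproduces the data, and an error of 1 means it does no better than replacing every point by `q`.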

In [26]:
# Plot (1) the theoretical lower bound (tangent-space error), (2) the naive manifold error, (3) the curvature corrected manifold error and (4) the exact manifold error
plot(1:d, [naive_tangent_distances_r, true_naive_distances_r, true_curvature_corrected_distances_r, exact_distances_r], label = ["theoretical lower bound" "naive" "curvature corrected" "exact"], ylims=(0,1), xlims=(1,d),xaxis=("approximation rank"), yaxis=(L"$\varepsilon_{rel}$"))
savefig("results/artificial1D_errors_by_rank.png")
plot(1:d, [naive_tangent_distances_r .+ 1e-16, true_naive_distances_r .+ 1e-16, true_curvature_corrected_distances_r .+ 1e-16, exact_distances_r .+ 1e-16], label = ["theoretical lower bound" "naive" "curvature corrected" "exact"], ylims=(1e-16,1), xlims=(1,d), xaxis=("approximation rank"), yaxis=(L"$\varepsilon_{rel}$", :log), legend=:bottomleft)
savefig("results/artificial1D_logerrors_by_rank.png")
for i in 1:d-1
    if i == 1
        plot(1:length(eCosts[1]), eCosts[1], label = "rank 1", ylims=(1e-4,1), yaxis=(L"$\varepsilon_{rel}$", :log))
    else
        plot!(1:length(eCosts[i]), eCosts[i], label = "rank $(i)", ylims=(1e-4,1), yaxis=(L"$\varepsilon_{rel}$", :log))
    end
end
savefig("results/artificial1D_exact_iterate_loss.png")

"/Users/wdiepeveen/Documents/PhD/Projects/8 - Manifold-valued tensor decomposition/src/manifold-valued-tensors/experiments/1D/P3/results/artificial1D_exact_iterate_loss.png"

In [None]:
# Compare the manifold loss predicted by the curvature corrected loss (CCL) with the actual loss
# (1) for the naive approach (2) for the curvature corrected approach
plot(1:d, [(predicted_naive_distances_r .+ 1e-16) ./ (true_naive_distances_r .+ 1e-16), (predicted_curvature_corrected_distances_r .+ 1e-16) ./ (true_curvature_corrected_distances_r .+ 1e-16)], label = ["discrepancy in initialisation" "discrepancy in solutions"], xlims=(1,d),xaxis=("approximation rank"), yaxis=(L"$\delta_{rel}$"))
savefig("results/artificial1D_discrepancy_by_rank.png")
plot(1:d, [(predicted_naive_distances_r .+ 1e-16) ./ (true_naive_distances_r .+ 1e-16), (predicted_curvature_corrected_distances_r .+ 1e-16) ./ (true_curvature_corrected_distances_r .+ 1e-16)], label = ["discrepancy in initialisation" "discrepancy in solutions"], xlims=(1,d), xaxis=("approximation rank"), yaxis=(L"$\delta_{rel}$", :log), legend=:bottomleft)
savefig("results/artificial1D_logdiscrepancy_by_rank.png")

"/Users/wdiepeveen/Documents/PhD/Projects/8 - Manifold-valued tensor decomposition/src/manifold-valued-tensors/experiments/1D/P3/results/artificial1D_logdiscrepancy_by_rank.png"

### Benchmark different methods ###

In [None]:
t = @benchmark naive_low_rank_approximation(M, q, data, 2)
mean(t).time
std(t).time
# TODO also do the benchmarking in here similarly as run above | also do this for 3 different step sizes for exact | if benchmark=true
@benchmark exact_low_rank_approximation(M, q, data, 2; stepsize=1/10000, max_iter=50) 
@benchmark curvature_corrected_low_rank_approximation(M, q, data, 2) 

In [12]:
@benchmark naive_low_rank_approximation(M, q, data, 2)

BenchmarkTools.Trial: 273 samples with 1 evaluation.
 Range (min … max):  15.884 ms … 29.450 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     17.508 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   18.321 ms ±  2.068 ms  ┊ GC (mean ± σ):  2.10% ± 5.79%

In [13]:
@benchmark curvature_corrected_low_rank_approximation(M, q, data, 2) 

BenchmarkTools.Trial: 145 samples with 1 evaluation.
 Range (min … max):  31.986 ms … 43.204 ms  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     33.731 ms              ┊ GC (median):    0.00%
 Time  (mean ± σ):   34.649 ms ±  2.331 ms  ┊ GC (mean ± σ):  2.35% ± 4.75%

In [14]:
@benchmark exact_low_rank_approximation(M, q, data, 2; stepsize=1/100000, max_iter=50) 

BenchmarkTools.Trial: 4 samples with 1 evaluation.
 Range (min … max):  1.269 s …   1.353 s  ┊ GC (min … max): 2.87% … 2.76%
 Time  (median):     1.317 s              ┊ GC (median):    2.74%
 Time  (mean ± σ):   1.314 s ± 35.715 ms  ┊ GC (mean ± σ):  2.69% ± 0.20%