# Low-rank approximation on $\mathcal{P}(d)$ – the space of $d \times d$ symmetric positive definite (SPD) matrices

In this notebook we build intuition for different approaches to computing low-rank approximations of manifold-valued signals.

In [1]:
using Manifolds
using Manopt
using LinearAlgebra
using Random
using Plots
using LaTeXStrings
using Profile

In [92]:
include("../../../src/decompositions/signals/naive_low_rank_approximation.jl")
include("../../../src/decompositions/signals/curvature_corrected_low_rank_approximation.jl")
include("../../../src/decompositions/signals/exact_low_rank_approximation.jl")

include("../../../src/functions/loss_functions/curvature_corrected_loss.jl")
include("../../../src/functions/loss_functions/exact_loss.jl")

exact_loss (generic function with 2 methods)

### Construct manifold and generate synthetic data ###

In [54]:
# Manifold P(3) of 3×3 symmetric positive definite matrices and its intrinsic dimension d;
# the synthetic signal on it is generated in the next cell.
M = SymmetricPositiveDefinite(3)
d = manifold_dimension(M)
# Number of samples in the signal
n = 100


100

In [55]:
# Base point: the 3×3 identity matrix e ∈ P(3)
e = Matrix{Float64}(I, 3, 3)
# Orthonormal basis of the tangent space T_e P(3)
Θ = get_basis(M, e, DefaultOrthonormalBasis())
# Construct synthetic data: points exp_e(t Xₑ) with random t, each then perturbed by
# Gaussian noise drawn in its own tangent space.
τ = 2.   # variance of the position t along the geodesic (t = sqrt(τ) * standard normal)
σ = .05  # noise level, passed as σ to random_tangent (documented by Manifolds.jl as a standard deviation)
Xₑ = Θ.data[4]  # fourth basis vector of T_e P(3) is the direction of the underlying 1D signal
print(Xₑ)

Random.seed!(31)
# randn() draws the same values as the former randn(1)[1], but without the throwaway array
predata = [exp(M, e, sqrt(τ) * randn() * Xₑ) for _ in 1:n]

data = [exp(M, p, random_tangent(M, p, Val(:Gaussian), σ)) for p in predata]; # ∈ P(3)^n


[0.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 0.0]

In [50]:
# Export the first `num_export` samples of the signal as an Asymptote slice image.
num_export = 10
# Opening a file inside a missing directory throws, so create the output directory first.
mkpath("results")
asymptote_export_SPD("results/artificial1D_orig.asy", data=data[1:min(num_export, n)], scale_axes=(2, 2, 2));

### Construct low-rank approximations ###

In [56]:
# Base point q of the tangent space: the Riemannian mean of the signal (Manifolds.jl `mean`)
q = mean(M, data)
# Tangent-space representation log_q(x_k) of every sample, ∈ T_q P(3)^n
log_q_data = map(x -> log(M, q, x), data);

In [95]:
# Stochastic curvature-corrected rank-`st_rank` approximation: solve for the factors,
# rebuild the tangent-space approximation Σᵢ R_q[i] * U[k, i] of every sample k, map it back
# to P(3), and report the curvature-corrected loss of the reconstruction.
st_rank = 3  # single source of truth for the rank; the reconstruction below sums exactly st_rank terms
st_ccRr_q, st_ccUr = stochastic_curvature_corrected_low_rank_approximation(M, q, data, st_rank; stepsize=1/10, max_iter=2000);
st_curvature_corrected_log_q_data_r = Symmetric.([sum(st_ccRr_q[i] * st_ccUr[k, i] for i in 1:st_rank) for k in 1:n])
st_curvature_corrected_data_r = exp.(Ref(M), Ref(q), st_curvature_corrected_log_q_data_r)
curvature_corrected_loss(M, q, data, st_curvature_corrected_log_q_data_r)

Initial 
# 1     
# 2     
# 3     
# 4     
# 5     
# 6     
# 7     
# 8     
# 9     
# 10    
# 11    
# 12    
# 13    
# 14    
# 15    
# 16    
# 17    
# 18    
# 19    
# 20    
# 21    
# 22    
# 23    
# 24    
# 25    
# 26    
# 27    
# 28    
# 29    
# 30    
# 31    
# 32    
# 33    
# 34    
# 35    
# 36    
# 37    
# 38    
# 39    
# 40    
# 41    
# 42    
# 43    
# 44    
# 45    
# 46    
# 47    
# 48    
# 49    
# 50    
# 51    
# 52    
# 53    
# 54    
# 55    
# 56    
# 57    
# 58    
# 59    
# 60    
# 61    
# 62    
# 63    
# 64    
# 65    
# 66    
# 67    
# 68    
# 69    
# 70    
# 71    
# 72    
# 73    
# 74    
# 75    
# 76    
# 77    
# 78    
# 79    
# 80    
# 81    
# 82    
# 83    
# 84    
# 85    
# 86    
# 87    
# 88    
# 89    
# 90    
# 91    
# 92    
# 93    
# 94    
# 95    
# 96    
# 97    
# 98    
# 99    
# 100   
# 101   
# 102   
# 103   
# 104   
# 105   
# 106   
# 107   
# 108   
# 109   
# 110   
#

0.10074589713974623

In [38]:
# Print the samples recorded by the profiler as a call tree.
# NOTE(review): no `@profile` call appears in the visible cells, so this reports whatever an
# earlier cell recorded — confirm which run the output belongs to.
Profile.print(format = :tree)

Overhead ╎ [+additional indent] Count File:Line; Function
    1╎1     ...e/abstractarray.jl:1007; copyto_unaliased!(deststyle::...
    1╎1     ...e/abstractarray.jl:1008; copyto_unaliased!(deststyle::...
    1╎1     @Base/array.jl:325; _copyto_impl!(dest::Vector{Fl...
    1╎1     @Base/broadcast.jl:509; _bcs(shape::Tuple{Base.OneTo{...
    1╎1     ...ultidimensional.jl:848; _unsafe_getindex(::IndexLinea...
     ╎6969  @Base/task.jl:423; (::IJulia.var"#15#18")()
     ╎ 6969  ...ia/src/eventloop.jl:8; eventloop(socket::ZMQ.Socket)
     ╎  6969  @Base/essentials.jl:714; invokelatest
     ╎   6969  @Base/essentials.jl:716; #invokelatest#2
     ╎    6969  ...xecute_request.jl:67; execute_request(socket::ZMQ....
     ╎     6969  ...oftGlobalScope.jl:65; softscope_include_string(m:...
     ╎    ╎ 6969  @Base/loading.jl:1196; include_string(mapexpr::t...
    2╎    ╎  6969  @Base/boot.jl:373; eval
     ╎    ╎   6966  ...approximation.jl:8; (::var"#curvature_correcte...
     ╎    ╎    1211  ...p

In [6]:
# exact_low_rank_approximation(M, q, data, 1; stepsize=1/1000, max_iter=200); 

In [35]:
# For every rank i = 1, …, d compute the naive, curvature-corrected and exact low-rank
# approximations; entry i of each container holds the factors returned for rank i.
max_iter = 50  # iteration budget shared by the two iterative solvers

# Factor containers: naive (n…), curvature corrected (cc…) and exact (e…)
nR_q, nU = Any[], Any[]
ccR_q, ccU = Any[], Any[]
eR_q, eU = Any[], Any[]
for i in 1:d
    println("#$(i) | computing naive low-rank approximation")
    naive_R_i, naive_U_i = naive_low_rank_approximation(M, q, data, i)
    push!(nR_q, naive_R_i)
    push!(nU, naive_U_i)

    println("#$(i) | computing curvature corrected low-rank approximation")
    cc_R_i, cc_U_i = curvature_corrected_low_rank_approximation(M, q, data, i; stepsize=1/100000, max_iter=max_iter)
    push!(ccR_q, cc_R_i)
    push!(ccU, cc_U_i)

    println("#$(i) | computing exact low-rank approximation")
    exact_R_i, exact_U_i = exact_low_rank_approximation(M, q, data, i; stepsize=1/1000, max_iter=max_iter)
    push!(eR_q, exact_R_i)
    push!(eU, exact_U_i)
end

#1 | computing naive low-rank approximation
#1 | computing curvature corrected low-rank approximation
Initial  F(x): 1.07408003505 | 
# 1     change: 0.003679917 |  F(x): 0.14818842524 | 
# 2     change: 0.001352597 |  F(x): 0.02305719338 | 
# 3     change: 0.000497828 |  F(x): 0.00609899385 | 
# 4     change: 0.000183553 |  F(x): 0.00379222475 | 
# 5     change: 0.000067840 |  F(x): 0.00347687144 | 
# 6     change: 0.000025153 |  F(x): 0.00343347165 | 
# 7     change: 0.000009366 |  F(x): 0.00342744609 | 
# 8     change: 0.000003507 |  F(x): 0.00342659990 | 
# 9     change: 0.000001322 |  F(x): 0.00342647932 | 
# 10    change: 0.000000503 |  F(x): 0.00342646180 | 
The algorithm performed a step with a change (5.031321101913103e-7) less than 1.0e-6.
#1 | computing exact low-rank approximation
Initial  F(x): 0.13166108791 | 
# 1     change: 0.000756393 |  F(x): 0.12184377278 | 
# 2     change: 0.000659314 |  F(x): 0.11319636661 | 
# 3     change: 0.000582925 |  F(x): 0.10557464829 | 
# 

In [36]:
"""
    tangent_reconstruction(R_q, U, r, n)

Return the rank-`r` tangent-space reconstruction of an `n`-sample signal: for every sample
`k` the sum `Σᵢ R_q[i] * U[k, i]` over the first `r` factors, wrapped in `Symmetric`.
"""
tangent_reconstruction(R_q, U, r, n) = Symmetric.([sum(R_q[i] * U[k, i] for i in 1:r) for k in 1:n])

# Sum of squared geodesic distances from q to the samples; normalises the tangent-space errors.
ref_distance = sum(distance.(Ref(M), Ref(q), data).^2)

# Error measures per approximation rank (entry r belongs to the rank-r approximation)
naive_tangent_distances_r = zeros(d)
predicted_naive_distances_r = zeros(d)
true_naive_distances_r = zeros(d)

curvature_corrected_tangent_distances_r = zeros(d)
predicted_curvature_corrected_distances_r = zeros(d)
true_curvature_corrected_distances_r = zeros(d)

exact_tangent_distances_r = zeros(d)
exact_distances_r = zeros(d)

# `r` rather than `rank` so the loop does not shadow LinearAlgebra.rank.
for r in 1:d
    # Tangent-space reconstructions (∈ T_q P(3)^n) of the three rank-r approximations.
    # The loss functions below take these log-space representations directly, so no
    # exponentiation back to P(3) is needed here.
    naive_log_q_data_r = tangent_reconstruction(nR_q[r], nU[r], r, n)
    curvature_corrected_log_q_data_r = tangent_reconstruction(ccR_q[r], ccU[r], r, n)
    exact_log_q_data_r = tangent_reconstruction(eR_q[r], eU[r], r, n)

    # Relative tangent-space error: Σ_k ‖log_q(x_k) − ξ_k‖² / ref_distance
    naive_tangent_distances_r[r] = sum(norm.(Ref(M), Ref(q), log_q_data - naive_log_q_data_r).^2) / ref_distance
    curvature_corrected_tangent_distances_r[r] = sum(norm.(Ref(M), Ref(q), log_q_data - curvature_corrected_log_q_data_r).^2) / ref_distance
    exact_tangent_distances_r[r] = sum(norm.(Ref(M), Ref(q), log_q_data - exact_log_q_data_r).^2) / ref_distance

    # Manifold error of each reconstruction: predicted by the curvature-corrected loss and
    # measured by the exact loss.
    # NOTE(review): unlike the tangent errors these are not divided by ref_distance here;
    # presumably the loss functions already return relative errors (plots label them ε_rel) — confirm.
    predicted_naive_distances_r[r] = curvature_corrected_loss(M, q, data, naive_log_q_data_r)
    true_naive_distances_r[r] = exact_loss(M, q, data, naive_log_q_data_r)
    predicted_curvature_corrected_distances_r[r] = curvature_corrected_loss(M, q, data, curvature_corrected_log_q_data_r)
    true_curvature_corrected_distances_r[r] = exact_loss(M, q, data, curvature_corrected_log_q_data_r)
    exact_distances_r[r] = exact_loss(M, q, data, exact_log_q_data_r)
end

In [37]:
# Plot, per approximation rank, (1) the tangent-space lower bound, (2) the manifold error of
# the naive (uncorrected) approximation, (3) that of the curvature-corrected approximation and
# (4) that of the exact approximation — once on a linear and once on a log scale.
let ε = 1e-16,
    errors = [naive_tangent_distances_r, true_naive_distances_r,
              true_curvature_corrected_distances_r, exact_distances_r],
    series_labels = ["theoretical lower bound" "naive" "curvature corrected" "exact"]

    plot(1:d, errors; label=series_labels, ylims=(0, 1), xlims=(1, d),
         xaxis="approximation rank", yaxis=(L"$\varepsilon_{rel}$"))
    savefig("results/artificial1D_errors_by_rank.png")

    # ε keeps exact zeros representable on the log axis
    plot(1:d, [err .+ ε for err in errors]; label=series_labels, ylims=(ε, 1), xlims=(1, d),
         xaxis="approximation rank", yaxis=(L"$\varepsilon_{rel}$", :log), legend=:bottomleft)
    savefig("results/artificial1D_logerrors_by_rank.png")
end

"/Users/wdiepeveen/Documents/PhD/Projects/8 - Manifold-valued tensor decomposition/src/manifold-valued-tensors/experiments/1D/P3/results/artificial1D_logerrors_by_rank.png"

In [40]:
# Plot how well the curvature-corrected loss predicts the actual (exact) manifold loss, as the
# ratio predicted / actual per rank, for (1) the naive approach (the initialisation) and
# (2) the curvature-corrected approach (the solutions) — on a linear and a log scale.
let ε = 1e-16,
    ratios = [(predicted_naive_distances_r .+ ε) ./ (true_naive_distances_r .+ ε),
              (predicted_curvature_corrected_distances_r .+ ε) ./ (true_curvature_corrected_distances_r .+ ε)],
    series_labels = ["discrepancy in initialisation" "discrepancy in solutions"]

    plot(1:d, ratios; label=series_labels, xlims=(1, d),
         xaxis="approximation rank", yaxis=(L"$\delta_{rel}$"))
    savefig("results/artificial1D_discrepancy_by_rank.png")

    plot(1:d, ratios; label=series_labels, xlims=(1, d),
         xaxis="approximation rank", yaxis=(L"$\delta_{rel}$", :log), legend=:bottomleft)
    savefig("results/artificial1D_logdiscrepancy_by_rank.png")
end

"/Users/wdiepeveen/Documents/PhD/Projects/8 - Manifold-valued tensor decomposition/src/manifold-valued-tensors/experiments/1D/P3/results/artificial1D_logdiscrepancy_by_rank.png"