In [None]:
using Revise

using Test
import TensorCrossInterpolation as TCI
import Random

In [None]:
using Test
import TensorCrossInterpolation as TCI
import TensorCrossInterpolation: rank, linkdims, TensorCI2, MultiIndex, evaluate, crossinterpolate2, pivoterror, tensortrain
import Random
import QuanticsGrids as QD

In [None]:
import TensorCrossInterpolation as TCI
using TensorCrossInterpolation
using Random
using Test
using ITensors
# Silence ITensors' warning about creating high-order ITensors: the dense contraction
# `reduce(*, Ψ)` of the R = 20 site MPS in a later cell yields an order-20 tensor.
ITensors.disable_warn_order()

In [None]:
# Pivot search strategy forwarded to crossinterpolate2 below.
pivotsearch = :full
seed = 124

# Seed the global RNG: `random_mps` and the `randn` noise below both draw from it, so the
# order of the random calls after this line determines the test tensor.
Random.seed!(seed)

# Disabled alternative test function: a 1D quantics function on a grid from 0.0 to 1.0,
# built as a random sum of Gaussians (requires QD = QuanticsGrids). If re-enabled it also
# consumes random numbers from the seeded RNG.
#==
R = 30
abstol = 1e-1
grid = QD.DiscretizedGrid{1}(R, (0.0,), (1.0,))

f(bitlist) = fx(QD.quantics_to_origcoord(grid, bitlist)[1])

B = 2^(-30) # global variable

nc = 1000
coeffs = [randn() for _ in 1:nc]
exps = 10 .* rand(nc)

function fx(x)
    #return cos(x / B) * cos(x / (4 * sqrt(5) * B)) * exp(-x^2) + 2 * exp(-x)
    #cos(x^2) + exp(-x^2) + 2 * exp(-x)
    sum(c * exp(- e * x^2) for (c, e) in zip(coeffs, exps))
end
==#

# Active test function: entries of a random MPS (link dimension 20) on R binary sites,
# contracted to a dense array and perturbed by Gaussian noise of scale 0.1 * abstol.
R = 20
abstol = 1e-3
sites = [Index(2, "n=$n") for n in 1:R]

Ψ = random_mps(sites; linkdims=20)

# Dense array with 2^R entries; its dimensions follow the order of `reverse(sites)`.
tensor = Array(reduce(*, Ψ), reverse(sites))
tensor .+= (0.1 * abstol) .* randn(size(tensor))

# Multi-index (vector of 1-based local indices, one per dimension) -> tensor entry.
f(x) = tensor[x...]

localdims = fill(2, R)
firstpivot = ones(Int, R)
# Run TCI2 starting from the all-ones pivot. With normalizeerror=false the tolerance is
# used as an absolute error threshold. `ranks` and `errors` are presumably per-iteration
# diagnostics — verify against the TCI docs.
tci, ranks, errors = crossinterpolate2(
    Float64,
    f,
    localdims,
    [firstpivot];
    tolerance=abstol,
    maxbonddim=1000,
    maxiter=20,
    loginterval=1,
    verbosity=1,
    normalizeerror=false,
    pivotsearch=pivotsearch,
)
@show abstol

In [None]:
#@show length(tci.globalpivots)
#@show length(TCI.reducedglobalpivots(tci.globalpivots))
#@show length(TCI.fullglobalpivots(TCI.reducedglobalpivots(tci.globalpivots)))

In [None]:
# For every bond b, show how many pivot rows (I_{b+1}) and pivot columns (J_b) it holds.
# The @show expression is kept verbatim because @show prints its own source text.
foreach(1:length(tci)-1) do b
    @show length(TCI.Iset(tci, b+1)), length(TCI.Jset(tci, b))
end

In [None]:
# Largest |tt(x) - f(x)| over the multi-indices in `points`; 0.0 when `points` is empty
# (a bare `maximum` over an empty collection throws).
maxabsdeviation(tt, f, points) = maximum(x -> abs(TCI.evaluate(tt, x) - f(x)), points; init=0.0)

# For each orthocenter, build a tensor train from the TCI site tensors, SVD-compress it
# (tolerance 1e-8), and print the largest absolute deviation from `f` on
#   - the pivot crosses             I_{b+1} × J_b,
#   - the one-site sets             I_l × {1..d_l} × J_l,
#   - the two-site sets             I_b × {1..d_b} × {1..d_{b+1}} × J_{b+1},
#   - the global pivots.
# Each line prints the bond/site index, the error, and whether it exceeds `abstol`.
# Other orthocenters worth trying: [1, length(tci), length(tci) ÷ 2].
for orthocenter in [1]
    # `local` keeps this compressed train from overwriting the global `tt` (used by later
    # cells) under the notebook's soft-scope assignment rules.
    local tt = TCI.TensorTrain(TCI.sitetensors(tci, f; orthocenter=orthocenter))
    TCI.compress!(tt, :SVD; tolerance=1e-8)
    println("Orthocenter: ", orthocenter)
    @show TCI.linkdims(tt)

    println("Error on pivots:")
    for b in 1:length(tci)-1
        # `err` (not `diff`) avoids shadowing Base.diff; `local` avoids leaking a global.
        local err = maxabsdeviation(tt, f,
            (vcat(i, j) for i in TCI.Iset(tci, b+1), j in TCI.Jset(tci, b)))
        println(b, "    ", err, " ", err > abstol)
    end

    println("Error on T:")
    for l in 1:length(tci)-1
        local err = maxabsdeviation(tt, f,
            (vcat(i, m, j) for i in TCI.Iset(tci, l), j in TCI.Jset(tci, l),
                               m in 1:localdims[l]))
        println(l, "    ", err, " ", err > abstol)
    end

    println("Error on Π:")
    for b in 1:length(tci)-1
        local err = maxabsdeviation(tt, f,
            (vcat(i, i1, i2, j) for i in TCI.Iset(tci, b), j in TCI.Jset(tci, b+1),
                                    i1 in 1:localdims[b], i2 in 1:localdims[b+1]))
        println(b, "    ", err, " ", err > abstol)
    end

    println("error on global pivots: ", maxabsdeviation(tt, f, tci.globalpivots))
end

In [None]:
# Build two tensor trains from the same converged TCI: `tt` from `sitetensors` and
# `tt_old` from `sitetensors_site0update` (presumably the previous construction — verify),
# then print their link dimensions side by side. Both are used by later cells.
tt = TCI.TensorTrain(TCI.sitetensors(tci, f))
tt_old = TCI.TensorTrain(TCI.sitetensors_site0update(tci, f))

@show TCI.linkdims(tt)
@show TCI.linkdims(tt_old)

In [None]:
# For every site l (including the last), print the largest |tt - f| over the one-site set
# I_l × {1, …, localdims[l]} × J_l.
# The local index at site l is named `m` so it no longer shadows the site index `l`.
# Iset/Jset are queried inline so no `Iset`/`Jset` globals are created under soft scope.
for l in 1:length(tci)
    println(l, "    ", maximum(
        abs(TCI.evaluate(tt, vcat(i, m, j)) - f(vcat(i, m, j)))
        for i in TCI.Iset(tci, l), j in TCI.Jset(tci, l), m in 1:localdims[l]
    ))
end

In [None]:
# Largest |tt(p) - f(p)| over the global pivots; 0.0 when there are none instead of
# throwing on an empty reduction.
maximum(p -> abs(TCI.evaluate(tt, p) - f(p)), tci.globalpivots; init=0.0)

In [None]:
# For `tt_old`: print, per bond b, the largest |tt_old - f| on the pivot cross
# I_{b+1} × J_b. Then return the largest error on the global pivots (0.0 if there are
# none, instead of throwing on an empty reduction).
# Iset/Jset are queried inline so no `Iset`/`Jset` globals are created under soft scope.
for b in 1:length(tci)-1
    println(b, "    ", maximum(
        abs(TCI.evaluate(tt_old, vcat(i, j)) - f(vcat(i, j)))
        for i in TCI.Iset(tci, b+1), j in TCI.Jset(tci, b)
    ))
end

maximum(p -> abs(TCI.evaluate(tt_old, p) - f(p)), tci.globalpivots; init=0.0)

In [None]:
using Plots

# Overlay the sorted entry magnitudes of site tensor `b`: first series from `tt`,
# second from `tt_old`, with the y-axis limited to [0, 2].
b = 5
p = plot(ylims=(0, 2))
for train in (tt, tt_old)
    plot!(p, sort(vec(abs.(train[b]))))
end
p

In [None]:
# Number of global pivots collected by the TCI run.
tci.globalpivots |> length