# Vertices for the Hubbard atom

In [None]:
using Revise
using LinearAlgebra
using SparseIR
import SparseIR: valueim
using Quantics
import Quantics: QuanticsInd, index_to_fused_quantics
using OvercompleteIR
import OvercompleteIR: PHConvention, freq_box
import OvercompleteIR.Atom: HubbardAtom, MagneticChannel, chi0, full_vertex, gamma
using Plots
import TensorCrossInterpolation as TCI

# One-element vector holding a zero-dimensional `CartesianIndex`. Used as an index
# (e.g. `A[:, newaxis]`) it inserts a singleton axis, like NumPy's `newaxis`.
# Not referenced in the cells shown here.
newaxis = fill(CartesianIndex(), 1)

In [None]:
using ITensors

In [None]:
# Parameters passed to `HubbardAtom` (presumably the interaction U and inverse temperature beta).
U = 1.0
beta = 10.0
# Only `MagneticChannel` is imported from `OvercompleteIR.Atom` above, so an unqualified
# `DensityChannel` is not guaranteed to be in scope; refer to it by its qualified name.
# NOTE(review): assumes DensityChannel is defined in OvercompleteIR.Atom next to
# MagneticChannel — confirm against the package.
ch = OvercompleteIR.Atom.DensityChannel()
conv = PHConvention()  # not referenced by the cells shown here
model = HubbardAtom(U, beta)

In [None]:
# TCI with R bits per frequency axis
"""
    create_func(func, R)

Wrap `func` so it can be evaluated on quantics indices of a three-dimensional
integer grid with `R` bits per axis, whose origin is `-2^(R-1)` along every axis.

`func` is called with the tuple `(FermionicFreq(n1), FermionicFreq(n2), BosonicFreq(n3))`
built from the decoded grid coordinates `(n1, n2, n3)`.

Returns `(q_to_n, func_q, m)`:
- `q_to_n(q)`: decode a `Vector{QuanticsInd{3}}` into an `NTuple{3,Int}` coordinate;
  the quantics indices are reversed before decoding.
- `func_q(q)`: `func` at `q_to_n(q)` as a `ComplexF64`, or zero when the coordinate
  has the wrong parity (the first two must be odd, the third even).
- `m`: the underlying `Quantics.InherentDiscreteGrid{3}`.

Throws `ArgumentError` unless `1 <= R < Sys.WORD_SIZE`, so that `2^(R-1)` fits in an `Int`.
"""
function create_func(func, R)
    1 <= R < Sys.WORD_SIZE ||
        throw(ArgumentError("R must satisfy 1 <= R < $(Sys.WORD_SIZE), got $R"))

    # Origin of the frequency box: the grid starts at -2^(R-1) along each axis.
    origin = -2^(R - 1)

    m = Quantics.InherentDiscreteGrid{3}(R, (origin, origin, origin))

    function q_to_n(q::Vector{QuanticsInd{3}})::NTuple{3,Int}
        # Reverse the quantics indices before decoding.
        # NOTE(review): presumably matches the site order used by the TCI wrapper — confirm.
        return Quantics.originalcoordinate(m, reverse(q))
    end

    function func_q(q::Vector{QuanticsInd{3}})::ComplexF64
        n1, n2, n3 = q_to_n(q)
        # SparseIR's FermionicFreq accepts only odd integers and BosonicFreq only even
        # ones, so every other grid point is mapped to zero instead of being evaluated.
        if iseven(n1) || iseven(n2) || isodd(n3)
            return zero(ComplexF64)
        end
        return func((FermionicFreq(n1), FermionicFreq(n2), BosonicFreq(n3)))
    end

    return q_to_n, func_q, m
end

R = 20
# The do-block is passed as `create_func`'s first argument: `gamma(ch, model, w)` for a
# frequency tuple `w`.
q2n, fq, mesh = create_func(R) do w
    gamma(ch, model, w)
end
;

In [None]:
# Adapter for TCI: take a vector of plain integer site indices, convert each one to a
# `QuanticsInd{3}`, and evaluate `fq`.
fI = q -> fq(map(QuanticsInd{3}, q))

# Each site fuses one bit from each of the three axes: 2^3 = 8 local states.
localdims = fill(2^3, R)

# Initial pivot from index (2, 2, 1), then refined by TCI.optfirstpivot;
# the function value is printed before and after the refinement.
firstpivot = convert.(Int, index_to_fused_quantics((2, 2, 1), R))
@show fI(firstpivot)
firstpivot = TCI.optfirstpivot(fI, localdims, firstpivot)
@show fI(firstpivot)

In [None]:
tol = 1e-5

# Run TCI starting from the single refined pivot. Besides the tensor train `qtt`,
# the call returns two further values kept as `ranks` and `errors`.
qtt, ranks, errors = TCI.crossinterpolate2(
    ComplexF64,
    fI,
    localdims,
    [firstpivot];
    tolerance=tol,
    maxiter=10,
    verbosity=1,
)

In [None]:
# Canonicalize `qtt` in place (mutating, per the `!`); `fI` is passed along so TCI can
# evaluate the target function. NOTE(review): confirm exact semantics in TCI's docs.
TCI.makecanonical!(qtt, fI)

In [None]:
#TCI.print_nesting_info(qtt)

In [None]:
# Insert global pivots into `qtt` (mutating, per the `!`), searching with `nsearch=20`
# at the same tolerance used for the main interpolation.
nglobalpivots = TCI.insert_global_pivots!(
    qtt, fI;
    nsearch=20,
    verbosity=1,
    tolerance=tol,
)

In [None]:
# Sample 2^5 = 32 points along the diagonal, stepping by 2^(R-5) from offset (2, 2, 1).
# NOTE(review): `fq` reverses its quantics input before decoding, so confirm that
# `to_quantics` yields the matching bit order if the x-axis is read as a frequency.
q_diagonal = map(0:2^(R-5):2^R-1) do i
    Quantics.to_quantics(mesh, (2, 2, 1) .+ i)
end
# Tensor-train reconstruction vs. direct evaluation at the same quantics indices.
reconst_diagonal = map(q -> TCI.evaluate(qtt, convert.(Int, q)), q_diagonal)
ref_diagonal = fq.(q_diagonal)
;

In [None]:
# Bond dimensions of the tensor train.
qtt |> TCI.linkdims |> plot

In [None]:
# Log-scale comparison along the diagonal: reconstruction, reference, then absolute error.
p = plot(; yaxis=:log, ylims=(1e-10, 1e4))
for (series, mk) in ((abs.(reconst_diagonal), :x), (abs.(ref_diagonal), :+))
    plot!(p, series; marker=mk)
end
plot!(p, abs.(ref_diagonal .- reconst_diagonal))

In [None]:
# Odd values for the two fermionic axes and one even value for the bosonic axis,
# matching the parity that `fq` accepts. NOTE(review): these are presumably original
# (Matsubara-index) coordinates converted by `gridpoint` — confirm its direction.
f = [-101:2:101;]
b = 2

# With a scalar third component the sample set is a length(f) × length(f) Matrix,
# entry [i, j] == (f[i], f[j], b).
is = [(n1, n2, b) for n1 in f, n2 in f]
ws = map(n -> Quantics.gridpoint(mesh, n), is)
qs = map(w -> Quantics.to_quantics(mesh, w), ws)

# Tensor-train values vs. direct evaluation on the same quantics indices.
vals = map(q -> TCI.evaluate(qtt, convert.(Int, q)), qs)
refs = fq.(qs)
;

In [None]:
# Magnitude of the tensor-train reconstruction on the 2D slice.
heatmap(abs.(vals))

In [None]:
# Magnitude of the directly evaluated reference on the same slice.
heatmap(abs.(refs))

In [None]:
# Compare reference, reconstruction and their difference along the middle column.
col = size(refs, 2) ÷ 2
p = plot(; yaxis=:log, ylims=(1e-10, 1e4))
let ref_col = refs[:, col], val_col = vals[:, col]
    plot!(p, abs.(ref_col); marker=:x)
    plot!(p, abs.(val_col); marker=:o)
    plot!(p, abs.(ref_col .- val_col))
end

In [None]:
# Natural log of the absolute reconstruction error on the slice (exact matches give -Inf).
heatmap(@. log(abs(vals - refs)))

In [None]:
# Bond dimensions of the (possibly pivot-enriched) tensor train on a log scale.
plot(
    TCI.linkdims(qtt);
    yaxis=:log,
    marker=:x,
)