# Vertices of the Hubbard atom

In [None]:
using Revise
using LinearAlgebra
using SparseIR
import SparseIR: valueim
using MSSTA
import MSSTA: QuanticsInd, quantics_to_index, asqubits, index_to_quantics, QubitInd, qubit_to_index, index_to_qubit
using OvercompleteIR
import OvercompleteIR: PHConvention, freq_box
import OvercompleteIR.Atom: HubbardAtom, MagneticChannel, chi0, full_vertex, gamma
using Plots
import TensorCrossInterpolation as TCI

# NumPy-style `newaxis`: a one-element vector of zero-dimensional CartesianIndex.
# Used as an index it consumes no dimensions but contributes an axis of length 1,
# e.g. `v[:, newaxis]` turns a length-n vector into an n×1 matrix.
newaxis = fill(CartesianIndex(), 1)

In [None]:
using ITensors

In [None]:
# Parameters taken from "Solving the Bethe–Salpeter equation with exponential convergence".
# Model parameters (U, beta) passed to `HubbardAtom` below.
U, beta = 2.3, 1.55
# NOTE(review): `DensityChannel` is not in the explicit `import OvercompleteIR.Atom: ...`
# list (only `MagneticChannel` is); confirm it is brought into scope by `using OvercompleteIR`.
ch = DensityChannel()
#phconv = PHConvention()
#ppconv = PPConvention()
model = HubbardAtom(U, beta)

In [None]:
#phbox = freq_box(phconv, 4, 3);
#box4 = to_full_freq.(phconv, phbox);

In [None]:
# TCI with R bits
"""
    create_func(func, R)

Build a quantics front end for a function of two fermionic and one bosonic
Matsubara frequency on a 3-dimensional `R`-bit mesh.

`func` is called with the tuple `(FermionicFreq(n1), FermionicFreq(n2), BosonicFreq(nb))`.
Every axis of the integer mesh starts at `-2^(R-1)` (presumably with `2^R` points per
axis; confirm against `MSSTA.DiscreteMesh`).

Returns `(q_to_n, func_q, m)`:
- `q_to_n(q)`: maps a quantics index vector to the integer triple `(n1, n2, nb)`.
- `func_q(q)`: returns `func` at that triple as a `ComplexF64`. Returns zero when the
  triple has the wrong Matsubara parity (`n1` or `n2` even, or `nb` odd), since no
  frequency object is built for such points.
- `m`: the underlying `MSSTA.DiscreteMesh{3}`.
"""
function create_func(func, R)
    # Lower corner of the frequency box along each of the three axes.
    lower = -2^(R - 1)
    grid = MSSTA.DiscreteMesh{3}(R, ntuple(_ -> lower, 3))

    q_to_n(q::Vector{QuanticsInd{3}})::NTuple{3, Int} = MSSTA.originalcoordinate(grid, q)

    function func_q(q::Vector{QuanticsInd{3}})::ComplexF64
        n1, n2, nb = q_to_n(q)
        # Only odd fermionic and even bosonic indices are valid Matsubara points;
        # everything else on the mesh is padded with zero.
        onlattice = isodd(n1) && isodd(n2) && iseven(nb)
        onlattice || return zero(ComplexF64)
        return func((FermionicFreq(n1), FermionicFreq(n2), BosonicFreq(nb)))
    end

    return q_to_n, func_q, grid
end

# Number of bits per frequency axis of the quantics mesh.
R = 15
# `ch` and `model` are untyped (non-const) globals. Rebinding them as `let` locals makes
# the closure capture concretely typed values, so the many evaluations made by TCI do
# not pay for dynamic dispatch on global lookups.
q2n, fq, mesh = let ch = ch, model = model
    create_func(w -> gamma(ch, model, w), R)
end
;

In [None]:
# TCI works with plain integer indices, so wrap them back into QuanticsInd{3}
# before calling the quantics function `fq`.
fI = qidx -> fq(map(QuanticsInd{3}, qidx))
# Eight local states per quantics site (one per combination of three bits).
localdims = [8 for _ in 1:R]

# Initial pivot at mesh point (2, 2, 1), then let TCI search for a better one.
# NOTE(review): presumably chosen so the pivot has odd/odd/even parity (non-zero value);
# confirm against the mesh indexing convention.
firstpivot = convert.(Int, index_to_quantics((2, 2, 1), R))
@show fI(firstpivot)
firstpivot = TCI.optfirstpivot(fI, localdims, firstpivot)
@show fI(firstpivot)

In [None]:
# Tensor cross interpolation of `fI`, seeded with the optimized pivot.
# Produces the quantics tensor train plus the `ranks` and `errors` histories.
qtt, ranks, errors = TCI.crossinterpolate2(
    ComplexF64,
    fI,
    localdims,
    [firstpivot];
    tolerance=1e-8,
    maxiter=200,
    verbosity=1,
)

In [None]:
# Sample 2^5 = 32 points along the mesh diagonal starting at (2, 2, 1). The shift is a
# multiple of 2^(R-5) (even for R > 5), so each point keeps the parity of the start.
q_diagonal = map(0:2^(R - 5):2^R - 1) do shift
    MSSTA.quanticsindex(mesh, (2, 2, 1) .+ shift)
end
# TCI reconstruction vs. direct evaluation at the same points.
reconst_diagonal = map(q -> TCI.evaluate(qtt, convert.(Int, q)), q_diagonal)
ref_diagonal = map(fq, q_diagonal)
;

In [None]:
# Magnitude of the TCI reconstruction, the direct evaluation and their difference along
# the sampled diagonal, on a log scale. Labels identify each series in the legend.
p = plot(; yaxis=:log, ylims=(1e-10, 1e+4), xlabel="diagonal sample", ylabel="magnitude")
plot!(p, abs.(reconst_diagonal); marker=:x, label="TCI")
plot!(p, abs.(ref_diagonal); marker=:+, label="exact")
plot!(p, abs.(ref_diagonal .- reconst_diagonal); label="|exact - TCI|")

In [None]:
f = collect(-21:2:21) # box for fermionic frequencies
b = 0 # bosonic frequency

# Square grid of integer frequency triples: is[i, j] == (f[i], f[j], b).
is = [(nu1, nu2, b) for nu1 in f, nu2 in f]
# frequency -> mesh index -> quantics index -> TCI value
ws = map(i -> MSSTA.meshindex(mesh, i), is)
qs = map(w -> MSSTA.quanticsindex(mesh, w), ws)
vals = map(q -> TCI.evaluate(qtt, convert.(Int, q)), qs)
;

In [None]:
# Symmetric color limits center the colormap on zero; `maxval` is reused by the
# reference plot so both heatmaps share one color scale.
maxval = maximum(abs.(vals))
heatmap(f, f, map(real, vals); c=:redsblues, clim=(-maxval, maxval))

In [None]:
# Direct evaluation of `gamma` on the same grid, for comparison with the TCI heatmap.
ws_ = [(FermionicFreq(n[1]), FermionicFreq(n[2]), BosonicFreq(n[3])) for n in is]
# A comprehension rather than `gamma.(ch, model, ws_)`: broadcasting would try to iterate
# the `ch` and `model` structs unless their types opt out via `Base.broadcastable`.
vals_ref = [gamma(ch, model, w) for w in ws_]

# Same frequency axes and color scale as the TCI heatmap above.
heatmap(f, f, real.(vals_ref); c=:redsblues, clim=(-maxval, maxval))

In [None]:
# Log-magnitude of the TCI reconstruction, on the same frequency axes as the plots above.
heatmap(f, f, log.(abs.(vals)))