# Vertices for the Hubbard atom

In [None]:
using Revise
using Plots
using LinearAlgebra
using SparseIR
import SparseIR: valueim
using OvercompleteIR
import OvercompleteIR: PHConvention, freq_box
import OvercompleteIR.Atom: HubbardAtom, MagneticChannel, chi0, _delta, full_vertex, gamma

# Index helper in the style of NumPy's `np.newaxis`: indexing with `[CartesianIndex()]`
# inserts a singleton dimension (e.g. `v[:, newaxis]`). Not used in the cells shown here.
newaxis = [CartesianIndex()]

In [None]:
# Use 16 BLAS threads for the dense linear algebra below (matrix products, SVDs).
# NOTE(review): hard-coded count, presumably matched to the machine this ran on — adjust.
BLAS.set_num_threads(16)

In [None]:
# Report how many Julia threads this session was started with.
Threads.nthreads() |> println

In [None]:
using ITensors

# BLAS thread count as seen by ITensors (should reflect the BLAS.set_num_threads call above).
ITensors.blas_get_num_threads() |> println

In [None]:
# Quantics setup: each fermionic frequency axis is encoded by N bits (2^N points).
N = 8
half_N = div(N, 2)
nw = 1 << N        # number of fermionic frequencies per axis
half_nw = nw >> 1  # frequencies on each side of zero
nw


In [None]:
# Frequency grid at a fixed bosonic index: box[a, b] holds the fermionic frequencies
# built from 2i+1 and 2j+1 (i, j in -half_nw:half_nw-1) plus the bosonic one from 2wb.
wb = 0
box = Tuple{FermionicFreq,FermionicFreq,BosonicFreq}[
    (FermionicFreq(2 * i + 1), FermionicFreq(2 * j + 1), BosonicFreq(2 * wb))
    for i in -half_nw:(half_nw - 1), j in -half_nw:(half_nw - 1)
]
nothing

In [None]:
U = 4.0
beta = 1.0
atom = HubbardAtom(U, beta)
# NOTE(review): `DensityChannel` is not in the explicit import list above (only
# `MagneticChannel` is); presumably it is exported by `OvercompleteIR` — confirm.
ch = DensityChannel()

# Lift `chi0` to a function of the full frequency tuple (ν, ν′, ω): evaluate `chi0`
# at (ν, ω) and multiply by `_delta(ν, ν′)` (presumably a Kronecker delta, which
# would make the result diagonal in ν, ν′ — confirm against OvercompleteIR).
function chi0_(
    ch, atom::HubbardAtom, nnpm::Tuple{FermionicFreq,FermionicFreq,BosonicFreq}
)
    nu, nu_prime, omega = nnpm
    return chi0(ch, atom, (nu, omega)) * _delta(nu, nu_prime)
end

In [None]:
#data = Float64[]
#Us = LinRange(3.0, 4.0, 100)
#for U in Us
    #atom = HubbardAtom(U, beta)
    #push!(data, real(sum(gamma.(DensityChannel(), atom, box))))
#end
#plot(Us, data)

In [None]:
# Evaluate `gamma`, `chi0_` and `full_vertex` on every frequency tuple of `box`.
# `Ref` pins the channel and atom as scalars while broadcasting over the grid.
gamma_box = gamma.(Ref(ch), Ref(atom), box)
chi0_box = chi0_.(Ref(ch), Ref(atom), box)
full_box = full_vertex.(Ref(ch), Ref(atom), box)
;

In [None]:
# Rebuild the full vertex from `gamma_box` via the matrix relation
# F = Γ + β⁻² Γ χ0 F (this has the form of a Bethe–Salpeter equation), with the
# products taken over the two fermionic frequency axes of `box`.
# Note: `full_reconst` is reassigned later to the MPS reconstruction.
full_reconst = gamma_box + beta^(-2) * gamma_box * chi0_box * full_box
;

In [None]:
# Largest deviation between `full_box` and `full_reconst`, then the largest
# magnitude of `full_box` itself for scale.
let residual = full_box - full_reconst
    println(maximum(abs, residual))
end
println(maximum(abs, full_box))

In [None]:
# Magnitude of the full vertex over the two fermionic frequency axes.
heatmap(map(abs, full_box))

In [None]:
# Magnitude of `gamma_box` over the two fermionic frequency axes.
heatmap(map(abs, gamma_box))

In [None]:
"""
    to_tensor(vertex)

Reshape a square `2^n × 2^n` matrix into an `n`-index tensor with local dimension 4
(quantics form). Each index `k` fuses bit `k` of the row index with bit `k` of the
column index as `row_bit + 2 * col_bit` (0-based). Bit 1 is the least-significant
bit because Julia arrays are column-major.

`n` is inferred from `size(vertex, 1)`, so the function no longer depends on the
global `N`. Throws `ArgumentError` unless `vertex` is a square matrix with a
power-of-two side length.
"""
function to_tensor(vertex)
    nrow = size(vertex, 1)
    if !(ndims(vertex) == 2 && nrow == size(vertex, 2) && ispow2(nrow))
        throw(ArgumentError(
            "vertex must be a square 2^n × 2^n matrix, got size $(size(vertex))"))
    end
    n = trailing_zeros(nrow)

    # Split each matrix index into n binary digits: dims 1:n are row bits and
    # dims n+1:2n are column bits.
    bits = reshape(vertex, ntuple(_ -> 2, 2n))
    # Interleave (row bit k, column bit k) so each pair fuses into one size-4 index.
    perm = collect(Iterators.flatten(zip(1:n, (n + 1):(2n))))
    return reshape(permutedims(bits, perm), ntuple(_ -> 4, n))
end

In [None]:
# Singular values of the quantics tensor of `full_box`, cut between the first
# N - half_N sites (rows) and the last half_N sites (columns). Only the spectrum is
# plotted, so `svdvals` avoids computing the unused singular vectors.
s = svdvals(reshape(to_tensor(full_box), :, 4^half_N))
plot(s / s[1], yaxis=:log)

In [None]:
# Singular values of the quantics tensor of `chi0_box`, cut between the first
# N - half_N sites (rows) and the last half_N sites (columns). Only the spectrum is
# plotted, so `svdvals` avoids computing the unused singular vectors.
s = svdvals(reshape(to_tensor(chi0_box), :, 4^half_N))
plot(s / s[1], yaxis=:log)

In [None]:
# Singular values of the quantics tensor of `gamma_box`, cut between the first
# N - half_N sites (rows) and the last half_N sites (columns). Only the spectrum is
# plotted, so `svdvals` avoids computing the unused singular vectors.
s = svdvals(reshape(to_tensor(gamma_box), :, 4^half_N))
plot(s / s[1], yaxis=:log)

In [None]:
# One dimension-4 site index per quantics bit pair (N sites in total).
sites = siteinds(4, N)

"""
    tensor_to_mps(tensor; cutoff=1e-20, maxdim=100)

Decompose a dense tensor with one size-4 index per entry of the global `sites`
(in that order, as produced by `to_tensor`) into an ITensors `MPS`. Truncation
uses `cutoff` and `maxdim`; the defaults reproduce the previous hard-coded values.
"""
function tensor_to_mps(tensor; cutoff=1e-20, maxdim=100)
    return MPS(tensor, sites; cutoff=cutoff, maxdim=maxdim)
end

In [None]:
# Quantics MPS of the bare susceptibility box.
chi0_mps = chi0_box |> to_tensor |> tensor_to_mps

In [None]:
# Quantics MPS of the full vertex box.
full_mps = full_box |> to_tensor |> tensor_to_mps

In [None]:
# Quantics MPS of `gamma_box`.
gamma_mps = gamma_box |> to_tensor |> tensor_to_mps

In [None]:
# Contract each MPS back into a single dense ITensor and convert it to an Array whose
# indices follow the order of `sites` (the same layout `to_tensor` produced).
# The result has 4^N entries, so this is only a sanity check for small N.
# Note: this reassigns `full_reconst`, replacing the earlier matrix-equation result.
chi0_reconst = Array(reduce(*, chi0_mps), sites...)
full_reconst = Array(reduce(*, full_mps), sites...)
gamma_reconst = Array(reduce(*, gamma_mps), sites...)
;

In [None]:
"""
    from_tensor(tensor)

Invert the quantics reshaping: turn a tensor of length `4^n` with `n` fused
(row bit, column bit) indices, laid out as by `to_tensor`, back into a
`2^n × 2^n` matrix.

`n` is inferred from `length(tensor)`, so the function no longer depends on the
global `N`, and it produces no debug output. Throws `ArgumentError` if the length
is not a power of 4.
"""
function from_tensor(tensor)
    len = length(tensor)
    if !(ispow2(len) && iseven(trailing_zeros(len)))
        throw(ArgumentError("tensor length must be a power of 4, got $len"))
    end
    n = trailing_zeros(len) ÷ 2

    # Each size-4 index k unfolds column-major into (row bit k, column bit k):
    # (r_1, c_1, ..., r_n, c_n) => (r_1, ..., r_n, c_1, ..., c_n)
    bits = reshape(tensor, ntuple(_ -> 2, 2n))
    perm = vcat(1:2:(2n), 2:2:(2n))
    return reshape(permutedims(bits, perm), 2^n, 2^n)
end

In [None]:
# Magnitude of the full vertex after the MPS round trip, back on the matrix grid.
heatmap(map(abs, from_tensor(full_reconst)))

In [None]:
# Largest entrywise error of the full vertex after the MPS round trip.
maximum(abs.(from_tensor(full_reconst) - full_box))