# Compress momentum & frequency dependence

In [None]:
using Revise
using Plots
using LinearAlgebra
using SparseIR
import SparseIR: valueim

# NumPy-style `newaxis`: indexing with a one-element vector holding the zero-dimensional
# CartesianIndex() inserts a length-1 axis, e.g. `v[:, newaxis]` turns a vector into n×1.
newaxis = fill(CartesianIndex(), 1)

In [None]:
# Use up to 16 BLAS threads, but never more than the machine's hardware threads:
# oversubscribing cores slows down dense linear algebra instead of speeding it up.
BLAS.set_num_threads(min(16, Sys.CPU_THREADS))

In [None]:
# Report how many Julia threads this session was started with.
Threads.nthreads() |> println

In [None]:
using ITensors
import ITensors.NDTensors

# Report the BLAS thread count as seen by ITensors.
ITensors.blas_get_num_threads() |> println

In [None]:
# Inverse temperature (passed as β to the IR basis below) and chemical potential.
beta = 10.0

μ = 0.5

# Square-lattice dispersion ε(kx, ky) = 2cos(kx) + 2cos(ky), with the sign as written.
function ek(kx, ky)
    return 2 * (cos(kx) + cos(ky))
end

In [None]:
# N binary digits (qubits) per momentum direction, giving nk = 2^N points per axis.
N = 8
half_N = div(N, 2)
nk = 1 << N
nk

In [None]:
# Fermionic IR basis with frequency cutoff wmax and singular-value cutoff 1e-7,
# and the sparse Matsubara-frequency sampling built on top of it.
wmax = 20.0
basis = FiniteTempBasis(Fermionic(), beta, wmax, 1e-7)
smpl = basis |> MatsubaraSampling

In [None]:
# Momentum grid actually used for G: k_j = 2πj/nk for j = 1..nk (k = 2π coincides with
# k = 0 by periodicity). The former `LinRange(0, 2π, nk)` was unused and, because it
# contains both endpoints, did not match this grid.
kmesh = [2π * j / nk for j in 1:nk]
v = smpl.sampling_points
iv = valueim.(v, beta)   # i·ν_n at each sampling frequency
nw = length(v)

# Dispersion on the (kx, ky) grid, evaluated once per k-point instead of once per (w, k).
ek_grid = [ek(kx, ky) for kx in kmesh, ky in kmesh]   # nk × nk, rows indexed by kx

# G(iν, kx, ky) = 1 / (iν - ε_k + μ), frequency as the first (fastest-varying) axis.
giv0 = Array{ComplexF64}(undef, nw, nk, nk)
giv0 .= 1 ./ (iv .- reshape(ek_grid, 1, nk, nk) .+ μ)
size(giv0)

In [None]:
# |G| over the (kx, ky) grid at the first sampling frequency.
heatmap(abs.(@view giv0[1, :, :]))

In [None]:
# |G| over the (kx, ky) grid at the middle sampling-frequency index (nw ÷ 2 + 1).
iw_mid = nw ÷ 2 + 1
heatmap(abs.(view(giv0, iw_mid, :, :)))

In [None]:
# Quantics form: split each momentum axis of length nk = 2^N into N binary axes.
# Julia reshapes column-major, so for each momentum the first binary axis is the
# least-significant bit of (k index - 1).
tensor = reshape(giv0, nw, ntuple(_ -> 2, 2 * N)...)
;

In [None]:
# One ITensor index per axis of `tensor`, in the same order: frequency, then the N
# binary axes of kx, then the N binary axes of ky.
indsw = [Index(nw, "w")]
indsx = map(k -> Index(2, "Qubit,kx=$(k)"), 1:N)
indsy = map(k -> Index(2, "Qubit,ky=$(k)"), 1:N)
indsall = [indsw; indsx; indsy]
itens = ITensor(tensor, indsall)

In [None]:
# Bipartition: the first half_N binary axes of kx and of ky on the left; everything
# else (remaining bits and the frequency index) on the right.
half_inds_ = (indsx[1:half_N]..., indsy[1:half_N]...)
@show half_inds_
u, s, vt = svd(itens, half_inds_)
;

In [None]:
# Singular values of the bipartition, normalised by the largest one, on a log scale.
s_ = collect(diag(s))
plot(s_ ./ first(s_); yaxis=:log)

In [None]:
# MPS site order: the frequency index first, then the kx/ky bits interleaved:
# (w, kx=1, ky=1, kx=2, ky=2, ..., kx=N, ky=N).
offset = length(indsw)   # number of leading (frequency) sites; 1 here
sites = [indsw; collect(Iterators.flatten(zip(indsx, indsy)))]
sites

In [None]:
ITensors.set_warn_order(10000)

cutoff = 1E-5
maxdim = 20000
# Build the MPS from the *labelled* ITensor `itens`, not the raw array `tensor`.
# `MPS(::AbstractArray, sites)` assigns the array's axes to `sites` positionally, but the
# axes of `tensor` are ordered (w, kx=1..N, ky=1..N) while `sites` is interleaved
# (w, kx=1, ky=1, ...). Passing the raw array silently mislabels the axes, so the MPS
# would not follow the interleaved ordering at all.
M = MPS(itens, sites; cutoff=cutoff, maxdim=maxdim)

In [None]:
# Contract the MPS back to a dense array with axes in the same order as `tensor`
# (w, kx bits, ky bits), so the two can be compared element by element.
tensor_reconst = Array(reduce(*, M), indsall...)
;

In [None]:
# Error scaled by the maximum magnitude of the original data.
# Named `abserr` (not `error`) so that `Base.error` is not shadowed in Main.
abserr = abs.(vec(tensor_reconst) .- vec(tensor))
y1 = abserr ./ maximum(abs, tensor)
y2 = abs.(vec(tensor))
@show maximum(y1)
plot(y1[1:100])

In [None]:
# Bond (link) dimensions between neighbouring MPS tensors.
# `size(m)[1]` depends on each tensor's internal index order and is not a bond dimension
# for the first tensor (it has no left link); `linkdims` returns exactly the link sizes.
bonddims = linkdims(M)
plot(bonddims)

In [None]:
# Total number of stored entries across all MPS tensors.
mapreduce(m -> prod(size(m)), +, M)

In [None]:
# Number of entries in the dense Green's function array.
length(giv0)

In [None]:
# Compression ratio: MPS entries relative to the dense array's entries.
let nparams = mapreduce(m -> prod(size(m)), +, M)
    nparams / length(giv0)
end