# Fourier transform

In [None]:
using Revise
using Plots
using LinearAlgebra
using SparseIR
import SparseIR: valueim

# Numpy-style `newaxis`: indexing with a vector holding the zero-dimensional
# `CartesianIndex()` inserts a singleton dimension, e.g. `v[:, newaxis]` is an n×1 matrix.
newaxis = fill(CartesianIndex(), 1)

In [None]:
# Use up to 16 BLAS threads, but never more than the number of logical CPUs:
# requesting more threads than cores only oversubscribes the machine.
BLAS.set_num_threads(min(16, Sys.CPU_THREADS))

In [None]:
# Print the number of Julia threads available to this session.
Threads.nthreads() |> println

In [None]:
using ITensors

# Report the BLAS thread count after loading ITensors, using the public
# `LinearAlgebra.BLAS` API (the same pool configured by `BLAS.set_num_threads` above).
println(BLAS.get_num_threads())

We want to create an MPO for the discrete Fourier transform of length $N$:
$$
F(t) = \sum_{x=0}^{N-1} f(x) e^{-i \frac{2\pi t x}{N}} = \sum_{x=0}^{N-1} T(t, x) f(x).
$$

MPS/MPO tensors are indexed from left to right in ascending order.
The least significant bit can be assigned either to the first or to the last tensor.
We stick to the former convention for the input $x$.

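With $N = 2^Q$, each $x = 0, \ldots, 2^Q-1$ is represented by a $Q$-bit binary number $x = (x_{Q-1} x_{Q-2} \cdots x_0)_2 = \sum_{k=0}^{Q-1} 2^k x_k$, e.g. $0b000, 0b001, \ldots, 0b111$ for $Q=3$. The frequency $t$ is written the same way with bits $t_0, \ldots, t_{Q-1}$.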

$$
F(t_0, \cdots, t_{Q-1}) = \sum_{x_0=0}^1 \cdots \sum_{x_{Q-1}=0}^1  T(t_0, \cdots, t_{Q-1}, x_0, \cdots, x_{Q-1}) f(x_0, \cdots, x_{Q-1}).
$$

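We permute the indices of $T$ into $\bar T$, interleaving the bits of $t$ (most significant first) with those of $x$ (least significant first):

$$
\bar T(t_{Q-1}, x_0, t_{Q-2}, x_1, \cdots, t_0, x_{Q-1}) = T(t_0, \cdots, t_{Q-1}, x_0, \cdots, x_{Q-1})
$$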

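In the MPS/MPO, site 1 (`states[1]`) carries $x_0$ and $t_{Q-1}$: the least significant bit of $x$ (finest spatial scale) and the most significant bit of $t$ (highest frequency). In general, site $j$ carries $x_{j-1}$ and $t_{Q-j}$.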

In [None]:
nbit = 6
N = 2^nbit

sites = siteinds("Qubit", nbit)
# sitesT[j] == sites[nbit + 1 - j]. Reading an output tensor in `sitesT` order gives the
# bits (t_0, ..., t_{Q-1}), least significant first.
sitesT = reverse(sites)

# Dense DFT matrix: tmat[t+1, x+1] = exp(-2πi t x / N).
# Reducing t * x modulo N in exact integer arithmetic keeps the phase in [0, 2π), so the
# accuracy of `exp` does not degrade as N grows.
tmat = ComplexF64[exp(-im * 2π * mod(t * x, N) / N) for t in 0:N-1, x in 0:N-1]

# Julia arrays are column-major, so splitting each length-N axis into nbit axes of size 2
# puts the least significant bit first.
# `tmat`: (t_0, ..., t_{Q-1}, x_0, ..., x_{Q-1})
tmat = reshape(tmat, ntuple(_ -> 2, 2 * nbit))

# `contract(::MPO, ::MPS)` sums over the MPO's unprimed site indices and leaves the primed
# ones as the output. So the input x goes on the unprimed `sites` (x_0 on sites[1]) and the
# output t on the primed, reversed sites (t_{Q-1} on sites[1]'), i.e. site 1 carries x_0
# and t_{Q-1} as described in the notes above.
trans_t = ITensor(tmat, prime(sitesT)..., sites...)

# Drop the reference to the dense array; only `trans_t` is used from here on.
tmat = nothing
;

In [None]:
# Bipartition between the first and second halves of the chain: the input bits
# x_0, ..., x_{Q/2-1} together with the output bits t_{Q-1}, ..., t_{Q/2}.
# The decay of the singular values roughly indicates the bond dimension an MPO needs.
indices = vcat(sites[1:nbit÷2], prime(sites[1:nbit÷2]))
U, S, V = svd(trans_t, indices...)
plot(Array(diag(S)), yaxis=:log)

In [None]:
# All bits "1": the input is the basis state x = N - 1, whose exact transform is
# F(t) = exp(-2πi t (N-1) / N) = exp(2πi t / N).
states = repeat(["1"], nbit)

psix = MPS(sites, states)
# Read the output in `sitesT` order, (t_0, ..., t_{Q-1}), so that `vec` runs over
# t = 0, ..., N-1 in ascending order.
psiy = Array(noprime(trans_t * reduce(*, psix)), sitesT)

p = plot()
plot!(p, real.(vec(psiy)))
plot!(p, imag.(vec(psiy)))

In [None]:
cutoff = 1E-10
maxdim = 100

# Compress the dense operator into an MPO on `sites`.
M = MPO(trans_t, sites; cutoff=cutoff, maxdim=maxdim)

In [None]:
# Contract the MPO back into a dense tensor and report the worst-case entrywise error.
trans_t_reconst = reduce(*, M)

@show maximum(abs, Array(trans_t_reconst, sites, sites') - Array(trans_t, sites, sites'))

In [None]:
states = repeat(["1"], nbit)

psix = MPS(sites, states)
psiy = reduce(*, noprime(contract(M, psix)))
psiy2 = noprime(trans_t * reduce(*, psix))

# MPO-MPS contraction vs. dense contraction: report the largest deviation rather than
# printing the whole difference array.
@show maximum(abs, Array(psiy, sites) - Array(psiy2, sites))

In [None]:
# Dense reference for the check below: dft[t+1, x+1] = exp(-2πi t x / N).
dft = [exp(-im * 2π * mod(t * x, N) / N) for t in 0:N-1, x in 0:N-1]

# Inputs: x = N - 1 (all bits 1), x = 0 (all bits 0) and x = 1 (only x_0 set); the last
# one distinguishes the two possible bit orders of the input.
for states in [
    repeat(["1"], nbit),
    repeat(["0"], nbit),
    [j == 1 ? "1" : "0" for j in 1:nbit],
]
    psix = MPS(sites, states)
    psiy = noprime(contract(M, psix))

    # Indices of psiy read in `sitesT` order: (t_0, ...., t_{Q-1}), so `vec` runs over t.
    vecy = vec(Array(reduce(*, psiy), sitesT...))

    # Indices of psix: (x_0, ...., x_{Q-1}), so `vec` runs over x.
    vecx = vec(Array(reduce(*, psix), sites...))

    # Should be small, limited only by the MPO truncation (`cutoff`, `maxdim`).
    @show maximum(abs, vecy - dft * vecx)
end

In [None]:
# Bit layout on the MPS sites (sites[1] is the leftmost tensor):
#   input:  x_0, ..., x_{Q-1} from left to right;
#   output: t_{Q-1}, ..., t_0 from left to right.
# Input basis state: x_0 = 1 and every higher bit 0, i.e. x = 1. With the layout above the
# ideal output is F(t) = exp(-2πi t / N), whose real part is one period of a cosine.
# `MPS(sites, states)` needs one state per site, so all `nbit` entries are given.
states = [j == 1 ? "1" : "0" for j in 1:nbit]

psix = MPS(sites, states)
psiy = noprime(contract(M, psix))

# Indices of psiy: (t_0, ...., t_{Q-1})
psiy_arr = Array(reduce(*, psiy), sitesT...)

p = plot()
plot!(p, real.(vec(psiy_arr)))

In [None]:
# Contract all tensors of the input MPS `psix` into a single ITensor carrying every
# site index, for direct inspection of its amplitudes.
psix_full = reduce(*, psix)

In [None]:
# Amplitude of the basis state with x_0 = 1 (index value 2) and every higher bit 0
# (index value 1), i.e. x = 1. All `nbit` site indices must be fixed to get a scalar.
psix_full[sites[1] => 2, (s => 1 for s in sites[2:end])...]

In [None]:
# Contract all tensors of the output MPS `psiy` into a single ITensor carrying every
# site index, for direct inspection of its amplitudes.
psiy_full = reduce(*, psiy)