Skip to content

Commit

Permalink
Merge pull request #14 from tensor4all/13-needed-a-function-for-creat…
Browse files Browse the repository at this point in the history
…ing-an-mpo-for-affinetransform

13 needed a function for creating an mpo for affinetransform
  • Loading branch information
shinaoka committed Jul 30, 2024
2 parents 0103130 + fd157a8 commit 464ce91
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 44 deletions.
73 changes: 58 additions & 15 deletions src/binaryop.jl
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,23 @@ function affinetransform(M::MPS,
shift::AbstractVector{Int},
bc::AbstractVector{Int};
kwargs...)
transformer = affinetransformmpo(siteinds(M), tags, coeffs_dic, shift, bc)
return apply(transformer, M; kwargs...)
end


function affinetransformmpo(sites::AbstractVector{Index{T}},
tags::AbstractVector{String},
coeffs_dic::AbstractVector{Dict{String,Int}},
shift::AbstractVector{Int},
bc::AbstractVector{Int})::MPO where {T}
# f(x, y) = g(a * x + b * y + s1, c * x + d * y + s2)
# = h(a * x + b * y, c * x + d * y),
# where h(x, y) = g(x + s1, y + s2).
# The transformation is executed in this order: g -> h -> f.

mpos = MPO[]

# Number of variables involved in transformation
ntransvars = length(tags)

Expand All @@ -140,7 +152,7 @@ function affinetransform(M::MPS,

sites_for_tag = []
for tag in tags
push!(sites_for_tag, findallsites_by_tag(siteinds(M); tag=tag))
push!(sites_for_tag, findallsites_by_tag(sites; tag=tag))
if length(sites_for_tag[end]) == 0
error("Tag $tag is not found.")
end
Expand All @@ -154,20 +166,45 @@ function affinetransform(M::MPS,
# If shift is required
if !all(shift .== 0)
for i in 1:ntransvars
M = shiftaxis(M, shift[i]; tag=tags[i], bc=bc[i], kwargs...)
push!(mpos, shiftaxismpo(sites, shift[i]; tag=tags[i], bc=bc[i]))
end
end

# Followed by a rotation
return affinetransform(M, tags, coeffs_dic, bc; kwargs...)
push!(mpos, affinetransformmpo(sites, tags, coeffs_dic, bc))

# Contract MPOs
res = mpos[1]
for n in 2:length(mpos)
res = apply(mpos[n], res; cutoff=1e-25, maxdim=typemax(Int))
end

return res
end

# Version without shift

"""
Affine transform of a MPS with no shift
Significant bits are assumed to be aligned from left to right for all tags.
"""
function affinetransform(M::MPS,
tags::AbstractVector{String},
coeffs_dic::AbstractVector{Dict{String,Int}},
bc::AbstractVector{Int};
kwargs...)
transformer = affinetransformmpo(siteinds(M), tags, coeffs_dic, bc)
return apply(transformer, M; kwargs...)
end

"""
Generate an MPO representing an affine transform of a MPS with no shift
Significant bits are assumed to be aligned from left to right for all tags.
"""
function affinetransformmpo(sites::AbstractVector{Index{T}},
tags::AbstractVector{String},
coeffs_dic::AbstractVector{Dict{String,Int}},
bc::AbstractVector{Int})::MPO where {T}
mpos = MPO[]

# f(x, y) = g(a * x + b * y + s1, c * x + d * y + s2)
# = h(a * x + b * y, c * x + d * y),
Expand All @@ -182,7 +219,7 @@ function affinetransform(M::MPS,

sites_for_tag = []
for tag in tags
push!(sites_for_tag, findallsites_by_tag(siteinds(M); tag=tag))
push!(sites_for_tag, findallsites_by_tag(sites; tag=tag))
if length(sites_for_tag[end]) == 0
error("Tag $tag is not found.")
end
Expand Down Expand Up @@ -214,8 +251,6 @@ function affinetransform(M::MPS,
length(pos_sites_in) == ntransvars ||
error("Length of pos_sites_in does not match that of coeffs")

sites = siteinds(M)

# Check if the order of significant bits is consistent among all tags
rev_carrydirecs = Bool[]
pos_for_tags = []
Expand All @@ -231,19 +266,20 @@ function affinetransform(M::MPS,
valid_rev_carrydirecs ||
error("The order of significant bits must be consistent among all tags!")

#all(rev_carrydirecs .== true) ||
#error("Significant bits are aligned from left to right for all tags!")

length(unique([length(s) for s in sites_for_tags])) == 1 ||
error("The number of sites for each tag must be the same! $([length(s) for s in sites_for_tags])")

rev_carrydirec = all(rev_carrydirecs .== true) # If true, significant bits are at the left end.

if !rev_carrydirec
M_ = MPS([M[i] for i in length(M):-1:1]) # Reverse the order of sites
M_ = affinetransform(M_, reverse(tags), reverse(coeffs_dic), reverse(bc); kwargs...)
return MPS([M_[i] for i in length(M_):-1:1])
transformer_ = affinetransformmpo(
reverse(sites), reverse(tags), reverse(coeffs_dic), reverse(bc))
return MPO([transformer_[n] for n in reverse(1:length(transformer_))])
end

# Below, we assume rev_carrydirec = true (left significant bits are at the left end)

# First check transformations with -1 and -1; e.g., (a, b) = (-1, -1)
# These transformations are not supported in the backend.
# To support this case, we need to flip the sign of coeffs as follows:
Expand All @@ -254,7 +290,7 @@ function affinetransform(M::MPS,

for v in 1:ntransvars
if sign_flips[v]
M = bc[v] * reverseaxis(M; tag=tags[v], bc=bc[v], kwargs...)
push!(mpos, bc[v] * reverseaxismpo(sites; tag=tags[v], bc=bc[v]))
end
end

Expand All @@ -265,9 +301,16 @@ function affinetransform(M::MPS,
transformer = _binaryop_mpo(sites_mpo, coeffs_positive, pos_sites_in;
rev_carrydirec=true, bc=bc)
transformer = matchsiteinds(transformer, sites)
M = apply(transformer, M; kwargs...)

return M
push!(mpos, transformer)

# Contract MPOs
res = mpos[1]
for n in 2:length(mpos)
res = apply(mpos[n], res; cutoff=1e-25, maxdim=typemax(Int))
end

return res
end

"""
Expand Down
56 changes: 43 additions & 13 deletions src/transformer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,25 +65,35 @@ Note that x = 0, 1, 2, ..., N-1 are mapped to x = 0, N-1, N-2, ..., 1 mod N.
"""
function reverseaxis(M::MPS; tag="x", bc::Int=1, kwargs...)
bc [1, -1] || error("bc must be either 1 or -1")
return apply(reverseaxismpo(siteinds(M); tag=tag, bc=bc), M; kwargs...)
end


sites = siteinds(M)
function reverseaxismpo(sites::AbstractVector{Index{T}}; tag="x", bc::Int=1)::MPO where {T}
bc [1, -1] || error("bc must be either 1 or -1")
targetsites = findallsiteinds_by_tag(sites; tag=tag)
pos = findallsites_by_tag(sites; tag=tag)
!isascendingordescending(pos) && error("siteinds for tag $(tag) must be sorted.")
rev_carrydirec = isascendingorder(pos)
siteinds_MPO = rev_carrydirec ? targetsites : reverse(targetsites)
transformer_tag = flipop(siteinds_MPO; rev_carrydirec=rev_carrydirec, bc=bc)
transformer = matchsiteinds(transformer_tag, sites)
return apply(transformer, M; kwargs...)
return matchsiteinds(transformer_tag, sites)
end

"""
f(x) = g(x + shift) for x = 0, 1, ..., 2^R-1 and 0 <= shift < 2^R.
"""
function shiftaxis(M::MPS, shift::Int; tag="x", bc::Int=1, kwargs...)
bc [1, -1] || error("bc must be either 1 or -1")
return apply(shiftaxismpo(siteinds(M), shift; tag=tag, bc=bc), M; kwargs...)
end


sites = siteinds(M)
"""
f(x) = g(x + shift) for x = 0, 1, ..., 2^R-1 and 0 <= shift < 2^R.
"""
function shiftaxismpo(sites::AbstractVector{Index{T}}, shift::Int; tag="x", bc::Int=1)::MPO where {T}
bc [1, -1] || error("bc must be either 1 or -1")
targetsites = findallsiteinds_by_tag(sites; tag=tag) # From left to right: x=1, 2, ...
pos = findallsites_by_tag(sites; tag=tag)
!isascendingordescending(pos) && error("siteinds for tag $(tag) must be sorted.")
Expand All @@ -101,25 +111,45 @@ function shiftaxis(M::MPS, shift::Int; tag="x", bc::Int=1, kwargs...)
transformer = matchsiteinds(transformer, sites)
transformer *= bc^nbc

return apply(transformer, M; kwargs...)
return transformer
end

"""
Multiply by exp(i θ x), where x = (x_1, ..., x_R)_2.
"""
function phase_rotation(M::MPS, θ::Float64; targetsites=nothing, tag="")
sitepos, target_sites = _find_target_sites(M; sitessrc=targetsites, tag=tag)
res = copy(M)
function phase_rotation(M::MPS, θ::Float64; targetsites=nothing, tag="")::MPS
transformer = phase_rotation_mpo(siteinds(M), θ; targetsites=targetsites, tag=tag)
apply(transformer, M)
end

"""
Create an MPO for multiplication by `exp(i θ x)`, where `x = (x_1, ..., x_R)_2`.
`sites`: site indices for `x_1`, `x_2`, ..., `x_R`.
"""
function phase_rotation_mpo(sites::AbstractVector{Index{T}}, θ::Float64; targetsites=nothing, tag="")::MPO where {T}
_, target_sites = _find_target_sites(sites; sitessrc=targetsites, tag=tag)
transformer = _phase_rotation_mpo(target_sites, θ)
return matchsiteinds(transformer, sites)
end

nqbit = length(sitepos)
for n in 1:nqbit
p = sitepos[n]
res[p] *= op("Phase", siteind(res, p); ϕ=θ * 2^(nqbit - n))
function _phase_rotation_mpo(sites::AbstractVector{Index{T}}, θ::Float64)::MPO where {T}
R = length(sites)
tensors = [ITensor(true) for _ in 1:R]
for n in 1:R
tensors[n] = op("Phase", sites[n]; ϕ=θ * 2^(R - n))
end
links = [Index(1, "Link,l=$l") for l in 1:(R-1)]
tensors[1] = ITensor(Array(tensors[1], sites[1]', sites[1]), sites[1], sites[1]', links[1])
for l in 2:(R-1)
tensors[l] = ITensor(Array(tensors[l], sites[l]', sites[l]), links[l-1], sites[l], sites[l]', links[l])
end
tensors[end] = ITensor(Array(tensors[end], sites[end]', sites[end]), links[end], sites[end], sites[end]')

return noprime(res)
return MPO(tensors)
end


function _upper_lower_triangle(upper_or_lower::Symbol)::Array{Float64,4}
upper_or_lower [:upper, :lower] || error("Invalid upper_or_lower $(upper_or_lower)")
T = Float64
Expand Down
7 changes: 5 additions & 2 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,10 @@ function directprod(::Type{T}, sites, indices) where {T}
end

function _find_target_sites(M::MPS; sitessrc=nothing, tag="")
_find_target_sites(siteinds(M); sitessrc, tag)
end

function _find_target_sites(sites::AbstractVector{Index{T}}; sitessrc=nothing, tag="") where {T}
if tag == "" && sitessrc === nothing
error("tag or sitesrc must be specified")
elseif tag != "" && sitessrc !== nothing
Expand All @@ -388,12 +392,11 @@ function _find_target_sites(M::MPS; sitessrc=nothing, tag="")

# Set input site indices
if tag != ""
sites = siteinds(M)
sitepos = findallsites_by_tag(sites; tag=tag)
target_sites = [sites[p] for p in sitepos]
elseif sitessrc !== nothing
target_sites = sitessrc
sitepos = Int[findsite(M, s) for s in sitessrc]
sitepos = Int[findfirst(x->x==s, sites) for s in sitessrc]
end

return sitepos, target_sites
Expand Down
38 changes: 24 additions & 14 deletions test/transformer_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,21 +166,7 @@ end
@test fz_reconst fz_ref
end

@testset "phase_rotation" begin
nqbit = 3
xvec = collect(0:(2^nqbit - 1))
θ = 0.1
sites = [Index(2, "Qubit,x=$x") for x in 1:nqbit]
_reconst(x) = vec(Array(reduce(*, x), reverse(sites)))

f = randomMPS(sites)
f_vec = _reconst(f)

ref = exp.(im * θ * xvec) .* f_vec

@test ref _reconst(Quantics.phase_rotation(f, θ; tag="x"))
@test ref _reconst(Quantics.phase_rotation(f, θ; targetsites=sites))
end

#==
@testset "asdiagonal" begin
Expand All @@ -204,6 +190,30 @@ end
==#
end

@testitem "transformer_tests.jl/phase_rotation" begin
using Test
import Quantics
using ITensors
using LinearAlgebra

@testset "phase_rotation" begin
nqbit = 3
xvec = collect(0:(2^nqbit - 1))
θ = 0.1
sites = [Index(2, "Qubit,x=$x") for x in 1:nqbit]
_reconst(x) = vec(Array(reduce(*, x), reverse(sites)))

f = randomMPS(sites)
f_vec = _reconst(f)

ref = exp.(im * θ * xvec) .* f_vec

@test ref _reconst(Quantics.phase_rotation(f, θ; tag="x"))
@test ref _reconst(Quantics.phase_rotation(f, θ; targetsites=sites))
end

end

@testitem "transformer_tests.jl/shiftaxis" begin
using Test
import Quantics
Expand Down

0 comments on commit 464ce91

Please sign in to comment.