Skip to content

Commit

Permalink
Implement makesitediagonal and extractdiagonal
Browse files Browse the repository at this point in the history
  • Loading branch information
shinaoka committed Jun 22, 2024
1 parent c9c107e commit e978bbb
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/tag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ function findallsites_by_tag(sites::Vector{Vector{Index{T}}}; tag::String="x",
sitesflatten = collect(Iterators.flatten(sites))
for n in 1:maxnsites
tag_ = tag * "=$n"
idx = findall(hastags(tag_), sitesflatten)
idx = findall(i -> hastags(i, tag_) && hasplev(i, 0), sitesflatten)
if length(idx) == 0
break
elseif length(idx) > 1
Expand Down
28 changes: 28 additions & 0 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -495,3 +495,31 @@ function makesitediagonal(M::AbstractMPS, tag::String)::MPS

return MPS(collect(M_))
end

"""
Extract diagonal components
"""
function extractdiagonal(M::AbstractMPS, tag::String)::MPS
M_ = deepcopy(MPO(collect(M)))
sites = siteinds(M_)

target_positions = findallsites_by_tag(siteinds(M_); tag=tag)

for t in eachindex(target_positions)
i, j = target_positions[t]
M_[i] = _extract_diagonal(M_[i], sites[i][j], sites[i][j]')
end

return MPS(collect(M_))
end

function _extract_diagonal(t, site::Index{T}, site2::Index{T}) where {T<:Number}
dim(site) == dim(site2) || error("Dimension mismatch")
restinds = uniqueinds(inds(t), site, site2)
newdata = zeros(eltype(t), dim.(restinds)..., dim(site))
olddata = Array(t, restinds..., site, site2)
for i in 1:dim(site)
newdata[.., i] = olddata[.., i, i]
end
return ITensor(newdata, restinds..., site)
end
2 changes: 2 additions & 0 deletions test/transformer_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ end
f_ref[i] = g_reconst[mod(2^nbit - (i - 1), 2^nbit) + 1]
end
f_ref[1] *= bc

@test f_reconst f_ref
end

@testset "reverseaxis2" for nbit in 2:3, rev_carrydirec in [true, false]
Expand Down

0 comments on commit e978bbb

Please sign in to comment.