In [1]:
using BenchmarkTools

In [13]:
# Collect `x[i]` for each `i` in `indices` (in that order) and return them as a Tuple.
_getindex(x, indices) = map(i -> x[i], Tuple(indices))

"""
    _contract(a, b, idx_a, idx_b)

Contract array `a` with array `b`, summing over each pair of dimensions
`idx_a[k]` (of `a`) and `idx_b[k]` (of `b`), using a single matrix product.

The result has the uncontracted dimensions of `a` (in increasing order)
followed by the uncontracted dimensions of `b` (in increasing order).
Contracting every dimension of both arrays yields a 0-dimensional array;
passing `()` for both index tuples yields the outer product.

Throws an `ErrorException` (via `error`) if the index tuples differ in length,
contain duplicates, or contain out-of-range dimensions. Throws a
`DimensionMismatch` if a pair of contracted dimensions has different sizes.
"""
function _contract(
    a::AbstractArray{T1,N1},
    b::AbstractArray{T2,N2},
    idx_a::NTuple{n1,Int},
    idx_b::NTuple{n2,Int}
    ) where {T1, T2, N1, N2, n1, n2}
    length(idx_a) == length(idx_b) || error("length(idx_a) != length(idx_b)")
    allunique(idx_a) || error("idx_a contains duplicate elements")
    allunique(idx_b) || error("idx_b contains duplicate elements")
    all(1 <= idx <= N1 for idx in idx_a) || error("idx_a contains elements out of range")
    all(1 <= idx <= N2 for idx in idx_b) || error("idx_b contains elements out of range")
    # Equal total sizes are not enough: every paired dimension must match, otherwise
    # the reshaped matrix product below would silently pair the wrong elements.
    for (da, db) in zip(idx_a, idx_b)
        size(a, da) == size(b, db) || throw(DimensionMismatch(
            "size(a, $da) = $(size(a, da)) != size(b, $db) = $(size(b, db))"))
    end

    rest_idx_a = setdiff(1:N1, idx_a)
    rest_idx_b = setdiff(1:N2, idx_b)

    # Sizes kept as tuples so any number of remaining dimensions (including zero)
    # works; `prod(())` is 1.
    rest_dims_a = map(d -> size(a, d), Tuple(rest_idx_a))
    rest_dims_b = map(d -> size(b, d), Tuple(rest_idx_b))
    contracted_len = prod(map(d -> size(a, d), idx_a))

    # Move the contracted dims of `a` last and of `b` first, then flatten each
    # side so the contraction becomes an ordinary matrix product.
    amat = reshape(permutedims(a, (rest_idx_a..., idx_a...)),
                   prod(rest_dims_a), contracted_len)
    bmat = reshape(permutedims(b, (idx_b..., rest_idx_b...)),
                   contracted_len, prod(rest_dims_b))

    return reshape(amat * bmat, (rest_dims_a..., rest_dims_b...))
end

_contract (generic function with 1 method)

In [16]:
# Random operands sharing dimensions 1 (size 2) and 3 (size 4).
a, b = rand(2, 3, 4), rand(2, 5, 4)
# Sum over dims 1 and 3 of both arrays; the remaining dims give a 3×5 result.
ab = _contract(a, b, (1, 3), (1, 3))

3×5 Matrix{Float64}:
 1.9065   3.86227  2.6992   2.38542  2.32756
 1.35216  2.76803  1.49784  1.68708  1.66962
 1.56337  2.81816  2.12559  1.76629  1.60034

In [9]:
# Converting a Vector to a Tuple yields (1, 2).
Tuple(collect(1:2))

(1, 2)

In [8]:
module My

"""
    IndexSetCache{T}(dims)

Memoization table for values of type `T`, keyed by a multi-index `x` with
`length(x) == length(dims)`. Each multi-index is encoded as one `BigInt` key
using mixed-radix strides derived from `dims`: `coeffs[1] = 1` and
`coeffs[n] = dims[n-1] * coeffs[n-1]`. The strides are computed in `BigInt`
arithmetic, so they do not overflow for long index sets.
"""
struct IndexSetCache{T}
    cache::Dict{BigInt,T}
    coeffs::Vector{BigInt}
    function IndexSetCache{T}(dims) where T
        coeffs = ones(BigInt, length(dims))
        for n in 2:length(dims)
            coeffs[n] = dims[n-1] * coeffs[n-1]
        end
        new(Dict{BigInt,T}(), coeffs)
    end
end

# Encode multi-index `x` as sum((x[n] - 1) * coeffs[n]).
# Assumes 1-based entries with 1 <= x[n] <= dims[n] (not checked here); entries
# outside that window can collide with other keys.
function _key(cf::IndexSetCache, x)
    key = zero(BigInt)
    for (xi, c) in zip(x, cf.coeffs)
        key += (xi - 1) * c
    end
    return key
end

"""
    (cf::IndexSetCache{T})(f, x)

Return the value cached for multi-index `x`. On the first request for `x`,
compute it as `f(x)` and store it; the `Dict{BigInt,T}` converts it to `T`.
Because `f` comes first, do-block syntax works: `cf(x) do x ... end`.
Errors if `length(x)` differs from the number of dimensions.
"""
function (cf::IndexSetCache{T})(
    f,
    x::AbstractVector{<:Integer}
) where {T}
    length(x) == length(cf.coeffs) || error("Invalid length of x")
    return get!(() -> f(x), cf.cache, _key(cf, x))
end

"""
    (cf::IndexSetCache{T})(x)

Return the value previously cached for multi-index `x`. Throws a `KeyError` if
nothing is stored for `x`; use `cf(f, x)` to compute-and-store.
"""
function (cf::IndexSetCache{T})(
    x::AbstractVector{<:Integer}
) where {T}
    length(x) == length(cf.coeffs) || error("Invalid length of x")
    return cf.cache[_key(cf, x)]
end

end



Main.My

In [11]:
import .My

# Twenty local dimensions of size 2 each, and an Int-valued cache over them.
localdims = [2 for _ in 1:20]
cache = My.IndexSetCache{Int}(localdims)

Main.My.IndexSetCache{Int64}(Dict{BigInt, Int64}(), BigInt[1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072, 262144, 524288])

In [None]:
# NOTE(review): this cell uses names not defined anywhere in this notebook
# (`localdims1`, `localdims3`, `N`, `TCI`). It also passes `ab` as the function
# to interpolate, but `ab` above is the plain array returned by `_contract`.
# It appears to come from a different session; confirm before running.
tolerance = 1e-12
# Rebinds the global `localdims` set in the IndexSetCache cell.
localdims = localdims1 .* localdims3
# Presumably tensor-cross-interpolates `ab` on a grid of shape `localdims`,
# starting from the all-ones pivot, and returns the TCI object plus per-iteration
# ranks and errors. TODO confirm against the TCI package docs.
tci, ranks, errors =
    TCI.crossinterpolate2(Float64, ab, localdims, [ones(Int, N)]; tolerance = tolerance, loginterval=1)

In [None]:
# One site index per local dimension for each operand, tagged "a=1", "a=2", ...
# and "b=1", "b=2", ... (assumes ITensors-style `Index(dim, tags)`; TODO confirm).
sitesa = [Index(d, "a=$i") for (i, d) in enumerate(localdims1)]
sitesb = [Index(d, "b=$i") for (i, d) in enumerate(localdims3)]
# Convert the TCI result to an MPS. Then, presumably, replace each fused site
# index with the pair (sitesa[n], sitesb[n]) without splitting tensors.
# NOTE(review): semantics of unfuse_siteinds_nosplit not visible here; confirm.
M = Quantics.TCItoMPS(tci)
M = Quantics.unfuse_siteinds_nosplit(M, siteinds(M), [collect(i) for i in zip(sitesa, sitesb)])

In [None]:
# Reference result: contract the two MPOs stored in `ab` with the "naive" algorithm,
# then relabel site indices so the result can be compared with `M`.
# NOTE(review): expects `ab` to have fields a_MPO, b_MPO, sites1, sites3, links_a,
# links_b. The `ab` produced by `_contract` earlier in this notebook is a plain
# array, so this cell appears to rely on a different definition of `ab`.
ab_ref = contract(ab.a_MPO, ab.b_MPO; alg="naive")
for n in eachindex(ab_ref)
    # Rename the operands' site indices at position n to sitesa[n] / sitesb[n].
    ab_ref[n] = replaceind(ab_ref[n], ab.sites1[n], sitesa[n])
    ab_ref[n] = replaceind(ab_ref[n], ab.sites3[n], sitesb[n])
end
# Multiply the first and last tensors by one-hot tensors on the outer link
# indices (value 1). This presumably closes the boundary links; TODO confirm.
ab_ref[1] *= onehot(ab.links_a[1] => 1)
ab_ref[1] *= onehot(ab.links_b[1] => 1)
ab_ref[end] *= onehot(ab.links_a[end] => 1)
ab_ref[end] *= onehot(ab.links_b[end] => 1)
ab_ref

In [None]:
# Approximate-equality check of the directly contracted reference `ab_ref`
# against the TCI-derived MPS `M`.
ab_ref ≈ M