In [8]:
using Revise
import TensorCrossInterpolation as TCI
using Plots
using BenchmarkTools
using Test
import LinearAlgebra as LA

In [9]:
# Dense 4×4 sample matrix. It is kept as an alternative input, but note that it is
# immediately superseded by the diagonal test case assigned right after it.
A = [
    0.711002 0.724557 0.789335 0.382373
    0.910429 0.726781 0.719957 0.486302
    0.632716 0.39967 0.571809 0.0803125
    0.885709 0.531645 0.569399 0.481214
]
# Diagonal matrix whose entries span two orders of magnitude.
# `diagm` has to be qualified: LinearAlgebra is only imported under the alias `LA`,
# so the bare name is not in scope on a fresh kernel.
A = LA.diagm([1.0, 0.1, 0.01])

# Factorize `A` with TCI's `rrlu`; the result is inspected in the following cells.
lu = TCI.rrlu(A)

TensorCrossInterpolation.rrLU{Float64}([1, 2, 3], [1, 2, 3], [1.0 0.0 0.0; 0.0 1.0 0.0; 0.0 0.0 1.0], [1.0 0.0 0.0; 0.0 0.1 0.0; 0.0 0.0 0.01], true, 3, 0.0)

In [10]:
# Display the `L` field stored in the factorization object.
lu.L

3×3 Matrix{Float64}:
 1.0  0.0  0.0
 0.0  1.0  0.0
 0.0  0.0  1.0

In [13]:
# NOTE(review): `rrLU` has no field `D`; this access raises the ErrorException that is
# recorded as this cell's output.
lu.D

ErrorException: type rrLU has no field D

In [12]:
"""
    LDU(lu)

Split the two factors stored in `lu` into unit-diagonal factors and a diagonal matrix.

`lu` only needs to provide the matrix fields `L` and `U`. The factorization object has
no `D` field, so the diagonal is derived here from the diagonals of `lu.L` and `lu.U`.

# Returns
A tuple `(L, U, D)` (same order as the original stub) such that
`L * D * U ≈ lu.L * lu.U`, where `L` and `U` have ones on their diagonals and `D` is a
`LinearAlgebra.Diagonal`.

# Throws
- `DimensionMismatch` if the inner dimensions of `lu.L` and `lu.U` do not match their
  diagonals.
- `ArgumentError` if a diagonal entry is zero, since the factors cannot be normalised.
"""
function LDU(lu)
    dl = LA.diag(lu.L)
    du = LA.diag(lu.U)
    npivots = size(lu.L, 2)
    if !(length(dl) == length(du) == npivots == size(lu.U, 1))
        throw(DimensionMismatch("factors L and U have incompatible shapes"))
    end

    # Normalising both diagonals makes the result independent of which of the two
    # factors carries the pivots.
    pivots = dl .* du
    any(iszero, pivots) &&
        throw(ArgumentError("cannot normalise factors with a zero diagonal entry"))

    L = lu.L ./ transpose(dl)  # scale column k of L by 1 / L[k, k]
    U = lu.U ./ du             # scale row k of U by 1 / U[k, k]
    return L, U, LA.Diagonal(pivots)
end

LDU (generic function with 1 method)

In [None]:
# Five positions, each with local dimension 2.
localdims = fill(2, 5)
# For every position, the list of admissible local indices 1:d.
localset = map(d -> collect(1:d), localdims)
# 100 (independently allocated) copies of the index pair [1, 1] on each side.
leftindexset = [ones(Int, 2) for _ in 1:100]
rightindexset = [ones(Int, 2) for _ in 1:100]

module My
import TensorCrossInterpolation as TCI

"""
    TestFunction(localset)

Toy evaluator (a `TCI.BatchEvaluator{Float64}`) whose value at an index set is the sum
of the indices. `localset` holds one vector of local indices per position.
"""
struct TestFunction <: TCI.BatchEvaluator{Float64}
    localset::Vector{Vector{Int}}
    TestFunction(localset) = new(localset)
end

# Single-point evaluation: sum of all entries of `indexset`, returned as Float64.
function (obj::TestFunction)(indexset)::Float64
    return sum(indexset)
end

# Batched evaluation over every combination of a left index set, the local indices of
# the `M` positions that follow the left block, and a right index set.
# The result has shape (number of left sets, local sizes..., number of right sets).
function (obj::TestFunction)(
    leftindexset, rightindexset, ::Val{M}
)::Array{Float64,M+2} where {M}
    nleft = length(first(leftindexset))
    centersets = obj.localset[nleft+1:nleft+M]
    centerranges = (1:length(s) for s in centersets)
    sums = [
        sum(vcat(lhs, collect(mid), rhs)) for lhs in leftindexset,
        mid in Iterators.product(centerranges...), rhs in rightindexset
    ]
    return reshape(
        sums, length(leftindexset), length.(centersets)..., length(rightindexset)
    )
end

end

# Wrap the toy evaluator in a `TCI.CachedFunction` and make sure TCI still treats it as
# batch-evaluable.
f = TCI.CachedFunction{Float64}(My.TestFunction(localset), localdims)
@assert TCI.isbatchevaluable(f)

# Batched evaluation with one free position between the left and right index sets.
result = TCI._batchevaluate_dispatch(
    Float64, f, localset, leftindexset, rightindexset, Val(1)
)
# Reference values computed directly: the left indices have length 2, so the free
# position is the third one.
ref = [sum(l) + c + sum(r) for l in leftindexset, c in localset[3], r in rightindexset]

isapprox(result, ref)

In [None]:
# Print the inferred types for the batched dispatch call to look for type instabilities.
@code_warntype TCI._batchevaluate_dispatch(Float64, f, localset, leftindexset, rightindexset, Val(1))

In [None]:
"""
    func(leftindexset, rightindexset, dims)

Iterate once over every combination of a left index set, the local indices of the
positions between the left and right blocks, and a right index set.

The number of left/right positions is taken from the length of the first element of
`leftindexset`/`rightindexset`; `dims[k]` is the local size of position `k`. The
iteration count is accumulated but not returned: the function always evaluates to the
constant `1.0`.
"""
function func(leftindexset, rightindexset, dims)
    nleft = length(first(leftindexset))
    nright = length(first(rightindexset))
    npositions = length(dims)

    # One index range per position that belongs to neither the left nor the right block.
    centerranges = (1:dims[k] for k in (nleft + 1):(npositions - nright))
    combinations = Iterators.product(leftindexset, centerranges..., rightindexset)

    visited = 0.0
    for _ in combinations
        visited += 1.0
    end
    return 1.0
end

In [None]:
# Inputs for the inference check of `func`: five positions of local dimension 2 and
# 100 index pairs [1, 1] on each side.
localdims = fill(2, 5)
x = [ones(Int, 2) for _ in 1:100]
y = [ones(Int, 2) for _ in 1:100]
@code_warntype func(x, y, localdims)

In [None]:
# Repeat the inference report for `func` on the same inputs (presumably re-run after
# editing `func`).
@code_warntype func(x, y, localdims)

In [None]:
# Zero-argument function returning the constant 2; it is the subject of the
# `has_zero_arg_method` benchmark in this cell.
f() = 2

"""
    has_zero_arg_method(f)

Return `true` if the callable `f` has a method that takes no arguments at all.

The method table of `f` is queried directly with `methods(f)`. Unlike a scan with
`methodswith`, this does not walk the names of every loaded module, so it is fast and
also works for callables that are not reachable through a module-level name (for
example anonymous functions).

Only methods declared with exactly zero arguments are counted; a method that accepts
zero arguments solely through varargs (`g(xs...)`) is not.
"""
function has_zero_arg_method(f)
    # `nargs` includes the callable itself, so `nargs == 1` means "no arguments".
    return any(m -> m.nargs == 1, methods(f))
end

# Measure how long the zero-argument-method check takes for `f`.
@benchmark has_zero_arg_method(f)

In [None]:
"""
    run(N, maxiter)

Call `TCI.crossinterpolate` on a function that returns a fresh `randn()` sample for
every index, defined on `N` positions of local dimension 2, starting from the
all-ones pivot.

`tolerance` is set to `0.0` and `verbosity` to `1`; `maxiter` is forwarded unchanged.
Returns whatever `TCI.crossinterpolate` returns.
"""
function run(N, maxiter)
    dims = fill(2, N)
    firstpivot = ones(Int, N)
    noise = _ -> randn()
    return TCI.crossinterpolate(
        Float64, noise, dims, firstpivot; tolerance=0.0, maxiter, verbosity=1
    )
end

In [None]:
# Warm-up call with small arguments so that `run` is compiled before it is timed.
run(10, 10)

In [None]:
# Values passed as `maxiter` to `run`, one timed call per entry.
Ds = [20, 40, 80, 160, 240, 400, 800, 1000]

# Wall-clock duration in seconds of each call; filled incrementally so that partial
# results survive an interrupted run.
timings = Float64[]
for maxiter in Ds
    started = time_ns()
    @time run(20, maxiter)
    elapsed_ns = time_ns() - started
    push!(timings, elapsed_ns * 1e-9)
end

In [None]:
# Log-log axes for comparing the measured timings with power-law guide curves.
p = plot(; xaxis=:log, yaxis=:log)
# Guide curves proportional to D^2 and D^3.
plot!(p, Ds, 1e-3 .* Ds .^ 2; label="D^2")
plot!(p, Ds, 1e-5 .* Ds .^ 3; label="D^3")
# Measured wall-clock times.
plot!(p, Ds, timings; marker=:cross)
