In [295]:
import Pkg; Pkg.add("FastChebInterp")
import Pkg; Pkg.add("ThreadsX")
import Pkg; Pkg.add("Zygote")

[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `C:\Users\zinli\.julia\environments\v1.8\Project.toml`
[32m[1m  No Changes[22m[39m to `C:\Users\zinli\.julia\environments\v1.8\Manifest.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `C:\Users\zinli\.julia\environments\v1.8\Project.toml`
[32m[1m  No Changes[22m[39m to `C:\Users\zinli\.julia\environments\v1.8\Manifest.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m   Installed[22m[39m RealDot ────────────── v0.1.0
[32m[1m   Installed[22m[39m GPUArraysCore ──────── v0.1.4
[32m[1m   Installed[22m[39m DiffRules ──────────── v1.13.0
[32m[1m   Installed[22m[39m DiffResults ────────── v1.1.0
[32m[1m   Installed[22m[39m SpecialFunctions ───── v2.2.0
[32m[1m   Installed[22m[39m IRTools ────────────── v0.4.9
[32m[1m   Installed[22m[39m Zygote ─────────────── v0.6.60
[32m[1m   Installed[22m[39m NaNMath ──────

In [1]:
using DelimitedFiles
using FastChebInterp
using ThreadsX
using Base.Threads
using Zygote

const c∂tup{N} = Tuple{Array{ComplexF64,N},Array{ComplexF64,N}}
const r∂tup = Tuple{Float64,Vector{Float64}}
const cheb = FastChebInterp.ChebPoly

"""
    getmodel

Generates a chebyshev polynomial interpolated from the datafile. 
The latter must be in the format ipt, DoF, Re(t[freq1]), Im(t[freq2]) ... 
In other words, the dimensions of the datafile must be (order+1,2+2*nfreqs)
"""
function getmodel(lb,ub,filename)
    dat = readdlm(filename,' ',Float64,'\n')
    dat = dat[:,3:end]'
    dat = [dat[:,i] for i in 1:size(dat,2)]
    model = chebinterp(dat,lb,ub)
end

"""
    eval2c

Multi-threaded evaluation of meta-atom transmission coefficients 
for multiple frequencies using the chebyshev "model". 
Known benchmark: 9 million 5 freqs (=90 mil evals) of a 1000-degree poly take 4 sec on 64 threads
"""
function eval2c(model::cheb,p::Vector{Float64})::c∂tup{2}
    ndof = size(p)[1]
    nfreqs = size(model(p[1]))[1]÷2
    F = Array{ComplexF64,2}(undef,ndof,nfreqs)
    ∂F = Array{ComplexF64,2}(undef,ndof,nfreqs)
    Threads.@threads for i in 1:ndof
        @inbounds t = chebjacobian(model,p[i])
        @inbounds F[i,:]  .= t[1][1:2:end] .+ im * t[1][2:2:end]
        @inbounds ∂F[i,:] .= t[2][1:2:end] .+ im * t[2][2:2:end]
    end
    (F,∂F)
end

function end2end(model::cheb, p::Vector{Float64}, getF::Function, f::Function, fdat::Any)::r∂tup
    F,∂F = getF(model,p)
    ret,back = Zygote.pullback(ξ->f(ξ,fdat),F)
    (ret, real.(sum(conj.(back(1)[1]).*∂F, dims=2))[:,1])
end


# setup(;ncells::Int64=3000,npix::Int64=500,nintg::Int64=5,nspatl::Int64=120,
#        Dz::Float64=5000, freqs::Vector{Float64}=[1.2,1.1,1.0,0.9,0.8],
#        lb::Float64=0.11,ub::Float64=0.68,
#        filename::String="alldat_5wavs.dat", 
#        kwargs...)

end2end (generic function with 1 method)

In [4]:
function evalmodel(model::Any,p::Vector{Float64})
    ThreadsX.map(a->chebjacobian(model,a),p)
end

function r2c(t)
    nrows = size(t)[1]
    ncols = size(t[1][1])[1]÷2
    F = Array{ComplexF64,2}(undef,nrows,ncols)
    ∂F = Array{ComplexF64,2}(undef,nrows,ncols)
    Threads.@threads for i in 1:nrows
        F[i,:]  = t[i][1][1:2:end] + im * t[i][1][2:2:end]
        ∂F[i,:] = t[i][2][1:2:end] + im * t[i][2][2:2:end]
    end
    (F,∂F)
end

lb,ub=0.11,0.68
filename="alldat_5wavs.dat"
model = getmodel(lb,ub,filename)

x = rand(lb:0.000000001:ub,1000000)
@time map(a->chebjacobian(model,a),x);
@time ThreadsX.map(a->chebjacobian(model,a),x);
x = rand(lb:0.000000001:ub,1000000)
@time map(a->chebjacobian(model,a),x);
@time ThreadsX.map(a->chebjacobian(model,a),x);
x = rand(lb:0.000000001:ub,1000000)
@time map(a->chebjacobian(model,a),x);
@time ThreadsX.map(a->chebjacobian(model,a),x);


p = rand(lb:0.000000001:ub,1000000)
@time evalmodel(model,p);
p = rand(lb:0.000000001:ub,1000000)
@time evalmodel(model,p);
p = rand(lb:0.000000001:ub,1000000)
@time evalmodel(model,p);

println("hello")

  3.264581 seconds (2.65 M allocations: 372.743 MiB, 3.64% gc time, 4.86% compilation time)
  2.211243 seconds (12.09 M allocations: 1.892 GiB, 14.39% gc time, 49.65% compilation time)
  3.140972 seconds (2.05 M allocations: 338.568 MiB, 2.84% gc time, 1.11% compilation time)
  1.186752 seconds (7.33 M allocations: 1.645 GiB, 14.99% gc time, 9.32% compilation time)
  3.123095 seconds (2.05 M allocations: 338.560 MiB, 1.20% gc time, 1.15% compilation time)
  1.152620 seconds (7.33 M allocations: 1.645 GiB, 10.71% gc time, 20.91% compilation time)
  1.110561 seconds (664.40 k allocations: 984.273 MiB, 12.98% gc time, 23.32% compilation time)
  0.817125 seconds (1.38 k allocations: 949.722 MiB, 4.65% gc time)
  0.811426 seconds (1.37 k allocations: 949.722 MiB)
hello


In [4]:
lb,ub=0.11,0.68
filename="alldat_5wavs.dat"
model = getmodel(lb,ub,filename)
p = rand(lb:0.000000001:ub,1000000)
@time eval2c(model,p)[2][4435,:];
p = rand(lb:0.000000001:ub,1000000)
@time eval2c(model,p)[2][5131,:];
p = rand(lb:0.000000001:ub,1000000)
@time eval2c(model,p)[1][1155,:];
println("hello")

  0.986189 seconds (6.00 M allocations: 793.463 MiB, 9.89% gc time)
  0.965276 seconds (6.00 M allocations: 793.473 MiB, 9.19% gc time)
  0.994329 seconds (6.00 M allocations: 793.474 MiB, 8.44% gc time)
hello


In [7]:
function f(F::Array{ComplexF64,2})::Float64
    #sum(abs2.(F))
    sum(real.(F.^2) + 2.0*imag.(F))
end

function makeF(p::Vector{Float64})::c∂tup{2}
    F = reduce(hcat,(1+2im)*[p,p,p])
    tmp = ones(size(p))
    ∂F = reduce(hcat,(1+2im)*[tmp,tmp,tmp])
    (F,∂F)
end

function e2e(p::Vector{Float64},getF::Function,obj::Function)::r∂tup
    F,∂F = getF(p)
    ret,back = Zygote.pullback(obj,F)
    (ret, real.(sum(conj.(back(1)[1]).*∂F, dims=2))[:,1])
end
    
p=[1.,2.,3.]
@time e2e(p,makeF,f)
p=rand(10)
@time a = e2e(p,makeF,f)
typeof(a)

  0.048871 seconds (4.79 k allocations: 178.695 KiB, 99.84% compilation time)
  0.000049 seconds (54 allocations: 7.688 KiB)


Tuple{Float64, Vector{Float64}}

In [5]:
typeof(makeF)

typeof(makeF) (singleton type of function makeF, subtype of Function)