In [None]:
import Pkg
# Install the third-party packages loaded in the next cell. BenchmarkTools is
# needed for `@btime` and is not part of the standard library, so it must be
# installed explicitly as well.
Pkg.add(["FastChebInterp", "ThreadsX", "Zygote", "Memoize", "BenchmarkTools"])

In [1]:
using DelimitedFiles
using FastChebInterp
using ThreadsX
using Base.Threads
using Zygote
using BenchmarkTools
using Memoize

# Short alias for FastChebInterp's Chebyshev-polynomial model type; used below to
# constrain the `model` argument of `eval2c!` and `end2end`.
const cheb = FastChebInterp.ChebPoly

"""
    getmodel

Generates a chebyshev polynomial interpolated from the datafile. 
The latter must be in the format ipt, DoF, Re(t[freq1]), Im(t[freq2]) ... 
In other words, the dimensions of the datafile must be (order+1,2+2*nfreqs)
"""
function getmodel(lb,ub,filename)
    dat = readdlm(filename,' ',Float64,'\n')
    dat = dat[:,3:end]'
    dat = [dat[:,i] for i in 1:size(dat,2)]
    model = chebinterp(dat,lb,ub)
end

"""
    eval2c!(F,∂F, model,p)

In-place multi-threaded evaluation of meta-atom transmission coefficients for multiple frequencies using the chebyshev model. 

F and ∂F must be pre-allocated as
 F = Array{ComplexF64,2}(undef,#freqs,#unit cells)
∂F = Array{ComplexF64,2}(undef,#freqs,#unit cells)
"""
function eval2c!(F,∂F, model::cheb,p::Vector{Float64})
    ndof = size(p)[1]
    Threads.@threads for i in 1:ndof
        @inbounds t,∂t = chebjacobian(model,p[i])
        @inbounds @views @.  F[:,i] = complex( t[1:2:end], t[2:2:end])
        @inbounds @views @. ∂F[:,i] = complex(∂t[1:2:end],∂t[2:2:end])
    end
end

"""
Explanation: for f(z=x+iy) ∈ ℜ, Zygote returns df = ∂f/∂x + i ∂f/∂y 
The Wirtinger derivative is ∂f/∂z = 1/2 (∂f/∂x - i ∂f/∂y) = 1/2 conj(df)
The chain rule is ∂f/∂p = ∂f/∂z ∂z/∂p + ∂f/∂z' ∂z'/∂p = 2 real( ∂f/∂z ∂z/∂p ) = real( conj(df) ∂z/∂p ) 
grad must be pre-allocated as
grad = Vector{Float64}(undef,#unit cells)
"""
function end2end(grad, F,∂F, model::cheb, p::Vector{Float64}, getF!::Function, f::Function, fdat::Any)
    getF!(F,∂F, model,p)
    ret,back = Zygote.pullback(ξ->f(ξ,fdat),F)
    @. grad[:] = real(sum(conj(back(1)[1]) * ∂F, dims=1))[1,:]
    return ret
end


# setup(;ncells::Int64=3000,npix::Int64=500,nintg::Int64=5,nspatl::Int64=120,
#        Dz::Float64=5000, freqs::Vector{Float64}=[1.2,1.1,1.0,0.9,0.8],
#        lb::Float64=0.11,ub::Float64=0.68,
#        filename::String="alldat_5wavs.dat", 
#        kwargs...)

end2end

In [5]:
# Benchmark eval2c! on a large batch of random unit-cell parameters.
lb, ub = 0.11, 0.68
filename = "alldat_5wavs.dat"
model = getmodel(lb, ub, filename)
ncells = 20_000_000
# Uniformly distributed design parameters in [lb, ub).
p = lb .+ (ub - lb) .* rand(ncells)
# The model output is interleaved Re/Im pairs, so the number of frequencies is
# half its length. Deriving it keeps F consistent with the data file instead of
# hard-coding 5.
nfreqs = length(model(lb)) ÷ 2
F = Array{ComplexF64,2}(undef, nfreqs, ncells)
∂F = similar(F)
@btime eval2c!($F, $∂F, $model, $p);

  17.743 s (64 allocations: 5.80 KiB)
