In [None]:
import Pkg; 
Pkg.add("FastChebInterp");
Pkg.add("ThreadsX");
Pkg.add("Zygote");
Pkg.add("Memoize");

In [5]:
using DelimitedFiles
using FastChebInterp
using ThreadsX
using Base.Threads
using Zygote
using BenchmarkTools
using Memoize

const c∂tup{N} = Tuple{Array{ComplexF64,N},Array{ComplexF64,N}}
const r∂tup = Tuple{Float64,Vector{Float64}}
const cheb = FastChebInterp.ChebPoly

"""
    getmodel

Generates a chebyshev polynomial interpolated from the datafile. 
The latter must be in the format ipt, DoF, Re(t[freq1]), Im(t[freq2]) ... 
In other words, the dimensions of the datafile must be (order+1,2+2*nfreqs)
"""
function getmodel(lb,ub,filename)
    dat = readdlm(filename,' ',Float64,'\n')
    dat = dat[:,3:end]'
    dat = [dat[:,i] for i in 1:size(dat,2)]
    model = chebinterp(dat,lb,ub)
end

"""
    eval2c

Multi-threaded evaluation of meta-atom transmission coefficients 
for multiple frequencies using the chebyshev "model". 
"""

@memoize function wrapcheb(model::cheb,p::Float64)
    chebjacobian(model,p)
end

function eval2c(model::cheb,p::Vector{Float64};moi::Bool=false)::c∂tup{2}
    ndof = size(p)[1]
    nfreqs = size(model(p[1]))[1]÷2
    F = Array{ComplexF64,2}(undef,ndof,nfreqs)
    ∂F = Array{ComplexF64,2}(undef,ndof,nfreqs)
    Threads.@threads for i in 1:ndof
        if moi==true
            @inbounds t = wrapcheb(model,p[i])
        else
            @inbounds t = chebjacobian(model,p[i])
        end
        @inbounds @views F[i,:]  .= t[1][1:2:end] .+ im * t[1][2:2:end]
        @inbounds @views ∂F[i,:] .= t[2][1:2:end] .+ im * t[2][2:2:end]
    end
    (F,∂F)
end

function end2end(model::cheb, p::Vector{Float64}, getF::Function, f::Function, fdat::Any)::r∂tup
    F,∂F = getF(model,p)
    ret,back = Zygote.pullback(ξ->f(ξ,fdat),F)
    (ret, real.(sum(conj.(back(1)[1]).*∂F, dims=2))[:,1])
end


# setup(;ncells::Int64=3000,npix::Int64=500,nintg::Int64=5,nspatl::Int64=120,
#        Dz::Float64=5000, freqs::Vector{Float64}=[1.2,1.1,1.0,0.9,0.8],
#        lb::Float64=0.11,ub::Float64=0.68,
#        filename::String="alldat_5wavs.dat", 
#        kwargs...)

end2end (generic function with 1 method)

In [6]:
lb,ub=0.11,0.68
filename="alldat_5wavs.dat"
model = getmodel(lb,ub,filename)
(F,∂F) = @btime eval2c(model,rand(lb:0.0000000001:ub,9000000),moi=false);
(F,∂F) = @btime eval2c(model,rand(lb:0.0000000001:ub,9000000),moi=true);

  9.380 s (18000075 allocations: 3.82 GiB)


9000000×5 Matrix{ComplexF64}:
   0.216324+0.0615456im      0.428683-0.252186im   …   -0.168921+0.376084im
   0.496724+0.209764im      -0.487921-0.105827im        0.434532+0.351672im
   0.510531-0.347361im      -0.574915+0.0140278im       -0.12277+0.549228im
   0.493134+0.109135im       0.110919-0.122127im         0.24294+0.142552im
   0.154665-0.39737im       -0.216191-0.532624im       -0.474834+0.576495im
   0.215984+0.336301im     -0.0711706+0.0919682im  …    0.156992+0.101803im
  -0.276101-0.392033im       0.252094-0.212244im        0.118452+0.183981im
   0.544555-0.258547im      -0.441464-0.44301im       -0.0358631+0.0585361im
   0.507922-0.519292im      -0.361114-0.333995im       -0.730169+0.117613im
 -0.0512592-0.417539im       0.417643-0.156095im       -0.173691+0.258694im
   0.341069+0.109064im      -0.198795-0.323981im   …     0.12124+0.426028im
   0.293174-0.334204im       0.168859-0.553083im        0.261522+0.63242im
   0.520892+0.119029im      -0.411917-0.405262im      -0.0

In [3]:
model

ChebPoly{1,StaticArraysCore.SVector{10, Float64},Float64} order (999,) polynomial on [0.11,0.68]

In [4]:
function evalmodel(model::Any,p::Vector{Float64})
    ThreadsX.map(a->chebjacobian(model,a),p)
end

function r2c(t)
    nrows = size(t)[1]
    ncols = size(t[1][1])[1]÷2
    F = Array{ComplexF64,2}(undef,nrows,ncols)
    ∂F = Array{ComplexF64,2}(undef,nrows,ncols)
    Threads.@threads for i in 1:nrows
        F[i,:]  = t[i][1][1:2:end] + im * t[i][1][2:2:end]
        ∂F[i,:] = t[i][2][1:2:end] + im * t[i][2][2:2:end]
    end
    (F,∂F)
end

lb,ub=0.11,0.68
filename="alldat_5wavs.dat"
model = getmodel(lb,ub,filename)

x = rand(lb:0.000000001:ub,1000000)
@time map(a->chebjacobian(model,a),x);
@time ThreadsX.map(a->chebjacobian(model,a),x);
x = rand(lb:0.000000001:ub,1000000)
@time map(a->chebjacobian(model,a),x);
@time ThreadsX.map(a->chebjacobian(model,a),x);
x = rand(lb:0.000000001:ub,1000000)
@time map(a->chebjacobian(model,a),x);
@time ThreadsX.map(a->chebjacobian(model,a),x);


p = rand(lb:0.000000001:ub,1000000)
@time evalmodel(model,p);
p = rand(lb:0.000000001:ub,1000000)
@time evalmodel(model,p);
p = rand(lb:0.000000001:ub,1000000)
@time evalmodel(model,p);

println("hello")

  3.264581 seconds (2.65 M allocations: 372.743 MiB, 3.64% gc time, 4.86% compilation time)
  2.211243 seconds (12.09 M allocations: 1.892 GiB, 14.39% gc time, 49.65% compilation time)
  3.140972 seconds (2.05 M allocations: 338.568 MiB, 2.84% gc time, 1.11% compilation time)
  1.186752 seconds (7.33 M allocations: 1.645 GiB, 14.99% gc time, 9.32% compilation time)
  3.123095 seconds (2.05 M allocations: 338.560 MiB, 1.20% gc time, 1.15% compilation time)
  1.152620 seconds (7.33 M allocations: 1.645 GiB, 10.71% gc time, 20.91% compilation time)
  1.110561 seconds (664.40 k allocations: 984.273 MiB, 12.98% gc time, 23.32% compilation time)
  0.817125 seconds (1.38 k allocations: 949.722 MiB, 4.65% gc time)
  0.811426 seconds (1.37 k allocations: 949.722 MiB)
hello


In [7]:
function f(F::Array{ComplexF64,2})::Float64
    #sum(abs2.(F))
    sum(real.(F.^2) + 2.0*imag.(F))
end

function makeF(p::Vector{Float64})::c∂tup{2}
    F = reduce(hcat,(1+2im)*[p,p,p])
    tmp = ones(size(p))
    ∂F = reduce(hcat,(1+2im)*[tmp,tmp,tmp])
    (F,∂F)
end

function e2e(p::Vector{Float64},getF::Function,obj::Function)::r∂tup
    F,∂F = getF(p)
    ret,back = Zygote.pullback(obj,F)
    (ret, real.(sum(conj.(back(1)[1]).*∂F, dims=2))[:,1])
end
    
p=[1.,2.,3.]
@time e2e(p,makeF,f)
p=rand(10)
@time a = e2e(p,makeF,f)
typeof(a)

  0.048871 seconds (4.79 k allocations: 178.695 KiB, 99.84% compilation time)
  0.000049 seconds (54 allocations: 7.688 KiB)


Tuple{Float64, Vector{Float64}}