In [None]:
import Pkg; 
Pkg.add("FFTW");
Pkg.add("PaddedViews");
Pkg.add("BenchmarkTools");
Pkg.add("ChainRules");
Pkg.add("ChainRulesCore");

In [1]:
using FFTW
using Base.Threads
using PaddedViews
using BenchmarkTools
using ChainRulesCore

# TODO: try out different FFTW planner flags, see https://docs.juliahub.com/FourierAnalysis/1aWDG/1.0.1/fftw/
# TODO: switch to in-place transforms at some point
"""
    create_2Dffp_plans(n, nthr)

Return a forward FFT plan for `n`×`n` `ComplexF64` matrices.

Calls `FFTW.set_num_threads(nthr)` before planning, then plans with the
`FFTW.MEASURE` flag on a throwaway random matrix of the requested size.
"""
function create_2Dffp_plans(n::Int64, nthr::Int64)
    FFTW.set_num_threads(nthr)
    # The array only fixes the size and element type of the plan.
    scratch = rand(ComplexF64, n, n)
    return plan_fft(scratch; flags=FFTW.MEASURE)
end

"""
    fftconv2d(arr, fftker, plan, fwd=true)

Apply a pre-transformed kernel to `arr` in Fourier space and crop the result.

`arr` is padded with zeros against `fftker` (via `sym_paddedviews`), transformed with
`plan`, multiplied element-wise by `fftker`, transformed back with `plan \\ ...`, and
passed through `ifftshift`. A square window of side
`size(fftker, 1) - size(arr, 1)` is then returned.

# Arguments
- `arr`: input field; only its first dimension is used to compute sizes.
- `fftker`: kernel that is already in the Fourier domain.
- `plan`: FFT plan matching the size of `fftker`.
- `fwd`: selects the crop window; when `true` it starts one index later than when
  `false`.
"""
function fftconv2d(arr::Array{ComplexF64,2}, fftker::Array{ComplexF64,2}, plan, fwd::Bool=true)::Array{ComplexF64,2}
    narr = size(arr, 1)
    nout = size(fftker, 1) - narr

    # Forward and backward directions use windows shifted by one sample.
    start = fwd ? narr ÷ 2 + 1 : narr ÷ 2
    window = start:(start + nout - 1)

    padded = sym_paddedviews(0.0 + im * 0.0, arr, fftker)[1]
    spectrum = fftker .* (plan * padded)
    shifted = ifftshift(plan \ spectrum)
    return shifted[window, window]
end
    
"""
    green2d(nx, ny, δx, δy, freq, ϵ, μ, Dz, plan)

Sample a kernel on an `ny`×`nx` grid and return its transform together with the
transform of the reversed kernel.

The grid is `x = range(-Lx/2, Lx/2 - δx, nx)` (as a row) and
`y = range(-Ly/2, Ly/2 - δy, ny)` (as a column) with `Lx = nx*δx`, `Ly = ny*δy`.
With `k = sqrt(ϵ*μ) * 2π * freq` and `r = sqrt(x^2 + y^2 + Dz^2)` the sampled value is

    Dz * (-1 + im*k*r) * cis(k*r) / (4π r^3) * δx*δy * (-μ/ϵ)

# Returns
A tuple `(plan * gz, plan * reverse(gz))`.
"""
function green2d(nx::Int64,ny::Int64, δx::Float64, δy::Float64, freq::Float64, ϵ::Float64,μ::Float64, Dz::Float64, plan)::Tuple{Array{ComplexF64,2},Array{ComplexF64,2}}
    angfreq = 2*π*freq
    index = sqrt(ϵ*μ)
    k = index*angfreq
    ik = im * k

    width, height = nx*δx, ny*δy
    cellarea = δx*δy

    # Row vector of x and column vector of y, so broadcasting yields an ny×nx grid.
    xs = range(-width/2, width/2-δx, nx)'
    ys = range(-height/2, height/2-δy, ny)

    # Distance from each grid point to the source plane offset, evaluated once.
    r = @. sqrt(xs^2 + ys^2 + Dz^2)

    gz = @. Dz * (-1 + ik * r) * cis(k*r)/(4*π*r^3) * cellarea * (-μ/ϵ)

    forward_kernel = plan * gz
    reversed_kernel = plan * reverse(gz)

    return (forward_kernel, reversed_kernel)
end

# Field matrix F has the format (unit cells, frequencies)
"""
    ffp(F, fgs, plan, fwd)

Run `fftconv2d` on every column of `F`, one column per frequency.

# Arguments
- `F`: field matrix of size `(narr^2, nfreqs)`; each column is reshaped to
  `narr`×`narr` before the convolution.
- `fgs`: one `(forward, backward)` pair of Fourier-domain kernels per frequency.
  Entry `i` is used for column `i`; the first element is used when `fwd` is `true`,
  the second otherwise.
- `plan`: FFT plan passed through to `fftconv2d`.
- `fwd`: direction flag, also passed through to `fftconv2d`.

# Returns
A `(nout^2, nfreqs)` matrix with `nout = size(fgs[1][1], 1) - narr`.

# Throws
- `DimensionMismatch` if `size(F, 1)` is not a perfect square or if `fgs` has fewer
  entries than `F` has columns.
"""
function ffp(F::Array{ComplexF64,2}, fgs::Vector{Tuple{Array{ComplexF64,2},Array{ComplexF64,2}}}, plan, fwd::Bool)::Array{ComplexF64,2}
    ncells, nfreqs = size(F)

    narr = isqrt(ncells)
    narr * narr == ncells || throw(DimensionMismatch(
        "size(F, 1) must be a perfect square, got $ncells"))
    # Checked up front so that a short `fgs` fails loudly instead of being read
    # out of bounds inside the threaded loop.
    length(fgs) >= nfreqs || throw(DimensionMismatch(
        "fgs has $(length(fgs)) entries but F has $nfreqs frequency columns"))

    nker = size(fgs[1][1], 1)
    nout = nker - narr

    out = zeros(ComplexF64, nout*nout, nfreqs)
    Threads.@threads for i in 1:nfreqs
        kernel = fwd ? fgs[i][1] : fgs[i][2]
        # `F[:, i]` copies the column, so `reshape` yields the plain `Array`
        # that `fftconv2d` requires.
        field = reshape(F[:, i], narr, narr)
        out[:, i] = vec(fftconv2d(field, kernel, plan, fwd))
    end
    return out
end

"""
    near2far(F, fgs, plan)

Forward transform: calls `ffp(F, fgs, plan, true)` and returns its result.
"""
function near2far(F::Array{ComplexF64,2}, fgs::Vector{Tuple{Array{ComplexF64,2},Array{ComplexF64,2}}}, plan)::Array{ComplexF64,2}
    forward = true
    return ffp(F, fgs, plan, forward)
end
    

"""
    ChainRulesCore.rrule(::typeof(near2far), F, fgs, plan)

Reverse-mode rule for `near2far`.

Returns the primal `near2far(F, fgs, plan)` and a pullback. The pullback maps an
output cotangent to `(NoTangent(), dF, ZeroTangent(), ZeroTangent())`, where `dF` is a
thunk that evaluates `ffp(cotangent, fgs, plan, false)`. `fgs` and `plan` are treated
as constants.
"""
function ChainRulesCore.rrule(::typeof(near2far), F::Array{ComplexF64,2}, fgs::Vector{Tuple{Array{ComplexF64,2},Array{ComplexF64,2}}}, plan)
    efar = near2far(F,fgs,plan)
    # The cotangent may arrive as a thunk or as a non-`Array` matrix type, so it is
    # accepted untyped and normalised lazily, only when `dF` is actually needed.
    function near2far_pullback(Δefar)
        # NOTE(review): the backward pass uses the reversed kernel `fgs[i][2]` without
        # complex conjugation — confirm this matches the adjoint convention in use.
        dF = @thunk(ffp(convert(Array{ComplexF64,2}, unthunk(Δefar)), fgs, plan, false))
        return NoTangent(), dF, ZeroTangent(), ZeroTangent()
    end
    return efar, near2far_pullback
end

In [2]:
# Benchmark `near2far` on a 300×300 field and 500×500 kernels at five frequencies.
plan = create_2Dffp_plans(500,1);
freqs = [0.8,0.9,1.0,1.1,1.2]
fgs = [ green2d(500,500,0.7,0.7, freq,1.0,1.0, 1000., plan) for freq in freqs ]
display(typeof(fgs))
F = rand(ComplexF64,300*300,5)
# Interpolate the globals with `$` so the timing excludes the cost of looking up
# untyped global variables.
@btime near2far($F,$fgs,$plan);

Vector{Tuple{Matrix{ComplexF64}, Matrix{ComplexF64}}} (alias for Array{Tuple{Array{Complex{Float64}, 2}, Array{Complex{Float64}, 2}}, 1})

  16.912 ms (171 allocations: 108.35 MiB)


In [20]:
# Compare `fftconv2d` timings for plans created with 2 and 4 FFTW threads.
arr = rand(ComplexF64,5000,5000)
fftker = rand(ComplexF64,8000,8000)

# Globals are interpolated with `$` so the timing excludes untyped-global lookups.
p1 = create_2Dffp_plans(8000,2); @btime fftconv2d($arr,$fftker,$p1,true);
p1 = create_2Dffp_plans(8000,4); @btime fftconv2d($arr,$fftker,$p1,true);

  2.506 s (139 allocations: 4.90 GiB)
  2.337 s (201 allocations: 4.90 GiB)
