In [None]:
import Pkg
# Install every third-party package this notebook loads. Zygote and
# FiniteDifferences are used by the gradient-check cell further down and were
# previously missing from this list.
Pkg.add([
    "FFTW",
    "PaddedViews",
    "BenchmarkTools",
    "ChainRules",
    "ChainRulesCore",
    "Zygote",
    "FiniteDifferences",
]);

In [8]:
using FFTW
using Base.Threads
using PaddedViews
using BenchmarkTools
using ChainRulesCore

#try out different fft planners https://docs.juliahub.com/FourierAnalysis/1aWDG/1.0.1/fftw/
#also switch to in-place operations at some point
"""
    create_2Dffp_plans(n, nthr)

Set the FFTW thread count to `nthr`, then build and return a forward FFT plan
for `n`×`n` `ComplexF64` arrays using the `FFTW.MEASURE` planner flag.

The plan is created on a freshly drawn random array that is discarded
afterwards.
"""
function create_2Dffp_plans(n::Int64, nthr::Int64)
    FFTW.set_num_threads(nthr)
    scratch = rand(ComplexF64, n, n)
    return plan_fft(scratch; flags=FFTW.MEASURE)
end

"""
    fftconv2d(arr, fftker, plan, fwd=true)

FFT-based 2-D convolution of `arr` with a kernel given by its spectrum
`fftker`.

`arr` is padded with zeros (via `sym_paddedviews`) to match `fftker`,
transformed with `plan`, multiplied elementwise by `fftker`, transformed back
with `plan \\ …` and passed through `ifftshift`. When `fwd` is `false` the
shifted result is additionally rotated by one sample along both axes with
`circshift(…, (1, 1))`.

The returned matrix is the square window `[i1:i2, i1:i2]` of that result with
`i1 = 1 + narr ÷ 2` and side length `nker - narr`, where `narr` and `nker` are
the first-dimension sizes of `arr` and `fftker`.
"""
function fftconv2d(arr::Array{ComplexF64,2}, fftker::Array{ComplexF64,2}, plan, fwd::Bool=true)::Array{ComplexF64,2}
    narr = size(arr, 1)
    nker = size(fftker, 1)
    nout = nker - narr

    # Same index window is used along both axes.
    lo = 1 + (narr ÷ 2)
    hi = lo + nout - 1
    window = lo:hi

    # Materialise the padded view so the plan receives a plain Array.
    padded = collect(first(sym_paddedviews(0.0+im*0.0, arr, fftker)))
    shifted = ifftshift(plan \ (fftker .* (plan * padded)))

    if fwd
        return shifted[window, window]
    end
    return circshift(shifted, (1, 1))[window, window]
end
    
"""
    green2d(nx, ny, δx, δy, freq, ϵ, μ, Dz, plan)

Sample a propagation kernel on an `ny`×`nx` grid and return the FFTs of the
kernel and of its fully reversed copy as the tuple `(fgz, fgzT)`.

With `k = sqrt(ϵ*μ) * 2π*freq` and `r = sqrt(x^2 + y^2 + Dz^2)`, each sample is

    Dz * (-1 + im*k*r) * cis(k*r) / (4π * r^3) * (δx*δy) * (-μ/ϵ)

where `x` takes `nx` equally spaced values from `-nx*δx/2` to `nx*δx/2 - δx`
(along columns) and `y` takes `ny` values from `-ny*δy/2` to `ny*δy/2 - δy`
(along rows).

`plan` is applied to the sampled matrix, so it presumably has to have been
created for arrays of that shape — confirm against the caller.
"""
function green2d(nx::Int64,ny::Int64, δx::Float64, δy::Float64, freq::Float64, ϵ::Float64,μ::Float64, Dz::Float64, plan)::Tuple{Array{ComplexF64,2},Array{ComplexF64,2}}
    ω = 2*π*freq
    k = sqrt(ϵ*μ) * ω
    ik = im * k

    width, height = nx*δx, ny*δy
    cellarea = δx*δy

    # Row vector of x samples and column vector of y samples, so broadcasting
    # them together produces an ny×nx matrix.
    xs = range(-width/2, width/2-δx, nx)'
    ys = range(-height/2, height/2-δy, ny)

    # Distance from each grid point to the origin of a plane offset by Dz.
    r = @. sqrt(xs^2 + ys^2 + Dz^2)
    gz = @. Dz * (-1 + ik * r) * cis(k*r)/(4*π*r^3) * cellarea * (-μ/ϵ)

    spectrum = plan * gz
    spectrum_reversed = plan * reverse(gz)

    return (spectrum, spectrum_reversed)
end

"""
    ffp(F, fgs, plan, fwd)

Apply [`fftconv2d`](@ref) independently to every column of `F`.

# Arguments
- `F`: field matrix in the format (unit cells, frequencies). Each column is
  reshaped to a square `narr`×`narr` matrix, so the number of rows must be a
  perfect square (otherwise the `Int64(sqrt(…))` conversion throws an
  `InexactError`).
- `fgs`: one `(kernel_fft, reversed_kernel_fft)` pair per column of `F`. The
  first entry is used when `fwd` is `true`, the second otherwise. The output
  size is derived from the first pair only.
- `plan`: FFT plan handed through to `fftconv2d`.
- `fwd`: direction flag handed through to `fftconv2d`.

# Returns
A `nout^2`×`nfreqs` matrix whose `i`-th column is the flattened result for
column `i` of `F`, with `nout = nker - narr`.

# Throws
- `DimensionMismatch` if `fgs` has fewer entries than `F` has columns.
"""
function ffp(F::Array{ComplexF64,2}, fgs::Vector{Tuple{Array{ComplexF64,2},Array{ComplexF64,2}}}, plan, fwd::Bool)::Array{ComplexF64,2}
    ncells, nfreqs = size(F)

    # Checked up front: the loop below indexes fgs[i] for every column of F.
    length(fgs) >= nfreqs || throw(DimensionMismatch(
        "F has $nfreqs frequency columns but fgs holds only $(length(fgs)) kernel pairs"))

    narr = Int64(sqrt(ncells))
    nker = size(fgs[1][1], 1)
    nout = nker - narr

    out = zeros(ComplexF64, nout*nout, nfreqs)
    # Each iteration writes a distinct column of `out`, so iterations are independent.
    Threads.@threads for i in 1:nfreqs
        fftker = fwd ? fgs[i][1] : fgs[i][2]
        # F[:, i] is copied on purpose: fftconv2d requires a plain Array.
        out[:, i] .= vec(fftconv2d(reshape(F[:, i], narr, narr), fftker, plan, fwd))
    end
    return out
end

"""
    near2far(F, fgs, plan)

Forward entry point: evaluates `ffp(F, fgs, plan, true)`.
"""
function near2far(F::Array{ComplexF64,2}, fgs::Vector{Tuple{Array{ComplexF64,2},Array{ComplexF64,2}}}, plan)::Array{ComplexF64,2}
    forward = true
    return ffp(F, fgs, plan, forward)
end
    

"""
    ChainRulesCore.rrule(::typeof(near2far), F, fgs, plan)

Reverse-mode rule for `near2far`.

Returns the primal `near2far(F, fgs, plan)` together with a pullback. For an
incoming cotangent `Δ` the pullback yields, for `F`,

    conj.(ffp(conj.(Δ), fgs, plan, false))

wrapped in a thunk so that the reverse convolution only runs if the tangent is
actually requested. `plan` is treated as non-differentiable (`NoTangent()`).

NOTE(review): `fgs` receives `ZeroTangent()`, i.e. no gradient is propagated
to the kernels even though the output depends on them — confirm that no caller
differentiates with respect to `fgs`.
"""
function ChainRulesCore.rrule(::typeof(near2far), F::Array{ComplexF64,2}, fgs::Vector{Tuple{Array{ComplexF64,2},Array{ComplexF64,2}}}, plan)
    efar = near2far(F, fgs, plan)

    function near2far_pullback(Δ)
        # The cotangent may arrive thunked or with a non-ComplexF64 element
        # type; `ffp` only accepts a dense ComplexF64 matrix.
        ybar = convert(Array{ComplexF64,2}, unthunk(Δ))
        # Keep the outer conjugation inside the thunk: broadcasting over a
        # thunk would force it immediately and defeat the laziness.
        dF = @thunk(conj.(ffp(conj.(ybar), fgs, plan, false)))
        return NoTangent(), dF, ZeroTangent(), NoTangent()
    end

    return efar, near2far_pullback
end

In [9]:
using Zygote
using FiniteDifferences
using LinearAlgebra
using Random
# Gradient check: compare the Zygote gradient of a scalar objective built on
# `near2far` (which goes through the custom rrule) against central finite
# differences.

# n1: side length of the kernel grid / FFT plan, n2: side length of the input
# field (F has n2*n2 rows), n3: number of frequencies (columns of F).
n1=26
n2=10
n3=4
plan = create_2Dffp_plans(n1,4);
freqs = range(0.8,1.2,n3)
fgs = [ green2d(n1,n1,0.7,0.7, freq,1.0,1.0, 1000., plan) for freq in freqs ]
F = rand(ComplexF64,n2*n2,n3)
# Real-valued scalar objective of the far field used for the check.
function test(F,fgs,plan)
    G = near2far(F,fgs,plan)
    sum(real.(G).*imag.(G).^2)
end
# Reverse-mode gradient with respect to F, seeded with 1.
ret,back = Zygote.pullback(x->test(x,fgs,plan),F)
gdat = back(1)[1]
# Finite-difference reference, evaluated with a separate single-threaded plan.
plan0 = create_2Dffp_plans(n1,1);
# Objective as a function of the real part of F only (imaginary part held fixed).
function tmp1(Fr)
    test(Fr .+ im .* imag.(F),fgs,plan0)
end
Δ = grad(central_fdm(5,1), tmp1, real.(F))[1]
# Largest absolute deviation from real.(gdat), relative to the mean |Δ|.
# NOTE(review): `[1:end-1]` leaves the last element out of the comparison; the
# reason is not visible here — confirm it is intentional.
# NOTE(review): `mean` is not provided by the stdlib modules loaded in this
# notebook; presumably it is re-exported by one of the loaded packages —
# confirm, or add `using Statistics`.
display(maximum(abs.(Δ .- real.(gdat))[1:end-1])/mean(abs.(Δ)))
# Objective as a function of the imaginary part of F only (real part held fixed).
function tmp2(Fi)
    test(real.(F) .+ im .* Fi,fgs,plan0)
end
Δ = grad(central_fdm(5,1), tmp2, imag.(F))[1]
# Same relative-error measure, now against imag.(gdat).
display(maximum(abs.(Δ .- imag.(gdat))[1:end-1])/mean(abs.(Δ)))


8.75326694839738e-12

2.188285606871917e-11

In [21]:
# Benchmark near2far at scale: 5000×5000 kernels (plan created with 2 FFTW
# threads), a 3000×3000 input field per frequency, and five frequencies.
plan = create_2Dffp_plans(5000,2);
freqs = [0.8,0.9,1.0,1.1,1.2]
fgs = [ green2d(5000,5000,0.7,0.7, freq,1.0,1.0, 1000., plan) for freq in freqs ]
# Show the concrete container type of the kernel list.
display(typeof(fgs))
# One column per frequency, each holding a flattened 3000×3000 field.
F = rand(ComplexF64,3000*3000,5)
# `$` interpolates the globals so @btime times the call itself.
@btime near2far($F,$fgs,$plan);

Vector{Tuple{Matrix{ComplexF64}, Matrix{ComplexF64}}}[90m (alias for [39m[90mArray{Tuple{Array{Complex{Float64}, 2}, Array{Complex{Float64}, 2}}, 1}[39m[90m)[39m

  15.357 s (773 allocations: 10.58 GiB)


In [22]:
# Time a single fftconv2d call on a 5000×5000 input with an 8000×8000 kernel
# spectrum, first with a plan created for 2 FFTW threads and then for 4.
arr = rand(ComplexF64,5000,5000)
fftker = rand(ComplexF64,8000,8000)

# Interpolate the globals with `$` so the measurement does not include
# untyped-global access, matching the near2far benchmark above.
p1 = create_2Dffp_plans(8000,2); @btime fftconv2d($arr,$fftker,$p1,true);
p1 = create_2Dffp_plans(8000,4); @btime fftconv2d($arr,$fftker,$p1,true);

In [7]:
"""
    fftconv1d(arr, ker)

FFT-based circular convolution of `arr` with `ker`, intended for vectors.

`arr` is zero-padded (via `sym_paddedviews`) to match `ker`; both are
transformed, multiplied elementwise, transformed back, and the result is
passed through `fftshift`. For vector inputs the returned vector has the same
length as `ker`.

Side effect: sets the global FFTW thread count to 1.
"""
function fftconv1d(arr,ker)

    FFTW.set_num_threads(1)
    # Plan on a scratch array of the same shape: FFTW's MEASURE planner may
    # overwrite the array it plans on, which would clobber the caller's `ker`
    # before it is transformed below.
    plan = plan_fft(similar(ker),flags=FFTW.MEASURE)

    padarr = sym_paddedviews(0.0+im*0.0,arr,ker)[1]
    fftarr = plan * padarr
    fftker = plan * ker
    return fftshift(plan \ (fftker .* fftarr))

end

# Small worked example for fftconv1d: a 6-sample signal and a 12-sample
# kernel, both built as complex vectors with zero imaginary part.
arr = vec([1.0,3.0,4.0,1.0,6.0,3.0] .+ im .* zeros((6,1)));
ker = vec([2.0,2.0,4.0,3.0,1.0,1.0,3.0,1.0,2.0,2.0,4.0,5.0] .+ im .* zeros((12,1)));
fftconv1d(arr,ker)

12-element Vector{ComplexF64}:
 43.0 + 0.0im
 45.0 + 0.0im
 33.0 + 0.0im
 24.0 + 0.0im
 40.0 + 0.0im
 42.0 + 0.0im
 50.0 + 0.0im
 50.0 + 0.0im
 53.0 + 0.0im
 67.0 + 0.0im
 55.0 + 0.0im
 38.0 + 0.0im