Skip to content

Commit

Permalink
Add NFFT pullback
Browse files Browse the repository at this point in the history
  • Loading branch information
ptiede committed Aug 19, 2022
1 parent 690d918 commit cdbcc6b
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/Comrade.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ using FLoops
#using MappedArrays: mappedarray
import MeasureBase as MB
import MeasureTheory as MT
using NFFT: nfft, plan_nfft
using NFFT
using PaddedViews
using PyCall: pyimport, PyNULL, PyObject
using SpecialFunctions
Expand Down
8 changes: 4 additions & 4 deletions src/models/modelimage/nfft_alg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ function padimage(alg::NFFTAlg, img)
nny = nextprod((2,3,5,7), padfac*ny)
nsx = nnx÷2-nx÷2
nsy = nny÷2-ny÷2
return PaddedView(zero(eltype(img)), img,
return PaddedView(zero(eltype(img)), img.im,
(1:nnx, 1:nny),
(nsx+1:nsx+nx, nsy+1:nsy+ny)
)
Expand All @@ -71,7 +71,7 @@ function plan_nuft(alg::ObservedNUFT{<:NFFTAlg}, img, dx, dy)
uv2 = similar(alg.uv)
uv2[1,:] .= alg.uv[1,:]*dx
uv2[2,:] .= alg.uv[2,:]*dy
plan = plan_nfft(uv2, size(img); precompute=NFFT.POLYNOMIAL)
plan = plan_nfft(uv2, size(img'); precompute=alg.alg.precompute)
return plan
end

Expand All @@ -83,8 +83,8 @@ function make_phases(alg::ObservedNUFT{<:NFFTAlg}, img)
end

@inline function create_cache(alg::ObservedNUFT{<:NFFTAlg}, plan, phases, img)
timg = IntensityMap(transpose(img.im), img.fovx, img.fovy, img.pulse)
return NUFTCache(alg, plan, phases, img.pulse, timg)
#timg = #IntensityMap(transpose(img.im), img.fovx, img.fovy, img.pulse)
return NUFTCache(alg, plan, phases, img.pulse, img.im')
end

# Allow NFFT to work with ForwardDiff.
Expand Down
30 changes: 17 additions & 13 deletions src/models/modelimage/nuft.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ padfac(alg::NUFT) = alg.padfac
function padimage(::NUFT, img)
#pf = padfac(alg)
#cimg = convert(Matrix{Complex{eltype(img)}}, img.img)
return IntensityMap(img, img.fovx, img.fovy, img.pulse)
return IntensityMap(img, img.fovx, img.fovy, img.pulse).im
# ny,nx = size(img)
# nnx = nextpow(2, pf*nx)
# nny = nextpow(2, pf*ny)
Expand Down Expand Up @@ -69,18 +69,17 @@ function nuft(A, b)
return A*complex.(b)
end

# function ChainRulesCore.rrule(::typeof(nuft), A, b)
# bc = complex.(b)
# vis = A*bc
# println("Pld")
# function nuft_pullback(Δy)
# Δf = NoTangent()
# ΔA = @thunk(A'*Δy)
# Δb = @thunk(Δy*bc')
# return Δf, ΔA, Δb
# end
# return vis, nuft_pullback
# end
function ChainRulesCore.rrule(::typeof(nuft), A::NFFTPlan, b)
bc = complex.(b)
pr = ChainRulesCore.ProjectTo(b)
vis = A*bc
function nuft_pullback(Δy)
Δf = NoTangent()
ΔA = pr(A'*unthunk(Δy))
return Δf, NoTangent(), ΔA
end
return vis, nuft_pullback
end


# ReverseDiff.@grad_from_chainrules nuft(A::ReverseDiff.TrackedArray, b::ReverseDiff.TrackedArray)
Expand All @@ -89,6 +88,7 @@ end
# ReverseDiff.@grad_from_chainrules nuft(A, b::Vector{<:ReverseDiff.TrackedReal})


ChainRulesCore.@non_differentiable checkuv(alg, u::AbstractArray, v::AbstractArray)

function _visibilities(m::ModelImage{M,I,<:NUFTCache{A}},
u::AbstractArray,
Expand Down Expand Up @@ -208,6 +208,10 @@ Base.@kwdef struct NFFTAlg <: NUFT
Controls the accuracy of the NFFT usually don't need to change this
"""
m::Int = 10
"""
NFFT interpolation algorithm
"""
precompute=NFFT.TENSOR
end
include(joinpath(@__DIR__, "nfft_alg.jl"))

Expand Down
40 changes: 40 additions & 0 deletions test/Core/models.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ChainRulesTestUtils
using ChainRulesCore


function testmodel(m::Comrade.AbstractModel, npix=1024, atol=1e-4)
Expand Down Expand Up @@ -326,3 +327,42 @@ end
c = intensitymap(rotated(stretched(Gaussian(), 2.0, 1.0), π/8), 12.0, 12.0, 12, 12; pulse=BSplinePulse{3}())
testmodel(modelimage(c, FFTAlg(padfac=3)), 1024, 1e-3)
end

@testset "modelimage cache" begin
img = intensitymap(rotated(stretched(Gaussian(), μas2rad(2.0), μas2rad(1.0)), π/8),
μas2rad(12.0), μas2rad(12.0), 24, 12; pulse=BSplinePulse{3}())
_,_, amp, lcamp, cphase = load_data()

cache_nf = create_cache(NFFTAlg(amp), img)
cache_df = create_cache(DFTAlg(amp), img)
ac_amp = arrayconfig(amp)
ac_lcamp = arrayconfig(lcamp)
ac_cphase = arrayconfig(cphase)

mimg_nf = modelimage(img, cache_nf)
mimg_df = modelimage(img, cache_df)

vnf = visibilities(mimg_nf, ac_amp)
vdf = visibilities(mimg_df, ac_amp)

atol = 1e-5

@test isapprox(maximum(abs, vnf-vdf), 0, atol=atol)


cpnf = closure_phases(mimg_nf, ac_cphase)
cpdf = closure_phases(mimg_df, ac_cphase)

@test isapprox(maximum(abs, cis.(cpnf-cpdf) .- 1.0 ), 0, atol=atol)

lcnf = logclosure_amplitudes(mimg_nf, ac_lcamp)
lcdf = logclosure_amplitudes(mimg_df, ac_lcamp)

@test isapprox(maximum(abs, lcnf-lcdf), 0, atol=atol)



@testset "nuft pullback" begin
test_rrule(Comrade.nuft, cache_nf.plan NoTangent(), img.im')
end
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Expand Down

0 comments on commit cdbcc6b

Please sign in to comment.