Skip to content

Commit 78bcac6

Browse files
Merge pull request #1152 from willtebbutt/wct/mooncake-inside-problems
Mooncake Inside Problems
2 parents 99156d1 + 6ac50d1 commit 78bcac6

10 files changed

+181
-13
lines changed

Project.toml

+10-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SciMLSensitivity"
22
uuid = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
33
authors = ["Christopher Rackauckas <accounts@chrisrackauckas.com>", "Yingbo Ma <mayingbo5@gmail.com>"]
4-
version = "7.71.2"
4+
version = "7.72.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -42,6 +42,12 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
4242
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
4343
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4444

45+
[weakdeps]
46+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
47+
48+
[extensions]
49+
SciMLSensitivityMooncakeExt = "Mooncake"
50+
4551
[compat]
4652
ADTypes = "1.9"
4753
Accessors = "0.1.36"
@@ -71,6 +77,7 @@ LinearSolve = "2"
7177
Lux = "1"
7278
Markdown = "1.10"
7379
ModelingToolkit = "9.42"
80+
Mooncake = "0.4.52"
7481
NLsolve = "4.5.1"
7582
NonlinearSolve = "3.0.1"
7683
Optimization = "4"
@@ -110,6 +117,7 @@ DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
110117
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
111118
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
112119
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
120+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
113121
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
114122
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
115123
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
@@ -123,4 +131,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
123131
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
124132

125133
[targets]
126-
test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "ModelingToolkit", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"]
134+
test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "ModelingToolkit", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"]

docs/src/manual/differential_equation_sensitivities.md

+1
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ ZygoteVJP
212212
EnzymeVJP
213213
TrackerVJP
214214
ReverseDiffVJP
215+
MooncakeVJP
215216
```
216217

217218
## More Details on Sensitivity Algorithm Choices

ext/SciMLSensitivityMooncakeExt.jl

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
module SciMLSensitivityMooncakeExt
2+
3+
using SciMLSensitivity, Mooncake
4+
import SciMLSensitivity: get_paramjac_config, mooncake_run_ad, MooncakeVJP, MooncakeLoaded
5+
6+
function get_paramjac_config(::MooncakeLoaded, ::MooncakeVJP, pf, p, f, y, _t)
7+
dy_mem = zero(y)
8+
λ_mem = zero(y)
9+
cache = Mooncake.prepare_pullback_cache(pf, dy_mem, y, p, _t)
10+
return cache, pf, λ_mem, dy_mem
11+
end
12+
13+
function mooncake_run_ad(paramjac_config::Tuple, y, p, t, λ)
14+
cache, pf, λ_mem, dy_mem = paramjac_config
15+
λ_mem .= λ
16+
dy, _ = Mooncake.value_and_pullback!!(cache, λ_mem, pf, dy_mem, y, p, t)
17+
y_grad = cache.tangents[3]
18+
p_grad = cache.tangents[4]
19+
return dy, y_grad, p_grad
20+
end
21+
22+
end

src/adjoint_common.jl

+47
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,9 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
211211
paramjac_config = get_paramjac_config(autojacvec, p, f, y, _p, _t; numindvar, alg)
212212
pf = get_pf(autojacvec; _f = unwrappedf, isinplace = isinplace, isRODE = isRODE)
213213
paramjac_config = (paramjac_config..., Enzyme.make_zero(pf))
214+
elseif autojacvec isa MooncakeVJP
215+
pf = get_pf(autojacvec, prob, unwrappedf)
216+
paramjac_config = get_paramjac_config(MooncakeLoaded(), autojacvec, pf, p, f, y, _t)
214217
elseif SciMLBase.has_paramjac(f) || quad || !(autojacvec isa Bool) ||
215218
autojacvec isa EnzymeVJP
216219
paramjac_config = nothing
@@ -460,6 +463,15 @@ function get_paramjac_config(autojacvec::EnzymeVJP, p, f, y, _p, _t; numindvar,
460463
return paramjac_config
461464
end
462465

466+
# Dispatched on inside extension.
467+
struct MooncakeLoaded end
468+
469+
function get_paramjac_config(::Any, ::MooncakeVJP, pf, p, f, y, _t)
470+
msg = "MooncakeVJP requires Mooncake.jl is loaded. Install the package and do " *
471+
"`using Mooncake` to use this functionality"
472+
error(msg)
473+
end
474+
463475
function get_pf(autojacvec::ReverseDiffVJP; _f = nothing, isinplace = nothing,
464476
isRODE = nothing)
465477
nothing
@@ -492,6 +504,41 @@ function get_pf(autojacvec::EnzymeVJP; _f, isinplace, isRODE)
492504
end
493505
end
494506

507+
function get_pf(::MooncakeVJP, prob, _f)
508+
isinplace = DiffEqBase.isinplace(prob)
509+
isRODE = isa(prob, RODEProblem)
510+
pf = let f = _f
511+
if isinplace && isRODE
512+
function (out, u, _p, t, W)
513+
f(out, u, _p, t, W)
514+
return out
515+
end
516+
elseif isinplace
517+
function (out, u, _p, t)
518+
f(out, u, _p, t)
519+
return out
520+
end
521+
elseif !isinplace && isRODE
522+
function (out, u, _p, t, W)
523+
out .= f(u, _p, t, W)
524+
return out
525+
end
526+
else
527+
# !isinplace
528+
function (out, u, _p, t)
529+
out .= f(u, _p, t)
530+
return out
531+
end
532+
end
533+
end
534+
end
535+
536+
function mooncake_run_ad(paramjac_config, y, p, t, λ)
537+
msg = "MooncakeVJP requires Mooncake.jl is loaded. Install the package and do " *
538+
"`using Mooncake` to use this functionality"
539+
error(msg)
540+
end
541+
495542
function getprob(S::SensitivityFunction)
496543
(S isa ODEBacksolveSensitivityFunction) ? S.prob : S.sol.prob
497544
end

src/derivative_wrappers.jl

+8
Original file line numberDiff line numberDiff line change
@@ -751,6 +751,14 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad,
751751
return
752752
end
753753

754+
function _vecjacobian!(dλ, y, λ, p, t, S::SensitivityFunction, ::MooncakeVJP, dgrad, dy, W)
755+
_dy, y_grad, p_grad = mooncake_run_ad(S.diffcache.paramjac_config, y, p, t, λ)
756+
dy !== nothing && recursive_copyto!(dy, _dy)
757+
!== nothing && recursive_copyto!(dλ, y_grad)
758+
dgrad !== nothing && recursive_copyto!(dgrad, p_grad)
759+
return
760+
end
761+
754762
function jacNoise!(λ, y, p, t, S::SensitivityFunction;
755763
dgrad = nothing, dλ = nothing, dy = nothing)
756764
_jacNoise!(λ, y, p, t, S, S.sensealg.autojacvec, dgrad, dλ, dy)

src/gauss_adjoint.jl

+8
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,11 @@ function GaussIntegrand(sol, sensealg, checkpoints, dgdp = nothing)
428428
end
429429
paramjac_config = zero(y), zero(y), Enzyme.make_zero(pf)
430430
pJ = nothing
431+
elseif sensealg.autojacvec isa MooncakeVJP
432+
pf = get_pf(sensealg.autojacvec, prob, f)
433+
paramjac_config = get_paramjac_config(
434+
MooncakeLoaded(), sensealg.autojacvec, pf, p, f, y, tspan[2])
435+
pJ = nothing
431436
elseif isautojacvec # Zygote
432437
paramjac_config = nothing
433438
pf = nothing
@@ -500,6 +505,9 @@ function vec_pjac!(out, λ, y, t, S::GaussIntegrand)
500505
Enzyme.Reverse, Enzyme.Duplicated(pf, tmp6), Enzyme.Const,
501506
Enzyme.Duplicated(tmp3, tmp4),
502507
Enzyme.Const(y), Enzyme.Duplicated(p, out), Enzyme.Const(t))
508+
elseif sensealg.autojacvec isa MooncakeVJP
509+
_, _, p_grad = mooncake_run_ad(paramjac_config, y, p, t, λ)
510+
out .= p_grad
503511
else
504512
error("autojacvec choice $(sensealg.autojacvec) is not supported by GaussAdjoint")
505513
end

src/quadrature_adjoint.jl

+8
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,11 @@ function AdjointSensitivityIntegrand(sol, adj_sol, sensealg, dgdp = nothing)
235235
end
236236
paramjac_config = zero(y), zero(y), Enzyme.make_zero(pf)
237237
pJ = nothing
238+
elseif sensealg.autojacvec isa MooncakeVJP
239+
pf = get_pf(sensealg.autojacvec, prob, f)
240+
paramjac_config = get_paramjac_config(
241+
MooncakeLoaded(), sensealg.autojacvec, pf, p, f, y, tspan[2])
242+
pJ = nothing
238243
elseif isautojacvec # Zygote
239244
paramjac_config = nothing
240245
pf = nothing
@@ -288,6 +293,9 @@ function vec_pjac!(out, λ, y, t, S::AdjointSensitivityIntegrand)
288293
else
289294
out[:] .= vec(tmp[1])
290295
end
296+
elseif sensealg.autojacvec isa MooncakeVJP
297+
_, _, p_grad = mooncake_run_ad(paramjac_config, y, p, t, λ)
298+
out .= p_grad
291299
elseif sensealg.autojacvec isa EnzymeVJP
292300
tmp3, tmp4, tmp6 = paramjac_config
293301
tmp4 .= λ

src/sensitivity_algorithms.jl

+17
Original file line numberDiff line numberDiff line change
@@ -1226,6 +1226,23 @@ struct ReverseDiffVJP{compile} <: VJPChoice
12261226
ReverseDiffVJP(compile = false) = new{compile}()
12271227
end
12281228

1229+
"""
1230+
```julia
1231+
MooncakeVJP <: VJPChoice
1232+
```
1233+
1234+
Uses Mooncake.jl to compute the vector-Jacobian products.
1235+
1236+
Does not support GPUs (CuArrays).
1237+
1238+
## Constructor
1239+
1240+
```julia
1241+
MooncakeVJP()
1242+
```
1243+
"""
1244+
struct MooncakeVJP <: VJPChoice end
1245+
12291246
@inline convert_tspan(::ForwardDiffSensitivity{CS, CTS}) where {CS, CTS} = CTS
12301247
@inline convert_tspan(::Any) = nothing
12311248
@inline function alg_autodiff(alg::AbstractSensitivityAlgorithm{

test/adjoint.jl

+59-11
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,14 @@ _, easy_res14 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
135135
abstol = 1e-14,
136136
reltol = 1e-14,
137137
sensealg = GaussAdjoint())
138+
_, easy_res15 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
139+
abstol = 1e-14,
140+
reltol = 1e-14,
141+
sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))
142+
_, easy_res16 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
143+
abstol = 1e-14,
144+
reltol = 1e-14,
145+
sensealg = QuadratureAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))
138146
_, easy_res142 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
139147
abstol = 1e-14,
140148
reltol = 1e-14,
@@ -158,6 +166,10 @@ _, easy_res146 = adjoint_sensitivities(sol_nodense, Tsit5(), t = t, dgdu_discret
158166
sensealg = GaussAdjoint(checkpointing = true,
159167
autojacvec = false),
160168
checkpoints = sol.t[1:500:end])
169+
_, easy_res147 = adjoint_sensitivities(solb, Tsit5(), t = t, dgdu_discrete = dg,
170+
abstol = 1e-14,
171+
reltol = 1e-14,
172+
sensealg = GaussAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))
161173
adj_prob = ODEAdjointProblem(sol,
162174
QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14,
163175
autojacvec = SciMLSensitivity.ReverseDiffVJP()),
@@ -189,11 +201,14 @@ res, err = quadgk(integrand, 0.0, 10.0, atol = 1e-14, rtol = 1e-12)
189201
@test isapprox(res, easy_res12, rtol = 1e-9)
190202
@test isapprox(res, easy_res13, rtol = 1e-9)
191203
@test isapprox(res, easy_res14, rtol = 1e-9)
204+
@test isapprox(res, easy_res15, rtol = 1e-9)
205+
@test isapprox(res, easy_res16, rtol = 1e-9)
192206
@test isapprox(res, easy_res142, rtol = 1e-9)
193207
@test isapprox(res, easy_res143, rtol = 1e-9)
194208
@test isapprox(res, easy_res144, rtol = 1e-9)
195209
@test isapprox(res, easy_res145, rtol = 1e-9)
196210
@test isapprox(res, easy_res146, rtol = 1e-9)
211+
@test isapprox(res, easy_res147, rtol = 1e-9)
197212

198213
println("OOP adjoint sensitivities ")
199214

@@ -203,14 +218,11 @@ _, easy_res = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
203218
_, easy_res2 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
204219
abstol = 1e-14,
205220
reltol = 1e-14,
206-
sensealg = QuadratureAdjoint(abstol = 1e-14,
207-
reltol = 1e-14))
221+
sensealg = QuadratureAdjoint(abstol = 1e-14, reltol = 1e-14))
208222
_, easy_res22 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
209223
abstol = 1e-14,
210224
reltol = 1e-14,
211-
sensealg = QuadratureAdjoint(autojacvec = false,
212-
abstol = 1e-14,
213-
reltol = 1e-14))
225+
sensealg = QuadratureAdjoint(autojacvec = false, abstol = 1e-14, reltol = 1e-14))
214226
_, easy_res2 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
215227
abstol = 1e-14,
216228
reltol = 1e-14,
@@ -224,17 +236,15 @@ _, easy_res3 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
224236
@test easy_res32 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
225237
abstol = 1e-14,
226238
reltol = 1e-14,
227-
sensealg = InterpolatingAdjoint(autojacvec = false))[1] isa
228-
AbstractArray
239+
sensealg = InterpolatingAdjoint(autojacvec = false))[1] isa AbstractArray
229240
_, easy_res4 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
230241
abstol = 1e-14,
231242
reltol = 1e-14,
232243
sensealg = BacksolveAdjoint())
233244
@test easy_res42 = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
234245
abstol = 1e-14,
235246
reltol = 1e-14,
236-
sensealg = BacksolveAdjoint(autojacvec = false))[1] isa
237-
AbstractArray
247+
sensealg = BacksolveAdjoint(autojacvec = false))[1] isa AbstractArray
238248
_, easy_res5 = adjoint_sensitivities(soloop,
239249
Kvaerno5(nlsolve = NLAnderson(), smooth_est = false),
240250
t = t, dgdu_discrete = dg, abstol = 1e-12,
@@ -248,8 +258,7 @@ _, easy_res6 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_discre
248258
_, easy_res62 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t,
249259
dgdu_discrete = dg, abstol = 1e-14,
250260
reltol = 1e-14,
251-
sensealg = InterpolatingAdjoint(checkpointing = true,
252-
autojacvec = false),
261+
sensealg = InterpolatingAdjoint(checkpointing = true, autojacvec = false),
253262
checkpoints = soloop_nodense.t[1:5:end])
254263

255264
_, easy_res8 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_discrete = dg,
@@ -289,6 +298,39 @@ _, easy_res123 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_disc
289298
reltol = 1e-14,
290299
sensealg = GaussAdjoint(checkpointing = true),
291300
checkpoints = soloop_nodense.t[1:5:end])
301+
302+
_, easy_res2_mc_quad = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
303+
abstol = 1e-14,
304+
reltol = 1e-14,
305+
sensealg = QuadratureAdjoint(
306+
abstol = 1e-14, reltol = 1e-14, autojacvec = SciMLSensitivity.MooncakeVJP()))
307+
_, easy_res2_mc_interp = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
308+
abstol = 1e-14,
309+
reltol = 1e-14,
310+
sensealg = InterpolatingAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))
311+
_, easy_res2_mc_back = adjoint_sensitivities(soloop, Tsit5(), t = t, dgdu_discrete = dg,
312+
abstol = 1e-14,
313+
reltol = 1e-14,
314+
sensealg = BacksolveAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))
315+
_, easy_res6_mc_quad = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t,
316+
dgdu_discrete = dg,
317+
abstol = 1e-14,
318+
reltol = 1e-14,
319+
sensealg = QuadratureAdjoint(
320+
abstol = 1e-14, reltol = 1e-14, autojacvec = SciMLSensitivity.MooncakeVJP()))
321+
_, easy_res6_mc_interp = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t,
322+
dgdu_discrete = dg,
323+
abstol = 1e-14,
324+
reltol = 1e-14,
325+
sensealg = InterpolatingAdjoint(checkpointing = true,
326+
autojacvec = SciMLSensitivity.MooncakeVJP()),
327+
checkpoints = soloop_nodense.t[1:5:end])
328+
_, easy_res6_mc_back = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t,
329+
dgdu_discrete = dg,
330+
abstol = 1e-14,
331+
reltol = 1e-14,
332+
sensealg = BacksolveAdjoint(autojacvec = SciMLSensitivity.MooncakeVJP()))
333+
292334
@test isapprox(res, easy_res, rtol = 1e-10)
293335
@test isapprox(res, easy_res2, rtol = 1e-10)
294336
@test isapprox(res, easy_res22, rtol = 1e-10)
@@ -309,6 +351,12 @@ _, easy_res123 = adjoint_sensitivities(soloop_nodense, Tsit5(), t = t, dgdu_disc
309351
@test isapprox(res, easy_res12, rtol = 1e-9)
310352
@test isapprox(res, easy_res122, rtol = 1e-9)
311353
@test isapprox(res, easy_res123, rtol = 1e-4)
354+
@test isapprox(res, easy_res2_mc_quad, rtol = 1e-9)
355+
@test isapprox(res, easy_res2_mc_interp, rtol = 1e-9)
356+
@test isapprox(res, easy_res2_mc_back, rtol = 1e-9)
357+
@test isapprox(res, easy_res6_mc_quad, rtol = 1e-4)
358+
@test isapprox(res, easy_res6_mc_interp, rtol = 1e-9)
359+
@test isapprox(res, easy_res6_mc_back, rtol = 1e-9)
312360

313361
println("Calculate adjoint sensitivities ")
314362

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using SciMLSensitivity, SafeTestsets
22
using Test, Pkg
3+
import Mooncake
34

45
const GROUP = get(ENV, "GROUP", "All")
56

0 commit comments

Comments
 (0)