# Training Neural Ordinary Differential Equations

Docs: <https://diffeqflux.sciml.ai/dev/examples/neural_ode/>

## Runtime information

In [1]:
# Print the Julia version and platform details of the machine that ran this notebook.
versioninfo()

Julia Version 1.7.3
Commit 742b9abb4d (2022-05-06 12:58 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: AMD EPYC 7B13
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-12.0.1 (ORCJIT, znver3)
Environment:
  JULIA_PATH = /usr/local/julia/


In [2]:
using Pkg
# List the package versions of the active project environment, for reproducibility.
Pkg.status()

[32m[1m      Status[22m[39m `/tmp/cirrus-ci-build/Project.toml`
 [90m [aae7a2af] [39mDiffEqFlux v1.51.2
 [90m [0c46a032] [39mDifferentialEquations v7.2.0
 [90m [b2108857] [39mLux v0.4.9
 [90m [7f7a1694] [39mOptimization v3.7.1
 [90m [36348300] [39mOptimizationOptimJL v0.1.1
 [90m [500b13db] [39mOptimizationPolyalgorithms v0.1.0
 [90m [91a5bcdd] [39mPlots v1.31.2
 [90m [37e2e46d] [39mLinearAlgebra
 [90m [10745b16] [39mStatistics


## First N-ODE example

A neural ODE is an ODE whose derivative function is given by a neural network: $\dot{u} = \mathrm{NN}(u)$.

In [3]:
using Lux
using DiffEqFlux
using DifferentialEquations
using Optimization
using OptimizationOptimJL
using Random
using Plots

# Global RNG used by `Lux.setup` below to initialize the network parameters.
# NOTE(review): no seed is set in this notebook, so the initial weights (and
# therefore the training results) differ from run to run.
rng = Random.default_rng()

TaskLocalRNG()

In [4]:
# Problem setup: number of saved time points, integration window, initial
# state, and the evenly spaced save times. Everything is Float32 so it
# matches the Float32 network parameters.
datasize = 30
tspan = (0.0f0, 1.5f0)
u0 = Float32[2.0, 0.0]
tsteps = range(first(tspan), last(tspan); length = datasize)

0.0f0:0.05172414f0:1.5f0

A spiral ODE, $\dot{u} = A^\mathsf{T} u^3$ (with the cube taken elementwise), generates the data to train against.

In [5]:
# Linear map of the reference spiral system. It is `const` so the global is
# type-stable inside the right-hand side below.
const true_A = Float32[-0.1 2.0; -2.0 -0.1]

# In-place right-hand side of the reference ODE: du = true_A' * u.^3, with the
# cube taken elementwise. `p` and `t` are unused. Returns `du`.
function trueODEfunc!(du, u, p, t)
    cubed = u .^ 3
    du .= true_A' * cubed
    return du
end

trueODEfunc! (generic function with 1 method)

In [6]:
# Reference trajectory: solve the true ODE, keep the states at `tsteps`, and
# convert them to a matrix with one column per saved time point.
prob_trueode = ODEProblem(trueODEfunc!, u0, tspan)
ode_data = solve(prob_trueode, Tsit5(); saveat = tsteps) |> Array

2×30 Matrix{Float32}:
 2.0  1.9465    1.74178  1.23837  0.577127  …  1.40696   1.37033   1.29217
 0.0  0.798831  1.46473  1.80877  1.86465      0.451557  0.728934  0.972362

In [7]:
# Network defining the learned derivative. It applies an elementwise cube
# (mirroring the u.^3 in the reference dynamics), then dense layers
# 2 → 50 (tanh) → 2.
nodeFunc = Lux.Chain(
    ActivationFunction(x -> x .^ 3),
    Lux.Dense(2, 50, tanh),
    Lux.Dense(50, 2),
)

# Initial parameters `p` and layer states `st`.
p, st = Lux.setup(rng, nodeFunc)

((layer_1 = NamedTuple(), layer_2 = (weight = Float32[0.103460394 0.31779912; 0.121634796 0.20668077; … ; 0.038555816 0.15278105; 0.20319603 0.3091554], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[0.011574831 -0.18084669 … 0.30424872 0.16619845; -0.26771912 -0.30523896 … -0.2698946 -0.066985734], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))

In [8]:
# Display the initial neural-network parameters (one entry per layer of the chain).
p

(layer_1 = NamedTuple(), layer_2 = (weight = Float32[0.103460394 0.31779912; 0.121634796 0.20668077; … ; 0.038555816 0.15278105; 0.20319603 0.3091554], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[0.011574831 -0.18084669 … 0.30424872 0.16619845; -0.26771912 -0.30523896 … -0.2698946 -0.066985734], bias = Float32[0.0; 0.0;;]))

In [9]:
# Wrap the network as a neural ODE solved with Tsit5 over `tspan`. It saves at
# the same time points as the reference data so predictions and data align.
prob_node = NeuralODE(nodeFunc, tspan, Tsit5(); saveat = tsteps)

NeuralODE()         [90m# 252 parameters[39m

In [10]:
# Solve the neural ODE from `u0` with parameters `p` and return the trajectory
# as a matrix. Calling the model returns a tuple whose first element is the solution.
function predict_neuralode(p)
    sol = first(prob_node(u0, p, st))
    return Array(sol)
end

# Sum-of-squares loss against the reference trajectory. The prediction is
# returned alongside the loss so the optimizer callback can plot it.
function loss_neuralode(p)
    pred = predict_neuralode(p)
    return sum(abs2, ode_data .- pred), pred
end

# Collect one animation frame per optimizer iteration.
anim = Animation()

# Optimizer callback. It scatters the first state component of the data and of
# the current prediction, then records a frame. It always returns `false`, so
# training never stops early. Uncomment the `display` lines to watch live.
callback = function (p, l, pred; doplot = true)
    # display(l)
    doplot || return false
    plt = scatter(tsteps, ode_data[1, :], label = "data")
    scatter!(plt, tsteps, pred[1, :], label = "prediction")
    frame(anim)
    # display(plot(plt))
    return false
end

#3 (generic function with 1 method)

In [11]:
# Build the optimization problem for the first (Adam) training stage.
# Gradients come from Zygote. The second lambda argument (optimization
# hyperparameters) is unused. The flattened ComponentArray of `p` is the
# vector being optimized.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, _) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(p))

[36mOptimizationProblem[0m. In-place: [36mtrue[0m
u0: [0mComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[0.103460394 0.31779912; 0.121634796 0.20668077; … ; 0.038555816 0.15278105; 0.20319603 0.3091554], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[0.011574831 -0.18084669 … 0.30424872 0.16619845; -0.26771912 -0.30523896 … -0.2698946 -0.066985734], bias = Float32[0.0; 0.0;;]))

In [12]:
# Run 300 Adam iterations with learning rate 0.05, calling `callback` after each one.
result_neuralode = Optimization.solve(optprob, ADAM(0.05); callback, maxiters = 300)

u: [0mComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-0.6050676 0.33445832; -0.8355316 0.8275537; … ; -0.3130034 0.40927017; -0.060234003 0.7589445], bias = Float32[0.11882788; -0.11025208; … ; -0.11265482; -0.9667147;;]), layer_3 = (weight = Float32[-0.90814555 -0.7977073 … -0.88943666 -0.5020189; -0.6035449 0.51919854 … -0.80775946 -0.079122886], bias = Float32[-0.5328792; 0.17050569;;]))

In [13]:
# Second stage: refine with LBFGS, starting from the parameters found by Adam.
optprob2 = remake(optprob; u0 = result_neuralode.u)

result_neuralode2 = Optimization.solve(
    optprob2, LBFGS();
    callback,
    allow_f_increases = false,
)

u: [0mComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-0.58544064 0.3144334; -1.0717564 0.6722638; … ; -0.1276328 0.38965508; -0.17196192 0.7304533], bias = Float32[0.118197456; -0.15167712; … ; -0.117327794; -1.0128113;;]), layer_3 = (weight = Float32[-0.95375854 -0.86862737 … -0.93861306 -0.60804164; -0.6028863 0.57106614 … -0.8449319 -0.20607091], bias = Float32[-0.46012595; 0.40458617;;]))

In [14]:
# Encode the frames recorded by `callback` during both training stages as an MP4 at 15 fps.
mp4(anim, fps=15)

┌ Info: Saved animation to 
│   fn = /tmp/cirrus-ci-build/docs/tmp.mp4
└ @ Plots /root/.julia/packages/Plots/OeNV1/src/animation.jl:126


## Multiple Shooting

Docs: <https://diffeqflux.sciml.ai/dev/examples/multiple_shooting/>

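In multiple shooting, the training data is split into overlapping intervals, and the neural ODE is solved on each interval separately. The loss then adds a continuity penalty, weighted by `continuity_term`, that pushes the solutions of adjacent intervals to agree where they meet.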

In [15]:
using Lux
using DiffEqFlux
using Optimization
using OptimizationPolyalgorithms
using DifferentialEquations
using DiffEqFlux: group_ranges
using Random
# Global RNG used by `Lux.setup` to initialize this section's network.
# NOTE(review): it is not seeded here, so the initial weights differ between runs.
rng = Random.default_rng()

TaskLocalRNG()

In [16]:
# Setup for multiple shooting: the same initial state and sample count as
# before, but over a longer window of 5 time units.
datasize = 30
tspan = (0.0f0, 5.0f0)
u0 = Float32[2.0, 0.0]
tsteps = range(first(tspan), last(tspan); length = datasize)

0.0f0:0.1724138f0:5.0f0

In [17]:
# This section reuses the spiral system from the first example. `true_A` is
# defined only when missing. That lets this section run on its own, and it
# avoids a const-redefinition warning when the whole notebook runs top to
# bottom. (`if` at top level does not open a new scope, so `const` is valid.)
if !@isdefined(true_A)
    const true_A = Float32[-0.1 2.0; -2.0 -0.1]
end

# Get the data: in-place right-hand side du = true_A' * u.^3 (elementwise cube).
# `p` and `t` are unused. Returns `du`.
function trueODEfunc!(du, u, p, t)
    du .= ((u.^3)'true_A)'
end

trueODEfunc! (generic function with 1 method)

In [18]:
# Reference trajectory over the longer window, as a matrix with one column per
# saved time point.
prob_trueode = ODEProblem(trueODEfunc!, u0, tspan)
ode_data = solve(prob_trueode, Tsit5(); saveat = tsteps) |> Array

2×30 Matrix{Float32}:
 2.0  1.02407  -1.07772  -1.70874   …  0.332336  0.0334096  -0.252243
 0.0  1.84867   1.72465   0.323451     0.958119  0.946642    0.931118

In [19]:
# Smaller network for multiple shooting: an elementwise cube, then dense
# layers 2 → 16 (tanh) → 2.
nn = Lux.Chain(
    ActivationFunction(x -> x .^ 3),
    Lux.Dense(2, 16, tanh),
    Lux.Dense(16, 2),
)
p_init, st = Lux.setup(rng, nn)

((layer_1 = NamedTuple(), layer_2 = (weight = Float32[-0.5716352 0.41558835; 0.24233463 -0.2886684; … ; 0.50533813 -0.15855698; -0.2739899 -0.53264177], bias = Float32[0.0; 0.0; … ; 0.0; 0.0;;]), layer_3 = (weight = Float32[-0.22297698 -0.4532547 … -0.12405861 0.24362002; 0.31044325 -0.4119331 … 0.5233362 -0.09594146], bias = Float32[0.0; 0.0;;])), (layer_1 = NamedTuple(), layer_2 = NamedTuple(), layer_3 = NamedTuple()))

In [20]:
# `neuralode` is the NeuralODE wrapper; it is not used in the cells below.
# `prob_node` is a plain out-of-place ODEProblem whose right-hand side evaluates
# the network and discards the returned layer state. This is the problem that
# `multiple_shoot` solves on each group.
neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps)
prob_node = ODEProblem(
    (u, p, t) -> first(nn(u, p, st)), u0, tspan, Lux.ComponentArray(p_init)
)

[36mODEProblem[0m with uType [36mVector{Float32}[0m and tType [36mFloat32[0m. In-place: [36mfalse[0m
timespan: (0.0f0, 5.0f0)
u0: 2-element Vector{Float32}:
 2.0
 0.0

In [21]:
# Overlay each shooting group's prediction (first state component) onto `plt`.
#
# `preds[i]` is the predicted trajectory for the i-th group of time indices,
# in the order produced by `group_ranges`. The group ranges are computed from
# `length(tsteps)`, the same time points they index into, so the ranges always
# stay within `tsteps`. (The original used the separate `datasize` global and an
# unused `step` local.)
# NOTE(review): reads the global `tsteps`. Returns `nothing`; `plt` is mutated.
function plot_multiple_shoot(plt, preds, group_size)
    ranges = group_ranges(length(tsteps), group_size)
    for (i, rg) in enumerate(ranges)
        plot!(
            plt, tsteps[rg], preds[i][1, :];
            markershape = :circle, label = "Group $(i)",
        )
    end
    return nothing
end

# Collect one animation frame per optimizer iteration.
anim = Animation()

# Optimizer callback for multiple shooting. It scatters the data (first state
# component), overlays every group's prediction, and records a frame. It always
# returns `false`, so training never stops early. `group_size` is the global set
# in the loss cell below. Uncomment the `display` lines to watch live.
callback = function (p, l, preds; doplot = true)
    # display(l)
    doplot || return false
    plt = scatter(tsteps, ode_data[1, :], label = "Data")
    plot_multiple_shoot(plt, preds, group_size)
    frame(anim)
    # display(plot(plt))
    return false
end


#12 (generic function with 1 method)

In [22]:
# Multiple-shooting settings: each group spans `group_size` saved points.
# `continuity_term` is the weight passed to `multiple_shoot` for its continuity
# penalty between groups.
group_size = 3
continuity_term = 200

# Per-group sum-of-squares loss between data and prediction.
loss_function(data, pred) = sum(abs2, data - pred)

# Multiple-shooting loss for parameters `p`. Returns whatever `multiple_shoot`
# returns; the callback receives it as `(l, preds)`.
function loss_multiple_shooting(p)
    return multiple_shoot(
        p, ode_data, tsteps, prob_node, loss_function, Tsit5(), group_size;
        continuity_term,
    )
end

loss_multiple_shooting (generic function with 1 method)

In [23]:
# Minimize the multiple-shooting loss with PolyOpt, using Zygote gradients and
# calling `callback` after every iteration. The second lambda argument
# (optimization hyperparameters) is unused.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, _) -> loss_multiple_shooting(x), adtype)
optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(p_init))
res_ms = Optimization.solve(optprob, PolyOpt(); callback)

u: [0mComponentVector{Float32}(layer_1 = Float32[], layer_2 = (weight = Float32[-0.15761043 0.311609; -0.095350154 0.13569558; … ; 0.16677283 0.166042; -0.34521943 -0.588498], bias = Float32[1.3108658; -0.2862964; … ; -0.30260047; 2.4755518;;]), layer_3 = (weight = Float32[-1.7368861 -0.9135754 … -0.39605406 2.011941; -0.27080452 -1.115441 … 1.7061656 -1.3873035], bias = Float32[-0.58091336; -0.038488653;;]))

In [24]:
# Encode the frames recorded during multiple-shooting training as an MP4 at 15 fps.
mp4(anim, fps=15)

┌ Info: Saved animation to 
│   fn = /tmp/cirrus-ci-build/docs/tmp.mp4
└ @ Plots /root/.julia/packages/Plots/OeNV1/src/animation.jl:126
