# First Neural ODE example

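A neural ODE is an ODE whose derivative function is given by a neural network, $\dot{u} = \mathrm{NN}(u)$. Training adjusts the network's parameters so that the solution of the ODE matches observed data.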

From: https://docs.sciml.ai/DiffEqFlux/stable/examples/neural_ode/
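To make the definition concrete, here is a minimal plain-Julia sketch that uses neither Lux nor DiffEqFlux. The network type `TinyMLP`, its hidden size of 16, and the helper `euler_solve` are hypothetical. Forward Euler is a deliberately simple stand-in for the `Tsit5()` solver used in the rest of the notebook.

```julia
using Random

# Hypothetical one-hidden-layer network standing in for the Lux model used below:
# NN(u) = W2 * tanh.(W1 * u .+ b1) .+ b2
struct TinyMLP
    W1::Matrix{Float64}
    b1::Vector{Float64}
    W2::Matrix{Float64}
    b2::Vector{Float64}
end

# Small random weights, zero biases
TinyMLP(rng::AbstractRNG, nin::Int, nhidden::Int) =
    TinyMLP(0.1 .* randn(rng, nhidden, nin), zeros(nhidden),
            0.1 .* randn(rng, nin, nhidden), zeros(nin))

# The network itself is the right-hand side of the ODE: du/dt = NN(u)
(nn::TinyMLP)(u) = nn.W2 * tanh.(nn.W1 * u .+ nn.b1) .+ nn.b2

# Forward Euler with a fixed step (a simple stand-in for Tsit5), saving every state
function euler_solve(f, u0, tspan, nsteps)
    dt = (tspan[2] - tspan[1]) / nsteps
    u = copy(u0)
    states = [copy(u)]
    for _ in 1:nsteps
        u = u .+ dt .* f(u)
        push!(states, u)
    end
    return reduce(hcat, states)   # column j is the state after j - 1 steps
end

sketch_rng = MersenneTwister(0)
nn = TinyMLP(sketch_rng, 2, 16)
traj = euler_solve(nn, [2.0, 0.0], (0.0, 1.5), 30)
size(traj)   # (2, 31): initial state plus 30 Euler steps
```

Training would then adjust `W1`, `b1`, `W2`, and `b2` until `traj` matches the data. The rest of the notebook does this with Lux, DiffEqFlux, and Optimization.jl.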

In [None]:
using Lux, DiffEqFlux, DifferentialEquations, ComponentArrays
using Optimization, OptimizationOptimJL, OptimizationOptimisers
using Random, Plots

rng = Random.default_rng()

The true ODE, used to generate the training data

In [None]:
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    du .= ((u.^3)'true_A)'
end
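The transpose expression in `trueODEfunc` is a compact way of writing $\dot{u} = A^\top u^{3}$, where $u^{3}$ is the elementwise cube and $A$ is `true_A`. Written out, $\dot{u}_1 = -0.1u_1^3 - 2u_2^3$ and $\dot{u}_2 = 2u_1^3 - 0.1u_2^3$. The short check below uses an arbitrary, hypothetical test state `u_check` to confirm that the three forms agree.

```julia
# Check that the transpose form in trueODEfunc equals transpose(A) * u.^3
A_check = [-0.1 2.0; -2.0 -0.1]   # same values as true_A
u_check = [2.0, 0.5]              # arbitrary test state

lhs = ((u_check .^ 3)' * A_check)'          # form used in trueODEfunc
rhs = transpose(A_check) * (u_check .^ 3)   # matrix-vector form
explicit = [-0.1 * u_check[1]^3 - 2.0 * u_check[2]^3,
             2.0 * u_check[1]^3 - 0.1 * u_check[2]^3]

(lhs ≈ rhs, rhs ≈ explicit)   # (true, true)
```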

Generate the training data by solving the true ODE and saving the state at 30 evenly spaced time points in $[0, 1.5]$.

In [None]:
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[begin], tspan[end], length = datasize)
prob_trueode = ODEProblem(trueODEfunc, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))
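`Array(solve(...; saveat = tsteps))` returns a matrix with one row per state variable and one column per saved time, so `ode_data` is 2×30. The sketch below reproduces that layout with a hand-written fixed-step RK4 integrator. The names `rk4_sample` and `f_true` and the choice of 20 substeps per interval are assumptions for illustration; `Tsit5()` itself is a different, adaptive-step method.

```julia
# Right-hand side of the true ODE: du/dt = transpose(A) * u.^3
f_true(u) = transpose([-0.1 2.0; -2.0 -0.1]) * (u .^ 3)

# Classic fixed-step RK4, sampled at the times in ts (a stand-in for Tsit5 + saveat)
function rk4_sample(f, u0, ts; substeps = 20)
    out = zeros(length(u0), length(ts))
    out[:, 1] = u0
    u = copy(u0)
    for k in 2:length(ts)
        h = (ts[k] - ts[k-1]) / substeps
        for _ in 1:substeps
            k1 = f(u)
            k2 = f(u .+ (h / 2) .* k1)
            k3 = f(u .+ (h / 2) .* k2)
            k4 = f(u .+ h .* k3)
            u = u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
        end
        out[:, k] = u   # one column per saved time
    end
    return out          # size (state dimension, number of time points)
end

ts = range(0.0, 1.5, length = 30)
data_sketch = rk4_sample(f_true, [2.0, 0.0], ts)
size(data_sketch)   # (2, 30), the same layout as ode_data
```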

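Build the neural network with `Lux.jl` and wrap it in a `NeuralODE` problem. The first layer cubes the input elementwise, mirroring the `u.^3` term in the true dynamics.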

In [None]:
dudt2 = Lux.Chain(
    x -> x.^3,
    Lux.Dense(2, 50, tanh),
    Lux.Dense(50, 2)
)

p, st = Lux.setup(rng, dudt2)
prob_neuralode = NeuralODE(dudt2, tspan, Tsit5(), saveat = tsteps)
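Assuming Lux's `Dense` layers compute `activation.(W * x .+ b)` with a weight matrix of size (out, in) plus a bias vector, the chain has $2 \cdot 50 + 50 = 150$ parameters in the first dense layer and $50 \cdot 2 + 2 = 102$ in the second, 252 in total; the cube layer has none. The plain-Julia sketch below mirrors that forward pass. The helpers `dense_param_count` and `chain_forward` are hypothetical, and the random weights only illustrate shapes (they are not Lux's initializer).

```julia
# Parameter count of Dense(in, out): an (out × in) weight matrix plus a length-out bias
dense_param_count(nin, nout) = nin * nout + nout

# x -> x.^3 has no parameters; Dense(2, 50) and Dense(50, 2) do
n_params = dense_param_count(2, 50) + dense_param_count(50, 2)   # 150 + 102

# Forward pass equivalent to the chain, written out by hand
function chain_forward(x, W1, b1, W2, b2)
    h = x .^ 3                # elementwise cube layer
    h = tanh.(W1 * h .+ b1)   # Dense(2, 50, tanh)
    return W2 * h .+ b2       # Dense(50, 2) with identity activation
end

W1, b1 = randn(50, 2), zeros(50)
W2, b2 = randn(2, 50), zeros(2)
y = chain_forward([2.0, 0.0], W1, b1, W2, b2)
(n_params, length(y))   # (252, 2)
```

If the assumptions about Lux hold, `length(ComponentArray(p))` should also report 252.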

Define the prediction and loss functions. The loss is the sum of squared errors between the training data and the prediction.

In [None]:
function predict_neuralode(p)
    # Calling the NeuralODE returns (solution, state); keep the solution as a matrix
    Array(prob_neuralode(u0, p, st)[1])
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end
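The loss `sum(abs2, ode_data .- pred)` adds up the squared residuals over every state and every time point, which is the squared Frobenius norm of the residual matrix. Returning `(loss, pred)` as a tuple is what lets the notebook write `callback(p, loss_neuralode(p)...)`, splatting both values into the callback. A small check with made-up 2×3 matrices (`data_demo`, `pred_demo`, and `loss_and_pred` are hypothetical):

```julia
using LinearAlgebra

# Hypothetical "data" and "prediction" matrices (state × time)
data_demo = [1.0 2.0 3.0; 0.0 1.0 0.0]
pred_demo = [1.5 2.0 2.0; 0.0 0.0 0.0]

# Same reduction as in loss_neuralode: sum of squared residuals over all entries
sse = sum(abs2, data_demo .- pred_demo)   # 0.25 + 1 + 1 = 2.25

# Equivalent: squared Frobenius norm (norm of a matrix is the Frobenius norm)
sse_norm = norm(data_demo .- pred_demo)^2

# Returning (loss, pred) lets callers destructure or splat both values
loss_and_pred(pred) = (sum(abs2, data_demo .- pred), pred)
l, pr = loss_and_pred(pred_demo)
```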

Plots are not generated by default; call the callback with `doplot = true` to display them.

In [None]:
callback = function (p, l, pred; doplot = false)
    println(l)
    # plot current prediction against data
    if doplot
        plt = scatter(tsteps, ode_data[1,:], label = "data")
        scatter!(plt, tsteps, pred[1,:], label = "prediction")
        display(plt)  # a plot created inside a function is not shown automatically
    end
    return false
end
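With Optimization.jl, the callback's return value controls training: `false` continues and `true` halts the solve. The sketch below imitates that contract with a hypothetical hand-written gradient-descent loop, `descend`, on a toy quadratic, using a callback that stops once the loss falls below a threshold. `descend`, `loss_toy`, `grad_toy`, and the threshold `1e-6` are all illustrative assumptions, not part of Optimization.jl.

```julia
# Hypothetical gradient-descent loop that calls a callback after every step,
# passing the current parameters and loss, and stops as soon as it returns true
function descend(grad, loss, x0; lr = 0.1, maxiters = 1000, callback = (x, l) -> false)
    x = copy(x0)
    for iter in 1:maxiters
        x .-= lr .* grad(x)
        if callback(x, loss(x))   # `true` halts early, `false` continues
            return x, iter
        end
    end
    return x, maxiters
end

loss_toy(x) = sum(abs2, x .- 3.0)
grad_toy(x) = 2 .* (x .- 3.0)

# Early-stopping callback: halt once the loss drops below 1e-6
stop_cb = (x, l) -> l < 1e-6

x_opt, iters = descend(grad_toy, loss_toy, [0.0, 0.0]; callback = stop_cb)
```

A callback that always returns `false`, like the one in this notebook, simply lets the solver run to `maxiters` or its own convergence criterion.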

Run the callback once on the initial parameters to see the untrained prediction.

In [None]:
pinit = ComponentArray(p)
callback(pinit, loss_neuralode(pinit)...; doplot=true)

Use Optimization.jl to fit the network parameters:
- `Zygote` for automatic differentiation (AD)
- `loss_neuralode` as the objective to minimize
- an `OptimizationProblem` starting from the initial parameters `pinit`

In [None]:
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, pinit)
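In `OptimizationFunction((x, p) -> loss_neuralode(x), adtype)`, `x` is the vector being optimized and `p` holds fixed hyperparameters, which this example does not use. `AutoZygote()` tells Optimization.jl to compute gradients with Zygote, a reverse-mode AD package. As a rough illustration of how AD yields exact derivatives rather than finite-difference approximations, here is a minimal forward-mode sketch built on dual numbers. The `Dual` type and `derivative` helper are hypothetical, and forward mode is a different technique from Zygote's reverse mode.

```julia
# Minimal forward-mode AD: each Dual carries a value and its derivative
struct Dual
    val::Float64   # function value
    der::Float64   # derivative propagated alongside
end

# Arithmetic rules follow the sum and product rules of calculus
Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)
Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)
Base.:+(a::Dual, c::Real) = Dual(a.val + c, a.der)
Base.:*(c::Real, a::Dual) = Dual(c * a.val, c * a.der)
Base.tanh(a::Dual) = Dual(tanh(a.val), (1 - tanh(a.val)^2) * a.der)   # chain rule

# Derivative of a scalar function at x: seed the derivative part with 1
derivative(f, x::Real) = f(Dual(x, 1.0)).der

g_cubic(x) = x * x * x + 2.0 * x    # g'(x) = 3x^2 + 2
h_tanh(x) = tanh(0.5 * x + 1.0)     # h'(x) = 0.5 * (1 - tanh(0.5x + 1)^2)

derivative(g_cubic, 2.0)   # 14.0
```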

Train with the Adam optimizer (learning rate 0.05) for 300 iterations.

In [None]:
result_neuralode = Optimization.solve(
    optprob,
    OptimizationOptimisers.Adam(0.05),
    callback = callback,
    maxiters = 300
)
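Adam keeps running averages of the gradient and of its elementwise square and divides one by the square root of the other, which makes the step size in each coordinate largely insensitive to the gradient's scale. The sketch below implements that update on a toy quadratic, assuming the commonly used defaults $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$. `adam!`, `target`, `f_quad`, and `grad_quad` are hypothetical; this is not the Optimisers.jl implementation.

```julia
# Sketch of the Adam update rule with bias-corrected moment estimates
function adam!(x, grad; lr = 0.05, beta1 = 0.9, beta2 = 0.999, epsilon = 1e-8, maxiters = 300)
    m = zero(x)   # running mean of the gradient (first moment)
    v = zero(x)   # running mean of the squared gradient (second moment)
    for t in 1:maxiters
        g = grad(x)
        @. m = beta1 * m + (1 - beta1) * g
        @. v = beta2 * v + (1 - beta2) * g^2
        mhat = m ./ (1 - beta1^t)   # correct the bias from zero initialization
        vhat = v ./ (1 - beta2^t)
        @. x -= lr * mhat / (sqrt(vhat) + epsilon)
    end
    return x
end

target = [1.0, -2.0]
f_quad(x) = sum(abs2, x .- target)
grad_quad(x) = 2 .* (x .- target)

# Learning rate 0.05 and 300 iterations, as in the notebook's Adam stage
x_adam = adam!(zeros(2), grad_quad)
f_quad(x_adam)
```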

Refine the fit with a second optimizer, `Optim.BFGS`, starting from the parameters where Adam stopped.

In [None]:
optprob2 = remake(optprob, u0 = result_neuralode.u)

result_neuralode2 = Optimization.solve(
    optprob2,
    Optim.BFGS(initial_stepnorm=0.01),
    callback=callback,
    allow_f_increases = false
)
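BFGS maintains an approximation $H$ of the inverse Hessian and updates it from the step $s$ and the gradient change $y$ via $H \leftarrow (I - \rho s y^\top) H (I - \rho y s^\top) + \rho s s^\top$ with $\rho = 1/(y^\top s)$. Starting it from Adam's result (via `remake`) lets it begin near a minimum, where this curvature information pays off. The sketch below uses a simple Armijo backtracking line search, not Optim.jl's default line search, on a hypothetical convex quadratic; `x_warm` is a made-up stand-in for Adam's output, and `bfgs`, `Q`, `c_q`, `f_quadform`, and `grad_quadform` are illustrative names.

```julia
using LinearAlgebra

# BFGS with a simple Armijo backtracking line search
function bfgs(f, grad, x0; maxiters = 100, gtol = 1e-10)
    x = copy(x0)
    n = length(x)
    H = Matrix{Float64}(I, n, n)   # inverse-Hessian approximation, starts as identity
    g = grad(x)
    for _ in 1:maxiters
        norm(g) < gtol && break
        d = -H * g                 # quasi-Newton search direction
        α = 1.0
        while f(x + α * d) > f(x) + 1e-4 * α * dot(g, d)   # Armijo sufficient decrease
            α /= 2
        end
        s = α * d
        xnew = x + s
        gnew = grad(xnew)
        y = gnew - g
        sy = dot(s, y)
        if sy > 1e-12              # update only when the curvature condition holds
            ρ = 1 / sy
            V = I - ρ * y * s'
            H = V' * H * V + ρ * s * s'   # (I - ρ s yᵀ) H (I - ρ y sᵀ) + ρ s sᵀ
        end
        x, g = xnew, gnew
    end
    return x
end

# Hypothetical convex quadratic with minimizer c_q
Q = [3.0 1.0; 1.0 2.0]
c_q = [1.0, -2.0]
f_quadform(x) = 0.5 * dot(x - c_q, Q * (x - c_q))
grad_quadform(x) = Q * (x - c_q)

x_warm = [0.8, -1.5]   # stand-in for the point where a first optimizer stopped
x_bfgs = bfgs(f_quadform, grad_quadform, x_warm)
```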

Plot the final prediction against the training data to check the fit.

In [None]:
callback(result_neuralode2.u, loss_neuralode(result_neuralode2.u)...; doplot=true)

## Animated solving process
Let's reset the problem and visualize the training process.

In [None]:
rng = Random.default_rng()
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 1.5f0)
tsteps = range(tspan[begin], tspan[end], length = datasize)

Generate the ground-truth training data.

In [None]:
true_A = Float32[-0.1 2.0; -2.0 -0.1]

function trueODEfunc!(du, u, p, t)
    du .= ((u.^3)'true_A)'
end

In [None]:
prob_trueode = ODEProblem(trueODEfunc!, u0, tspan)
ode_data = Array(solve(prob_trueode, Tsit5(), saveat = tsteps))

In [None]:
nodeFunc = Lux.Chain(
    x -> x.^3,
    Lux.Dense(2, 50, tanh),
    Lux.Dense(50, 2)
)

p, st = Lux.setup(rng, nodeFunc)

Parameters in the neural network:

In [None]:
p

Use `NeuralODE()` to construct the problem.

In [None]:
prob_node = NeuralODE(nodeFunc, tspan, Tsit5(), saveat = tsteps)

Prediction function.

In [None]:
function predict_neuralode(p)
    Array(prob_node(u0, p, st)[1])
end

The loss function.

In [None]:
function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, ode_data .- pred)
    return loss, pred
end

Callback that records one animation frame per iteration.

In [None]:
anim = Animation()
callback = function (p, l, pred; doplot = true)
    if doplot
        plt = scatter(tsteps, ode_data[1,:], label = "data")
        scatter!(plt, tsteps, pred[1,:], label = "prediction")
        frame(anim, plt)
    end
    return false
end

In [None]:
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_neuralode(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(p))

Train with the Adam optimizer.

In [None]:
result_neuralode = Optimization.solve(
    optprob,
    OptimizationOptimisers.Adam(0.05),
    callback = callback,
    maxiters = 300
)

Then refine the solution with the L-BFGS optimizer, starting from the Adam result.

In [None]:
optprob2 = remake(optprob, u0 = result_neuralode.u)

result_neuralode2 = Optimization.solve(
    optprob2,
    Optim.LBFGS(),
    callback = callback,
    allow_f_increases = false
)
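L-BFGS avoids storing a dense $n \times n$ inverse-Hessian approximation, which matters when the network has many parameters. It keeps only the last $m$ pairs $(s, y)$ and applies the approximation to the gradient with the two-loop recursion. The sketch below uses hypothetical helpers `two_loop` and `lbfgs`, a memory of $m = 5$, and Armijo backtracking on a made-up ill-conditioned quadratic; Optim.jl's `LBFGS` defaults and line search may differ.

```julia
using LinearAlgebra

# Two-loop recursion: apply the L-BFGS inverse-Hessian approximation to g
# using only the stored pairs s_i = x_{i+1} - x_i and y_i = g_{i+1} - g_i
function two_loop(g, S, Y)
    k = length(S)
    q = copy(g)
    α = zeros(k)
    ρ = [1 / dot(Y[i], S[i]) for i in 1:k]
    for i in k:-1:1                         # newest pair to oldest
        α[i] = ρ[i] * dot(S[i], q)
        q .-= α[i] .* Y[i]
    end
    γ = k > 0 ? dot(S[k], Y[k]) / dot(Y[k], Y[k]) : 1.0   # scaling for H0 = γI
    r = γ .* q
    for i in 1:k                            # oldest pair to newest
        β = ρ[i] * dot(Y[i], r)
        r .+= (α[i] - β) .* S[i]
    end
    return r
end

function lbfgs(f, grad, x0; m = 5, maxiters = 200, gtol = 1e-10)
    x = copy(x0)
    g = grad(x)
    S = Vector{Vector{Float64}}()
    Y = Vector{Vector{Float64}}()
    for _ in 1:maxiters
        norm(g) < gtol && break
        d = -two_loop(g, S, Y)
        α = 1.0
        while f(x + α * d) > f(x) + 1e-4 * α * dot(g, d)   # Armijo backtracking
            α /= 2
        end
        xnew = x + α * d
        gnew = grad(xnew)
        s, y = xnew - x, gnew - g
        if dot(s, y) > 1e-12                # keep only pairs with positive curvature
            push!(S, s); push!(Y, y)
            if length(S) > m                # forget the oldest pair beyond memory m
                popfirst!(S); popfirst!(Y)
            end
        end
        x, g = xnew, gnew
    end
    return x
end

# Hypothetical ill-conditioned quadratic with minimizer c_ill
dscale = [1.0, 10.0, 100.0]
c_ill = [1.0, -2.0, 0.5]
f_ill(x) = 0.5 * sum(dscale .* (x .- c_ill) .^ 2)
grad_ill(x) = dscale .* (x .- c_ill)

x_lbfgs = lbfgs(f_ill, grad_ill, zeros(3))
```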

Visualize the fitting process.

In [None]:
mp4(anim, fps=15)