# Training Neural Ordinary Differential Equations

Docs: https://diffeqflux.sciml.ai/dev/examples/neural_ode/

## First N-ODE example

A neural ODE is an ODE where a neural network defines its derivative function. $\dot{u} = NN(u)$

In [None]:
using Lux
using DiffEqFlux
using DifferentialEquations
using Optimization
using OptimizationOptimJL
using Random
using Plots

# Random number generator handed to `Lux.setup` when the network is initialised.
rng = Random.default_rng()

In [None]:
# Problem setup for the first example: sample count, initial state, and time window.
datasize = 30
u0 = Float32[2.0, 0.0]
tspan = (0.0f0, 1.5f0)
# `datasize` evenly spaced save points spanning `tspan`.
tsteps = range(first(tspan), last(tspan); length = datasize)

A spiral ODE to train against.

In [None]:
# Matrix of the reference system; it is applied to the element-wise cubed state.
const true_A = Float32[-0.1 2.0; -2.0 -0.1]

"""
    trueODEfunc!(du, u, p, t)

In-place right-hand side of the reference ODE: writes `true_A' * u.^3` into `du`.

`p` and `t` are part of the signature but are not read. Returns `du`.
"""
function trueODEfunc!(du, u, p, t)
    cubed = u .^ 3
    du .= true_A' * cubed
    return du
end

In [None]:
# Solve the reference problem with Tsit5, saving at `tsteps`, and keep the
# trajectory as a plain array to train against.
prob_trueode = ODEProblem(trueODEfunc!, u0, tspan)
ode_data = solve(prob_trueode, Tsit5(); saveat = tsteps) |> Array

In [None]:
# Network used as the derivative function: cube the 2-element state, then apply a
# 2 -> 50 -> 2 dense stack with a tanh hidden layer.
nodeFunc = Lux.Chain(
    ActivationFunction(x -> x .^ 3), Lux.Dense(2, 50, tanh), Lux.Dense(50, 2)
)

# `p` holds the network parameters and `st` the layer state, both created with `rng`.
(p, st) = Lux.setup(rng, nodeFunc)

In [None]:
# Parameters for neural network, as returned by `Lux.setup` (shown as the cell output)
p

In [None]:
# Wrap `nodeFunc` in a NeuralODE over `tspan`, using the Tsit5 solver and saving at `tsteps`.
prob_node = NeuralODE(nodeFunc, tspan, Tsit5(), saveat = tsteps)

In [None]:
"""
    predict_neuralode(p)

Evaluate `prob_node` from `u0` with parameters `p` and layer state `st`, and return
the first element of its result converted to an `Array`.
"""
function predict_neuralode(p)
    solution, _ = prob_node(u0, p, st)
    return Array(solution)
end
  
"""
    loss_neuralode(p)

Return `(loss, pred)`, where `pred = predict_neuralode(p)` and `loss` is the sum of
squared differences between `ode_data` and `pred`.
"""
function loss_neuralode(p)
    prediction = predict_neuralode(p)
    return sum(abs2, ode_data .- prediction), prediction
end

# Callback function to observe training.
# `anim` collects one frame per callback invocation made with `doplot = true`.
anim = Animation()

# Arguments: `p` (current parameters) and `l` (current loss) are unused here;
# `pred` is the prediction returned alongside the loss by `loss_neuralode`.
callback = function (p, l, pred; doplot = true)
    # display(l)
    if doplot
        # First state component: training data versus the current prediction.
        plt = scatter(tsteps, ode_data[1, :]; label = "data")
        scatter!(plt, tsteps, pred[1, :]; label = "prediction")
        # Pass `plt` explicitly so the recorded frame does not depend on which plot
        # happens to be the globally "current" one.
        frame(anim, plt)
        # display(plot(plt))
    end
    # `false` signals the optimizer to continue (see the Optimization.jl callback docs).
    return false
end

In [None]:
# Train using the ADAM optimizer.
# Gradients come from Zygote; the objective ignores its second (hyper-parameter)
# argument and optimizes the network parameters as a ComponentArray.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((params, _) -> loss_neuralode(params), adtype)
optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(p))

In [None]:
# First training stage: ADAM(0.05) capped at 300 iterations, with `callback`
# invoked by the optimizer.
result_neuralode = Optimization.solve(
    optprob, ADAM(0.05); maxiters = 300, callback = callback
)

In [None]:
# Retrain using the LBFGS optimizer, starting from the parameters found by ADAM.
optprob2 = remake(optprob; u0 = result_neuralode.u)

result_neuralode2 = Optimization.solve(
    optprob2, LBFGS(); allow_f_increases = false, callback = callback
)

In [None]:
# Render the frames collected in `anim` as an MP4 at 15 frames per second.
mp4(anim, fps=15)

## Multiple Shooting

Docs: <https://diffeqflux.sciml.ai/dev/examples/multiple_shooting/>

In Multiple Shooting, the training data is split into overlapping intervals. The model is then trained on each interval individually rather than on the whole trajectory at once.

In [None]:
using Lux
using DiffEqFlux
using Optimization
using OptimizationPolyalgorithms
using DifferentialEquations
using DiffEqFlux: group_ranges
using Random
# Random number generator handed to `Lux.setup` when the network is initialised.
rng = Random.default_rng()

In [None]:
# Initial condition, sample count, and time grid for the multiple-shooting example.
u0 = Float32[2.0; 0.0]
datasize = 30
tspan = (0.0f0, 5.0f0)
# `datasize` evenly spaced save points spanning `tspan`.
tsteps = range(first(tspan), last(tspan); length = datasize)

In [None]:
# Get the data.
# `true_A` is reused from the first example, where it is declared as
# `const true_A = Float32[-0.1 2.0; -2.0 -0.1]`.

"""
    trueODEfunc!(du, u, p, t)

In-place right-hand side of the reference ODE: writes `true_A' * u.^3` into `du`.

`p` and `t` are part of the signature but are not read. Returns `du`.
"""
function trueODEfunc!(du, u, p, t)
    cubed = u .^ 3
    du .= true_A' * cubed
    return du
end

In [None]:
# Solve the reference problem with Tsit5, saving at `tsteps`, and keep the
# trajectory as a plain array to train against.
prob_trueode = ODEProblem(trueODEfunc!, u0, tspan)
ode_data = solve(prob_trueode, Tsit5(); saveat = tsteps) |> Array

In [None]:
# Define the Neural Network: cube the 2-element state, then apply a 2 -> 16 -> 2
# dense stack with a tanh hidden layer.
nn = Lux.Chain(
    ActivationFunction(x -> x .^ 3), Lux.Dense(2, 16, tanh), Lux.Dense(16, 2)
)

# `p_init` holds the initial parameters and `st` the layer state, created with `rng`.
(p_init, st) = Lux.setup(rng, nn)

In [None]:
# NeuralODE layer over the full time span, saved at `tsteps`.
neuralode = NeuralODE(nn, tspan, Tsit5(); saveat = tsteps)

# Plain ODE problem whose right-hand side is the network output (the first element of
# the `nn(...)` result); its parameters are the initial weights as a ComponentArray.
prob_node = ODEProblem(
    (state, params, _t) -> nn(state, params, st)[1],
    u0,
    tspan,
    Lux.ComponentArray(p_init),
)

In [None]:
"""
    plot_multiple_shoot(plt, preds, group_size)

Overlay the per-group predictions of a multiple-shooting run onto `plt`.

The index range of each group comes from `group_ranges(datasize, group_size)`. For
group `i`, the first row of `preds[i]` is drawn against `tsteps` at that range and
labelled `"Group i"`. Reads the globals `datasize` and `tsteps`, mutates `plt`, and
returns `nothing`.
"""
function plot_multiple_shoot(plt, preds, group_size)
    ranges = group_ranges(datasize, group_size)

    for (i, rg) in enumerate(ranges)
        plot!(
            plt, tsteps[rg], preds[i][1, :];
            markershape = :circle, label = "Group $(i)",
        )
    end
    return nothing
end

# Animate training.
# `anim` collects one frame per callback invocation made with `doplot = true`.
anim = Animation()

# Arguments: `p` (current parameters) and `l` (current loss) are unused here;
# `preds` holds the per-group predictions returned alongside the loss.
callback = function (p, l, preds; doplot = true)
    # display(l)
    if doplot
        # plot the original data
        plt = scatter(tsteps, ode_data[1, :]; label = "Data")

        # plot the different predictions for individual shoot
        plot_multiple_shoot(plt, preds, group_size)

        # Pass `plt` explicitly so the recorded frame does not depend on which plot
        # happens to be the globally "current" one.
        frame(anim, plt)
        # display(plot(plt))
    end
    # `false` signals the optimizer to continue (see the Optimization.jl callback docs).
    return false
end


In [None]:
# Define parameters for Multiple Shooting
# Group size handed to `multiple_shoot` (and to `group_ranges` when plotting).
group_size = 3
# Forwarded to `multiple_shoot` as the `continuity_term` keyword; presumably the weight
# of the penalty on discontinuities between neighbouring groups — confirm against the
# DiffEqFlux documentation.
continuity_term = 200

"""
    loss_function(data, pred)

Return the sum of squared element-wise differences between `data` and `pred`.
"""
function loss_function(data, pred)
    residual = data - pred
    return sum(abs2, residual)
end

"""
    loss_multiple_shooting(p)

Evaluate the multiple-shooting objective for parameters `p` by forwarding to
`multiple_shoot` with the global data (`ode_data`, `tsteps`), problem (`prob_node`),
per-group `loss_function`, Tsit5 solver, `group_size`, and `continuity_term`.
Returns whatever `multiple_shoot` returns.
"""
function loss_multiple_shooting(p)
    return multiple_shoot(
        p, ode_data, tsteps, prob_node, loss_function, Tsit5(), group_size;
        continuity_term = continuity_term,
    )
end

In [None]:
# Optimize the multiple-shooting loss with Zygote gradients. The objective ignores its
# second (hyper-parameter) argument; the initial weights are passed as a ComponentArray.
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction(
    (params, _) -> loss_multiple_shooting(params), adtype
)
optprob = Optimization.OptimizationProblem(optf, Lux.ComponentArray(p_init))

# Run the PolyOpt polyalgorithm, passing `callback` to the optimizer.
res_ms = Optimization.solve(optprob, PolyOpt(); callback = callback)

In [None]:
# Render the frames collected in `anim` as an MP4 at 15 frames per second.
mp4(anim, fps=15)