# Neural ODEs

This notebook shows how to use Julia's DiffEqFlux library to implement neural ODEs. This is just a simple example, but see [here](https://diffeqflux.sciml.ai/stable/) for official tutorials and examples.

Use 'Shift + Enter' to run a cell.

We first need to load the Julia libraries for building and training neural ODEs. This step might take a couple of minutes.

In [None]:
using DiffEqFlux, OrdinaryDiffEq, Flux, Optim
using Plots

The first thing we need is some data. We will generate noisy samples of a spiral of the form $\langle x(t),y(t) \rangle= \langle (1+t)\cos(t), (1+t)\sin(t) \rangle$. This is inspired by an example given in the [2018 Neural ODEs paper](https://arxiv.org/abs/1806.07366).

In [None]:
ϵ = 0.3                 # scale of noise (ϵ*rand is uniform on [0, ϵ))
t_final = 4.0           # final time point
n_data = 60             # number of data points

t_data = range(0,stop=t_final,length=n_data)
data = zeros(2,n_data)
data[1,:] = (t_data .+ 1).*cos.(t_data) + ϵ*rand(n_data)
data[2,:] = (t_data .+ 1).*sin.(t_data) + ϵ*rand(n_data)

To give us an idea of what we're looking at, let's plot it.

In [None]:
scatter(data[1,:],data[2,:],xlabel="x",ylabel="y",title="Spiral Data",label="")

Now that we have our data, we need to set up a machine learning model. We could use a regression technique, or a classical neural network. Instead, we'll use a "Neural ODE", where the goal is to model the *derivatives* of our data. We could do this from scratch, but instead we will make use of the Julia library [DiffEqFlux](https://diffeqflux.sciml.ai/dev/).
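
For the noiseless spiral, the derivative can be written down by hand, which gives a concrete picture of what the network is being asked to learn. Differentiating gives $x' = \cos(t) - (1+t)\sin(t)$ and $y' = \sin(t) + (1+t)\cos(t)$, and since $\sqrt{x^2+y^2} = 1+t$ along the curve, this is the autonomous system $x' = x/r - y$, $y' = y/r + x$ with $r = \sqrt{x^2+y^2}$. The sketch below checks this against the closed-form spiral using a hand-written fixed-step RK4 integrator. The names 'spiral_rhs' and 'rk4' are illustrative only and are not part of DiffEqFlux.

```julia
# Right-hand side of the noiseless spiral ODE (derived by hand, not learned).
function spiral_rhs(u)
    x, y = u
    r = sqrt(x^2 + y^2)
    return [x / r - y, y / r + x]
end

# Classical fixed-step 4th-order Runge-Kutta integrator (illustrative helper).
function rk4(f, u0, t_final, n_steps)
    h = t_final / n_steps
    u = copy(u0)
    for _ in 1:n_steps
        k1 = f(u)
        k2 = f(u .+ (h / 2) .* k1)
        k3 = f(u .+ (h / 2) .* k2)
        k4 = f(u .+ h .* k3)
        u = u .+ (h / 6) .* (k1 .+ 2 .* k2 .+ 2 .* k3 .+ k4)
    end
    return u
end

t_end = 4.0
u_numeric = rk4(spiral_rhs, [1.0, 0.0], t_end, 400)   # start at (x(0), y(0)) = (1, 0)
u_exact = [(1 + t_end) * cos(t_end), (1 + t_end) * sin(t_end)]
err = maximum(abs.(u_numeric .- u_exact))
println("max error at t = $t_end: ", err)
```

A neural ODE replaces 'spiral_rhs' with a neural network and fits its parameters so that the integrated trajectory matches the (noisy) data.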

In [None]:
u0 = Float32[data[1,1]; data[2,1]]            # initial condition for ODE
tspan = (0.0f0, Float32(t_final))             # span of data (as 32-bit float)
n_dims = length(u0)                           # dimensions of data
width = 20                                    # width of neural network

model = FastChain(FastDense(n_dims, width, tanh),
                  FastDense(width, n_dims))
prob_neuralode = NeuralODE(model, tspan, Tsit5(),
                           saveat = t_data,
                           reltol = 1e-6, abstol = 1e-6)   # solver tolerances

The variable 'model' sets up our neural network, and 'prob_neuralode' sets up the problem to be solved using the [DifferentialEquations](https://diffeq.sciml.ai/v2.0/) library. This library has many advanced solver features, such as adaptive time stepping, callback control, and a wide choice of discretization methods. Above we chose the Tsit5 solver, which is [recommended](https://diffeq.sciml.ai/stable/solvers/ode_solve/) for non-stiff systems.
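
As a small illustration of these solver options on their own, the sketch below solves a harmonic oscillator with Tsit5, requests output at fixed times via 'saveat', and sets the tolerances with the 'reltol' and 'abstol' keywords. The oscillator is an assumption made for this example (it is chosen only because its exact solution is known) and is not part of the spiral problem.

```julia
using OrdinaryDiffEq

# Harmonic oscillator u1' = u2, u2' = -u1, with exact solution (cos t, -sin t).
function oscillator!(du, u, p, t)
    du[1] = u[2]
    du[2] = -u[1]
    return nothing
end

ts = range(0.0, stop = 4.0, length = 9)
prob = ODEProblem(oscillator!, [1.0, 0.0], (0.0, 4.0))
sol = solve(prob, Tsit5(), saveat = ts, reltol = 1e-6, abstol = 1e-6)

sol_matrix = Array(sol)                  # 2 x 9 matrix, one column per save point
exact = vcat(cos.(ts)', -sin.(ts)')      # exact solution at the same times
max_err = maximum(abs.(sol_matrix .- exact))
println("saved ", length(sol.t), " points, max error = ", max_err)
```

The solver still picks its own internal step sizes; 'saveat' only controls where the solution is reported, which is why the prediction can be compared column by column with the data.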

We now have to define two things:
* How to evaluate this neural network
* How the loss function is defined

In [None]:
function predict_neuralode(p)
    Array(prob_neuralode(u0, p))
end

function loss_neuralode(p)
    pred = predict_neuralode(p)
    loss = sum(abs2, data .- pred)
    return loss, pred
end

Because we are learning the dynamics of our data, and not the data values themselves, every evaluation of our neural network model requires solving an ODE system. This evaluation is computed in 'predict_neuralode', which takes in 'p' for the network parameters. In a real problem we might also want to make 'predict_neuralode' a function of the initial condition 'u0', but for our purposes this works fine as it is.

Our loss function is a standard sum of squares, defined in 'loss_neuralode', which also depends on the network parameters.
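
To make the loss concrete, here is the same expression applied to two tiny hand-made arrays. The arrays 'toy_data' and 'toy_pred' are made up for this illustration; they only mimic the layout used above, with one row per coordinate and one column per time point.

```julia
# Toy stand-ins for the real arrays: rows are (x, y), columns are time points.
toy_data = [1.0 2.0 3.0;
            0.0 1.0 0.0]
toy_pred = [1.0 1.0 1.0;
            0.0 0.0 2.0]

# Same expression as in loss_neuralode: square each residual, then sum them all.
# Residuals are 0, 1, 2 (row 1) and 0, 1, -2 (row 2), so the sum is 0+1+4+0+1+4.
toy_loss = sum(abs2, toy_data .- toy_pred)
println(toy_loss)   # prints 10.0
```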

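One last thing we will define is a 'callback function' to pass to the training routine. The optimizer calls it after each iteration with the current parameters, the loss, and the prediction. We will use it to plot the model prediction as it is trained. Returning 'false' tells the optimizer to keep going.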

In [None]:
iter = 0
callback = function (p, l, pred; doplot = true)
  global iter
  iter += 1

  if doplot && (mod(iter,3) == 0)
    IJulia.clear_output(true)  # passing true waits for new output before clearing, which prevents flickering
    plt = scatter(data[1,1:size(pred,2)], data[2,1:size(pred,2)], label = "data",title = string("iter: ",iter))
    scatter!(plt, pred[1,1:size(pred,2)], pred[2,1:size(pred,2)], label = "prediction",xlim=(-5,2.5),ylim=(-4,4))
    plot(plt) |> IJulia.display
  end

  return false
end
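
The role of the callback's return value can be sketched with a toy training loop. Everything below is hypothetical and written only for illustration: 'toy_train' is not how 'sciml_train' is implemented, and the quadratic loss and its hand-written gradient are assumptions. It only mimics the convention that the callback receives the parameters, the loss, and the prediction, and that a 'true' return value stops training early.

```julia
# Hypothetical stand-in for a training routine (plain gradient descent).
function toy_train(loss, p0, cb; lr = 0.1, maxiters = 100)
    p = p0
    for _ in 1:maxiters
        l, pred = loss(p)
        cb(p, l, pred) && break      # a `true` return value stops training early
        grad = 2 * (p - 3.0)         # hand-written gradient of quad_loss below
        p -= lr * grad
    end
    return p
end

quad_loss(p) = ((p - 3.0)^2, p)      # returns (loss, prediction), like loss_neuralode

n_calls = 0
stop_when_small = function (p, l, pred)
    global n_calls
    n_calls += 1
    return l < 1e-6                  # halt once the loss is small enough
end

p_final = toy_train(quad_loss, 0.0, stop_when_small)
println("stopped after $n_calls callback calls at p = $p_final")
```

The plotting callback above always returns 'false', so training runs for the full number of iterations.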

Everything is set up to run; all that is left to do is actually train the model. Again, this is where the DiffEqFlux library does the heavy lifting, as we only need to call a single function: 'sciml_train'.

This function takes in:
* loss function (loss_neuralode)
* initial model parameters (prob_neuralode.p)
* optimization method (ADAM(0.05))
* callback function (cb=callback)
* maximum iterations it should run (maxiters=600)

In [None]:
iter = 0       # reset iter (in case you run this multiple times)
result_neuralode = DiffEqFlux.sciml_train(loss_neuralode, prob_neuralode.p,
                                          ADAM(0.05), cb = callback,
                                          maxiters = 600)

Our model is now trained! The summary above gives some idea of how the training process went.

(Note: the NaN values in "Convergence measures" show up as actual values when I run this code outside of a Jupyter notebook. I am not sure why.)

So, how do you access the parameter values that define your trained model? With result_neuralode.minimizer. So, for example, we could generate our prediction with 'predict_neuralode(result_neuralode.minimizer)'.

Occasionally, we will want to combine different optimization techniques on the same problem, for example using a cheap, coarse optimization method at the beginning and then switching to something more accurate once we get 'close' to a local minimum. The code below does exactly this, using the [L-BFGS optimization algorithm](https://en.wikipedia.org/wiki/Limited-memory_BFGS). Note that 'result_neuralode.minimizer' is passed in as the parameter value so that we start from the previously trained parameters and don't train the model from scratch again.

In [None]:
result_neuralode2 = DiffEqFlux.sciml_train(loss_neuralode,
                                           result_neuralode.minimizer,
                                           LBFGS(),
                                           cb = callback,
                                           allow_f_increases = true)
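
The warm-start idea can also be seen on a one-dimensional toy problem. The sketch below is an assumption-laden illustration, not the method used above: the objective is made up, fixed-step gradient descent plays the role of the coarse optimizer, and Newton's method stands in for L-BFGS as the more accurate second stage.

```julia
# Toy objective with a known minimizer at x = log(2); derivatives written by hand.
obj(x)   = exp(x) - 2x
dobj(x)  = exp(x) - 2
d2obj(x) = exp(x)

# Stage 1: cheap, coarse optimizer (fixed-step gradient descent).
function gradient_descent(x; lr = 0.1, iters = 20)
    for _ in 1:iters
        x -= lr * dobj(x)
    end
    return x
end

# Stage 2: more accurate optimizer (Newton's method, standing in for L-BFGS).
function newton(x; iters = 5)
    for _ in 1:iters
        x -= dobj(x) / d2obj(x)
    end
    return x
end

x_coarse = gradient_descent(2.0)
x_fine = newton(x_coarse)            # warm start: begin where stage 1 stopped
println("coarse error: ", abs(x_coarse - log(2)))
println("fine error:   ", abs(x_fine - log(2)))
```

The second stage only has to close a small remaining gap, which is the same reason 'result_neuralode.minimizer' is handed to the L-BFGS run.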

That's all there is to it! The Julia library DiffEqFlux allows us to handle much more complex cases than the one demonstrated here.

For more information, see one of the following links:
* [DiffEqFlux announcement](https://julialang.org/blog/2019/01/fluxdiffeq/)
* [DiffEqFlux github page](https://github.com/SciML/DiffEqFlux.jl)
* [SciML homepage](https://sciml.ai/)