# Parameter Inference on Differential Equations

We need to load the package DifferentialEquations.jl and write our differential equation as a function.

DiffEqFlux.jl is an implicit deep learning library built on the SciML ecosystem; here we use its `sciml_train` function to fit the parameters of the equation.

Optim.jl is a Julia package for optimizing functions of various kinds; we use it for the BFGS optimizer.

In [None]:
using Pkg
Pkg.add("DifferentialEquations")
Pkg.add("Plots")
Pkg.add("DiffEqFlux")
Pkg.add("Optim")

In [None]:
using DifferentialEquations

In [None]:
using Plots; gr()

In [None]:
using DiffEqFlux, Optim

# Defining the Equation

In [None]:
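function henonheilis(du,u,p,t)
    # State u = [x, y, ẋ, ẏ]; parameters p = [β, δ]
    β,δ = p
    du[1] = u[3]                          # dx/dt = ẋ
    du[2] = u[4]                          # dy/dt = ẏ
    du[3] = -u[1] - β*u[1]*u[2]           # d²x/dt²
    du[4] = -u[2] + δ*(-u[1]^2+u[2]^2)    # d²y/dt²
    return nothing                        # in-place form: results are written into du
end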
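Written out with state $u = (x, y, \dot x, \dot y)$ and parameters $p = (\beta, \delta)$, the function above encodes the second-order system below. For $\beta = 2$ and $\delta = 1$ it is the classical Hénon-Heiles system.

$$
\begin{aligned}
\ddot{x} &= -x - \beta\, x y, \\
\ddot{y} &= -y + \delta\,(y^{2} - x^{2}).
\end{aligned}
$$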

In [None]:
u0 = [0.2;0.0;0.4;0.0];
tspan = (0.0,500.0);
p = [2.0,1.0];
prob = ODEProblem(henonheilis,u0,tspan,p)
sol=solve(prob,Vern9(),saveat=0.1);

In [None]:
dataset = Array(sol);

In [None]:
pinit = [1.90,0.85];

Our goal is to start from the initial guess `pinit` and recover parameters whose Hénon-Heiles solution reproduces the one computed above. We therefore define the loss as the sum of squared differences between the trajectory obtained with the candidate parameters and the reference data `dataset = Array(sol)`, which was generated with β = 2 and δ = 1.
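To make the loss concrete: `sum(abs2, A - B)` adds up the squared differences over every state component and every saved time point, i.e. $L(p) = \sum_{i,k} \big(u_i(t_k; p) - \hat u_i(t_k)\big)^2$. In the real problem both arrays are 4 × 5001 (four states, `saveat = 0.1` over `tspan = (0.0, 500.0)`). The toy check below uses hypothetical 2 × 3 arrays standing in for `Array(chn_sol)` and `dataset`.

```julia
# Hypothetical stand-ins: rows are state components, columns are saved time points
pred   = [1.0 2.0 3.0;
          0.0 1.0 0.0]
target = [1.0 2.5 2.0;
          0.0 0.0 0.0]

# Same reduction as in `loss`: sum of squared element-wise differences
sse = sum(abs2, pred - target)
println(sse)  # 0.25 + 1.0 + 1.0 = 2.25
```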

In [None]:
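function loss(p)
    # Re-solve the same problem with the candidate parameters p
    chn_prob = remake(prob, p = p)
    chn_sol = solve(chn_prob,Vern9(),saveat = 0.1)
    # Sum of squared errors against the reference data; the solution is
    # returned as well so that the callback can reuse it
    sum(abs2, Array(chn_sol) - dataset), chn_sol
end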

At every training step, the solution obtained with the current parameters is plotted together with the reference solution computed with p = [2.0, 1.0]. This is done with a callback.

In [None]:
function plot_callback(p,l,chn_sol)
    # Report the current loss value
    println("Loss: ", l)
    # Solution with the current parameters (already computed by `loss`)
    pred = Array(chn_sol)
    # One panel per state variable: current fit vs. reference data
    panels = [plot(sol.t, [pred[i, :] dataset[i, :]], xlim=(0,150)) for i in 1:4]
    fig = plot(panels..., layout = (2, 2), legend = false)
    display(fig)
    # Returning false tells the optimizer to keep going
    false
end

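`sciml_train` accepts a callback that is called at each step of the training loop. It receives the current parameter vector together with the values returned by the last call to the loss function (here the loss value and the solution `chn_sol`). We use it to display the current loss and plot the current fit. The callback must return a boolean: `false` lets training continue, while `true` stops the optimization early.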
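As a text-only alternative to `plot_callback`, here is a sketch of a callback that prints progress and asks the optimizer to stop once the loss falls below a tolerance. The factory name `make_stopping_callback` and the tolerance are hypothetical; the argument order (parameters, loss, extra value returned by `loss`) follows the convention described above.

```julia
# Hypothetical callback factory: prints progress and requests a stop
# once the loss drops below `tol`.
# Arguments: current parameters p, current loss l, extra value from `loss`.
function make_stopping_callback(tol)
    return function (p, l, chn_sol)
        println("loss = ", l, "   β = ", p[1], "   δ = ", p[2])
        return l < tol   # true halts the optimizer, false keeps training
    end
end

cb = make_stopping_callback(1e-6)
cb([1.9, 0.85], 3.0, nothing)   # prints progress, returns false
```

It could be passed as `cb = make_stopping_callback(1e-6)` to `sciml_train` in place of `plot_callback`.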

Many optimization methods could be used for this problem; here we use only ADAM and BFGS, both of which handle this kind of unconstrained minimization.

Basically, BFGS converges faster than ADAM once it is close to a minimum. A common strategy is therefore to run ADAM for the first iterations, since it is robust far from the optimum and brings us near a local optimum, and then switch to BFGS for the final steps.
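To illustrate what the ADAM phase does, here is a hand-written sketch of the ADAM update rule applied to a toy quadratic whose minimum is placed, hypothetically, at `[2.0, 1.0]`. It uses the commonly cited hyperparameters (decay rates 0.9 and 0.999, ε = 1e-8) and the same learning rate 0.01 as `ADAM(0.01)` in the next cell. It is not the implementation `sciml_train` calls, only the same algorithm on a problem with an analytic gradient.

```julia
# Toy objective with a hypothetical optimum θstar
θstar = [2.0, 1.0]
f(θ) = sum(abs2, θ .- θstar)
∇f(θ) = 2 .* (θ .- θstar)   # analytic gradient of f

# Hypothetical helper implementing the ADAM update rule
function adam_sketch(∇f, θ0; η = 0.01, b1 = 0.9, b2 = 0.999, ϵ = 1e-8, iters = 2000)
    θ = copy(θ0)
    m = zero(θ)   # running mean of gradients (first moment)
    v = zero(θ)   # running mean of squared gradients (second moment)
    for t in 1:iters
        g = ∇f(θ)
        m .= b1 .* m .+ (1 - b1) .* g
        v .= b2 .* v .+ (1 - b2) .* g .^ 2
        mhat = m ./ (1 - b1^t)   # bias-corrected moments
        vhat = v ./ (1 - b2^t)
        θ .-= η .* mhat ./ (sqrt.(vhat) .+ ϵ)   # per-coordinate scaled step
    end
    return θ
end

θ = adam_sketch(∇f, [1.90, 0.85])   # same starting point as pinit
println(θ, "  f = ", f(θ))
```

The per-coordinate scaling makes the early steps roughly of size η regardless of the gradient magnitude, which is why ADAM is forgiving far from the optimum; BFGS instead builds a curvature estimate and takes much larger, better-directed steps near it.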

In [None]:
res = DiffEqFlux.sciml_train(loss, pinit, ADAM(0.01),cb=plot_callback, maxiters = 200)

In [None]:
res2 = DiffEqFlux.sciml_train(loss,res.minimizer,BFGS(initial_stepnorm=0.02), cb=plot_callback,maxiters=300)

### A Closer Look at the Solutions

The first solution below is computed with the parameters found by training (`res2.minimizer`).

The second one is computed with the true parameters β = 2 and δ = 1.
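Besides comparing plots, one can check a physical invariant. When β = 2δ (true for β = 2, δ = 1) the system is Hamiltonian with energy $H = \tfrac12(\dot x^2 + \dot y^2) + \tfrac12(x^2 + y^2) + \delta\,(x^2 y - y^3/3)$, so $H$ should stay numerically constant along the reference solution. The helper `henon_heiles_energy` below is hypothetical and assumes the state ordering `u = [x, y, ẋ, ẏ]` used in `henonheilis`.

```julia
# Hypothetical helper: Hénon-Heiles energy for u = [x, y, ẋ, ẏ].
# Conserved only when β = 2δ (true for β = 2, δ = 1).
function henon_heiles_energy(u, δ = 1.0)
    x, y, vx, vy = u
    kinetic   = 0.5 * (vx^2 + vy^2)
    potential = 0.5 * (x^2 + y^2) + δ * (x^2 * y - y^3 / 3)
    return kinetic + potential
end

E0 = henon_heiles_energy([0.2, 0.0, 0.4, 0.0])   # initial condition u0
println(E0)   # ≈ 0.1 (0.08 kinetic + 0.02 potential)
```

After solving, `maximum(abs, henon_heiles_energy.(sol.u) .- E0)` measures the energy drift of the reference solution. For `sol2`, pass the learned δ as the second argument, keeping in mind that $H$ is exactly conserved only if the learned parameters also satisfy β = 2δ.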

In [None]:
u0 = [0.2;0.0;0.4;0.0];
tspan = (0.0,500.0);
p = res2.minimizer;
prob = ODEProblem(henonheilis,u0,tspan,p)
sol2 = solve(prob,Vern9(),saveat=0.1);
dataset2 = Array(sol2)
p1 = plot(sol2.t, dataset2'[:,1], xlim=(0,250)) 
p2 = plot(sol2.t, dataset2'[:,2], xlim=(0,250)) 
p3 = plot(sol2.t, dataset2'[:,3], xlim=(0,250))
p4 = plot(sol2.t, dataset2'[:,4], xlim=(0,250))
plot(p1, p2, p3, p4, layout = (2, 2), legend = false)

In [None]:
plot(sol2, vars=(1,2))

In [None]:
u0 = [0.2;0.0;0.4;0.0];
tspan = (0.0,500.0);
p = [2.0,1.0];
prob = ODEProblem(henonheilis,u0,tspan,p)
sol = solve(prob,Vern9(),saveat=0.1);
dataset = Array(sol)
p1 = plot(sol.t, dataset'[:,1], xlim=(0,250))
p2 = plot(sol.t, dataset'[:,2], xlim=(0,250)) 
p3 = plot(sol.t, dataset'[:,3], xlim=(0,250))
p4 = plot(sol.t, dataset'[:,4], xlim=(0,250))
plot(p1, p2, p3, p4, layout = (2, 2), legend = false)

In [None]:
plot(sol, vars=(1,2))