In [1]:
import jax
import jax.numpy as np
import numpy as onp
from jax.flatten_util import ravel_pytree
from functools import partial
import matplotlib.pyplot as plt
import scipy.optimize as opt

In [None]:
# About the decorator: f is a function (not JAX array) and JAX needs to know that so we specify it as "static"
@partial(jax.jit, static_argnums=(2,))
def rk4(state, t_crt, f, diff_args):
    """Advance the solution by one classical fourth-order Runge-Kutta step.

    Args:
        state: Tuple ``(y_prev, t_prev)`` holding the previous solution value
            and the time it belongs to.
        t_crt: Time to step to; the step size is ``t_crt - t_prev``.
        f: Right-hand side callable invoked as ``f(y, t, diff_args)``. It is
            marked static for ``jax.jit`` (argument index 2).
        diff_args: Extra arguments forwarded unchanged to ``f``.

    Returns:
        ``((y_crt, t_crt), y_crt)``: the new state tuple followed by the new
        solution value on its own.
    """
    y_prev, t_prev = state
    h = t_crt - t_prev
    t_mid = t_prev + h/2.

    def increment(y, t):
        # Step-scaled slope evaluated at (y, t).
        return h * f(y, t, diff_args)

    k1 = increment(y_prev, t_prev)
    k2 = increment(y_prev + k1/2., t_mid)
    k3 = increment(y_prev + k2/2., t_mid)
    k4 = increment(y_prev + k3, t_prev + h)

    # Weighted combination of the four stage increments (weights 1, 2, 2, 1 over 6).
    y_crt = y_prev + 1./6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return (y_crt, t_crt), y_crt

def Lorenz_rhs_func(state, t, diff_args):
    """Evaluate the right-hand side of the Lorenz system.

    Args:
        state: The three components ``(x, y, z)``.
        t: Time; accepted for interface compatibility but not used.
        diff_args: Sequence whose first entry unpacks to ``(rho, sigma, beta)``.

    Returns:
        Array ``[dx/dt, dy/dt, dz/dt]``.
    """
    params = diff_args[0]
    rho, sigma, beta = params
    x, y, z = state

    dx_dt = sigma * (y - x)
    dy_dt = x * (rho - z) - y
    dz_dt = x * y - beta * z
    return np.array([dx_dt, dy_dt, dz_dt])

def odeint(stepper, f, y0, ts, diff_args):
    """Integrate an ODE over the time grid ``ts`` with a one-step method.

    Args:
        stepper: One-step integrator called as
            ``stepper(state, t_crt, f, diff_args)`` and returning
            ``(new_state, y)``.
        f: Right-hand side function handed to ``stepper``.
        y0: Solution value at ``ts[0]``.
        ts: Sequence of time points; integration starts at ``ts[0]``.
        diff_args: Extra arguments handed to ``stepper``.

    Returns:
        Array of the solution values at ``ts[1:]`` (the initial value ``y0``
        is not included).
    """
    carry = (y0, ts[0])
    trajectory = []
    for t_next in ts[1:]:
        carry, y_next = stepper(carry, t_next, f, diff_args)
        trajectory.append(y_next)
    return np.array(trajectory)


def obj_func(yf, target_yf):
    """Return the sum of squared differences between ``yf`` and ``target_yf``."""
    residual = yf - target_yf
    return np.sum(np.square(residual))


def compute_gradient_finite_difference(y0, ts, obj_func, state_rhs_func, diff_args, h=1e-1):
    """Estimate d(objective)/d(diff_args) with central finite differences.

    Each entry of the flattened ``diff_args`` is perturbed by ``+h`` and
    ``-h``; for both perturbations the ODE is re-integrated with ``rk4`` and
    ``obj_func`` is evaluated on the final state.

    Args:
        y0: Initial state passed to ``odeint``.
        ts: Time grid passed to ``odeint``.
        obj_func: Callable taking the final state and returning a scalar.
        state_rhs_func: ODE right-hand side passed to ``odeint``.
        diff_args: Pytree of parameters to differentiate with respect to.
        h: Perturbation size used for the central difference.

    Returns:
        Gradient estimate with the same pytree structure as ``diff_args``.
    """
    diff_args_flat, unravel = ravel_pytree(diff_args)

    def objective_at(flat_args):
        # Forward solve for one perturbed parameter vector.
        ys_ = odeint(rk4, state_rhs_func, y0, ts, unravel(flat_args))
        return obj_func(ys_[-1])

    grads = []
    for i in range(len(diff_args_flat)):
        # `.at[i].add(...)` is the functional indexed update that replaces the
        # removed `jax.ops.index_add`.
        tau_plus = objective_at(diff_args_flat.at[i].add(h))
        tau_minus = objective_at(diff_args_flat.at[i].add(-h))
        grads.append((tau_plus - tau_minus) / (2 * h))

    grads = unravel(np.array(grads))

    print(f"running finite difference")

    return grads

def compute_gradient_ad(ys, ts, obj_func, state_rhs_func, diff_args):
    """Compute d(objective)/d(diff_args) by back-propagating through the rk4 steps.

    Walks the stored trajectory backwards. For every step a vector-Jacobian
    product of a single ``rk4`` step is taken with respect to the previous
    state and ``diff_args``; the per-step parameter cotangents are summed.

    Args:
        ys: Trajectory including the initial state; ``ys[k]`` is the state at
            ``ts[k]``.
        ts: Time grid matching ``ys``.
        obj_func: Differentiable callable taking the final state and returning
            a scalar.
        state_rhs_func: ODE right-hand side used by ``rk4``.
        diff_args: Pytree of parameters to differentiate with respect to.

    Returns:
        Gradient with the same pytree structure as ``diff_args``.

    Raises:
        FloatingPointError: If the state cotangent is found to be non-finite
            at one of the periodic checks (every 100 reverse steps).
    """
    rev_ts = ts[::-1]
    rev_ys = ys[::-1]

    # Seed the state cotangent with the gradient of the objective at the final state.
    y_bar = jax.grad(obj_func)(rev_ys[0])
    # `jax.tree_util.tree_map` replaces the removed top-level `jax.tree_map`.
    diff_args_bars = [jax.tree_util.tree_map(np.zeros_like, diff_args)]

    @jax.jit
    def adjoint_fn(y_prev, t_prev, t_crt, y_bar):
        def single_step(y, args):
            return rk4((y, t_prev), t_crt, state_rhs_func, args)[1]

        _, vjpfun = jax.vjp(single_step, y_prev, diff_args)
        # Returns (cotangent w.r.t. y_prev, cotangent w.r.t. diff_args).
        return vjpfun(y_bar)

    for i in range(len(ts) - 1):
        y_prev = rev_ys[i + 1]
        t_prev = rev_ts[i + 1]
        t_crt = rev_ts[i]

        y_bar, diff_args_bar = adjoint_fn(y_prev, t_prev, t_crt, y_bar)
        diff_args_bars.append(diff_args_bar)

        if i % 100 == 0:
            print(f"Reverse step {i}")
            if not np.all(np.isfinite(y_bar)):
                # Raise instead of calling exit(), which would shut down the
                # notebook kernel and discard all state.
                raise FloatingPointError(
                    f"Found np.inf or np.nan in y_bar at reverse step {i}")

    # Sum the per-step parameter cotangents leaf by leaf; multi-tree
    # `tree_map` replaces the removed `jax.tree_multimap`.
    grads = jax.tree_util.tree_map(
        lambda *xs: np.sum(np.stack(xs), axis=0), *diff_args_bars)

    print(f"running autodiff")

    return grads



def optimize(diff_args_0, y0, ts, obj_func, state_rhs_func, bounds=None):
    """Minimize ``obj_func`` of the final ODE state over the parameters with SciPy SLSQP.

    Args:
        diff_args_0: Pytree with the initial parameter guess.
        y0: Initial state of the ODE.
        ts: Time grid for the integration.
        obj_func: Callable taking the final state and returning a scalar.
        state_rhs_func: ODE right-hand side passed to ``odeint``.
        bounds: Optional bounds forwarded to ``scipy.optimize.minimize``.

    Returns:
        The optimizer's solution (``res.x``) restored to the pytree structure
        of ``diff_args_0``.
    """
    x_ini, unravel = ravel_pytree(diff_args_0)

    # Most recent forward solve, keyed by the flat parameter vector it was
    # computed for, so objective and gradient always refer to the same point.
    cache = {'x': None, 'diff_args': None, 'ys': None}

    def forward(x):
        # Re-integrate only when x differs from the cached evaluation point.
        x = onp.asarray(x)
        if cache['x'] is None or not onp.array_equal(cache['x'], x):
            diff_args = unravel(x)
            ys_ = odeint(rk4, state_rhs_func, y0, ts, diff_args)
            cache['x'] = x.copy()
            cache['diff_args'] = diff_args
            cache['ys'] = np.vstack((y0[None, ...], ys_))
        return cache['diff_args'], cache['ys']

    def objective(x):
        print(f"\n######################### Evaluating objective value - step {objective.counter}")
        diff_args, ys = forward(x)
        obj_val = obj_func(ys[-1])

        objective.counter += 1

        print(f"obj_val = {obj_val}")
        print(f"diff_args = {diff_args}")
        print(f"ys[-1] = {ys[-1]}")

        return obj_val

    def derivative(x):
        # Use the trajectory for the x that was actually requested, not
        # whatever the last objective call happened to leave behind.
        diff_args, ys = forward(x)

        # Finite-difference estimate, printed for comparison with the AD gradient.
        grads = compute_gradient_finite_difference(y0, ts, obj_func, state_rhs_func, diff_args)
        print(f"########################################################")
        print(f"grads = {grads}")
        print(f"########################################################")

        grads = compute_gradient_ad(ys, ts, obj_func, state_rhs_func, diff_args)
        print(f"########################################################")
        print(f"grads = {grads}")
        print(f"########################################################")

        grads_ravelled, _ = ravel_pytree(grads)
        # 'L-BFGS-B' requires the following conversion, otherwise we get an error message saying
        # -- input not fortran contiguous -- expected elsize=8 but got 4
        return onp.array(grads_ravelled, order='F', dtype=onp.float64)

    objective.counter = 0
    options = {'maxiter': 1000, 'disp': True}  # CG or L-BFGS-B or Newton-CG or SLSQP
    res = opt.minimize(fun=objective,
                       x0=x_ini,
                       method='SLSQP',
                       jac=derivative,
                       bounds=bounds,
                       callback=None,
                       options=options)

    # Return the optimizer's reported solution rather than the last point
    # the objective happened to be evaluated at.
    return unravel(res.x)


def ground_truth(y0, ts):
    """Integrate the Lorenz system with the reference parameters and plot the path.

    The reference parameters are ``rho=28``, ``sigma=10``, ``beta=8/3`` (the
    order ``Lorenz_rhs_func`` unpacks them in). The final state is printed and
    the full trajectory, including ``y0``, is drawn on a 3D axes.

    Args:
        y0: Initial state.
        ts: Time grid for the integration.

    Returns:
        The state at the last time point of ``ts``.
    """
    diff_args_gt = [np.array([28., 10., 8./3.])]

    ys_ = odeint(rk4, Lorenz_rhs_func, y0, ts, diff_args_gt)
    ys = np.vstack((y0[None, ...], ys_))
    print(ys[-1])

    fig = plt.figure(figsize=(6, 4), dpi=150)
    # `fig.gca(projection=...)` is no longer accepted by Matplotlib; create the
    # 3D axes explicitly instead.
    ax = fig.add_subplot(111, projection='3d')
    # NOTE(review): plot_3d_path is not defined in this part of the notebook;
    # it is assumed to be provided by another cell.
    plot_3d_path(ax, ys, 'b')
    return ys_[-1]


In [None]:
def exp():
    """Set up the Lorenz parameter-fitting experiment.

    Builds the initial state and the time grid, obtains the target final
    state from ``ground_truth`` and defines the initial parameter guess. The
    optimization call itself is currently disabled (kept below as comments),
    so the function returns ``None``.
    """
    initial_state = np.array([1., 1., 1.])
    step = 1e-2
    time_grid = np.arange(0, 101*step, step)
    target_yf = ground_truth(initial_state, time_grid)

    # Initial parameter guess. Other starting points that were tried:
    #   [np.array([27.63452, 10.724965 , 8./3.])]
    #   [np.array([27.8, 10.3, 8./3.])]
    diff_args_0 = [np.array([26., 10., 2.])]

    # Disabled optimization run:
    # obj_func_partial = lambda yf: obj_func(yf, target_yf)
    # bounds = np.array([[26., 30.], [8., 12.], [8./3., 8./3.]])
    # optimize(diff_args_0, initial_state, time_grid, obj_func_partial, Lorenz_rhs_func, bounds)