In [None]:
import os
import numpy as np
import sunode
import sympy as sym
import matplotlib.pyplot as plt
from scipy import optimize
import theano.tensor as tt
import theano
import xarray as xr
import pandas as pd
import pymc3 as pm

from nitrogene.common import sunsolve, ParamSet
from nitrogene.common.better_lambdify import lambdify_consts

In [None]:
def make_ode(*, ode_error_order=None):
    """Build a small two-state test ODE wrapped in ``sunsolve.SympyOde``.

    The right-hand side is::

        dx/dt = y
        dy/dt = b

    so only parameter ``b`` enters the dynamics. The remaining parameters
    exercise ``ParamSet``: scalars plus a nested group ``f`` containing ``g``
    of shape ``(3,)``.

    Parameters
    ----------
    ode_error_order : None, optional
        Placeholder for a future option; only ``None`` is supported.

    Returns
    -------
    sunsolve.SympyOde
        ODE with scalar states ``x`` and ``y``.

    Raises
    ------
    NotImplementedError
        If ``ode_error_order`` is anything other than ``None``.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would silently ignore an unsupported argument.
    if ode_error_order is not None:
        # TODO: implement support for ode_error_order.
        raise NotImplementedError(
            'ode_error_order is not supported yet (got {!r})'.format(ode_error_order)
        )

    # Entries are (name, flag) or (name, flag, shape); ('f', [...]) nests a
    # group, mirrored by the nested 'f' dict in the defaults.
    # NOTE(review): the boolean presumably marks a parameter as changeable
    # (counted in ``n_params``) versus fixed — confirm in ParamSet.
    paramset = ParamSet.ParamSet(
        [
            ('a', True),
            ('b', True),
            ('c', True),
            ('d', False),
            ('e', True),
            ('f', [
                ('g', False, (3,))
            ])
        ],
        {
            'a': 0.1,
            'b': 0.22,
            'c': 0.11,
            'd': 0.1,
            'e': 0.5,
            'f': {
                'g': np.ones(3)
            }
        }
    )

    # Both states are scalars (empty shape tuple).
    states = {
        'x': (),
        'y': (),
    }

    def rhs_sympy(t, y, params):
        # ``t`` is unused: the system is autonomous. ``y`` and ``params``
        # expose the symbolic states/parameters as attributes.
        return {
            'x': y.y,
            'y': params.b,
        }

    return sunsolve.SympyOde(paramset, states, rhs_sympy)

In [None]:
# Build the test ODE with the default (and only supported) error-order setting.
ode = make_ode(ode_error_order=None)

In [None]:
# 20 evenly spaced output times on [0, 10], both endpoints included.
tvals = np.linspace(start=0, stop=10, num=20)

# Adjoint solver built from the ODE's compiled SUNDIALS callbacks.
# ``user_data`` is copied so the solver does not share the ``ode.user_data``
# object itself.
solver = sunsolve.AdjointSolver(
    ode.n_states,                      # number of state variables
    ode.make_sundials_rhs(),           # forward right-hand side
    ode.make_sundials_adjoint(),       # adjoint right-hand side
    ode.make_sundials_adjoint_quad(),  # adjoint quadrature
    tvals,                             # output time points
    ode.user_data.copy(),
    n_params=ode.n_params,
)

In [None]:
# Buffers the solver writes into. Going by the names (TODO confirm): forward
# solution, gradients, and adjoint state ("lamda" sidesteps the keyword).
(out,
 grad_out,
 lamda_out) = solver.make_output_buffers(tvals)

In [None]:
y0 = np.ones(ode.n_states)
%timeit solver.solve_forward(0, tvals, y0, out)

In [None]:
# Seed values for the backward (adjoint) pass, all ones:
# one row per output time, one column per state.
grads = np.full((len(tvals), ode.n_states), 1.0)

In [None]:
%%timeit
# Time one forward solve followed by the backward (adjoint) solve. The
# backward call runs from tvals[-1] to 0 and is seeded with ``grads``.
solver.solve_forward(0, tvals, y0, out)
solver.solve_backward(tvals[-1], 0, tvals, grads, grad_out, lamda_out)

In [None]:
# Wrap the solver output in an xarray object and display it.
sol = ode.xarray_solution(
    tvals,
    out,
    solver.user_data,
)
sol

In [None]:
#sol.to_zarr('solution.zarr')

In [None]:
# Same solution with the states left stacked, so ``solution.data`` can go
# straight into a DataFrame (one column per record field). A new name keeps
# the unstacked ``sol`` shown above from being overwritten.
sol_stacked = ode.xarray_solution(tvals, out, solver.user_data, unstack_state=False)
ax = pd.DataFrame.from_records(sol_stacked.solution.data).plot()
# NOTE(review): assumes one record per entry of ``tvals`` — confirm.
ax.set(title='ODE solution', xlabel='time index (position in tvals)', ylabel='state value');

In [None]:
# Symbolic inputs for the Theano graph. The ``_tt`` suffix keeps them from
# overwriting the numeric ``y0``/``out``/``grads`` arrays used by the solver
# cells above. Otherwise, re-running those cells would pass Theano variables
# to ``solve_forward``/``solve_backward``.
params_tt = tt.dvector('params')
y0_tt = tt.dvector('y0')

# Theano Op around the adjoint solver. The arguments presumably mean
# t0 = 0 and output times tvals, mirroring ``solve_forward(0, tvals, ...)``.
solve_ode = sunsolve.SolveODEAdjoint(solver, 0, tvals)

out_tt = solve_ode(y0_tt, params_tt)

# Toy loss: squared distance of every solution value from 1.
loss = ((out_tt - 1) ** 2).sum()
grads_tt = tt.grad(loss, [y0_tt, params_tt])

# Outputs: [loss, solution, dloss/dy0, dloss/dparams].
func = theano.function([y0_tt, params_tt], [loss, out_tt, *grads_tt], profile=True)

In [None]:
# Random test point for the gradient checks below, shifted by +2 so values
# are typically positive. A fixed seed keeps the check reproducible.
rng = np.random.default_rng(42)
y0 = rng.standard_normal(ode.n_states) + 2
params = rng.standard_normal(ode.n_params) + 2

In [None]:
# Forward-difference setup for d(loss)/d(y0[0]): perturb only the first
# initial-state component by eps.
eps = 1e-6
h = np.zeros_like(y0)
h[0] = eps
a, b = func(y0, params), func(y0 + h, params)

In [None]:
# Finite-difference estimate vs. Theano gradient for y0[0]. Divide by the
# scalar step h[0], not the vector h. h is zero elsewhere, so dividing by it
# would yield inf/nan and divide-by-zero warnings.
(b[0] - a[0]) / h[0], a[2][0]

In [None]:
# Forward-difference setup for d(loss)/d(params[1]): perturb only the second
# parameter. This uses a larger step than the y0 check.
eps = 1e-3
h = np.zeros_like(params)
h[1] = eps
a, b = func(y0, params), func(y0, params + h)

In [None]:
# Finite-difference estimate vs. Theano gradient for params[1]. Divide by the
# scalar step h[1], not the vector h. h is zero elsewhere, so dividing by it
# would yield inf/nan and divide-by-zero warnings.
(b[0] - a[0]) / h[1], a[3][1]

In [None]:
with pm.Model() as model:
    # Priors. The Python names carry an ``_rv`` suffix so they do not
    # overwrite the numeric ``y0``/``params`` arrays used by the
    # gradient-check cells; the model variable names are unchanged.
    params_rv = pm.HalfNormal('params', sd=10, shape=ode.n_params)
    y0_rv = pm.HalfNormal('y0', shape=ode.n_states)

    mu = solve_ode(y0_rv, params_rv)
    # Fit column 0 of the solution (presumably state 'x') to the synthetic
    # targets 1, 2, ..., len(tvals).
    pm.Normal('y', mu=mu[:, 0], sd=0.1, observed=np.arange(len(tvals)) + 1)

    # Fixed seed so the trace is reproducible.
    trace = pm.sample(random_seed=42)

In [None]:
with model:
    # Posterior-predictive draws for the model's observed variable(s),
    # using the samples in ``trace``.
    tr = pm.sample_posterior_predictive(trace=trace)