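# Physics-enhanced regression

Estimate the parameters of the SEIR model from synthetic terminal-value data by minimising a loss that is computed with a probabilistic ODE solver (`odefilter`) and differentiated with JAX.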

In [1]:
import jax
import jax.numpy as jnp
import jax.random
import matplotlib.pyplot as plt
from diffeqzoo import backend, ivps
from jax.config import config

from odefilter import dense_output, ivpsolve, solvers
from odefilter.implementations import dense
from odefilter.strategies import filters, smoothers

# Use 64-bit floats in JAX (JAX defaults to 32-bit).
config.update("jax_enable_x64", True)

# Select diffeqzoo's JAX backend only when no backend has been chosen yet
# (guards against selecting twice, e.g. when this cell is re-run).
_backend_already_selected = backend.has_been_selected
if not _backend_already_selected:
    backend.select("jax")

In [2]:
# SEIR benchmark problem: vector field `f`, initial state `u0`, time span
# (t0, t1) and default parameters `f_args`.
f, u0, (t0, t1), f_args = ivps.seir()
# Parameters as an array so that `f_args + 0.05` and gradient steps act
# element-wise.
f_args = jnp.asarray(f_args)


@jax.jit
def vf(y, t, p):
    """Adapt ``f`` to the ``(y, t, p)`` signature ``ivpsolve`` presumably uses.

    The time argument is ignored (``f`` is not given it) and the parameter
    vector is splatted into ``f``'s positional arguments.
    """
    state, parameters = y, p
    return f(state, *parameters)



In [3]:
# Make synthetic data: solve the IVP with perturbed ("true") parameters and
# keep only the terminal state u(t1) as the observation.

# NOTE(review): `ts` is not passed to the terminal-value simulation below.
ts = jnp.linspace(t0, t1, endpoint=True, num=3)

# Offset added to every default parameter to obtain the parameters that
# generate the data; the fit later starts from the unperturbed `f_args`.
PARAMETER_OFFSET = 0.05

# State dimension taken from the initial value instead of hard-coding it
# (it is 4 for SEIR).
ode_dimension = jnp.shape(u0)[0]

# Filter with a dense IBM prior and a first-order Taylor correction
# (presumably the "EK1" the variable name refers to).
ek1 = solvers.MLESolver(
    strategy=filters.Filter(
        extrapolation=dense.IBM.from_params(ode_dimension=ode_dimension),
        correction=dense.TaylorFirstOrder(ode_dimension=ode_dimension),
    )
)

solution_true = ivpsolve.simulate_terminal_values(
    vf,
    initial_values=(u0,),
    t0=t0,
    t1=t1,
    solver=ek1,
    parameters=f_args + PARAMETER_OFFSET,
)
data = solution_true.u
print(data)

[1.30798356e+02 2.55013487e-04 1.45378298e-03 8.70199923e+02]


In [4]:
# Initial guess: solve the same IVP with the unperturbed parameters `f_args`;
# its terminal value is to be compared against `data`.
solution_wrong = ivpsolve.simulate_terminal_values(
    vf,
    initial_values=(u0,),
    t0=t0,
    t1=t1,
    solver=ek1,
    parameters=f_args,
)
initial_guess_terminal_value = solution_wrong.u
print(initial_guess_terminal_value)

[5.89716318e+01 1.14974217e-03 1.45134400e-02 9.42012686e+02]


In [5]:
@jax.jit
def param_to_nmll(p):
    """Squared-error loss between the simulated terminal state and ``data``.

    Despite its name, this currently does NOT compute a negative marginal
    log-likelihood. It solves the IVP with parameters ``p`` and returns
    ``||data - u(t1; p)||^2``.

    TODO: switch to the likelihood-based objective, e.g.
    ``dense_output.negative_marginal_log_likelihood(
    observation_std=jnp.ones_like(ts) * 0.1, u=data, solution=solution,
    solver=ek1)``.

    Args:
        p: ODE parameters, forwarded to ``vf`` via ``parameters=``.

    Returns:
        Scalar sum of squared residuals of the terminal value.
    """
    solution = ivpsolve.simulate_terminal_values(
        vf, initial_values=(u0,), t0=t0, t1=t1, solver=ek1, parameters=p
    )
    residual = data - solution.u
    return residual @ residual

In [6]:
# Loss at the initial guess (unperturbed parameters).
initial_loss = param_to_nmll(f_args)
initial_loss

DeviceArray(10316.15137609, dtype=float64)

In [7]:
# Derivative of the loss w.r.t. the parameters via forward-mode autodiff,
# jit-compiled once and reused for every gradient step.
# NOTE(review): forward mode is presumably chosen because reverse mode
# (`jax.grad`) cannot differentiate through the solver's loop; confirm
# before switching.
_param_to_nmll_jacfwd = jax.jacfwd(param_to_nmll)
df = jax.jit(_param_to_nmll_jacfwd)

In [None]:
# Plain gradient descent on the loss, starting from the unperturbed parameters.
# NOTE(review): in the recorded run, LEARNING_RATE = 1e-1 makes the iterates
# diverge after one step (some parameters become large and negative). The
# parameters live on very different scales (0.1 ... 998), so the step size
# and/or a preconditioner needs tuning.
LEARNING_RATE = 1e-1
NUM_STEPS = 3

# `f0` / `f1` are kept as names in case later cells use them.
f0 = f_args
f1 = f0
for _ in range(NUM_STEPS):
    f1 = f1 - LEARNING_RATE * df(f1)
    print(f1, f_args)

[-2.60744124e+00 -1.94274412e+04  5.82430278e+04  1.00384000e+03] [3.00e-01 3.00e-01 1.00e-01 9.98e+02]
