# Physics-enhanced regression

In [1]:
import functools

import jax
import jax.numpy as jnp
import jax.random
import matplotlib.pyplot as plt
from diffeqzoo import backend, ivps
from jax.config import config

from odefilter import dense_output, ivpsolve, solvers
from odefilter.implementations import dense, isotropic
from odefilter.strategies import filters, smoothers

config.update("jax_enable_x64", True)

if not backend.has_been_selected:
    backend.select("jax")

In [2]:
f, u0, (t0, t1), f_args = ivps.lotka_volterra()
f_args = jnp.asarray(f_args)

parameter_true = f_args + 0.1
parameter_guess = f_args


@jax.jit
def vf(y, t, p):
    return f(y, *p)



In [3]:
# make data

ts = jnp.linspace(t0, t1, endpoint=True, num=50)

strategy = smoothers.Smoother(
    extrapolation=isotropic.IsoIBM.from_params(num_derivatives=1),
)
solver = solvers.Solver(strategy=strategy, output_scale_sqrtm=1.0)


solution_true = ivpsolve.solve_fixed_grid(
    vf, initial_values=(u0,), ts=ts, solver=solver, parameters=parameter_true
)
data = solution_true.u
print(data[::2])

[[20.         20.        ]
 [ 5.85644978 23.58435139]
 [ 2.70237477 17.5369367 ]
 [ 1.02345233 12.57839811]
 [ 0.49518531  8.42506564]
 [ 0.36112363  5.46917492]
 [ 0.34736801  3.52811454]
 [ 0.40201114  2.2832835 ]
 [ 0.52365567  1.49283184]
 [ 0.7346505   0.99456952]
 [ 1.07935826  0.68377481]
 [ 1.63139296  0.49473127]
 [ 2.50730094  0.38859735]
 [ 3.88610151  0.34800812]
 [ 6.02858423  0.38297769]
 [ 9.25549781  0.57507167]
 [13.63547512  1.321135  ]
 [16.86465626  4.74749333]
 [ 8.86963908 16.01450933]
 [ 2.87510391 14.87751959]
 [ 1.10271054 11.0140109 ]
 [ 0.62948685  7.43269459]
 [ 0.49984863  4.89394377]
 [ 0.50551723  3.20990475]
 [ 0.60123878  2.11952303]]


In [4]:
# Initial guess

solution_wrong = ivpsolve.solve_fixed_grid(
    vf, initial_values=(u0,), ts=ts, solver=solver, parameters=parameter_guess
)
print(solution_wrong.u[::2])

[[20.         20.        ]
 [11.5270392  25.24509569]
 [ 6.31580598 23.38662539]
 [ 4.05411369 18.87187758]
 [ 3.12429197 14.43167203]
 [ 2.83061602 10.81775331]
 [ 2.91289134  8.08888401]
 [ 3.29387268  6.10787681]
 [ 3.98366713  4.71136769]
 [ 5.0468919   3.76221401]
 [ 6.59223256  3.16351572]
 [ 8.765498    2.86448784]
 [11.72704148  2.87476913]
 [15.57193688  3.30933878]
 [20.08359684  4.51464726]
 [24.0802936   7.36065494]
 [24.38287093 13.39071735]
 [17.86934738 21.90001611]
 [ 9.76967622 25.30493964]
 [ 5.53783212 22.30122466]
 [ 3.72832728 17.66198507]
 [ 3.00516145 13.40762431]
 [ 2.82351889 10.03115351]
 [ 2.98608021  7.51226778]
 [ 3.44372109  5.69792762]]


In [11]:
def data_likelihood(parameters, u0, ts, solver, vf, data):
    solution_wrong = ivpsolve.solve_fixed_grid(
        vf, initial_values=(u0,), ts=ts, solver=solver, parameters=parameters
    )

    observation_std = jnp.ones_like(ts) * 1e-4
    return dense_output.negative_marginal_log_likelihood(
        observation_std=observation_std, u=data, solution=solution_wrong, solver=solver
    )


parameter_to_solution = functools.partial(
    data_likelihood, solver=solver, ts=ts, vf=vf, u0=u0, data=data
)
sensitivity = jax.jit(jax.grad(parameter_to_solution))

In [12]:
parameter_to_solution(parameter_guess)
sensitivity(parameter_guess)

DeviceArray([  7702.76602445, -11055.5809865 ,   8484.00293328,
             -23099.32153709], dtype=float64)

In [16]:
f1 = parameter_guess
lrate = 1e-7
for i in range(100):
    for _ in range(100):
        f1 = f1 - lrate * sensitivity(f1)
    print(f"{i+1}00 iterations:", f1, parameter_true)


print(f1, parameter_true)

100 iterations: [0.50377739 0.06743706 0.5048235  0.05757416] [0.6  0.15 0.6  0.15]
200 iterations: [0.5104678  0.07049987 0.51201061 0.06099591] [0.6  0.15 0.6  0.15]
300 iterations: [0.51619257 0.07325552 0.51819264 0.06406593] [0.6  0.15 0.6  0.15]
400 iterations: [0.52114524 0.07579252 0.52357648 0.06681523] [0.6  0.15 0.6  0.15]
500 iterations: [0.52548343 0.07814263 0.52832601 0.06930586] [0.6  0.15 0.6  0.15]
600 iterations: [0.52933061 0.08033515 0.53256949 0.07158982] [0.6  0.15 0.6  0.15]
700 iterations: [0.53278303 0.08239589 0.53640668 0.07370962] [0.6  0.15 0.6  0.15]
800 iterations: [0.53591551 0.08434667 0.53991486 0.07569955] [0.6  0.15 0.6  0.15]
900 iterations: [0.53878595 0.0862053  0.5431536  0.0775868 ] [0.6  0.15 0.6  0.15]
1000 iterations: [0.54143893 0.08798582 0.54616845 0.07939242] [0.6  0.15 0.6  0.15]
1100 iterations: [0.54390838 0.08969871 0.54899379 0.08113212] [0.6  0.15 0.6  0.15]
1200 iterations: [0.54621974 0.09135132 0.55165514 0.08281692] [0.6  0.15 