In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import sympy as sp
from jax.lax import scan
import seaborn as sns
import projection_filter.n_d_exponential_family_projection_filter as pf
import symbolic.one_d
import projection_filter.util as util
from collections import namedtuple
from sde import SDESolverTypes, sde_solver
from sde.wiener import multidimensional_wiener_process



In [2]:
# Seaborn context first, so the explicit matplotlib settings below take precedence.
sns.set_context('notebook')

# Font sizes shared by every figure produced in this notebook.
SMALL_SIZE = 12
MEDIUM_SIZE = 14
BIGGER_SIZE = 16

# The 'mathtext.fontset' = 'stix' alternative is intentionally left disabled.
plt.rcParams.update({
    # Text and math fonts.
    'font.family': 'VictorMono Nerd Font',
    'mathtext.rm': 'VictorMono Nerd Font ExtraLight',
    'mathtext.it': 'VictorMono Nerd Font ExtraLight Italic',
    'mathtext.bf': 'VictorMono Nerd Font Light',
    'mathtext.default': 'rm',
    # Sizes.
    'font.size': SMALL_SIZE,          # default text size
    'axes.titlesize': SMALL_SIZE,     # axes title
    'axes.labelsize': MEDIUM_SIZE,    # x and y axis labels
    'xtick.labelsize': MEDIUM_SIZE,   # x tick labels
    'ytick.labelsize': MEDIUM_SIZE,   # y tick labels
    'legend.fontsize': BIGGER_SIZE,   # legend entries
    'figure.titlesize': BIGGER_SIZE,  # figure title
})

In [22]:
# Diffusion scale used for the two state components (see G and the symbolic dynamic SDE).
SIGMA_W = 1.0
# Diffusion scale used for the two measurement components.
SIGMA_V = 0.1


def boyd_bijection(x, param):
    """Evaluate the algebraic map ``x / (1 - x**2)``.

    Args:
        x: Point(s) at which the map is evaluated.
        param: Unused; kept so the signature matches the other bijections.

    Returns:
        ``x / (1 - x**2)``, computed element-wise when ``x`` is an array.
    """
    denominator = 1 - x ** 2
    return x / denominator

def archtanh_bijection(x, params):
    """Evaluate the inverse hyperbolic tangent of ``x``.

    Args:
        x: Point(s) at which the map is evaluated.
        params: Unused; kept so the signature matches the other bijections.

    Returns:
        ``jnp.arctanh(x)``.
    """
    mapped = jnp.arctanh(x)
    return mapped


def F(x, t):
    """Drift of the joint state/measurement SDE.

    The last two components (the integrated measurements) are driven by the
    same ``-x[0]`` and ``-x[1]`` terms as the first two (the state).

    Args:
        x: Joint vector; only ``x[0]`` and ``x[1]`` are read.
        t: Time; unused.

    Returns:
        Array ``[-x[0], -x[1], -x[0], -x[1]]``.
    """
    drift_0 = -x[0]
    drift_1 = -x[1]
    return jnp.array([drift_0, drift_1, drift_0, drift_1])


def G(x, t):
    """Constant diagonal diffusion matrix of the joint state/measurement SDE.

    Args:
        x: Joint vector; unused.
        t: Time; unused.

    Returns:
        4x4 diagonal matrix with ``SIGMA_W`` on the first two diagonal entries
        and ``SIGMA_V`` on the last two.
    """
    diagonal_scales = jnp.array([SIGMA_W, SIGMA_W, SIGMA_V, SIGMA_V])
    return jnp.diag(diagonal_scales)


def kalman_bucy(m0: jnp.ndarray, P0: jnp.ndarray, ys: jnp.ndarray, dt: float, Q: float, R: float):
    """Kalman--Bucy filter for a fixed linear model, discretised with Euler's scheme.

    The recursion is hard-wired to the model used in this notebook::

        dx = -x dt + sqrt(Q) dW,        dy = -x dt + sqrt(R) dV,

    i.e. drift ``-x``, measurement matrix ``H = -I``, process noise covariance
    ``Q * I`` and measurement noise covariance ``R * I``. The state dimension
    is taken from ``m0`` (it used to be fixed to 2).

    Args:
        m0: Initial mean, shape ``(n,)``.
        P0: Initial covariance, shape ``(n, n)``.
        ys: Integrated measurement record, shape ``(T, n)``; the filter is
            driven by its increments along axis 0.
        dt: Time step of the Euler discretisation.
        Q: Scalar process noise intensity.
        R: Scalar measurement noise intensity.

    Returns:
        Tuple ``(means, covariances)`` with shapes ``(T - 1, n)`` and
        ``(T - 1, n, n)``: the filter estimate after each measurement increment.
    """
    dys = jnp.diff(ys, axis=0)
    # Identity sized from the state so the covariance keeps its shape in the scan carry.
    eye = jnp.eye(jnp.shape(m0)[0])

    def scan_body(carry, elem):
        m, P = carry
        dy = elem

        # Gain K = P H^T / R with H = -I.
        K = P * (-1) / R
        # Mean: drift -m plus gain times innovation dy - H m dt.
        m = m + (-m * dt + K @ (dy - (-1) * m * dt))
        # Riccati equation: dP = (-2 P + Q I - K R K^T) dt.
        P = P + (-2 * P + Q * eye - K @ K.T * R) * dt
        return (m, P), (m, P)

    _, mmPP = scan(scan_body, (m0, P0), dys)

    return mmPP[0], mmPP[1]

In [12]:
# Seed NumPy's global generator and create a JAX key from the same seed value.
np.random.seed(666)
jax_key = jax.random.PRNGKey(666)

# NOTE(review): `r` and `measurement_length` are not referenced by the visible cells.
r = 0.2
measurement_length = 500

In [13]:
# --- Time grid ---------------------------------------------------------------
nt = 1000
dt = 1e-3  # default is 1e-2
tspan = dt * jnp.arange(nt)

# --- Gaussian prior: mean zero, covariance SIGMA_W**2 * I ---------------------
init_mean = jnp.zeros(2)
var_init = jnp.eye(2) * (SIGMA_W ** 2)
var_init_inv = jnp.linalg.solve(var_init, jnp.eye(2))
# Natural parameters of the prior: linear part and quadratic part
# (coefficients of x0**2, x0*x1 and x1**2, in that order).
theta_1 = jnp.matmul(var_init_inv, init_mean)
theta_2 = -0.5 * jnp.array([
    var_init_inv[0, 0],
    2 * var_init_inv[0, 1],
    var_init_inv[1, 1],
])

# --- Simulate the joint state/measurement SDE --------------------------------
dW = multidimensional_wiener_process((nt, 4), dt, jax.random.PRNGKey(15))
# Joint initial vector [init_mean[0], init_mean[1], init_mean[0], init_mean[1]].
X0 = jnp.tile(init_mean, 2)
X_integrated = sde_solver(F, G, X0, tspan, dW,
                          solver_type=SDESolverTypes.ItoEulerMaruyama)
# The last two columns are the integrated measurements.
measurement_record = X_integrated[:, 2:]

# --- Reference solution: Kalman--Bucy ----------------------------------------
kb_mm, kb_PP = kalman_bucy(
    m0=init_mean,
    P0=var_init,
    ys=measurement_record,
    dt=dt,
    Q=SIGMA_W ** 2,
    R=SIGMA_V ** 2,
)

# --- Symbolic model handed to the projection filter --------------------------
t = sp.symbols('t')
x, dw, dv = sp.symbols(('x0:2', 'dw0:2', 'dv0:2'))
f = sp.Matrix([-state for state in x])
g = sp.Matrix([[SIGMA_W, 0],
               [0, SIGMA_W]])
dynamic_sde = symbolic.one_d.SDE(f, g, t, x, dw)
measurement_sde = symbolic.one_d.SDE(
    drifts=sp.Matrix([-x[0], -x[1]]),
    diffusions=sp.Matrix([[SIGMA_V, 0], [0, SIGMA_V]]),
    time=t,
    variables=x,
    brownians=dv,
)

# Natural statistics and the prior's natural parameters, listed in the same order.
natural_statistics_symbolic = sp.Matrix([
    x[1], x[1] ** 2,
    x[0], x[0] * x[1], x[0] ** 2,
])
initial_condition = jnp.array([
    theta_1[1], theta_2[2],
    theta_1[0], theta_2[1], theta_2[0],
])

In [14]:
# One projection-filter configuration: integrator name, level, and number of nodes.
SimulationSetting = namedtuple('SimulationSetting', 'integrator level nodes_number')

In [15]:
# Configurations to compare, as (integrator, level, nodes_number) rows.
# All 'spg' rows use nodes_number 0; the 'qmc' rows give an explicit node count.
settings = [
    SimulationSetting(integrator, level, nodes_number)
    for integrator, level, nodes_number in (
        ('spg', 3, 0),
        ('spg', 4, 0),
        ('spg', 5, 0),
        ('spg', 6, 0),
        ('qmc', 3, 49),
        ('qmc', 4, 129),
        ('qmc', 5, 321),
        ('qmc', 6, 769),
    )
]

In [25]:
def generate_results(a_setting):
    """Run the projection filter for one setting and save plots comparing it to Kalman--Bucy.

    The filter's natural parameters are converted to Gaussian means and
    covariances, compared against the module-level Kalman--Bucy solution
    (``kb_mm``, ``kb_PP``) and written to PDF files in the working directory:
    one Hellinger-distance plot and one mean +/- standard deviation plot per
    state component.

    Args:
        a_setting: Object with ``level``, ``integrator`` and ``nodes_number``
            attributes (a ``SimulationSetting``).

    Returns:
        None. The results are the saved figures.
    """
    em_pf = pf.MultiDimensionalSStarProjectionFilter(dynamic_sde,
                                                     measurement_sde,
                                                     natural_statistics_symbolic,
                                                     constants=None,
                                                     bijection=archtanh_bijection,
                                                     initial_condition=initial_condition,
                                                     measurement_record=measurement_record,
                                                     delta_t=dt,
                                                     level=a_setting.level,
                                                     integrator=a_setting.integrator,
                                                     sRule="gauss-patterson",
                                                     nodes_number=a_setting.nodes_number,
                                                     epsilon=1e-9)
    em_pf.propagate()

    # Assumes the columns of state_history follow natural_statistics_symbolic:
    # [x1, x1**2, x0, x0*x1, x0**2] -- TODO confirm against the filter class.
    theta = em_pf.state_history

    # For exp(theta . c(x)) with precision matrix L, the exponent contains
    # -0.5*L00*x0**2 - L01*x0*x1 - 0.5*L11*x1**2. The squared terms therefore
    # carry a factor -2, while the cross term only carries a factor -1.
    off_diagonal = -theta[:, 3]
    var_inv = jnp.stack((jnp.stack((-2 * theta[:, 4], off_diagonal), axis=-1),
                         jnp.stack((off_diagonal, -2 * theta[:, 1]), axis=-1)),
                        axis=-2)
    I = jnp.eye(2)
    var_statistics = jnp.linalg.solve(var_inv, I[jnp.newaxis, :, :])

    # Mean = covariance @ linear natural parameter, ordered as (x0, x1).
    scaled_mean = jnp.stack((theta[:, 2], theta[:, 0]), axis=-1)
    mean_statistics = var_statistics @ scaled_mean[:, :, jnp.newaxis]
    # Drop only the trailing column axis so a single-step history keeps its time axis.
    mean_statistics = mean_statistics[:, :, 0]

    stdev_kalman = jnp.sqrt(jnp.diagonal(kb_PP, axis1=1, axis2=2))
    stdev_pf = jnp.sqrt(jnp.diagonal(var_statistics, axis1=1, axis2=2))
    hell_dist = util.hellinger_distance_between_two_gaussians(mean_statistics, kb_mm, var_statistics, kb_PP)

    # The filters produce one estimate per measurement increment, hence tspan[1:].
    time_axis = tspan[1:]

    # NOTE(review): file names depend only on integrator_type and nodes_number;
    # if em_pf.nodes_number stays 0 for every 'spg' level the files overwrite
    # each other -- confirm what the filter stores in that attribute.
    fig, ax = plt.subplots(dpi=300)
    ax.semilogy(time_axis, hell_dist, linewidth=0.5, color='black', label='Hellinger-distance')
    ax.set_xlabel('t')
    ax.set_ylabel('Hellinger-Distance')
    fig.savefig('Hell_pf_vs_kf_{}_{}.pdf'.format(em_pf.integrator_type, em_pf.nodes_number))
    plt.close(fig)

    for state_index in range(2):
        kalman_mean = kb_mm[:, state_index]
        kalman_std = stdev_kalman[:, state_index]
        projection_mean = mean_statistics[:, state_index]
        projection_std = stdev_pf[:, state_index]

        fig, ax = plt.subplots(dpi=300)
        ax.plot(time_axis, kalman_mean, linewidth=0.5, color='blue', label='Kalman-Bucy')
        ax.fill_between(time_axis, kalman_mean - kalman_std, kalman_mean + kalman_std,
                        facecolor='blue', alpha=0.25)
        ax.plot(time_axis, projection_mean, linewidth=0.5, color='red', label='Projection')
        ax.fill_between(time_axis, projection_mean - projection_std, projection_mean + projection_std,
                        facecolor='red', alpha=0.25)
        ax.legend()
        ax.set_xlabel('t')
        ax.set_ylabel(r'$x_{}$'.format(state_index+1))
        fig.savefig('x_{}_pf_vs_kf_{}_{}.pdf'.format(state_index, em_pf.integrator_type, em_pf.nodes_number))
        plt.close(fig)

In [26]:
# Run the projection filter and save the comparison figures for every configured setting.
for a_setting in settings:
    generate_results(a_setting)