Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added EKF sampler #313

Merged
merged 5 commits into from
May 21, 2023
Merged

Added EKF sampler #313

merged 5 commits into from
May 21, 2023

Conversation

calebweinreb
Copy link
Contributor

@calebweinreb calebweinreb commented May 13, 2023

This PR adds extended_kalman_posterior_sample, which has the same signature as lgssm_posterior_sample. It behaves as expected for the pendulum example in the docs:

from dynamax.nonlinear_gaussian_ssm import extended_kalman_posterior_sample

ekf_params = ParamsNLGSSM(
    initial_mean=pendulum_params.initial_state,
    initial_covariance=jnp.eye(states.shape[-1]) * 0.1,
    dynamics_function=pendulum_params.dynamics_function,
    dynamics_covariance=pendulum_params.dynamics_covariance,
    emission_function=pendulum_params.emission_function,
    emission_covariance=pendulum_params.emission_covariance,
)

ekf_posterior = extended_kalman_smoother(ekf_params, obs)
sampled_states = extended_kalman_posterior_sample(jr.PRNGKey(0), ekf_params, obs)

m_ekf = sampled_states[:, 0]
plot_pendulum(time_grid, states[:, 0], obs, x_est=m_ekf, est_type="EKF (sampled)")
compute_and_print_rmse_comparison(states[:, 0], m_ekf, r, "EKF")
Screenshot 2023-05-13 at 2 25 22 PM

Increasing the process and observation noise 20-fold:

Screenshot 2023-05-13 at 6 07 12 PM

This PR also adds the following tests:

from dynamax.nonlinear_gaussian_ssm.inference_ekf_test import (
    test_extended_kalman_sampler_linear,
    test_extended_kalman_sampler_nonlinear)

test_extended_kalman_sampler_linear()
test_extended_kalman_sampler_nonlinear()

In the process of writing these tests, I found that lgssm_filter and extended_kalman_filter returned asymmetric covariance matrices for the default examples in dynamax.nonlinear_gaussian_ssm.inference_test (see issue #317). As a result, the outputs of lgssm_posterior_sample and extended_kalman_posterior_sample were all NaN and could not be used for testing. This asymmetry issue is addressed in two separate PRs (#318 for lgssm_filter and #319 for extended_kalman_filter). Both of those PRs need to be merged before this one or the tests won't work. So the overall merge order would be:

  1. Symmetrize lgssm filtered covariance #318 (ensures LGSSM filtered covariance is symmetric)
  2. Symmetrize LGSSM and EKF filtered covariance #319 (ensures EKF filtered covariance is symmetric)
  3. Added EKF sampler #313 (the current PR that adds an EKF sampler)

@slinderman
Copy link
Collaborator

Thanks @calebweinreb! This looks good to me.

Is there a way to test this function? I guess we could check that the sample mean matches the EKF smoother mean?

@calebweinreb
Copy link
Contributor Author

Here's two tests possible tests:

1. Match to LGSSM

The output of extended_kalman_posterior_sample should match lgssm_posterior_sample when the dynamics function is linear. It does in the following example.

from jax import numpy as jnp
from jax import random as jr
from dynamax.nonlinear_gaussian_ssm import ParamsNLGSSM
from dynamax.nonlinear_gaussian_ssm import extended_kalman_posterior_sample
from dynamax.linear_gaussian_ssm import lgssm_posterior_sample
from dynamax.linear_gaussian_ssm.inference_test import build_lgssm_for_sampling

# simulate LGSSM
num_timesteps=100
key = jr.PRNGKey(0)
sample_key, key = jr.split(key)
lgssm, lgssm_params = build_lgssm_for_sampling()
states, emissions = lgssm.sample(lgssm_params, key=sample_key, num_timesteps=num_timesteps)

# sample from LGSSM
lgssm_sampled_states = lgssm_posterior_sample(key, lgssm_params, emissions)

# sample from EKF
ekf_params = ParamsNLGSSM(
    initial_mean=lgssm_params.initial.mean,
    initial_covariance=lgssm_params.initial.cov,
    dynamics_function=(lambda x: lgssm_params.dynamics.weights @ x),
    dynamics_covariance=lgssm_params.dynamics.cov,
    emission_function=(lambda x: lgssm_params.emissions.weights @ x),
    emission_covariance=lgssm_params.emissions.cov)
ekf_sampled_states = extended_kalman_posterior_sample(key, ekf_params, emissions)

print(jnp.allclose(lgssm_sampled_states, ekf_sampled_states)) 

2. Match to smoother

The site-wise mean and variance of the sampled states should match marginal mean and covariance from the smoother. Here's an example using the pendulum example from the docs.

pendulum_params = PendulumParams()

ekf_params = ParamsNLGSSM(
    initial_mean=pendulum_params.initial_state,
    initial_covariance=jnp.eye(states.shape[-1]) * 0.1,
    dynamics_function=pendulum_params.dynamics_function,
    dynamics_covariance=pendulum_params.dynamics_covariance*10,
    emission_function=pendulum_params.emission_function,
    emission_covariance=pendulum_params.emission_covariance*10,
)

# smoothing
ekf_posterior = extended_kalman_smoother(ekf_params, obs)

# sampling
num_samples = 100000
sample_fun = jax.vmap(extended_kalman_posterior_sample, in_axes=(0,None,None))
samples = (jr.split(jr.PRNGKey(0), num_samples), ekf_params, obs)

# compare by plotting
fig,axs = plt.subplots(1,2)
axs[0].plot(ekf_posterior.smoothed_means[:,0], label='smoothed mean')
axs[0].plot(jnp.mean(samples[:,:,0],axis=0), label='mean of samples')
axs[0].set_ylabel('mean')
axs[0].legend(loc='upper left')
axs[1].plot(ekf_posterior.smoothed_covariances[:,0,0], label='smoothed variance')
axs[1].plot(jnp.var(samples[:,:,0],axis=0), label='variance of samples')
axs[1].legend(loc='upper left')
axs[1].set_ylabel('variance')
fig.set_size_inches((15,3))

Here's the result
Screenshot 2023-05-15 at 8 55 11 PM

The error seems reasonable given the sample size:

empirical_error = jnp.std(ekf_posterior.smoothed_means[:,0] - jnp.mean(samples[:,:,0],axis=0))
expected_error = jnp.sqrt(ekf_posterior.smoothed_covariances[:,0,0] / num_samples)
print(empirical_error, expected_error.mean())

yields

empirical_error = 0.000929
expected_error = 0.000768

@murphyk
Copy link
Member

murphyk commented May 16, 2023

These are both great tests. Please add them to https://github.com/probml/dynamax/blob/main/dynamax/nonlinear_gaussian_ssm/inference_ekf_test.py.
And change print(allclose) to assert(allclose) :)

@calebweinreb
Copy link
Contributor Author

calebweinreb commented May 20, 2023

I added testing functions that are similar to the ones proposed above but use the examples generated by dynamax.nonlinear_gaussian_ssm.inference_test_utils to be consistent with all the other tests in dynamax.nonlinear_gaussian_ssm.inference_test.py. The tests can be run as follows:

from dynamax.nonlinear_gaussian_ssm.inference_ekf_test import (
    test_extended_kalman_sampler_linear,
    test_extended_kalman_sampler_nonlinear)

test_extended_kalman_sampler_linear()
test_extended_kalman_sampler_nonlinear()

In the process of writing these tests, I found that lgssm_filter and extended_kalman_filter returned asymmetric covariance matrices (see issue #317), and consequently the outputs of lgssm_posterior_sample and extended_kalman_posterior_sample were all NaN. This asymmetry issue is addressed in two separate PRs (#318 for lgssm_filter and #319 for extended_kalman_filter). Both of those PRs need to be merged before this one or the tests won't work. So the overall merge order would be:

  1. Symmetrize lgssm filtered covariance #318 (ensures LGSSM filtered covariance is symmetric)
  2. Symmetrize LGSSM and EKF filtered covariance #319 (ensures EKF filtered covariance is symmetric)
  3. Added EKF sampler #313 (the current PR that adds an EKF sampler)

@slinderman slinderman merged commit f40a2fd into probml:main May 21, 2023
@calebweinreb calebweinreb deleted the ekf_sampler branch July 26, 2023 18:21
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants