Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement EKF #42

Merged
merged 61 commits into from Jul 11, 2022
Merged

Implement EKF #42

merged 61 commits into from Jul 11, 2022

Conversation

petergchang
Copy link
Collaborator

@petergchang petergchang commented Jun 30, 2022

Description

  • Extended Kalman filter is implemented as extended_kalman_filter() in ekf/inference.py.
  • Unit tests implemented in ekf/inference_test.py
    • test_extended_kalman_filter_linear() verifies EKF-filtered means and covariances using randomly-generated data (using helper function random_linear_args) against those of regular Kalman filter (using lgssm_filter).
    • test_extended_kalman_filter_nonlinear() verifies the results against EKF implementation from the Sarkka-jax library. The Sarkka-jax implementations of various nlgssm algorithms are stored in nlgssm/sarkka_lib.py.
  • Extended Kalman smoother is implemented as extended_kalman_smoother() in ekf/inference.py.
    • test_extended_kalman_smoother_linear() and test_extended_kalman_smoother_nonlinear() verify the results.
  • EKF/smoother applied to the pendulum demo (from Sarkka) is implemented in ekf/demos/ekf_pendulum.py.

Issue

#40

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Copy link
Member

@murphyk murphyk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For a unit test, I guess we can do 2 things:

from typing import Callable

@chex.dataclass
class ESSMParams:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe call this NLGSSMParams.
Also this should probably live in the nlgssm/models.py file, since it will be shared between EKF, UKF, etc.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, since it's so lightweight I was thinking of having a separate Params dataclass for each nlgssm algorithm b/c each algorithm requires a slightly different combination of parameters (e.g. UKF requires alpha, beta, kappa, etc.), but you're right in that I should maybe factor out the common parameters into nlgssm/models.py and extend the dataclass for each algorithm.

emission_covariance: chex.Array

@chex.dataclass
class ESSMPosterior:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe call this NLGSSMPosterior? (It's actually the same as LGSSMPosterior :)
Also this should probably live in the nlgssm/models.py file, since it will be shared between EKF, UKF, etc.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. I wonder if we should use the code to reinforce the point that the EKF/UKF also return the same type of posterior as the Kalman smoother? I.e. they're all linear Gaussian chains. The same is true of Laplace approximations to, e.g., Poisson LDS posteriors.

return mu_cond, Sigma_cond


def essm_filter(params, emissions, inputs=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we call this ekf_filter or extended_kalman_filter?
(I know https://github.com/probml/ssm-jax/blob/main/ssm_jax/lgssm/inference.py#L114 is called lgssm_filter but I am not wild about that...)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like extended_kalman_filter!

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 for extended_kalman_filter.

Happy to revisit the discussion about lgssm_filter vs kalman_fitler. The latter is certainly more standard.

ssm_jax/nlgssm/extended_inference.py Outdated Show resolved Hide resolved

return (ll, pred_mean, pred_cov), (filtered_mean, filtered_cov)

# Run the Kalman filter
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very elegant!

@petergchang petergchang changed the title Implement NLGSSM Implement EKF Jul 6, 2022
@petergchang petergchang marked this pull request as ready for review July 6, 2022 19:27
Copy link
Member

@murphyk murphyk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great. I left a few small comments.

ssm_jax/ekf/inference.py Outdated Show resolved Hide resolved
f, h = params.dynamics_function, params.emission_function
F, H = jacfwd(f), jacfwd(h)
# If no input, add dummy input to functions
if inputs is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be factored out into a helper function, similar to _get_params?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea! Thank you. Implemented as _process_fn and _process_input now.

ssm_jax/ekf/inference_test.py Outdated Show resolved Hide resolved
ssm_jax/nlgssm/demos/simulations.py Outdated Show resolved Hide resolved
ssm_jax/nlgssm/models.py Outdated Show resolved Hide resolved
self.emission_covariance)

@property
def unconstrained_params(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Model fitting for nonlinear models is quite complex. I don't think return self.dynamics_function will be sufficient for SGD training. I propose we omit all functions related to learning (eg unconstrained_params)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense, fixed!

ssm_jax/nlgssm/models.py Outdated Show resolved Hide resolved
ssm_jax/nlgssm/models.py Show resolved Hide resolved
ssm_jax/nlgssm/models.py Show resolved Hide resolved
@murphyk
Copy link
Member

murphyk commented Jul 7, 2022

In sarkka-lib include link to http://users.aalto.fi/~ssarkka/pub/cup_book_online_20131111.pdf

@murphyk murphyk merged commit 800252d into probml:main Jul 11, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants