New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement EKF #42
Implement EKF #42
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For a unit test, I guess we can do 2 things:
- linear case should match KF
- compare against eg
https://filterpy.readthedocs.io/en/latest/kalman/ExtendedKalmanFilter.html
ssm_jax/nlgssm/extended_inference.py
Outdated
from typing import Callable | ||
|
||
@chex.dataclass | ||
class ESSMParams: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe call this NLGSSMParams
.
Also this should probably live in the nlgssm/models.py file, since it will be shared between EKF, UKF, etc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, since it's so lightweight I was thinking of having a separate Params dataclass for each nlgssm algorithm b/c each algorithm requires a slightly different combination of parameters (e.g. UKF requires alpha, beta, kappa, etc.), but you're right in that I should maybe factor out the common parameters into nlgssm/models.py and extend the dataclass for each algorithm.
ssm_jax/nlgssm/extended_inference.py
Outdated
emission_covariance: chex.Array | ||
|
||
@chex.dataclass | ||
class ESSMPosterior: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe call this NLGSSMPosterior
? (It's actually the same as LGSSMPosterior
:)
Also this should probably live in the nlgssm/models.py file, since it will be shared between EKF, UKF, etc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting. I wonder if we should use the code to reinforce the point that the EKF/UKF also return the same type of posterior as the Kalman smoother? I.e. they're all linear Gaussian chains. The same is true of Laplace approximations to, e.g., Poisson LDS posteriors.
ssm_jax/nlgssm/extended_inference.py
Outdated
return mu_cond, Sigma_cond | ||
|
||
|
||
def essm_filter(params, emissions, inputs=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we call this ekf_filter
or extended_kalman_filter
?
(I know https://github.com/probml/ssm-jax/blob/main/ssm_jax/lgssm/inference.py#L114 is called lgssm_filter
but I am not wild about that...)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like extended_kalman_filter
!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1 for extended_kalman_filter
.
Happy to revisit the discussion about lgssm_filter
vs kalman_fitler
. The latter is certainly more standard.
ssm_jax/nlgssm/extended_inference.py
Outdated
|
||
return (ll, pred_mean, pred_cov), (filtered_mean, filtered_cov) | ||
|
||
# Run the Kalman filter |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
very elegant!
…id circular import
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great. I left a few small comments.
ssm_jax/ekf/inference.py
Outdated
f, h = params.dynamics_function, params.emission_function | ||
F, H = jacfwd(f), jacfwd(h) | ||
# If no input, add dummy input to functions | ||
if inputs is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this be factored out into a helper function, similar to _get_params
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good idea! Thank you. Implemented as _process_fn
and _process_input
now.
ssm_jax/nlgssm/models.py
Outdated
self.emission_covariance) | ||
|
||
@property | ||
def unconstrained_params(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Model fitting for nonlinear models is quite complex. I don't think return self.dynamics_function
will be sufficient for SGD training. I propose we omit all functions related to learning (eg unconstrained_params
)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That makes sense, fixed!
In |
Description
extended_kalman_filter()
inekf/inference.py
.ekf/inference_test.py
test_extended_kalman_filter_linear()
verifies EKF-filtered means and covariances using randomly-generated data (using helper functionrandom_linear_args
) against those of regular Kalman filter (usinglgssm_filter
).test_extended_kalman_filter_nonlinear()
verifies the results against EKF implementation from the Sarkka-jax library. The Sarkka-jax implementations of various nlgssm algorithms are stored innlgssm/sarkka_lib.py
.extended_kalman_smoother()
inekf/inference.py
.test_extended_kalman_smoother_linear()
andtest_extended_kalman_smoother_nonlinear()
verify the results.ekf/demos/ekf_pendulum.py
.Issue
#40