NaNs returned by lgssm_posterior_sample #320

Closed
calebweinreb opened this issue May 21, 2023 · 1 comment

@calebweinreb
Contributor

calebweinreb commented May 21, 2023

We have been trying to incorporate dynamax into jax-moseq, a tool for unsupervised analysis of animal behavior. Specifically, we would like to replace our custom Kalman sampling code with the lgssm_posterior_sample method in dynamax. @ezhang94 has already done all the heavy lifting and tested it on some small-scale examples. However, we are still getting all-NaN outputs for more realistically sized datasets.

It seems like the problem can be solved by adding a small amount to the diagonal of the posterior covariance during each backward sampling step. Below is a brief recipe to reproduce the issue and a diagnosis of where the NaNs first appear.
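A minimal NumPy sketch of the diagonal-padding workaround, for illustration only: `sample_mvn_with_jitter` and its `jitter` default are hypothetical and not part of dynamax (which is written in JAX). In forward-filter backward-sample, the conditional covariance at each backward step is a difference of PSD terms (roughly P_t minus G Σ_{t+1|t} Gᵀ), so floating-point cancellation can push its smallest eigenvalue slightly below zero. Symmetrizing and padding the diagonal before the Cholesky factorization avoids that. NumPy raises `LinAlgError` on a non-PD input, whereas in JAX a failed Cholesky typically yields NaNs instead, which would fit the NaN outputs described here.

```python
import numpy as np


def sample_mvn_with_jitter(rng, mean, cov, jitter=1e-4):
    """Draw one sample from N(mean, cov + jitter * I).

    Hypothetical helper; `jitter` is an illustrative default, not a value
    taken from dynamax or jax-moseq.
    """
    cov = np.asarray(cov, dtype=float)
    cov = 0.5 * (cov + cov.T)                    # remove round-off asymmetry
    cov = cov + jitter * np.eye(cov.shape[0])    # pad the diagonal
    chol = np.linalg.cholesky(cov)               # raises LinAlgError if still not PD
    z = rng.standard_normal(cov.shape[0])        # independent N(0, 1) draws
    return np.asarray(mean, dtype=float) + chol @ z


# Usage: a slightly indefinite covariance (determinant is -1e-5),
# on which a plain Cholesky factorization fails.
rng = np.random.default_rng(0)
bad_cov = np.array([[1.0, 1.0], [1.0, 1.0 - 1e-5]])
x = sample_mvn_with_jitter(rng, np.zeros(2), bad_cov)
```

A fixed pad only rescues matrices whose most negative eigenvalue is smaller in magnitude than `jitter`; for eigenvalues below -1e-4 like those reported in this issue, the pad would need to be larger, or chosen adaptively from the observed minimum eigenvalue.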

  • The issue can be reproduced by running the keypoint-moseq tutorial, using the version of jax-moseq as of this commit. At some point during fitting, lgssm_posterior_sample returns all NaNs for some of the dataset.
  • The NaNs first appear during the forward filtering pass. This can be solved by forcing the output of the conditioning function to be symmetric. I recently submitted a separate issue and PR that implements this change.
  • Even after the above fix, however, NaNs still appear, now during the backward sampling pass. They are rare and seem to be caused by sharp discontinuities in the emissions. Once a NaN appears, though, it propagates through the rest of the backward pass.
  • These sampling NaNs specifically appear during the MVN sampling step when the covariance is non-PSD (min eigenvalue < -1e-4). The problem can be solved by padding the diagonal of the covariance matrix before passing it to the MVN sampler.
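The symmetrization fix for the forward pass and the eigenvalue check from the last bullet can be sketched together as follows. This is a textbook Kalman measurement update in NumPy, not dynamax's actual conditioning code; `kalman_condition` and `is_numerically_psd` are hypothetical names, and the -1e-4 tolerance is simply the threshold quoted above. Averaging the updated covariance with its transpose makes it exactly symmetric, so round-off asymmetry cannot accumulate from one filtering step to the next.

```python
import numpy as np


def kalman_condition(mean, cov, H, R, y):
    """Condition N(mean, cov) on y = H x + v with v ~ N(0, R).

    Hypothetical helper illustrating the fix; not dynamax's own function.
    """
    S = H @ cov @ H.T + R                    # innovation covariance
    K = np.linalg.solve(S, H @ cov).T        # gain K = cov H^T S^{-1} (cov, S symmetric)
    new_mean = mean + K @ (y - H @ mean)
    new_cov = cov - K @ S @ K.T              # subtraction can break symmetry/PSD
    new_cov = 0.5 * (new_cov + new_cov.T)    # force exact symmetry
    return new_mean, new_cov


def is_numerically_psd(cov, tol=-1e-4):
    """False if the smallest eigenvalue of the symmetrized cov is below tol."""
    sym = 0.5 * (cov + cov.T)
    return bool(np.linalg.eigvalsh(sym).min() >= tol)


# Usage: 2-D latent state, observing only the first coordinate.
mean = np.zeros(2)
cov = np.array([[1.0, 0.5], [0.5, 1.0]])
H = np.array([[1.0, 0.0]])
R = np.array([[0.1]])
y = np.array([0.3])
post_mean, post_cov = kalman_condition(mean, cov, H, R, y)
```

Running a check like `is_numerically_psd` on each backward-pass covariance is one way to locate the step where sampling first breaks down, before the NaN spreads to earlier time steps.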
@slinderman
Collaborator

Thanks for digging into this and finding/fixing these issues. I've merged your PRs. Hope we can get jax-moseq to work with dynamax!
