Improve sequential_gaussian_filter_sample() #3146

Merged: 10 commits merged into dev from gaussian-rsample on Oct 23, 2022

Conversation

@fritzo (Member) commented on Oct 19, 2022

This makes a number of improvements to sequential_gaussian_filter_sample():

  1. Allows the user to pass in noise rather than drawing it internally (i.e. more JAX-style). The motivation is to support both (a) computing the mean and (b) antithetic sampling, where we can pass in cat([noise, -noise]) (a minimal sketch of this pattern follows the list).
  2. Avoids dropping the first entry of the returned samples. This makes sequential_gaussian_filter_sample() more useful outside the context of GaussianHMM, e.g. it will be easier to use in a fully observed Markov model. This attempts to make up for my admittedly bad design choice of making the hidden Z series one step longer than the observed X series in GaussianHMM 😬
  3. Adds a profiling script. I saw no speed changes due to this PR.
  4. Refactors some backward-sample logic to write into a preallocated torch.empty() rather than using torch.nn.functional.pad and torch.stack. This is admittedly less functional, but it reduces memory usage and IMHO reads more cleanly (see the second sketch after this list).
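
A minimal sketch of the antithetic pattern from item 1, in plain PyTorch. The toy affine map below is only an illustrative stand-in for the filter's own noise-to-sample transform; the idea of passing cat([noise, -noise]) (or zeros, for the mean) as the noise argument is the one described above.

    import torch

    # Toy stand-in: an affine map from standard-normal noise to samples,
    # of the kind applied per step when sampling a Gaussian.
    loc = torch.tensor([1.0, -2.0])
    scale_tril = torch.tensor([[0.5, 0.0], [0.3, 1.2]])

    def sample_from_noise(noise):
        # noise: (num_samples, dim) of standard-normal draws.
        return loc + noise @ scale_tril.T

    # Antithetic sampling: pair each draw with its negation. For an affine
    # map of the noise the paired draws cancel exactly, so the Monte Carlo
    # estimate of the mean is (up to rounding) exact.
    noise = torch.randn(1000, 2)
    samples = sample_from_noise(torch.cat([noise, -noise]))
    assert torch.allclose(samples.mean(0), loc, atol=1e-4)

    # Passing zeros as the noise recovers the mean directly, the other use
    # case mentioned in item 1.
    mean = sample_from_noise(torch.zeros(1, 2))

A similarly minimal sketch of the refactor in item 4: writing each step into a preallocated buffer instead of collecting per-step tensors and stacking them (the pad step is omitted and the shapes are made up, not the filter's actual ones).

    import torch

    steps = [torch.randn(4) for _ in range(10)]

    # Stack-style: keep a list of per-step tensors and stack at the end,
    # which holds every intermediate alive until the final stack.
    stacked = torch.stack(steps)

    # Placement-style: allocate the result once and write each step in place.
    out = torch.empty(10, 4)
    for t, x in enumerate(steps):
        out[t] = x

    assert torch.equal(out, stacked)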

This also exposes the helper matrix_and_gaussian_to_gaussian(), which I'm finding useful.

Tested

  • added new gradient tests
  • added tests for antithetic sampling
  • profiled with and without gradients, saw no time difference

@fehiepsi (Member) left a comment

The changes make sense to me. Looks great overall! Thanks, Fritz.

    if noise is None:
        # Draw standard normal noise internally, matching loc's dtype and device.
        noise = torch.randn(shape, dtype=loc.dtype, device=loc.device)
    else:
        # Require user-provided noise to have exactly the expected number of
        # elements; .reshape() raises on a mismatch rather than broadcasting.
        noise = noise.reshape(shape)
@fehiepsi: I think it is better to broadcast noise here.

@fehiepsi: Never mind, it is better to avoid broadcasting the noise.
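
For context on the reshape-vs-broadcast question resolved above, a minimal sketch with made-up shapes (not the actual ones inside sequential_gaussian_filter_sample()): reshape demands exactly the expected number of noise elements and fails loudly on a mismatch, whereas broadcasting would silently reuse a too-small noise tensor.

    import torch

    shape = (2, 5, 3)  # made-up (sample, time, dim) shape

    # Strict path (what the code keeps): the caller must supply exactly
    # prod(shape) noise elements; .reshape() raises on any mismatch.
    noise = torch.randn(2 * 5 * 3)
    strict = noise.reshape(shape)

    # Broadcasting path (what the review decided against): a smaller tensor
    # would be silently expanded, reusing the same noise across the
    # broadcast dimension instead of surfacing the shape bug.
    small = torch.randn(5, 3)
    broadcast = small.expand(*shape)

    assert strict.shape == broadcast.shape == shape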

@fehiepsi merged commit 1098e38 into dev on Oct 23, 2022
@fritzo (Member, Author) commented on Oct 24, 2022

Thanks for reviewing @fehiepsi!

@fritzo deleted the gaussian-rsample branch on August 10, 2023 at 17:37