cleanup HMM SGD code #33

Closed
murphyk opened this issue Jun 24, 2022 · 0 comments · Fixed by #47

@murphyk
Member

murphyk commented Jun 24, 2022

Rename hmm_fit_minibatch_gradient_descent in
https://github.com/probml/ssm-jax/blob/main/ssm_jax/hmm/learning.py#L75
to hmm_fit_sgd.
Rename emissions to batch_emissions. Add a comment that the input has shape (N, T),
but each step takes a minibatch of shape (B, T).
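To make the (N, T) versus (B, T) shape contract concrete, here is a minimal JAX sketch of a minibatch SGD loop. It is not the repository's hmm_fit_sgd. `toy_loss` and `fit_sgd_sketch` are hypothetical names, and the quadratic loss is an assumed stand-in for the HMM negative marginal log-likelihood. For simplicity, the sketch takes consecutive slices without shuffling.

```python
import jax
import jax.numpy as jnp


def toy_loss(params, batch_emissions):
    # Hypothetical stand-in for the negative HMM marginal log-likelihood.
    # batch_emissions has shape (B, T); compute one loss per sequence, then average.
    per_seq = jnp.sum((batch_emissions - params) ** 2, axis=1)  # shape (B,)
    return jnp.mean(per_seq)


def fit_sgd_sketch(params, batch_emissions, batch_size, num_epochs, learning_rate):
    """Plain minibatch SGD.

    batch_emissions: array of shape (N, T) holding N whole sequences.
    Each step uses a minibatch of shape (B, T), where B = batch_size.
    Leftover sequences (N not divisible by B) are dropped in this sketch.
    """
    num_sequences = batch_emissions.shape[0]
    num_batches = num_sequences // batch_size
    grad_fn = jax.jit(jax.value_and_grad(toy_loss))
    losses = []
    for _ in range(num_epochs):
        for b in range(num_batches):
            # Slice along the sequence axis only: shape (B, T).
            minibatch = batch_emissions[b * batch_size:(b + 1) * batch_size]
            loss, grads = grad_fn(params, minibatch)
            params = params - learning_rate * grads
            losses.append(loss)
    return params, jnp.array(losses)
```

For example, `fit_sgd_sketch(0.0, jnp.ones((8, 5)), batch_size=2, num_epochs=50, learning_rate=0.01)` runs 4 steps per epoch on (2, 5) minibatches drawn from an (8, 5) input.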

Remove old hmm_fit_sgd.

Move the permutation step in
https://github.com/probml/ssm-jax/blob/main/ssm_jax/hmm/learning.py#L93
inside _sample_minibatches. Split the RNG key.
(Also check: if B = N, there is no need for a random permutation.)
Add a comment that you are sampling a random subset of entire sequences, not of time steps.
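One possible shape for the requested _sample_minibatches change, sketched in JAX under assumptions: `sample_minibatches` and `epoch_minibatches` are hypothetical names, not the repository's code. The permutation lives inside the sampler, the caller splits the RNG key once per epoch so each epoch gets a fresh ordering, the B = N case skips the permutation, and only the sequence axis is shuffled.

```python
import jax
import jax.numpy as jnp


def sample_minibatches(key, batch_emissions, batch_size):
    """Yield minibatches of whole sequences from batch_emissions of shape (N, T, ...).

    We sample a random subset of entire sequences (axis 0), never individual
    time steps (axis 1), so each sequence stays intact.
    """
    num_sequences = batch_emissions.shape[0]
    if batch_size == num_sequences:
        # Full batch: permuting would not change the batch contents, so skip it.
        yield batch_emissions
        return
    # Permute sequence indices only.
    perm = jax.random.permutation(key, num_sequences)
    num_batches = num_sequences // batch_size  # leftover sequences are dropped
    for b in range(num_batches):
        idx = perm[b * batch_size:(b + 1) * batch_size]
        yield batch_emissions[idx]  # shape (B, T, ...)


def epoch_minibatches(key, batch_emissions, batch_size, num_epochs):
    """Yield, for each epoch, the list of minibatches for that epoch."""
    for _ in range(num_epochs):
        # Split the key so every epoch uses a fresh, independent permutation.
        key, subkey = jax.random.split(key)
        yield list(sample_minibatches(subkey, batch_emissions, batch_size))
```

Splitting inside the epoch loop, rather than reusing one key, avoids repeating the same sequence order every epoch.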
