- Rename `hmm_fit_minibatch_gradient_descent` in https://github.com/probml/ssm-jax/blob/main/ssm_jax/hmm/learning.py#L75 to `hmm_fit_sgd`.
- Rename `emissions` to `batch_emissions`. Add a comment that the input has shape (N, T) but a minibatch of shape (B, T) is taken at each step.
- Remove the old `hmm_fit_sgd`.
- Move the permutation step in https://github.com/probml/ssm-jax/blob/main/ssm_jax/hmm/learning.py#L93 inside of `_sample_minibatches`. Split the RNG key. (Also check: if B = N, there is no need to do a random permutation.)
- Add a comment that you are sampling a random subset of entire sequences, not time steps.
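
A rough sketch of what the requested sampler could look like, to make the shapes concrete. This is not the code currently in `learning.py`: the function name `sample_minibatches_sketch`, its argument order, and the choice to split one key per epoch in the caller are all assumptions made for illustration.

```python
import jax.numpy as jnp
import jax.random as jr


def sample_minibatches_sketch(key, batch_emissions, batch_size):
    """Yield minibatches of whole sequences (hypothetical helper).

    batch_emissions has shape (N, T, ...): N sequences of length T.
    Each minibatch has shape (B, T, ...): a random subset of entire
    sequences, not a subset of time steps.
    """
    num_sequences = batch_emissions.shape[0]
    if batch_size == num_sequences:
        # B == N: the single batch is the full dataset, so a random
        # permutation would not change what is sampled. Skip it.
        yield batch_emissions
        return
    # The permutation lives inside the sampler, driven by the key passed in.
    perm = jr.permutation(key, num_sequences)
    for start in range(0, num_sequences, batch_size):
        yield batch_emissions[perm[start:start + batch_size]]


# Usage: split the RNG key so every epoch gets a fresh permutation.
key = jr.PRNGKey(0)
data = jnp.arange(6 * 4).reshape(6, 4)  # N = 6 sequences, T = 4 steps
for epoch_key in jr.split(key, 2):
    for minibatch in sample_minibatches_sketch(epoch_key, data, batch_size=2):
        assert minibatch.shape == (2, 4)  # (B, T)
```

Splitting the key in the caller keeps the sampler free of hidden state; if N is not a multiple of B, the last minibatch in this sketch is simply smaller.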