Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix broadcasting in calculating emission log probs #22

Merged
merged 1 commit into from May 26, 2022

Conversation

gileshd
Copy link
Collaborator

@gileshd gileshd commented May 25, 2022

Fix reshaping of emissions so that broadcasting works with HMM.emission_distribution.log_prob() works for both scalar and vector observations.

The behaviour is now:

  • scalar - (T,) --[reshape]--> (T,1) --[tfp broadcast]--> (T,K).
  • vector - (T,D) --[reshape]--> (T,1,D) --[tfp broadcast]--> (T,K).

This fixes many methods in HMM subclasses with scalar observations (e.g. CategoricalHMM.filter(), CategoricalHMM.smoother(), etc.).

The PR also changes a line in BaseHMM.log_prob() to only sum over the leading axis, this helps for a failure case whereby extra dimensions in the input can cause broadcasting of the tfp log prob which is subsequently hidden by a .sum().

With the present change the method will output an array with shape determined by the broadcasting (e.g. if x is shape (T,) inputing x[:,None] will now output an array of shape (T,) rather than () which at least makes it clear that broadcasting has occurred). Going forward it might be useful to have some shape checking steps.

@murphyk murphyk merged commit c30107e into main May 26, 2022
@gileshd gileshd deleted the hmm-reshaping-bug branch May 26, 2022 20:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants