Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change scipy.special.logit into jnp.log #57

Merged
merged 1 commit into from May 24, 2022
Merged

Change scipy.special.logit into jnp.log #57

merged 1 commit into from May 24, 2022

Conversation

petergchang
Copy link
Contributor

IIUC, the logits argument for jax.random.categorical correspond to log probabilities instead of their logit values
(defined by log(p/(1-p))),
and so the usage of jax.random.categorical in hmm_forwards_filtering_backwards_sampling_jax and hmm_sample_jax should be modified to take in
jnp.log(.) instead of logit(.).

From the documentation for jax.random.catgegorical:

logits - Unnormalized log probabilities of the categorical distribution(s) to sample from, so that softmax(logits, axis) gives the corresponding probabilities.

IIUC, the `logits` argument for `jax.random.categorical` take in log probabilities instead of their logit values (defined by log(p/(1-p))),
and so the usage of `jax.random.categorical` in `hmm_forwards_filtering_backwards_sampling_jax` and `hmm_sample_jax` should be modified to take in
jnp.log(.) instead of logit(.).
@murphyk
Copy link
Member

murphyk commented May 24, 2022

Yes, you are right. thanks.

@murphyk murphyk merged commit 18c78a4 into probml:main May 24, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants