Skip to content

Conversation

mdhaber
Copy link
Contributor

@mdhaber mdhaber commented Mar 5, 2025

Reference issue

gh-22585

What does this implement/fix?

gh-22585 adjusted the condition under which p is normalized in scipy.stats.multinomial. multinomial has converted p to float64 internally since it was introduced, but users have been allowed to pass in input with other dtypes, and the change did not take this into account. This compensates by making the tolerance dtype-dependent.

@mdhaber mdhaber added scipy.stats maintenance Items related to regular maintenance tasks labels Mar 5, 2025
n_ and p_ are arrays of the correct shape; npcond is a boolean array
flagging values out of the domain.
"""
eps = np.finfo(np.result_type(np.asarray(p), np.float32)).eps * 10
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
eps = np.finfo(np.result_type(np.asarray(p), np.float32)).eps * 10
eps = np.finfo(np.result_type(np.asarray(p), np.float16)).eps * 10

maybe?

Copy link
Member

@ev-br ev-br Mar 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this be just `eps = xp.finfo(xp_default_dtype(xp)).eps * 10 # or other wiggle factor`

This will not work currently since xp_default_dtype only account for torch defaulting to f32;
Assuming that the failure mode is that jax returns f32 array even when explicitly asked for an f64, xp_default_dtype should grow some logic to account for JAX idiosyncrasy.

EDIT: no, what I wrote above makes no sense: it's not related to JAX at all, it's a pure float32 thing #22585 (comment)

Then the right thing to do seems to be just eps = np.finfo(asarray(p)).dtype).eps * 10

Copy link
Contributor Author

@mdhaber mdhaber Mar 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are tests that are passing integers like [0, 1] or something, so not without other changes.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok then, that figures. Then indeed, there's no way around result_type. Sorry for the noise

@mdhaber mdhaber marked this pull request as ready for review March 5, 2025 17:45
@mdhaber
Copy link
Contributor Author

mdhaber commented Mar 5, 2025

@ev-br marked this as ready given #22585 (comment). I suppose we could also require that input be float64 if that conversion is going to be done anyway, but I'd rather not.

@ev-br ev-br merged commit b0ef326 into scipy:main Mar 5, 2025
40 checks passed
@ev-br
Copy link
Member

ev-br commented Mar 5, 2025

LGTM, merged. Thanks Matt

@ev-br ev-br added this to the 1.16.0 milestone Mar 5, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
maintenance Items related to regular maintenance tasks scipy.stats
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants