Random key does not typecheck Key[Scalar, ""] #142

Closed
EtaoinWu opened this issue Oct 26, 2023 · 2 comments
Labels
bug Something isn't working

Comments


EtaoinWu commented Oct 26, 2023

Problem

Code:

```python
import jax.random as random
from jaxtyping import Key, Scalar

k = random.key(1)
print(isinstance(k, Key[Scalar, ""]))
```

Output should be True. Instead it is False.

Potential reason

Commit 1e5229c added these lines:

```python
def _is_jax_extended_dtype(dtype: Any) -> bool:
    if not has_jax:
        return False
    try:
        is_dtype = issubclass(dtype, jax.numpy.generic)
    except TypeError:
        # `dtype` not a class
        return False
    else:
        if is_dtype:
            if hasattr(jax.dtypes, "extended"):  # jax>=0.4.14
                return jax.numpy.issubdtype(dtype, jax.dtypes.extended)
            else:  # jax<=0.4.13
                return jax.core.is_opaque_dtype(dtype)
        else:
            return False
```

However, in this case our object's dtype is `key<fry>`. It is not a class but an instance of `jax._src.prng.KeyTy`, so `issubclass` raises `TypeError` and the function returns `False`.
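A short sketch to confirm this diagnosis, assuming a JAX release that provides `jax.random.key` and `jax.dtypes.extended` (such as the jax==0.4.19 listed under Environment) and the default threefry PRNG implementation, inspects the key's dtype directly:

```python
import jax
import jax.numpy as jnp

k = jax.random.key(1)
dtype = k.dtype

# With the default threefry implementation the typed key's dtype prints as
# "key<fry>". It is an *instance* (of KeyTy), not a class.
print(dtype)
print(isinstance(dtype, type))  # False

# Hence the `issubclass` probe in `_is_jax_extended_dtype` raises TypeError,
# which that function catches and turns into `return False`.
try:
    issubclass(dtype, jnp.generic)
except TypeError as err:
    print("TypeError:", err)

# JAX's own dtype check does recognise it as an extended dtype.
print(jnp.issubdtype(dtype, jax.dtypes.extended))  # True
```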

Because of this incorrect return value, jaxtyping stringifies the dtype the wrong way (giving `prng_key` rather than the `key<fry>` produced by older versions), so the regex match for `Key` fails.
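The stringify-then-match step can be illustrated with a minimal pure-Python sketch. The pattern `KEY_DTYPE_PATTERN` and the helper `matches_key_dtype` are hypothetical stand-ins for illustration, not jaxtyping's actual implementation:

```python
import re

# Hypothetical pattern for typed-key dtype names such as "key<fry>";
# jaxtyping's real matching logic may differ.
KEY_DTYPE_PATTERN = re.compile(r"key<\w+>")


def matches_key_dtype(dtype_str: str) -> bool:
    """Return True if a stringified dtype looks like a typed-key dtype."""
    return KEY_DTYPE_PATTERN.fullmatch(dtype_str) is not None


print(matches_key_dtype("key<fry>"))  # True: the string older versions produced
print(matches_key_dtype("prng_key"))  # False: the string produced in this bug
```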

Suggested fix: semantically, the function should return `True` here, since `KeyTy` derives from `jax._src.dtypes.ExtendedDType`.
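One possible shape for such a fix, sketched under the assumption that `jax.numpy.issubdtype` accepts extended dtype instances (as it does in recent JAX releases with `jax.dtypes.extended`), is to let JAX perform the check instead of calling `issubclass` on the dtype. `is_jax_extended_dtype` is a hypothetical name, the jax<=0.4.13 branch is omitted, and this is not necessarily how jaxtyping resolved the issue:

```python
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np


def is_jax_extended_dtype(dtype: Any) -> bool:
    """Hypothetical variant of `_is_jax_extended_dtype` that accepts dtype
    instances such as `key<fry>` as well as dtype classes."""
    try:
        # `jnp.issubdtype` handles ExtendedDType instances itself, so there is
        # no `issubclass` call on a possibly non-class object.
        return bool(jnp.issubdtype(dtype, jax.dtypes.extended))
    except TypeError:
        # Not something JAX can interpret as a dtype.
        return False


print(is_jax_extended_dtype(jax.random.key(1).dtype))  # True
print(is_jax_extended_dtype(np.dtype("float32")))      # False
```

Delegating to JAX avoids re-implementing the class-versus-instance distinction, so the check follows however JAX chooses to represent extended dtypes.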

Environment

Windows x64, but the platform should be irrelevant.

```shell
micromamba create -n jaxtest python=3.12 # or use conda/mamba, should be the same
micromamba activate jaxtest
pip install "jax[cpu]" jaxtyping # jax==0.4.19, jaxtyping==0.2.23
```

patrick-kidger added the bug (Something isn't working) label on Oct 28, 2023.
patrick-kidger (Owner) commented:

Thanks for the report! Looks like things changed a little in the underlying JAX, and we were missing a few test cases to pick up on this.

This should be fixed in #143. This will be included in the next release of jaxtyping (~3 weeks).

patrick-kidger (Owner) commented:

Closing as fixed in the v0.2.24 release!
