Environment
Windows x64, but the platform should be irrelevant.
```shell
micromamba create -n jaxtest python=3.12 # or use conda/mamba, should be the same
micromamba activate jaxtest
pip install jax[cpu] jaxtyping # jax==0.4.19, jaxtyping==0.2.23
```
Problem
Code:
Output should be `True`. Instead it is `False`.

Potential reason
Commit 1e5229c added these lines:
However, in this case our object's dtype is `key<fry>`. It is not a class, but rather an object of type `jax._src.prng.KeyTy`, so this function returns `False`.

Due to the faulty boolean value, jaxtyping chose the wrong way to stringify the dtype (it should be `key<fry>`, as in older versions, but got `prng_key`), and the regex match for `Key` fails.
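To see the two stringifications side by side, the snippet below inspects the dtype of a new-style key directly. It is a sketch assuming a JAX release where `jax.random.key` returns typed keys with the default threefry implementation (the `key<fry>` dtype reported here). It does not reproduce jaxtyping's internal check from the commit referenced above; it only shows why a class-based check fails and what each stringification yields.

```python
import jax
import numpy as np

# A new-style typed PRNG key; its dtype is an extended (non-numpy) dtype.
key = jax.random.key(0)
dtype = key.dtype

# The dtype is an *instance* of KeyTy, not a class and not a numpy dtype.
print(type(dtype))                  # the KeyTy class from jax._src.prng
print(isinstance(dtype, type))      # False
print(isinstance(dtype, np.dtype))  # False

# The two candidate stringifications described in this issue:
print(str(dtype))                   # key<fry>   (what the Key regex expects)
print(dtype.type.__name__)          # prng_key   (what jaxtyping produced)
```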
Suggested Fix: Semantically it should return `True`, since `KeyTy` is derived from `jax._src.dtypes.ExtendedDType`.
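As an illustration of the intended behaviour, here is a minimal sketch of a dtype stringification that avoids the class check. It assumes the only goal is to produce `key<fry>` for key arrays and the usual scalar names (such as `float32`) otherwise. `dtype_to_str` and `KEY_PATTERN` are hypothetical names, not jaxtyping's API, and the sketch swaps in a different test from the suggested one: it asks whether the dtype is a plain `numpy.dtype` instance rather than whether it derives from `ExtendedDType`.

```python
import re

import jax
import jax.numpy as jnp
import numpy as np


def dtype_to_str(dtype) -> str:
    """Hypothetical helper (not jaxtyping's API): stringify a dtype for matching.

    Ordinary numpy/JAX dtypes are np.dtype instances, so dtype.type.__name__
    yields names such as "float32". Extended dtypes such as JAX's KeyTy are
    not np.dtype instances, so fall back to str(dtype), e.g. "key<fry>".
    """
    if isinstance(dtype, np.dtype):
        return dtype.type.__name__
    return str(dtype)


# Hypothetical pattern for key dtypes; not necessarily the one jaxtyping uses.
KEY_PATTERN = re.compile(r"key<\w+>")

float_name = dtype_to_str(jnp.zeros(3).dtype)
key_name = dtype_to_str(jax.random.key(0).dtype)
print(float_name)                                   # float32
print(key_name)                                     # key<fry>
print(KEY_PATTERN.fullmatch(key_name) is not None)  # True
```

Testing for `np.dtype` keeps the existing path for ordinary arrays untouched and only changes how non-numpy dtypes are named, which is the case this issue is about.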