Raise a runtime error when trying to convert the jax.Array wrapped by jax.core.Token to a NumPy array, since that array is an internal implementation detail and its buffer has XLA token shape.
#5531
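The guard this change describes can be sketched in plain Python. This is a minimal, hypothetical illustration, not JAX's actual code: `TokenShapedBuffer` is an invented stand-in for a token-shaped device buffer. It relies on one real mechanism, which is that NumPy's conversion functions (such as `np.asarray`) call an object's `__array__` method. Raising there blocks the conversion before any host copy is attempted.

```python
class TokenShapedBuffer:
    """Hypothetical stand-in for a device buffer with XLA token shape.

    A token carries no data, only an ordering dependency, so there is
    no meaningful host value to copy out.
    """

    def __init__(self, device: str = "cpu:0") -> None:
        self.device = device

    def __array__(self, dtype=None, copy=None):
        # NumPy's conversion protocol (np.asarray, np.array) calls
        # __array__. Refuse explicitly instead of returning a bogus value.
        raise RuntimeError(
            "Cannot convert a buffer with XLA token shape to a NumPy array; "
            "it is an internal implementation detail."
        )


def describe_conversion(obj) -> str:
    """Invoke the __array__ protocol directly and report the outcome."""
    try:
        obj.__array__()
    except RuntimeError as err:
        return f"RuntimeError: {err}"
    return "converted"


if __name__ == "__main__":
    print(describe_conversion(TokenShapedBuffer()))
```

With NumPy installed, `np.asarray(TokenShapedBuffer())` would surface the same `RuntimeError`, because NumPy propagates exceptions raised inside `__array__`.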