diff --git a/tests/optix_test.py b/tests/optix_test.py index a4bae03f96f0..ff86d02375d2 100644 --- a/tests/optix_test.py +++ b/tests/optix_test.py @@ -19,7 +19,7 @@ from jax import numpy as jnp from jax.experimental import optimizers from jax.experimental import optix -import jax.test_util # imported only for flags +import jax.test_util from jax.tree_util import tree_leaves import numpy as onp @@ -60,6 +60,7 @@ def test_sgd(self): for x, y in zip(tree_leaves(jax_params), tree_leaves(optix_params)): onp.testing.assert_allclose(x, y, rtol=1e-5) + jax.test_util.skip_on_devices("tpu") def test_apply_every(self): # The frequency of the application of sgd k = 4