From cc53aa956b0571efb3b0237dd87d92d509f8b1fd Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Tue, 10 Mar 2020 06:59:54 -0700 Subject: [PATCH] skip new optix test on tpu (cf. #2350) --- tests/optix_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/optix_test.py b/tests/optix_test.py index a4bae03f96f0..ff86d02375d2 100644 --- a/tests/optix_test.py +++ b/tests/optix_test.py @@ -19,7 +19,7 @@ from jax import numpy as jnp from jax.experimental import optimizers from jax.experimental import optix -import jax.test_util # imported only for flags +import jax.test_util from jax.tree_util import tree_leaves import numpy as onp @@ -60,6 +60,7 @@ def test_sgd(self): for x, y in zip(tree_leaves(jax_params), tree_leaves(optix_params)): onp.testing.assert_allclose(x, y, rtol=1e-5) + jax.test_util.skip_on_devices("tpu") def test_apply_every(self): # The frequency of the application of sgd k = 4