diff --git a/pyproject.toml b/pyproject.toml index 2123c0e97..c9660b78b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,6 @@ test = [ "chex", "networkx>=2.5", "scikit-learn>=1.0", - "scikit-sparse>=0.4.6; python_version < '3.11'", "tslearn>=0.5; python_version < '3.11'", ] docs = [ @@ -177,6 +176,7 @@ legacy_tox_ini = """ [testenv] extras = test pass_env = CUDA_*,PYTEST_*,CI + deps = py{3.8,3.9,3.10}: scikit-sparse>=0.4.6 commands_pre = gpu: python -I -m pip install "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html jax-latest: python -I -m pip install 'git+https://github.com/google/jax@main'