diff --git a/.github/workflows/torchax.yml b/.github/workflows/torchax.yml index 2f1e930f48b5..a60e01341597 100644 --- a/.github/workflows/torchax.yml +++ b/.github/workflows/torchax.yml @@ -46,8 +46,8 @@ jobs: shell: bash working-directory: torchax run: | - pip install -r test-requirements.txt pip install -e .[cpu] + pip install -r test-requirements.txt - name: Run tests if: needs.check_code_changes.outputs.has_code_changes == 'true' working-directory: torchax diff --git a/torchax/pyproject.toml b/torchax/pyproject.toml index 2f30f30e7c68..da8149ef83f3 100644 --- a/torchax/pyproject.toml +++ b/torchax/pyproject.toml @@ -40,11 +40,11 @@ classifiers = [ path = "torchax/__init__.py" [project.optional-dependencies] -cpu = ["jax[cpu]>=0.6.2", "jax[cpu]"] +cpu = ["jax[cpu]>=0.6.2, <0.8.0"] # Add libtpu index `-f https://storage.googleapis.com/libtpu-wheels/index.html -f https://storage.googleapis.com/libtpu-releases/index.html` -tpu = ["jax[cpu]>=0.6.2", "jax[tpu]"] -cuda = ["jax[cpu]>=0.6.2", "jax[cuda12]"] -odml = ["jax[cpu]>=0.6.2", "jax[cpu]"] +tpu = ["jax[cpu,tpu]>=0.6.2, <0.8.0"] +cuda = ["jax[cpu,cuda12]>=0.6.2, <0.8.0"] +odml = ["jax[cpu]>=0.6.2, <0.8.0"] [tool.hatch.build.targets.wheel] packages = ["torchax"]