Skip to content

Commit

Permalink
Deprecate Python 3.8 (#554)
Browse files Browse the repository at this point in the history
* Bump Python>=3.9

* Update pre-commits

* Add a comment

* Always install `tslearn` for tests

* Run new pre-commit

* Fix docs linter

* Fix a typo
  • Loading branch information
michalk8 committed Jun 27, 2024
1 parent 96366f7 commit 2147cbe
Show file tree
Hide file tree
Showing 13 changed files with 24 additions and 48 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ['3.9']
python-version: ['3.10']
jax-version: [jax-default, jax-latest]

steps:
Expand Down Expand Up @@ -75,7 +75,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ['3.8', '3.10', '3.11', '3.12']
python-version: ['3.10', '3.11', '3.12']
os: [ubuntu-latest]
include:
- python-version: '3.9'
Expand Down
10 changes: 5 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ default_stages:
minimum_pre_commit_version: 3.0.0
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v4.6.0
hooks:
- id: detect-private-key
- id: check-ast
Expand All @@ -20,7 +20,7 @@ repos:
- id: trailing-whitespace
- id: check-case-conflict
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.2.1
rev: v0.4.10
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand All @@ -35,14 +35,14 @@ repos:
- id: yapf
additional_dependencies: [toml]
- repo: https://github.com/nbQA-dev/nbQA
rev: 1.7.1
rev: 1.8.5
hooks:
- id: nbqa-pyupgrade
args: [--py38-plus]
args: [--py39-plus]
- id: nbqa-black
- id: nbqa-isort
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
rev: v2.12.0
rev: v2.13.0
hooks:
- id: pretty-format-yaml
args: [--autofix, --indent, '2']
Expand Down
1 change: 0 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ In order to run tests, we utilize [tox](https://tox.wiki/):
```shell
tox run # run linter and all tests on all available Python versions
tox run -- -n auto -m fast # run linter and fast tests in parallel
tox -e py3.8 # run all tests on Python3.8
tox -e py3.9 -- -k "test_euclidean_point_cloud" # run tests matching the expression on Python3.9
tox -e py3.10 -- --memray # test also memory on Python3.10
```
Expand Down
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@
"https://doi.org/10.1145/2766963",
"https://keras.io/examples/nlp/pretrained_word_embeddings/",
]
linkcheck_report_timeouts_as_broken = False

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
Expand Down
9 changes: 5 additions & 4 deletions docs/tutorials/Monge_Gap.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,9 @@
"outputs": [],
"source": [
"import dataclasses\n",
"from collections.abc import Iterator, Mapping\n",
"from types import MappingProxyType\n",
"from typing import Any, Dict, Iterator, Literal, Mapping, Optional, Tuple, Union\n",
"from typing import Any, Dict, Literal, Optional, Tuple, Union\n",
"\n",
"import jax\n",
"import jax.numpy as jnp\n",
Expand Down Expand Up @@ -154,7 +155,7 @@
" train_batch_size: int = 256,\n",
" valid_batch_size: int = 256,\n",
" rng: Optional[jax.Array] = None,\n",
") -> Tuple[datasets.Dataset, datasets.Dataset, int]:\n",
") -> tuple[datasets.Dataset, datasets.Dataset, int]:\n",
" \"\"\"Samplers from ``SklearnDistribution``.\"\"\"\n",
" rng = jax.random.PRNGKey(0) if rng is None else rng\n",
" rng1, rng2, rng3, rng4 = jax.random.split(rng, 4)\n",
Expand Down Expand Up @@ -200,10 +201,10 @@
"outputs": [],
"source": [
"def plot_samples(\n",
" batch: Dict[str, Any],\n",
" batch: dict[str, Any],\n",
" num_points: Optional[int] = None,\n",
" title: Optional[str] = None,\n",
" figsize: Tuple[int, int] = (8, 6),\n",
" figsize: tuple[int, int] = (8, 6),\n",
" rng: Optional[jax.Array] = None,\n",
"):\n",
" \"\"\"Plot samples from the source and target measures.\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/gromov_wasserstein_multiomics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@
"source": [
"## Using {mod}`ott` {class}`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein` solvers\n",
"\n",
"The following `OTTSCOT` class inherits from the `SCOT` class but overrides the `find_correspondences` method in order to use `OTT` instead of `POT`. The matrix `T` is the optimal transport matrix (noted usuall `P` in `OTT`), coupling points in $x$ to $y$."
"The following `OTTSCOT` class inherits from the `SCOT` class but overrides the `find_correspondences` method in order to use `OTT` instead of `POT`. The matrix `T` is the optimal transport matrix (noted usually `P` in `OTT`), coupling points in $x$ to $y$."
]
},
{
Expand Down Expand Up @@ -16883,7 +16883,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.10.14"
},
"vscode": {
"interpreter": {
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/sinkhorn_divergence_gradient_flow.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@
"def gradient_flow(\n",
" x: jnp.ndarray,\n",
" y: jnp.ndarray,\n",
" divergence: Callable[[jnp.ndarray, jnp.ndarray, float], Tuple[float, Any]],\n",
" divergence: Callable[[jnp.ndarray, jnp.ndarray, float], tuple[float, Any]],\n",
" num_iter: int = 500,\n",
" lr: float = 0.2,\n",
" dump_every: int = 50,\n",
Expand Down
7 changes: 0 additions & 7 deletions environment.yml

This file was deleted.

18 changes: 7 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "ott-jax"
description = "Optimal Transport Tools in JAX."
requires-python = ">=3.8"
requires-python = ">=3.9"
dynamic = ["version"]
readme = {file = "README.md", content-type = "text/markdown"}
license = {file = "LICENSE"}
Expand All @@ -15,7 +15,7 @@ authors = [
dependencies = [
"jax>=0.4.0",
"jaxopt>=0.8",
"lineax>=0.0.1; python_version >= '3.9'",
"lineax>=0.0.5",
"numpy>=1.20.0",
]
keywords = [
Expand All @@ -42,7 +42,6 @@ classifiers = [
"Operating System :: MacOS :: MacOS X",
"Operating System :: Microsoft :: Windows",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
Expand Down Expand Up @@ -75,9 +74,7 @@ test = [
"networkx>=2.5",
"scikit-learn>=1.0",
"tqdm",
# tslearn needs numba, which isn't supported for 3.12
"tslearn>=0.5; python_version < '3.12'",
"lineax; python_version >= '3.9'",
"tslearn>=0.5",
"matplotlib",
]
docs = [
Expand All @@ -98,7 +95,7 @@ include-package-data = true

[tool.black]
line-length = 80
target-version = ["py38"]
target-version = ["py39"]
include = '\.ipynb$'

[tool.isort]
Expand All @@ -115,7 +112,6 @@ known_plotting = ["IPython", "matplotlib", "mpl_toolkits", "seaborn"]

[tool.pytest.ini_options]
minversion = "6.0"
addopts = '-m "not notebook"'
testpaths = [
"tests",
]
Expand Down Expand Up @@ -188,7 +184,7 @@ ignore_path = ["docs/**/_autosummary", "docs/contributing.rst"]
legacy_tox_ini = """
[tox]
min_version = 4.0
env_list = lint-code,py{3.8,3.9,3.10,3.11,3.12},py3.9-jax-default
env_list = lint-code,py{3.9,3.10,3.11,3.12},py3.9-jax-default
skip_missing_interpreters = true
[testenv]
Expand All @@ -207,7 +203,7 @@ commands =
[testenv:lint-code]
description = Lint the code.
deps = pre-commit>=2.16.0
deps = pre-commit>=3.0.0
skip_install = true
commands =
pre-commit run --all-files --show-diff-on-failure
Expand Down Expand Up @@ -276,7 +272,7 @@ exclude = [
"dist"
]
line-length = 80
target-version = "py38"
target-version = "py38" # TODO(michalk8): use py39 and fix the type hints

[tool.ruff.lint]
ignore = [
Expand Down
7 changes: 1 addition & 6 deletions tests/geometry/costs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,12 @@
import jax
import jax.numpy as jnp
import numpy as np
from tslearn import metrics as ts_metrics

from ott.geometry import costs, pointcloud
from ott.math import utils as mu
from ott.solvers import linear

try:
from tslearn import metrics as ts_metrics
except ImportError:
ts_metrics = None


def _proj(matrix: jnp.ndarray) -> jnp.ndarray:
u, _, v_h = jnp.linalg.svd(matrix, full_matrices=False)
Expand Down Expand Up @@ -280,7 +276,6 @@ def test_h_legendre_elastic_l2(self, rng: jax.Array, d: int):
)


@pytest.mark.skipif(ts_metrics is None, reason="Not supported for Python 3.11")
@pytest.mark.fast()
class TestSoftDTW:

Expand Down
2 changes: 0 additions & 2 deletions tests/problems/linear/potentials_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from typing import Type

import pytest
Expand Down Expand Up @@ -281,7 +280,6 @@ def test_potentials_sinkhorn_divergence(self, rng: jax.Array, eps: float):
with pytest.raises(AssertionError):
np.testing.assert_allclose(div_ref, div_points, rtol=1e-1, atol=1e-1)

@pytest.mark.skipif(sys.version_info < (3, 9), reason="Old JAX version.")
@pytest.mark.parametrize("cost_type", [costs.ElasticL1, costs.ElasticL2])
def test_potentials_diff_param_costs(
self, rng: jax.Array, cost_type: Type[costs.RegTICost]
Expand Down
4 changes: 0 additions & 4 deletions tests/solvers/linear/sinkhorn_diff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,12 +363,10 @@ def test_apply_transport_jacobian(
mode (False) is tested with looser convergence settings.
tau_a: loosen up 1st marginal constraint when <1.0
tau_b: loosen up 2nd marginal constraint when <1.0
shape: size for point clouds n, m.
arg: test jacobian w.r.t. either weight vectors a or locations x
axis: test the jacobian of the application of the (right) application of
transport to arbitrary vec (axis=0) or the left (axis=1).
"""
_ = pytest.importorskip("lineax") # only tested using lineax
n, m = (27, 13)
dim = 4
rngs = jax.random.split(rng, 9)
Expand Down Expand Up @@ -467,7 +465,6 @@ def test_potential_jacobian_sinkhorn(
shape: Tuple[int, int], arg: int
):
"""Test Jacobian of optimal potential w.r.t. weights and locations."""
_ = pytest.importorskip("lineax") # only tested using lineax
atol = 1e-2 if lse_mode else 5e-2 # lower tolerance for lse mode.
rtol = 1e-2 if lse_mode else 1.5e-1 # lower tolerance for lse mode.
n, m = shape
Expand Down Expand Up @@ -647,7 +644,6 @@ def test_potential_jacobian_sinkhorn_precond(
shape: Tuple[int, int], arg: int
):
"""Test Jacobian of optimal potential works across 2 precond_fun."""
_ = pytest.importorskip("lineax") # only tested using lineax
atol = 1e-2 if lse_mode else 5e-2 # lower tolerance for lse mode.
rtol = 1e-2 if lse_mode else 1.5e-1 # lower tolerance for lse mode.
n, m = shape
Expand Down
3 changes: 0 additions & 3 deletions tests/tools/gaussian_mixture/fit_gmm_pair_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,6 @@
gaussian_mixture_pair,
)

# on 3.8, neural (flax/optax) is not installed
_ = pytest.importorskip("optax")


class TestFitGmmPair:

Expand Down

0 comments on commit 2147cbe

Please sign in to comment.