diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 5906ace8c..d0d2b0bb3 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -22,9 +22,9 @@ jobs: lint-kind: [code, docs] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python 3.10 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.10' diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 0b01fd9be..2bf2d2ede 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -12,9 +12,9 @@ jobs: environment: publish-pypi steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python 3.10 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: '3.10' diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a06447d9f..066788968 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -25,9 +25,9 @@ jobs: jax-version: [jax-default, jax-latest] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python 3.9 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} @@ -51,7 +51,7 @@ jobs: image: docker://michalk8/cuda:12.2.2-cudnn8-devel-ubuntu22.04 options: --gpus="device=2" steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Install dependencies run: | @@ -84,9 +84,9 @@ jobs: os: macos-14 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} diff --git a/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py b/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py index 93d346495..0a23620a6 100644 --- a/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py +++ b/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py @@ -15,6 +15,7 @@ import jax import jax.numpy as jnp +import jax.tree_util as jtu import numpy as np from ott.tools.gaussian_mixture import gaussian_mixture, gaussian_mixture_pair @@ -170,7 +171,7 @@ def test_pytree_mapping(self, epsilon, tau, lock_gmm1): ) expected_gmm1_loc = 2.0 * self.gmm1.loc if not lock_gmm1 else self.gmm1.loc - pair_x_2 = jax.tree_map(lambda x: 2.0 * x, pair) + pair_x_2 = jtu.tree_map(lambda x: 2.0 * x, pair) # gmm parameters should be doubled np.testing.assert_allclose(2.0 * pair.gmm0.loc, pair_x_2.gmm0.loc) np.testing.assert_allclose(expected_gmm1_loc, pair_x_2.gmm1.loc) diff --git a/tests/tools/gaussian_mixture/gaussian_mixture_test.py b/tests/tools/gaussian_mixture/gaussian_mixture_test.py index fa81c723a..887ba7f1e 100644 --- a/tests/tools/gaussian_mixture/gaussian_mixture_test.py +++ b/tests/tools/gaussian_mixture/gaussian_mixture_test.py @@ -15,6 +15,7 @@ import jax import jax.numpy as jnp +import jax.tree_util as jtu import numpy as np from ott.tools.gaussian_mixture import gaussian_mixture, linalg @@ -163,7 +164,7 @@ def test_pytree_mapping(self, rng: jax.Array): gmm = gaussian_mixture.GaussianMixture.from_random( rng=rng, n_components=3, n_dimensions=2 ) - gmm_x_2 = jax.tree_map(lambda x: 2.0 * x, gmm) + gmm_x_2 = jtu.tree_map(lambda x: 2.0 * x, gmm) np.testing.assert_allclose(2.0 * gmm.loc, gmm_x_2.loc, atol=1e-4, rtol=1e-4) np.testing.assert_allclose( 2.0 * gmm.scale_params, gmm_x_2.scale_params, atol=1e-4, rtol=1e-4 diff --git a/tests/tools/gaussian_mixture/gaussian_test.py b/tests/tools/gaussian_mixture/gaussian_test.py index 9fc4feda0..f0c4563da 100644 --- a/tests/tools/gaussian_mixture/gaussian_test.py +++ b/tests/tools/gaussian_mixture/gaussian_test.py @@ -15,6 +15,7 @@ import jax import jax.numpy as jnp +import jax.tree_util as jtu import numpy as np from ott.tools.gaussian_mixture import gaussian, scale_tril @@ -137,14 +138,14 @@ def test_transport(self, rng: jax.Array): def test_flatten_unflatten(self, rng: jax.Array): g = gaussian.Gaussian.from_random(rng, n_dimensions=3) - children, aux_data = jax.tree_util.tree_flatten(g) - g_new = jax.tree_util.tree_unflatten(aux_data, children) + children, aux_data = jtu.tree_flatten(g) + g_new = jtu.tree_unflatten(aux_data, children) assert g == g_new def test_pytree_mapping(self, rng: jax.Array): g = gaussian.Gaussian.from_random(rng, n_dimensions=3) - g_x_2 = jax.tree_map(lambda x: 2 * x, g) + g_x_2 = jtu.tree_map(lambda x: 2 * x, g) np.testing.assert_allclose(2.0 * g.loc, g_x_2.loc) np.testing.assert_allclose(2.0 * g.scale.params, g_x_2.scale.params) diff --git a/tests/tools/gaussian_mixture/probabilities_test.py b/tests/tools/gaussian_mixture/probabilities_test.py index ec2c74a56..4cb0dc370 100644 --- a/tests/tools/gaussian_mixture/probabilities_test.py +++ b/tests/tools/gaussian_mixture/probabilities_test.py @@ -15,6 +15,7 @@ import jax import jax.numpy as jnp +import jax.tree_util as jtu import numpy as np from ott.tools.gaussian_mixture import probabilities @@ -63,15 +64,15 @@ def test_sample(self, rng: jax.Array): def test_flatten_unflatten(self): probs = jnp.array([0.1, 0.2, 0.3, 0.4]) pp = probabilities.Probabilities.from_probs(probs) - children, aux_data = jax.tree_util.tree_flatten(pp) - pp_new = jax.tree_util.tree_unflatten(aux_data, children) + children, aux_data = jtu.tree_flatten(pp) + pp_new = jtu.tree_unflatten(aux_data, children) np.testing.assert_array_equal(pp.params, pp_new.params) assert pp == pp_new def test_pytree_mapping(self): probs = jnp.array([0.1, 0.2, 0.3, 0.4]) pp = probabilities.Probabilities.from_probs(probs) - pp_x_2 = jax.tree_map(lambda x: 2 * x, pp) + pp_x_2 = jtu.tree_map(lambda x: 2 * x, pp) np.testing.assert_allclose( 2.0 * pp.params, pp_x_2.params, rtol=1e-6, atol=1e-6 ) diff --git a/tests/tools/gaussian_mixture/scale_tril_test.py b/tests/tools/gaussian_mixture/scale_tril_test.py index 3ef487f45..64256aac2 100644 --- a/tests/tools/gaussian_mixture/scale_tril_test.py +++ b/tests/tools/gaussian_mixture/scale_tril_test.py @@ -15,6 +15,7 @@ import jax import jax.numpy as jnp +import jax.tree_util as jtu import numpy as np from ott.math import matrix_square_root @@ -102,12 +103,12 @@ def test_transport(self, rng: jax.Array): def test_flatten_unflatten(self, rng: jax.Array): scale = scale_tril.ScaleTriL.from_random(rng=rng, n_dimensions=3) - children, aux_data = jax.tree_util.tree_flatten(scale) - scale_new = jax.tree_util.tree_unflatten(aux_data, children) + children, aux_data = jtu.tree_flatten(scale) + scale_new = jtu.tree_unflatten(aux_data, children) np.testing.assert_array_equal(scale.params, scale_new.params) assert scale == scale_new def test_pytree_mapping(self, rng: jax.Array): scale = scale_tril.ScaleTriL.from_random(rng=rng, n_dimensions=3) - scale_x_2 = jax.tree_map(lambda x: 2 * x, scale) + scale_x_2 = jtu.tree_map(lambda x: 2 * x, scale) np.testing.assert_allclose(2.0 * scale.params, scale_x_2.params)