From 05ff6ddb3a7f1e5f8f9739b50f242cc36a2969c4 Mon Sep 17 00:00:00 2001 From: othmanesebbouh <44581760+othmanesebbouh@users.noreply.github.com> Date: Fri, 24 Feb 2023 14:44:01 +0100 Subject: [PATCH] Make random seeding consistent with PRNG keys everywhere instead of seeds (#290) * add ENS * removed ENS * Test * Test * Change gw_barycenter and neural dual * fix random number generators in initializers * k-means and gromov * fix rng on low rank geometry and datasets * fix remaining seed changes to rng * fixed rng reference and increased epsilon to make algo converge, should open issue * fixed keys to rngs * extended rng instead of keys to inside functions and added jax.random.PRNGKeyArray as default type * fixed more keys to rng with correct default type * ran pre-commit * Fix typo --------- Co-authored-by: Anastasiia Co-authored-by: michalk8 <46717574+michalk8@users.noreply.github.com> --- src/ott/geometry/geometry.py | 17 ++- src/ott/geometry/low_rank.py | 6 +- .../initializers/linear/initializers_lr.py | 70 +++++------ src/ott/problems/nn/dataset.py | 26 ++-- .../problems/quadratic/quadratic_problem.py | 14 ++- .../solvers/linear/continuous_barycenter.py | 15 ++- src/ott/solvers/linear/sinkhorn_lr.py | 6 +- src/ott/solvers/nn/models.py | 4 +- src/ott/solvers/nn/neuraldual.py | 11 +- .../solvers/quadratic/gromov_wasserstein.py | 28 ++--- src/ott/solvers/quadratic/gw_barycenter.py | 19 ++- src/ott/tools/gaussian_mixture/fit_gmm.py | 30 ++--- src/ott/tools/gaussian_mixture/gaussian.py | 14 +-- .../gaussian_mixture/gaussian_mixture.py | 16 +-- src/ott/tools/gaussian_mixture/linalg.py | 4 +- .../tools/gaussian_mixture/probabilities.py | 8 +- src/ott/tools/gaussian_mixture/scale_tril.py | 10 +- src/ott/tools/k_means.py | 40 +++--- tests/geometry/costs_test.py | 31 ++--- tests/geometry/graph_test.py | 15 +-- tests/geometry/low_rank_test.py | 118 +++++++++--------- tests/geometry/pointcloud_test.py | 49 ++++---- tests/geometry/scaling_cost_test.py | 8 +- tests/geometry/subsetting_test.py | 34 ++--- .../initializers/linear/sinkhorn_init_test.py | 16 +-- .../linear/sinkhorn_lr_init_test.py | 38 +++--- tests/initializers/quadratic/gw_init_test.py | 25 ++-- tests/math/lse_test.py | 14 +-- tests/math/matrix_square_root_test.py | 40 +++--- tests/problems/linear/potentials_test.py | 82 ++++++------ .../linear/continuous_barycenter_test.py | 22 ++-- tests/solvers/linear/sinkhorn_diff_test.py | 82 ++++++------ tests/solvers/linear/sinkhorn_grid_test.py | 34 ++--- tests/solvers/linear/sinkhorn_lr_test.py | 2 +- tests/solvers/linear/sinkhorn_misc_test.py | 13 +- tests/solvers/linear/sinkhorn_test.py | 30 ++--- tests/solvers/nn/icnn_test.py | 18 +-- tests/solvers/quadratic/fgw_test.py | 41 +++--- tests/solvers/quadratic/gw_barycenter_test.py | 16 +-- tests/solvers/quadratic/gw_test.py | 51 ++++---- .../gaussian_mixture/fit_gmm_pair_test.py | 10 +- tests/tools/gaussian_mixture/fit_gmm_test.py | 8 +- .../gaussian_mixture_pair_test.py | 8 +- .../gaussian_mixture/gaussian_mixture_test.py | 46 +++---- tests/tools/gaussian_mixture/gaussian_test.py | 44 +++---- tests/tools/gaussian_mixture/linalg_test.py | 40 +++--- .../gaussian_mixture/probabilities_test.py | 6 +- .../tools/gaussian_mixture/scale_tril_test.py | 40 +++--- tests/tools/k_means_test.py | 90 +++++++------ tests/tools/segment_sinkhorn_test.py | 2 +- tests/tools/sinkhorn_divergence_test.py | 4 +- tests/tools/soft_sort_test.py | 22 ++-- 52 files changed, 743 insertions(+), 694 deletions(-) diff --git a/src/ott/geometry/geometry.py 
b/src/ott/geometry/geometry.py index 04452854e..6a0ee4a07 100644 --- a/src/ott/geometry/geometry.py +++ b/src/ott/geometry/geometry.py @@ -623,7 +623,7 @@ def to_LRCGeometry( self, rank: int = 0, tol: float = 1e-2, - seed: int = 0, + rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0), scale: float = 1. ) -> 'low_rank.LRCGeometry': r"""Factorize the cost matrix using either SVD (full) or :cite:`indyk:19`. @@ -642,7 +642,7 @@ def to_LRCGeometry( rank: Target rank of the :attr:`cost_matrix`. tol: Tolerance of the error. The total number of sampled points is :math:`min(n, m,\frac{rank}{tol})`. - seed: Random seed. + rng: The PRNG key to use for initializing the model. scale: Value used to rescale the factors of the low-rank geometry. Useful when this geometry is used in the linear term of fused GW. @@ -664,19 +664,18 @@ def to_LRCGeometry( cost_1 = u cost_2 = (s[:, None] * vh).T else: - rng = jax.random.PRNGKey(seed) - key1, key2, key3, key4, key5 = jax.random.split(rng, 5) + rng1, rng2, rng3, rng4, rng5 = jax.random.split(rng, 5) n_subset = min(int(rank / tol), n, m) - i_star = jax.random.randint(key1, shape=(), minval=0, maxval=n) - j_star = jax.random.randint(key2, shape=(), minval=0, maxval=m) + i_star = jax.random.randint(rng1, shape=(), minval=0, maxval=n) + j_star = jax.random.randint(rng2, shape=(), minval=0, maxval=m) ci_star = self.subset(i_star, None).cost_matrix.ravel() ** 2 # (m,) cj_star = self.subset(None, j_star).cost_matrix.ravel() ** 2 # (n,) p_row = cj_star + ci_star[j_star] + jnp.mean(ci_star) # (n,) p_row /= jnp.sum(p_row) - row_ixs = jax.random.choice(key3, n, shape=(n_subset,), p=p_row) + row_ixs = jax.random.choice(rng3, n, shape=(n_subset,), p=p_row) # (n_subset, m) s = self.subset(row_ixs, None).cost_matrix s /= jnp.sqrt(n_subset * p_row[row_ixs][:, None]) @@ -684,7 +683,7 @@ def to_LRCGeometry( p_col = jnp.sum(s ** 2, axis=0) # (m,) p_col /= jnp.sum(p_col) # (n_subset,) - col_ixs = jax.random.choice(key4, m, shape=(n_subset,), p=p_col) + col_ixs = jax.random.choice(rng4, m, shape=(n_subset,), p=p_col) # (n_subset, n_subset) w = s[:, col_ixs] / jnp.sqrt(n_subset * p_col[col_ixs][None, :]) @@ -696,7 +695,7 @@ def to_LRCGeometry( v = v.T / jnp.sqrt(d)[None, :] inv_scale = (1. / jnp.sqrt(n_subset)) - col_ixs = jax.random.choice(key5, m, shape=(n_subset,)) # (n_subset,) + col_ixs = jax.random.choice(rng5, m, shape=(n_subset,)) # (n_subset,) # (n, n_subset) A_trans = self.subset(None, col_ixs).cost_matrix * inv_scale diff --git a/src/ott/geometry/low_rank.py b/src/ott/geometry/low_rank.py index 17580b43a..126137c31 100644 --- a/src/ott/geometry/low_rank.py +++ b/src/ott/geometry/low_rank.py @@ -231,9 +231,13 @@ def finalize(carry): return max_value + self._bias def to_LRCGeometry( - self, rank: int = 0, tol: float = 1e-2, seed: int = 0 + self, + rank: int = 0, + tol: float = 1e-2, + rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0), ) -> 'LRCGeometry': """Return self.""" + del rank, tol, rng return self @property diff --git a/src/ott/initializers/linear/initializers_lr.py b/src/ott/initializers/linear/initializers_lr.py index 378a0ce55..74b20e9e1 100644 --- a/src/ott/initializers/linear/initializers_lr.py +++ b/src/ott/initializers/linear/initializers_lr.py @@ -66,7 +66,7 @@ def __init__(self, rank: int, **kwargs: Any): def init_q( self, ot_prob: Problem_t, - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, *, init_g: jnp.ndarray, **kwargs: Any, @@ -75,7 +75,7 @@ def init_q( Args: ot_prob: OT problem. - key: Random key for seeding. + rng: Random key for seeding. 
init_g: Initial value for :math:`g` factor. kwargs: Additional keyword arguments. @@ -87,7 +87,7 @@ def init_q( def init_r( self, ot_prob: Problem_t, - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, *, init_g: jnp.ndarray, **kwargs: Any, @@ -96,7 +96,7 @@ def init_r( Args: ot_prob: Linear OT problem. - key: Random key for seeding. + rng: Random key for seeding. init_g: Initial value for :math:`g` factor. kwargs: Additional keyword arguments. @@ -108,14 +108,14 @@ def init_r( def init_g( self, ot_prob: Problem_t, - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, **kwargs: Any, ) -> jnp.ndarray: """Initialize the low-rank factor :math:`g`. Args: ot_prob: OT problem. - key: Random key for seeding. + rng: Random key for seeding. kwargs: Additional keyword arguments. Returns: @@ -176,7 +176,7 @@ def __call__( r: Optional[jnp.ndarray] = None, g: Optional[jnp.ndarray] = None, *, - key: Optional[jnp.ndarray] = None, + rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0), **kwargs: Any ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Initialize the factors :math:`Q`, :math:`R` and :math:`g`. @@ -189,23 +189,21 @@ def __call__( using :meth:`init_r`. g: Factor of shape ``[rank,]``. If `None`, it will be initialized using :meth:`init_g`. - key: Random key for seeding. + rng: Random key for seeding. kwargs: Additional keyword arguments for :meth:`init_q`, :meth:`init_r` and :meth:`init_g`. Returns: The factors :math:`Q`, :math:`R` and :math:`g`, respectively. """ - if key is None: - key = jax.random.PRNGKey(0) - key1, key2, key3 = jax.random.split(key, 3) + rng1, rng2, rng3 = jax.random.split(rng, 3) if g is None: - g = self.init_g(ot_prob, key1, **kwargs) + g = self.init_g(ot_prob, rng1, **kwargs) if q is None: - q = self.init_q(ot_prob, key2, init_g=g, **kwargs) + q = self.init_q(ot_prob, rng2, init_g=g, **kwargs) if r is None: - r = self.init_r(ot_prob, key3, init_g=g, **kwargs) + r = self.init_r(ot_prob, rng3, init_g=g, **kwargs) assert g.shape == (self.rank,) assert q.shape == (ot_prob.a.shape[0], self.rank) @@ -240,37 +238,37 @@ class RandomInitializer(LRInitializer): def init_q( # noqa: D102 self, ot_prob: Problem_t, - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, *, init_g: jnp.ndarray, **kwargs: Any, ) -> jnp.ndarray: del kwargs, init_g a = ot_prob.a - init_q = jnp.abs(jax.random.normal(key, (a.shape[0], self.rank))) + init_q = jnp.abs(jax.random.normal(rng, (a.shape[0], self.rank))) return a[:, None] * (init_q / jnp.sum(init_q, axis=1, keepdims=True)) def init_r( # noqa: D102 self, ot_prob: Problem_t, - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, *, init_g: jnp.ndarray, **kwargs: Any, ) -> jnp.ndarray: del kwargs, init_g b = ot_prob.b - init_r = jnp.abs(jax.random.normal(key, (b.shape[0], self.rank))) + init_r = jnp.abs(jax.random.normal(rng, (b.shape[0], self.rank))) return b[:, None] * (init_r / jnp.sum(init_r, axis=1, keepdims=True)) def init_g( # noqa: D102 self, ot_prob: Problem_t, - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, **kwargs: Any, ) -> jnp.ndarray: del kwargs - init_g = jnp.abs(jax.random.uniform(key, (self.rank,))) + 1. + init_g = jnp.abs(jax.random.uniform(rng, (self.rank,))) + 1. 
return init_g / jnp.sum(init_g) @@ -314,32 +312,32 @@ def _compute_factor( def init_q( # noqa: D102 self, ot_prob: Problem_t, - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, *, init_g: jnp.ndarray, **kwargs: Any, ) -> jnp.ndarray: - del key, kwargs + del rng, kwargs return self._compute_factor(ot_prob, init_g, which="q") def init_r( # noqa: D102 self, ot_prob: Problem_t, - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, *, init_g: jnp.ndarray, **kwargs: Any, ) -> jnp.ndarray: - del key, kwargs + del rng, kwargs return self._compute_factor(ot_prob, init_g, which="r") def init_g( # noqa: D102 self, ot_prob: Problem_t, - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, **kwargs: Any, ) -> jnp.ndarray: - del key, kwargs + del rng, kwargs return jnp.ones((self.rank,)) / self.rank @@ -387,7 +385,7 @@ def _extract_array( def _compute_factor( self, ot_prob: Problem_t, - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, *, init_g: jnp.ndarray, which: Literal["q", "r"], @@ -413,7 +411,7 @@ def _compute_factor( arr = self._extract_array(geom, first=which == "q") marginals = ot_prob.a if which == "q" else ot_prob.b - centroids = fn(arr, self.rank, key=key).centroids + centroids = fn(arr, self.rank, rng=rng).centroids geom = pointcloud.PointCloud( arr, centroids, epsilon=0.1, scale_cost="max_cost" ) @@ -425,34 +423,34 @@ def _compute_factor( def init_q( # noqa: D102 self, ot_prob: Problem_t, - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, *, init_g: jnp.ndarray, **kwargs: Any, ) -> jnp.ndarray: return self._compute_factor( - ot_prob, key, init_g=init_g, which="q", **kwargs + ot_prob, rng, init_g=init_g, which="q", **kwargs ) def init_r( # noqa: D102 self, ot_prob: Problem_t, - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, *, init_g: jnp.ndarray, **kwargs: Any, ) -> jnp.ndarray: return self._compute_factor( - ot_prob, key, init_g=init_g, which="r", **kwargs + ot_prob, rng, init_g=init_g, which="r", **kwargs ) def init_g( # noqa: D102 self, ot_prob: Problem_t, - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, **kwargs: Any, ) -> jnp.ndarray: - del key, kwargs + del rng, kwargs return jnp.ones((self.rank,)) / self.rank def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 @@ -518,7 +516,7 @@ class State(NamedTuple): # noqa: D106 def _compute_factor( self, ot_prob: Problem_t, - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, *, init_g: jnp.ndarray, which: Literal["q", "r"], @@ -530,7 +528,7 @@ def _compute_factor( def init_fn() -> GeneralizedKMeansInitializer.State: n = geom.shape[0] - factor = jnp.abs(jax.random.normal(key, (n, self.rank))) + 1. # (n, r) + factor = jnp.abs(jax.random.normal(rng, (n, self.rank))) + 1. 
# (n, r) factor *= consts.marginal[:, None] / jnp.sum( factor, axis=1, keepdims=True ) diff --git a/src/ott/problems/nn/dataset.py b/src/ott/problems/nn/dataset.py index a39ea3782..9d4c96d2c 100644 --- a/src/ott/problems/nn/dataset.py +++ b/src/ott/problems/nn/dataset.py @@ -50,13 +50,13 @@ class GaussianMixture: - ``square_four`` (two-dimensional Gaussians in the corners of a rectangle) batch_size: batch size of the samples - init_key: initial PRNG key + init_rng: initial PRNG key scale: scale of the individual Gaussian samples variance: the variance of the individual Gaussian samples """ name: Name_t batch_size: int - init_key: jax.random.PRNGKey + init_rng: jax.random.PRNGKeyArray scale: float = 5.0 variance: float = 0.5 @@ -95,11 +95,11 @@ def create_sample_generators(self) -> Iterator[jnp.array]: Returns: A generator of samples from the Gaussian mixture. """ - key = self.init_key + rng = self.init_rng while True: - k1, k2, key = jax.random.split(key, 3) - means = jax.random.choice(k1, self.centers, [self.batch_size]) - normal_samples = jax.random.normal(k2, [self.batch_size, 2]) + rng1, rng2, rng = jax.random.split(rng, 3) + means = jax.random.choice(rng1, self.centers, [self.batch_size]) + normal_samples = jax.random.normal(rng2, [self.batch_size, 2]) samples = self.scale * means + self.variance ** 2 * normal_samples yield samples @@ -109,7 +109,7 @@ def create_gaussian_mixture_samplers( name_target: Name_t, train_batch_size: int = 2048, valid_batch_size: int = 2048, - key: jax.random.PRNGKey = jax.random.PRNGKey(0), + rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0), ) -> Tuple[Dataset, Dataset, int]: """Creates Gaussian samplers for :class:`~ott.solvers.nn.neuraldual.W2NeuralDual`. @@ -118,33 +118,33 @@ def create_gaussian_mixture_samplers( name_target: name of the target sampler train_batch_size: the training batch size valid_batch_size: the validation batch size - key: initial PRNG key + rng: initial PRNG key Returns: The dataset and dimension of the data. """ - k1, k2, k3, k4 = jax.random.split(key, 4) + rng1, rng2, rng3, rng4 = jax.random.split(rng, 4) train_dataset = Dataset( source_iter=iter( GaussianMixture( - name_source, batch_size=train_batch_size, init_key=k1 + name_source, batch_size=train_batch_size, init_rng=rng1 ) ), target_iter=iter( GaussianMixture( - name_target, batch_size=train_batch_size, init_key=k2 + name_target, batch_size=train_batch_size, init_rng=rng2 ) ) ) valid_dataset = Dataset( source_iter=iter( GaussianMixture( - name_source, batch_size=valid_batch_size, init_key=k3 + name_source, batch_size=valid_batch_size, init_rng=rng3 ) ), target_iter=iter( GaussianMixture( - name_target, batch_size=valid_batch_size, init_key=k4 + name_target, batch_size=valid_batch_size, init_rng=rng4 ) ) ) diff --git a/src/ott/problems/quadratic/quadratic_problem.py b/src/ott/problems/quadratic/quadratic_problem.py index 8c17179ab..f73f63747 100644 --- a/src/ott/problems/quadratic/quadratic_problem.py +++ b/src/ott/problems/quadratic/quadratic_problem.py @@ -365,11 +365,13 @@ def convertible(geom: geometry.Geometry) -> bool: (geom_xy is None or convertible(geom_xy)) ) - def to_low_rank(self, seed: int = 0) -> "QuadraticProblem": + def to_low_rank( + self, rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0) + ) -> "QuadraticProblem": """Convert geometries to low-rank. Args: - seed: Random seed. + rng: Random key for seeding. Returns: Quadratic problem with low-rank geometries. 
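For reference, the call-site migration this change implies, as a minimal sketch on a dense geometry (the shapes and the `Geometry` setup are illustrative, not taken from this patch):

import jax
from ott.geometry import geometry

rng = jax.random.PRNGKey(0)
rng_cost, rng_lr = jax.random.split(rng)
# Any geometry with a materialized cost matrix works here.
geom = geometry.Geometry(cost_matrix=jax.random.uniform(rng_cost, (37, 46)))
# Before this patch: geom.to_LRCGeometry(rank=2, tol=1e-1, seed=0)
geom_lr = geom.to_LRCGeometry(rank=2, tol=1e-1, rng=rng_lr)
# A convertible quadratic problem is migrated the same way:
# prob.to_low_rank(rng=jax.random.PRNGKey(0))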
@@ -388,11 +390,11 @@ def convert( return self (geom_xx, geom_yy, geom_xy, *children), aux_data = self.tree_flatten() - (s1, s2, s3) = jax.random.split(jax.random.PRNGKey(seed), 3)[:, 0] + rng1, rng2, rng3 = jax.random.split(rng, 3) (r1, r2, r3), (t1, t2, t3) = convert(self.ranks), convert(self.tolerances) - geom_xx = geom_xx.to_LRCGeometry(rank=r1, tol=t1, seed=s1) - geom_yy = geom_yy.to_LRCGeometry(rank=r2, tol=t2, seed=s2) + geom_xx = geom_xx.to_LRCGeometry(rank=r1, tol=t1, rng=rng1) + geom_yy = geom_yy.to_LRCGeometry(rank=r2, tol=t2, rng=rng2) if self.is_fused: if isinstance( geom_xy, pointcloud.PointCloud ): geom_xy = geom_xy.to_LRCGeometry(scale=self.fused_penalty) else: geom_xy = geom_xy.to_LRCGeometry( - rank=r3, tol=t3, seed=s3, scale=self.fused_penalty + rank=r3, tol=t3, rng=rng3, scale=self.fused_penalty ) return type(self).tree_unflatten( diff --git a/src/ott/solvers/linear/continuous_barycenter.py b/src/ott/solvers/linear/continuous_barycenter.py index 04db11c19..d584082b5 100644 --- a/src/ott/solvers/linear/continuous_barycenter.py +++ b/src/ott/solvers/linear/continuous_barycenter.py @@ -129,14 +129,14 @@ def solve_linear_ot( @jax.tree_util.register_pytree_node_class class FreeWassersteinBarycenter(was_solver.WassersteinSolver): - """A Continuous Wasserstein barycenter solver, built on WassersteinSolver.""" + """Continuous Wasserstein barycenter solver.""" def __call__( # noqa: D102 self, bar_prob: barycenter_problem.FreeBarycenterProblem, bar_size: int = 100, x_init: Optional[jnp.ndarray] = None, - rng: int = 0 + rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0) ) -> FreeBarycenterState: # TODO(michalk8): no reason for iterations to be outside this class run_fn = jax.jit(iterations, static_argnums=1) if self.jit else iterations @@ -147,8 +147,7 @@ def init_state( self, bar_prob: barycenter_problem.FreeBarycenterProblem, bar_size: int, x_init: Optional[jnp.ndarray] = None, - # TODO(michalk8): change the API to pass the PRNG key directly - rng: int = 0, + rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0), ) -> FreeBarycenterState: """Initialize the state of the Wasserstein barycenter iterations. Args: bar_prob: Barycenter problem. bar_size: Size of the barycenter. x_init: Initial barycenter estimate of shape ``[bar_size, ndim]``. If `None`, ``bar_size`` points will be sampled from the input measures according to their weights - :attr:`~ott.problems.linear.barycenter_problem.BarycenterProblem.flattened_y`. - rng: Seed for :func:`jax.random.PRNGKey`. + :attr:`~ott.problems.linear.barycenter_problem.FreeBarycenterProblem.flattened_y`. + rng: Random key for seeding. Returns: The initial barycenter state.
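A sketch of the updated solver entry point; the `FreeBarycenterProblem` construction, in particular the `num_per_segment` segmentation argument, is assumed here for illustration and is not part of this patch:

import jax
from ott.problems.linear import barycenter_problem
from ott.solvers.linear import continuous_barycenter

rng = jax.random.PRNGKey(0)
rng_data, rng_bar = jax.random.split(rng)
# Two measures of 50 points each, stacked into one flat array.
y = jax.random.normal(rng_data, (100, 2))
bar_prob = barycenter_problem.FreeBarycenterProblem(y, num_per_segment=(50, 50))
solver = continuous_barycenter.FreeWassersteinBarycenter()
# Before this patch: solver(bar_prob, bar_size=16, rng=0)  # integer seed
state = solver(bar_prob, bar_size=16, rng=rng_bar)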
@@ -170,7 +169,7 @@ def init_state( else: # sample randomly points in the support of the y measures indices_subset = jax.random.choice( - jax.random.PRNGKey(rng), + rng, a=bar_prob.flattened_y.shape[0], shape=(bar_size,), replace=False, @@ -202,7 +201,7 @@ def output_from_state( # noqa: D102 def iterations( solver: FreeWassersteinBarycenter, bar_size: int, bar_prob: barycenter_problem.FreeBarycenterProblem, x_init: jnp.ndarray, - rng: int + rng: jax.random.PRNGKeyArray ) -> FreeBarycenterState: """Jittable Wasserstein barycenter outer loop.""" diff --git a/src/ott/solvers/linear/sinkhorn_lr.py b/src/ott/solvers/linear/sinkhorn_lr.py index 550bc3c86..106457416 100644 --- a/src/ott/solvers/linear/sinkhorn_lr.py +++ b/src/ott/solvers/linear/sinkhorn_lr.py @@ -319,7 +319,7 @@ def __call__( ot_prob: linear_problem.LinearProblem, init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray], Optional[jnp.ndarray]] = (None, None, None), - key: Optional[jnp.ndarray] = None, + rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0), **kwargs: Any, ) -> LRSinkhornOutput: """Run low-rank Sinkhorn. @@ -333,7 +333,7 @@ def __call__( - :attr:`~ott.solvers.linear.sinkhorn_lr.LRSinkhornOutput.g`. Any `None` values will be initialized using the initializer. - key: Random key for seeding. + rng: Random key for seeding. kwargs: Additional arguments when calling the initializer. Returns: @@ -341,7 +341,7 @@ def __call__( """ assert ot_prob.is_balanced, "Unbalanced case is not implemented." initializer = self.create_initializer(ot_prob) - init = initializer(ot_prob, *init, key=key, **kwargs) + init = initializer(ot_prob, *init, rng=rng, **kwargs) run_fn = jax.jit(run) if self.jit else run return run_fn(ot_prob, self, init) diff --git a/src/ott/solvers/nn/models.py b/src/ott/solvers/nn/models.py index 92f7d3d48..7823b0339 100644 --- a/src/ott/solvers/nn/models.py +++ b/src/ott/solvers/nn/models.py @@ -289,7 +289,7 @@ def __call__(self, x: jnp.ndarray) -> float: # noqa: D102 def create_train_state( self, - rng: jnp.ndarray, + rng: jax.random.PRNGKeyArray, optimizer: optax.OptState, input: Union[int, Tuple[int, ...]], params: Optional[frozen_dict.FrozenDict[str, jnp.ndarray]] = None, @@ -348,7 +348,7 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: # noqa: D102 def create_train_state( self, - rng: jnp.ndarray, + rng: jax.random.PRNGKeyArray, optimizer: optax.OptState, input: Union[int, Tuple[int, ...]], params: Optional[frozen_dict.FrozenDict[str, jnp.ndarray]] = None, diff --git a/src/ott/solvers/nn/neuraldual.py b/src/ott/solvers/nn/neuraldual.py index 3f148393b..5af1e9d41 100644 --- a/src/ott/solvers/nn/neuraldual.py +++ b/src/ott/solvers/nn/neuraldual.py @@ -99,7 +99,7 @@ class W2NeuralDual: valid_freq: frequency with which model is validated log_freq: frequency with training and validation are logged logging: option to return logs - seed: random seed for network initializations + rng: random key used for seeding for network initializations pos_weights: option to train networks with positive weights or regularizer beta: regularization parameter when not training with positive weights conjugate_solver: numerical solver for the Fenchel conjugate. 
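From the caller's side, the constructor change in the next hunk looks as follows; default networks and optimizers are used, and the `dim_data` keyword is assumed from the current signature rather than shown in this diff:

import jax
from ott.solvers.nn import neuraldual

# Before this patch: neuraldual.W2NeuralDual(dim_data=2, seed=42)
solver = neuraldual.W2NeuralDual(dim_data=2, rng=jax.random.PRNGKey(42))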
@@ -123,7 +123,7 @@ def __init__( valid_freq: int = 1000, log_freq: int = 1000, logging: bool = False, - seed: int = 0, + rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0), pos_weights: bool = True, beta: float = 1.0, conjugate_solver: Conj_t = conjugate_solvers.DEFAULT_CONJUGATE_SOLVER, @@ -144,9 +144,6 @@ def __init__( self.conjugate_solver = conjugate_solver self.amortization_loss = amortization_loss - # set random key - rng = jax.random.PRNGKey(seed) - # set default optimizers if optimizer_f is None: optimizer_f = optax.adam(learning_rate=0.0001, b1=0.5, b2=0.9, eps=1e-8) @@ -168,14 +165,14 @@ def __init__( ) def setup( - self, rng: jnp.ndarray, neural_f: models.ModelBase, + self, rng: jax.random.PRNGKeyArray, neural_f: models.ModelBase, neural_g: models.ModelBase, dim_data: int, optimizer_f: optax.OptState, optimizer_g: optax.OptState, init_f_params: Optional[frozen_dict.FrozenDict[str, jnp.ndarray]], init_g_params: Optional[frozen_dict.FrozenDict[str, jnp.ndarray]] ) -> None: """Setup all components required to train the network.""" - # split random key + # split random number generator rng, rng_f, rng_g = jax.random.split(rng, 3) # check setting of network architectures diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py index 327b2243c..875e8311c 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -107,7 +107,7 @@ class GWState(NamedTuple): linearization of GW. linear_pb: Local linearization of the quadratic GW problem. old_transport_mass: Intermediary value of the mass of the transport matrix. - keys: Random keys passed to low-rank initializers at every GW iteration + rngs: Random keys passed to low-rank initializers at every GW iteration when not using warm start. errors: Holds sequence of vectors of errors of the Sinkhorn algorithm at each iteration. @@ -118,7 +118,7 @@ class GWState(NamedTuple): linear_state: LinearOutput linear_pb: linear_problem.LinearProblem old_transport_mass: float - keys: Optional[jnp.ndarray] = None + rngs: Optional[jax.random.PRNGKeyArray] = None errors: Optional[jnp.ndarray] = None def set(self, **kwargs: Any) -> 'GWState': @@ -202,7 +202,7 @@ def __call__( self, prob: quadratic_problem.QuadraticProblem, init: Optional[linear_problem.LinearProblem] = None, - key: Optional[jnp.ndarray] = None, + rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0), **kwargs: Any, ) -> GWOutput: """Run the Gromov-Wasserstein solver. @@ -216,19 +216,17 @@ def __call__( Returns: The Gromov-Wasserstein output. """ - if key is None: - key = jax.random.PRNGKey(0) - key1, key2 = jax.random.split(key, 2) + rng1, rng2 = jax.random.split(rng, 2) if prob._is_low_rank_convertible: prob = prob.to_low_rank() if init is None: initializer = self.create_initializer(prob) - init = initializer(prob, epsilon=self.epsilon, key=key1, **kwargs) + init = initializer(prob, epsilon=self.epsilon, rng=rng1, **kwargs) run_fn = jax.jit(iterations) if self.jit else iterations - out = run_fn(self, prob, init, key2) + out = run_fn(self, prob, init, rng2) # TODO(lpapaxanthos): remove stop_gradient when using backprop if self.is_low_rank: linearization = prob.update_lr_linearization( @@ -250,14 +248,14 @@ def init_state( self, prob: quadratic_problem.QuadraticProblem, init: linear_problem.LinearProblem, - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, ) -> GWState: """Initialize the state of the Gromov-Wasserstein iterations. Args: prob: Quadratic OT problem. 
init: Initial linearization of the quadratic problem. - key: Random key for low-rank initializers. Only used when + rng: Random key for low-rank initializers. Only used when :attr:`warm_start` is `False`. Returns: @@ -277,7 +275,7 @@ def init_state( linear_state=linear_state, linear_pb=init, old_transport_mass=transport_mass, - keys=jax.random.split(key, num_iter), + rngs=jax.random.split(rng, num_iter), errors=errors, ) @@ -355,7 +353,7 @@ def iterations( solver: GromovWasserstein, prob: quadratic_problem.QuadraticProblem, init: linear_problem.LinearProblem, - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, ) -> GWOutput: """Jittable Gromov-Wasserstein outer loop.""" @@ -372,11 +370,11 @@ def body_fn( lin_state = state.linear_state if solver.is_low_rank: - key = state.keys[iteration] + rng = state.rngs[iteration] init = (lin_state.q, lin_state.r, lin_state.g) if solver.warm_start else (None, None, None) linear_pb = prob.update_lr_linearization(state.linear_state) - out = solver.linear_ot_solver(linear_pb, init=init, key=key) + out = solver.linear_ot_solver(linear_pb, init=init, rng=rng) else: init = (lin_state.f, lin_state.g) if solver.warm_start else (None, None) linear_pb = prob.update_linearization( @@ -398,7 +396,7 @@ def body_fn( max_iterations=solver.max_iterations, inner_iterations=1, constants=solver, - state=solver.init_state(prob, init, key=key) + state=solver.init_state(prob, init, rng=rng) ) return solver.output_from_state(state) diff --git a/src/ott/solvers/quadratic/gw_barycenter.py b/src/ott/solvers/quadratic/gw_barycenter.py index 3900f6c93..76a769845 100644 --- a/src/ott/solvers/quadratic/gw_barycenter.py +++ b/src/ott/solvers/quadratic/gw_barycenter.py @@ -131,7 +131,7 @@ def init_state( bar_init: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None, a: Optional[jnp.ndarray] = None, - seed: int = 0, + rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0), ) -> GWBarycenterState: """Initialize the (fused) Gromov-Wasserstein barycenter state. @@ -150,7 +150,7 @@ def init_state( the fused case. a: An array of shape ``[bar_size,]`` containing the barycenter weights. - seed: Random seed used when ``bar_init = None``. + rng: Random key for seeding used when ``bar_init = None``. Returns: The initial barycenter state. @@ -162,11 +162,10 @@ def init_state( if bar_init is None: _, b = problem.segmented_y_b - rng = jax.random.PRNGKey(seed) - keys = jax.random.split(rng, problem.num_measures) + rngs = jax.random.split(rng, problem.num_measures) linear_solver = self._quad_solver.linear_ot_solver - transports = init_transports(linear_solver, keys, a, b, problem.epsilon) + transports = init_transports(linear_solver, rngs, a, b, problem.epsilon) x = problem.update_features(transports, a) if problem.is_fused else None cost = problem.update_barycenter(transports, a) else: @@ -272,14 +271,14 @@ def tree_unflatten( # noqa: D102 @partial(jax.vmap, in_axes=[None, 0, None, 0, None]) def init_transports( - solver, key: jnp.ndarray, a: jnp.ndarray, b: jnp.ndarray, + solver, rng: jax.random.PRNGKeyArray, a: jnp.ndarray, b: jnp.ndarray, epsilon: Optional[float] ) -> jnp.ndarray: """Initialize random 2D point cloud and solve the linear OT problem. Args: solver: Linear OT solver. - key: Random key. + rng: Random key for seeding. a: Source marginals (e.g., for barycenter) of shape ``[bar_size,]``. b: Target marginals of shape ``[max_measure_size,]``. epsilon: Entropy regularization. 
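The surrounding `init_state` splits one key into `num_measures` keys and vmaps this initializer over them; the pattern in isolation, as a toy sketch in pure `jax` with made-up shapes:

import jax

rng = jax.random.PRNGKey(0)
rngs = jax.random.split(rng, 4)  # one key per measure
# Each key deterministically seeds an independent 2-D point cloud,
# mirroring the per-measure rng1/rng2 usage in the next hunk.
xs = jax.vmap(lambda k: jax.random.normal(k, (10, 2)))(rngs)
assert xs.shape == (4, 10, 2)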
@@ -287,9 +286,9 @@ def init_transports( Returns: Transport map of shape ``[bar_size, max_measure_size]``. """ - key1, key2 = jax.random.split(key, 2) - x = jax.random.normal(key1, shape=(len(a), 2)) - y = jax.random.normal(key2, shape=(len(b), 2)) + rng1, rng2 = jax.random.split(rng, 2) + x = jax.random.normal(rng1, shape=(len(a), 2)) + y = jax.random.normal(rng2, shape=(len(b), 2)) geom = pointcloud.PointCloud( x, y, epsilon=epsilon, src_mask=a > 0, tgt_mask=b > 0 ) diff --git a/src/ott/tools/gaussian_mixture/fit_gmm.py b/src/ott/tools/gaussian_mixture/fit_gmm.py index b425145a0..35e16ff38 100644 --- a/src/ott/tools/gaussian_mixture/fit_gmm.py +++ b/src/ott/tools/gaussian_mixture/fit_gmm.py @@ -17,7 +17,7 @@ # initialize GMM with K-means++ gmm_init = fit_gmm.initialize( - key=key, + rng=rng, points=my_points, point_weights=None, n_components=COMPONENTS) @@ -195,12 +195,12 @@ def _dist_sq_one_loc(points: jnp.ndarray, loc: jnp.ndarray) -> jnp.ndarray: def _get_locs( - key: jnp.ndarray, points: jnp.ndarray, n_components: int + rng: jax.random.PRNGKeyArray, points: jnp.ndarray, n_components: int ) -> jnp.ndarray: """Get the initial component means. Args: - key: jax.random seed + rng: jax.random key points: (n, n_dimensions) array of observations n_components: desired number of components @@ -210,8 +210,8 @@ def _get_locs( points = points.copy() n_points = points.shape[0] weights = jnp.ones(n_points) / n_points - key, subkey = jax.random.split(key) - index = jax.random.choice(key=subkey, a=points.shape[0], p=weights) + rng, subrng = jax.random.split(rng) + index = jax.random.choice(key=subrng, a=points.shape[0], p=weights) loc = points[index] points = jnp.concatenate([points[:index], points[index + 1:]], axis=0) @@ -220,8 +220,8 @@ def _get_locs( dist_sq = _get_dist_sq(points, locs) min_dist_sq = jnp.min(dist_sq, axis=-1) weights = min_dist_sq / jnp.sum(min_dist_sq) - key, subkey = jax.random.split(key) - index = jax.random.choice(key=subkey, a=points.shape[0], p=weights) + rng, subrng = jax.random.split(rng) + index = jax.random.choice(key=subrng, a=points.shape[0], p=weights) loc = points[index] points = jnp.concatenate([points[:index], points[index + 1:]], axis=0) locs = jnp.concatenate([locs, loc[None]], axis=0) @@ -229,7 +229,7 @@ def _get_locs( def from_kmeans_plusplus( - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, points: jnp.ndarray, point_weights: Optional[jnp.ndarray], n_components: int, @@ -237,7 +237,7 @@ def from_kmeans_plusplus( """Initialize a GMM via a single pass of K-means++. Args: - key: jax.random seed + rng: jax.random key points: (n, n_dimensions) array of observations point_weights: (n,) array of weights for points n_components: desired number of components @@ -248,8 +248,8 @@ def from_kmeans_plusplus( Raises: ValueError if any fitted parameters are non-finite. """ - key, subkey = jax.random.split(key) - locs = _get_locs(key=subkey, points=points, n_components=n_components) + rng, subrng = jax.random.split(rng) + locs = _get_locs(rng=subrng, points=points, n_components=n_components) dist_sq = _get_dist_sq(points, locs) assignment_prob = (dist_sq == jnp.min(dist_sq, axis=-1)[:, None]).astype(points.dtype) @@ -265,7 +265,7 @@ def from_kmeans_plusplus( def initialize( - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, points: jnp.ndarray, point_weights: Optional[jnp.ndarray], n_components: int, @@ -275,7 +275,7 @@ def initialize( """Initialize a GMM via K-means++ with retries on failure. 
Args: - key: jax.random seed + rng: jax.random key points: (n, n_dimensions) array of observations point_weights: (n,) array of weights for points n_components: desired number of components @@ -289,10 +289,10 @@ def initialize( ValueError if initialization was unsuccessful after n_attempts attempts. """ for attempt in range(n_attempts): - key, subkey = jax.random.split(key) + rng, subrng = jax.random.split(rng) try: return from_kmeans_plusplus( - key=subkey, + rng=subrng, points=points, point_weights=point_weights, n_components=n_components diff --git a/src/ott/tools/gaussian_mixture/gaussian.py b/src/ott/tools/gaussian_mixture/gaussian.py index 5b40bfb59..a28414539 100644 --- a/src/ott/tools/gaussian_mixture/gaussian.py +++ b/src/ott/tools/gaussian_mixture/gaussian.py @@ -65,7 +65,7 @@ def from_samples( @classmethod def from_random( cls, - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, n_dimensions: int, stdev_mean: float = 0.1, stdev_cov: float = 0.1, @@ -75,7 +75,7 @@ def from_random( """Construct a random Gaussian. Args: - key: jax.random seed + rng: jax.random key n_dimensions: desired covariance dimensions stdev: standard deviation of loc and log eigenvalues (means for both are 0) @@ -84,12 +84,12 @@ def from_random( Returns: A random Gaussian. """ - key, subkey0, subkey1 = jax.random.split(key, num=3) + rng, subrng0, subrng1 = jax.random.split(rng, num=3) loc = jax.random.normal( - key=subkey0, shape=(n_dimensions,), dtype=dtype + key=subrng0, shape=(n_dimensions,), dtype=dtype ) * stdev_mean + ridge scale = scale_tril.ScaleTriL.from_random( - key=subkey1, n_dimensions=n_dimensions, stdev=stdev_cov, dtype=dtype + rng=subrng1, n_dimensions=n_dimensions, stdev=stdev_cov, dtype=dtype ) return cls(loc=loc, scale=scale) @@ -138,9 +138,9 @@ def log_prob( -0.5 * (d * LOG2PI + log_det[None] + jnp.sum(z ** 2., axis=-1)) ) # (?, k) - def sample(self, key: jnp.ndarray, size: int) -> jnp.ndarray: + def sample(self, rng: jax.random.PRNGKeyArray, size: int) -> jnp.ndarray: """Generate samples from the distribution.""" - std_samples_t = jax.random.normal(key=key, shape=(self.n_dimensions, size)) + std_samples_t = jax.random.normal(key=rng, shape=(self.n_dimensions, size)) return self.loc[None] + ( jnp.swapaxes( jnp.matmul(self.scale.cholesky(), std_samples_t), diff --git a/src/ott/tools/gaussian_mixture/gaussian_mixture.py b/src/ott/tools/gaussian_mixture/gaussian_mixture.py index add35478a..bbd37b325 100644 --- a/src/ott/tools/gaussian_mixture/gaussian_mixture.py +++ b/src/ott/tools/gaussian_mixture/gaussian_mixture.py @@ -80,7 +80,7 @@ def __init__( @classmethod def from_random( cls, - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, n_components: int, n_dimensions: int, stdev_mean: float = 0.1, @@ -93,9 +93,9 @@ def from_random( loc = [] scale_params = [] for _ in range(n_components): - key, subkey = jax.random.split(key) + rng, subrng = jax.random.split(rng) component = gaussian.Gaussian.from_random( - key=subkey, + rng=subrng, n_dimensions=n_dimensions, stdev_mean=stdev_mean, stdev_cov=stdev_cov, @@ -107,7 +107,7 @@ def from_random( loc = jnp.stack(loc, axis=0) scale_params = jnp.stack(scale_params, axis=0) weight_ob = probabilities.Probabilities.from_random( - key=subkey, n_dimensions=n_components, stdev=stdev_weights, dtype=dtype + rng=subrng, n_dimensions=n_components, stdev=stdev_weights, dtype=dtype ) return cls( loc=loc, scale_params=scale_params, component_weight_ob=weight_ob @@ -221,12 +221,12 @@ def components(self) -> List[gaussian.Gaussian]: """List of all GMM components.""" 
return [self.get_component(i) for i in range(self.n_components)] - def sample(self, key: jnp.ndarray, size: int) -> jnp.ndarray: + def sample(self, rng: jax.random.PRNGKeyArray, size: int) -> jnp.ndarray: """Generate samples from the distribution.""" - subkey0, subkey1 = jax.random.split(key) - component = self.component_weight_ob.sample(key=subkey0, size=size) + subrng0, subrng1 = jax.random.split(rng) + component = self.component_weight_ob.sample(rng=subrng0, size=size) std_samples = jax.random.normal( - key=subkey1, shape=(size, self.n_dimensions) + key=subrng1, shape=(size, self.n_dimensions) ) def _transform_single_component(k, scale, loc): diff --git a/src/ott/tools/gaussian_mixture/linalg.py b/src/ott/tools/gaussian_mixture/linalg.py index b7f866b68..869a0975e 100644 --- a/src/ott/tools/gaussian_mixture/linalg.py +++ b/src/ott/tools/gaussian_mixture/linalg.py @@ -134,11 +134,11 @@ def invmatvectril( def get_random_orthogonal( - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, dim: int, dtype: Optional[jnp.dtype] = None ) -> jnp.ndarray: """Get a random orthogonal matrix with the specified dimension.""" - m = jax.random.normal(key=key, shape=[dim, dim], dtype=dtype) + m = jax.random.normal(key=rng, shape=[dim, dim], dtype=dtype) q, _ = jnp.linalg.qr(m) return q diff --git a/src/ott/tools/gaussian_mixture/probabilities.py b/src/ott/tools/gaussian_mixture/probabilities.py index 376483756..ce7fa1d6d 100644 --- a/src/ott/tools/gaussian_mixture/probabilities.py +++ b/src/ott/tools/gaussian_mixture/probabilities.py @@ -37,7 +37,7 @@ def __init__(self, params): @classmethod def from_random( cls, - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, n_dimensions: int, stdev: Optional[float] = 0.1, dtype: Optional[jnp.dtype] = None @@ -45,7 +45,7 @@ def from_random( """Construct a random Probabilities.""" return cls( params=jax.random - .normal(key=key, shape=(n_dimensions - 1,), dtype=dtype) * stdev + .normal(key=rng, shape=(n_dimensions - 1,), dtype=dtype) * stdev ) @classmethod @@ -78,10 +78,10 @@ def probs(self) -> jnp.ndarray: """Get the probabilities.""" return jax.nn.softmax(self.unnormalized_log_probs()) - def sample(self, key: jnp.ndarray, size: int) -> jnp.ndarray: + def sample(self, rng: jax.random.PRNGKeyArray, size: int) -> jnp.ndarray: """Sample from the distribution.""" return jax.random.categorical( - key=key, logits=self.unnormalized_log_probs(), shape=(size,) + key=rng, logits=self.unnormalized_log_probs(), shape=(size,) ) def tree_flatten(self): # noqa: D102 diff --git a/src/ott/tools/gaussian_mixture/scale_tril.py b/src/ott/tools/gaussian_mixture/scale_tril.py index c88379b9d..8aac0a455 100644 --- a/src/ott/tools/gaussian_mixture/scale_tril.py +++ b/src/ott/tools/gaussian_mixture/scale_tril.py @@ -46,7 +46,7 @@ def from_points_and_weights( @classmethod def from_random( cls, - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, n_dimensions: int, stdev: Optional[float] = 0.1, dtype: jnp.dtype = jnp.float32, @@ -54,7 +54,7 @@ def from_random( """Construct a random ScaleTriL. Args: - key: pseudo-random number generator key + rng: pseudo-random number generator key n_dimensions: number of dimensions stdev: desired standard deviation (around 0) for the log eigenvalues dtype: data type for the covariance matrix @@ -63,12 +63,12 @@ def from_random( A ScaleTriL. 
""" # generate a random orthogonal matrix - key, subkey = jax.random.split(key) - q = linalg.get_random_orthogonal(key=subkey, dim=n_dimensions, dtype=dtype) + rng, subrng = jax.random.split(rng) + q = linalg.get_random_orthogonal(rng=subrng, dim=n_dimensions, dtype=dtype) # generate random eigenvalues eigs = stdev * jnp.exp( - jax.random.normal(key=key, shape=(n_dimensions,), dtype=dtype) + jax.random.normal(key=rng, shape=(n_dimensions,), dtype=dtype) ) # random positive definite matrix diff --git a/src/ott/tools/k_means.py b/src/ott/tools/k_means.py index db675623a..8f0b46334 100644 --- a/src/ott/tools/k_means.py +++ b/src/ott/tools/k_means.py @@ -28,7 +28,7 @@ class KPPState(NamedTuple): - key: jnp.ndarray + rng: jax.random.PRNGKeyArray centroids: jnp.ndarray centroid_dists: jnp.ndarray @@ -108,39 +108,41 @@ def _from_state( def _random_init( - geom: pointcloud.PointCloud, k: int, key: jnp.ndarray + geom: pointcloud.PointCloud, k: int, rng: jax.random.PRNGKeyArray ) -> jnp.ndarray: ixs = jnp.arange(geom.shape[0]) - ixs = jax.random.choice(key, ixs, shape=(k,), replace=False) + ixs = jax.random.choice(rng, ixs, shape=(k,), replace=False) return geom.subset(ixs, None).x def _k_means_plus_plus( geom: pointcloud.PointCloud, k: int, - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, n_local_trials: Optional[int] = None, ) -> jnp.ndarray: - def init_fn(geom: pointcloud.PointCloud, key: jnp.ndarray) -> KPPState: - key, next_key = jax.random.split(key, 2) - ix = jax.random.choice(key, jnp.arange(geom.shape[0]), shape=()) + def init_fn( + geom: pointcloud.PointCloud, rng: jax.random.PRNGKeyArray + ) -> KPPState: + rng, next_rng = jax.random.split(rng, 2) + ix = jax.random.choice(rng, jnp.arange(geom.shape[0]), shape=()) centroids = jnp.full((k, geom.cost_rank), jnp.inf).at[0].set(geom.x[ix]) dists = geom.subset(ix, None).cost_matrix[0] - return KPPState(key=next_key, centroids=centroids, centroid_dists=dists) + return KPPState(rng=next_rng, centroids=centroids, centroid_dists=dists) def body_fn( iteration: int, const: Tuple[pointcloud.PointCloud, jnp.ndarray], state: KPPState, compute_error: bool ) -> KPPState: del compute_error - key, next_key = jax.random.split(state.key, 2) + rng, next_rng = jax.random.split(state.rng, 2) geom, ixs = const # no need to normalize when `replace=True` probs = state.centroid_dists ixs = jax.random.choice( - key, ixs, shape=(n_local_trials,), p=probs, replace=True + rng, ixs, shape=(n_local_trials,), p=probs, replace=True ) geom = geom.subset(ixs, None) @@ -151,14 +153,14 @@ def body_fn( centroid_dists = candidate_dists[best_ix] return KPPState( - key=next_key, centroids=centroids, centroid_dists=centroid_dists + rng=next_rng, centroids=centroids, centroid_dists=centroid_dists ) if n_local_trials is None: n_local_trials = 2 + int(math.log(k)) assert n_local_trials > 0, n_local_trials - state = init_fn(geom, key) + state = init_fn(geom, rng) constants = (geom, jnp.arange(geom.shape[0])) state = fixed_point_loop.fixpoint_iter( lambda *_, **__: True, @@ -223,7 +225,7 @@ def _update_centroids( @functools.partial(jax.vmap, in_axes=[0] + [None] * 9) def _k_means( - key: jnp.ndarray, + rng: jax.random.PRNGKeyArray, geom: pointcloud.PointCloud, k: int, weights: Optional[jnp.ndarray] = None, @@ -248,7 +250,7 @@ def init_fn(init: Init_t) -> KMeansState: f"or a callable, found `{init_fn!r}`." 
) - centroids = init(geom, k, key) + centroids = init(geom, k, rng) if centroids.shape != (k, geom.cost_rank): raise ValueError( f"Expected initial centroids to have shape " @@ -351,7 +353,7 @@ def k_means( min_iterations: int = 0, max_iterations: int = 300, store_inner_errors: bool = False, - key: Optional[jnp.ndarray] = None, + rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0), ) -> KMeansOutput: r"""K-means clustering using Lloyd's algorithm :cite:`lloyd:82`. @@ -378,7 +380,7 @@ def k_means( min_iterations: Minimum number of iterations. max_iterations: Maximum number of iterations. store_inner_errors: Whether to store the errors (inertia) at each iteration. - key: Random key to seed the initializations. + rng: Random key for seeding the initializations. Returns: The k-means clustering. @@ -401,11 +403,9 @@ def k_means( weights = jnp.ones(geom.shape[0]) assert weights.shape == (geom.shape[0],) - if key is None: - key = jax.random.PRNGKey(0) - keys = jax.random.split(key, n_init) + rngs = jax.random.split(rng, n_init) out = _k_means( - keys, geom, k, weights, init, n_local_trials, tol, min_iterations, + rngs, geom, k, weights, init, n_local_trials, tol, min_iterations, max_iterations, store_inner_errors ) best_ix = jnp.argmin(out.error) diff --git a/tests/geometry/costs_test.py b/tests/geometry/costs_test.py index 8d3064d7a..5122da0b1 100644 --- a/tests/geometry/costs_test.py +++ b/tests/geometry/costs_test.py @@ -27,7 +27,7 @@ @pytest.mark.fast class TestCostFn: - def test_cosine(self, rng: jnp.ndarray): + def test_cosine(self, rng: jax.random.PRNGKeyArray): """Test the cosine cost function.""" x = jnp.array([0, 0]) y = jnp.array([0, 0]) @@ -45,9 +45,9 @@ def test_cosine(self, rng: jnp.ndarray): np.testing.assert_allclose(dist_x_y, 1.0 - -1.0, rtol=1e-5, atol=1e-5) n, m, d = 10, 12, 7 - keys = jax.random.split(rng, 2) - x = jax.random.normal(keys[0], (n, d)) - y = jax.random.normal(keys[1], (m, d)) + rngs = jax.random.split(rng, 2) + x = jax.random.normal(rngs[0], (n, d)) + y = jax.random.normal(rngs[1], (m, d)) cosine_fn = costs.Cosine() normalize = lambda v: v / jnp.sqrt(jnp.sum(v ** 2)) @@ -76,7 +76,7 @@ def test_cosine(self, rng: jnp.ndarray): @pytest.mark.fast class TestBuresBarycenter: - def test_bures(self, rng: jnp.ndarray): + def test_bures(self, rng: jax.random.PRNGKeyArray): d = 5 r = jnp.array([0.3206, 0.8825, 0.1113, 0.00052, 0.9454]) Sigma1 = r * jnp.eye(d) @@ -130,7 +130,9 @@ def test_reg_cost_legendre( @pytest.mark.parametrize("k", [1, 2, 7, 10]) @pytest.mark.parametrize("d", [10, 50, 100]) - def test_elastic_sq_k_overlap(self, rng: jax.random.PRNGKey, k: int, d: int): + def test_elastic_sq_k_overlap( + self, rng: jax.random.PRNGKeyArray, k: int, d: int + ): expected = jax.random.normal(rng, (d,)) cost_fn = costs.ElasticSqKOverlap(k=k, gamma=1e-2) @@ -149,9 +151,9 @@ def test_sparse_displacement( self, rng: jax.random.PRNGKeyArray, cost_fn: costs.RegTICost ): frac_sparse = 0.8 - key1, key2 = jax.random.split(rng, 2) - x = jax.random.normal(key1, (50, 30)) - y = jax.random.normal(key2, (71, 30)) + rng1, rng2 = jax.random.split(rng, 2) + x = jax.random.normal(rng1, (50, 30)) + y = jax.random.normal(rng2, (71, 30)) geom = pointcloud.PointCloud(x, y, cost_fn=cost_fn) dp = sinkhorn.solve(geom).to_dual_potentials() @@ -164,11 +166,12 @@ def test_sparse_displacement( def test_stronger_regularization_increases_sparsity( self, rng: jax.random.PRNGKeyArray, cost_clazz: Type[costs.RegTICost] ): - d, keys = 30, jax.random.split(rng, 4) - x = jax.random.normal(keys[0], (50, d)) - y 
= jax.random.normal(keys[1], (71, d)) - xx = jax.random.normal(keys[2], (25, d)) - yy = jax.random.normal(keys[3], (35, d)) + d, rngs = 30, jax.random.split(rng, 4) + x = jax.random.normal(rngs[0], (50, d)) + y = jax.random.normal(rngs[1], (71, d)) + xx = jax.random.normal(rngs[2], (25, d)) + yy = jax.random.normal(rngs[3], (35, d)) sparsity = {False: [], True: []} for gamma in [9, 10, 100]: diff --git a/tests/geometry/graph_test.py b/tests/geometry/graph_test.py index 3a29c235b..1370882ed 100644 --- a/tests/geometry/graph_test.py +++ b/tests/geometry/graph_test.py @@ -159,7 +159,7 @@ def test_solver(self, fmt: Optional[str]): @pytest.mark.fast.with_args("fmt", [None, "coo"], only_fast=0) def test_kernel_is_symmetric_positive_definite( - self, rng: jnp.ndarray, fmt: Optional[str] + self, rng: jax.random.PRNGKeyArray, fmt: Optional[str] ): n = 65 x = jax.random.normal(rng, (n,)) @@ -208,7 +208,8 @@ def test_automatic_t(self, fmt: Optional[str], as_laplacian: bool): only_fast=0, ) def test_approximates_ground_truth( - self, rng: jnp.ndarray, numerical_scheme: str, fmt: Optional[str] + self, rng: jax.random.PRNGKeyArray, numerical_scheme: str, + fmt: Optional[str] ): eps, n_steps = 1e-5, 20 G = random_graph(37, p=0.5, fmt=fmt) @@ -318,7 +319,7 @@ def laplacian(geom: graph.Graph) -> jnp.ndarray: np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6) @pytest.mark.fast - def test_factor_cache_works(self, rng: jnp.ndarray): + def test_factor_cache_works(self, rng: jax.random.PRNGKeyArray): def timeit(fn: Callable[[Any], Any]) -> Callable[[Any], float]: @@ -373,7 +374,7 @@ def callback(g: graph.Graph) -> decomposition.CholeskySolver: # Total memory allocated: 99.1MiB @pytest.mark.fast @pytest.mark.limit_memory("200 MB") - def test_sparse_graph_memory(self, rng: jnp.ndarray): + def test_sparse_graph_memory(self, rng: jax.random.PRNGKeyArray): # use a graph with some structure for Cholesky to be faster G = nx.grid_graph((200, 200)) # 40 000 nodes L = nx.linalg.laplacian_matrix(G).tocsc() @@ -392,7 +393,7 @@ only_fast=0, ) def test_graph_sinkhorn( - self, rng: jnp.ndarray, fmt: Optional[str], jit: bool + self, rng: jax.random.PRNGKeyArray, fmt: Optional[str], jit: bool ): def callback(geom: geometry.Geometry) -> sinkhorn.SinkhornOutput: @@ -436,7 +437,7 @@ ids=["not-implicit", "implicit"], ) def test_dense_graph_differentiability( - self, rng: jnp.ndarray, implicit_diff: bool + self, rng: jax.random.PRNGKeyArray, implicit_diff: bool ): def callback( @@ -468,7 +469,7 @@ actual = 2 * jnp.vdot(v_w, grad_w) np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4) - def test_tolerance_hilbert_metric(self, rng: jnp.ndarray): + def test_tolerance_hilbert_metric(self, rng: jax.random.PRNGKeyArray): n, n_steps, t, tol = 256, 1000, 1e-4, 3e-4 G = random_graph(n, p=0.15) x = jnp.abs(jax.random.normal(rng, (n,))) diff --git a/tests/geometry/low_rank_test.py b/tests/geometry/low_rank_test.py index b53abf9a0..c0c1e7f78 100644 --- a/tests/geometry/low_rank_test.py +++ b/tests/geometry/low_rank_test.py @@ -26,19 +26,19 @@ @pytest.mark.fast class TestLRGeometry: - def test_apply(self, rng: jnp.ndarray): + def test_apply(self, rng: jax.random.PRNGKeyArray): """Test application of cost to vec or matrix.""" n, m, r = 17, 11, 7 - keys = jax.random.split(rng, 5) - c1 = jax.random.normal(keys[0], (n, r)) - c2 =
jax.random.normal(keys[1], (m, r)) + rngs = jax.random.split(rng, 5) + c1 = jax.random.normal(rngs[0], (n, r)) + c2 = jax.random.normal(rngs[1], (m, r)) c = jnp.matmul(c1, c2.T) bias = 0.27 geom = geometry.Geometry(c + bias) geom_lr = low_rank.LRCGeometry(c1, c2, bias=bias) for dim, axis in ((m, 1), (n, 0)): for mat_shape in ((dim, 2), (dim,)): - mat = jax.random.normal(keys[2], mat_shape) + mat = jax.random.normal(rngs[2], mat_shape) np.testing.assert_allclose( geom.apply_cost(mat, axis=axis), geom_lr.apply_cost(mat, axis=axis), @@ -47,13 +47,13 @@ def test_apply(self, rng: jnp.ndarray): @pytest.mark.parametrize("scale_cost", ['mean', 'max_cost', 'max_bound', 42.]) def test_conversion_pointcloud( - self, rng: jnp.ndarray, scale_cost: Union[str, float] + self, rng: jax.random.PRNGKeyArray, scale_cost: Union[str, float] ): """Test conversion from PointCloud to LRCGeometry.""" n, m, d = 17, 11, 3 - keys = jax.random.split(rng, 3) - x = jax.random.normal(keys[0], (n, d)) - y = jax.random.normal(keys[1], (m, d)) + rngs = jax.random.split(rng, 3) + x = jax.random.normal(rngs[0], (n, d)) + y = jax.random.normal(rngs[1], (m, d)) geom = pointcloud.PointCloud(x, y, scale_cost=scale_cost) geom_lr = geom.to_LRCGeometry() @@ -64,27 +64,27 @@ def test_conversion_pointcloud( ) for dim, axis in ((m, 1), (n, 0)): for mat_shape in ((dim, 2), (dim,)): - mat = jax.random.normal(keys[2], mat_shape) + mat = jax.random.normal(rngs[2], mat_shape) np.testing.assert_allclose( geom.apply_cost(mat, axis=axis), geom_lr.apply_cost(mat, axis=axis), rtol=1e-4 ) - def test_apply_squared(self, rng: jnp.ndarray): + def test_apply_squared(self, rng: jax.random.PRNGKeyArray): """Test application of squared cost to vec or matrix.""" n, m = 27, 25 - keys = jax.random.split(rng, 5) + rngs = jax.random.split(rng, 5) for r in [3, 15]: - c1 = jax.random.normal(keys[0], (n, r)) - c2 = jax.random.normal(keys[1], (m, r)) + c1 = jax.random.normal(rngs[0], (n, r)) + c2 = jax.random.normal(rngs[1], (m, r)) c = jnp.matmul(c1, c2.T) geom = geometry.Geometry(c) geom2 = geometry.Geometry(c ** 2) geom_lr = low_rank.LRCGeometry(c1, c2) for dim, axis in ((m, 1), (n, 0)): for mat_shape in ((dim, 2), (dim,)): - mat = jax.random.normal(keys[2], mat_shape) + mat = jax.random.normal(rngs[2], mat_shape) out_lr = geom_lr.apply_square_cost(mat, axis=axis) np.testing.assert_allclose( geom.apply_square_cost(mat, axis=axis), out_lr, rtol=5e-4 @@ -96,16 +96,16 @@ def test_apply_squared(self, rng: jnp.ndarray): @pytest.mark.parametrize("bias", [(0, 0), (4, 5)]) @pytest.mark.parametrize("scale_factor", [(1, 1), (2, 3)]) def test_add_lr_geoms( - self, rng: jnp.ndarray, bias: Tuple[float, float], + self, rng: jax.random.PRNGKeyArray, bias: Tuple[float, float], scale_factor: Tuple[float, float] ): """Test application of cost to vec or matrix.""" n, m, r, q = 17, 11, 7, 2 - keys = jax.random.split(rng, 5) - c1 = jax.random.normal(keys[0], (n, r)) - c2 = jax.random.normal(keys[1], (m, r)) - d1 = jax.random.normal(keys[0], (n, q)) - d2 = jax.random.normal(keys[1], (m, q)) + rngs = jax.random.split(rng, 5) + c1 = jax.random.normal(rngs[0], (n, r)) + c2 = jax.random.normal(rngs[1], (m, r)) + d1 = jax.random.normal(rngs[0], (n, q)) + d2 = jax.random.normal(rngs[1], (m, q)) s1, s2 = scale_factor b1, b2 = bias @@ -119,13 +119,13 @@ def test_add_lr_geoms( geom_lr = geom_lr_c + geom_lr_d for dim, axis in ((m, 1), (n, 0)): - mat = jax.random.normal(keys[1], (dim, 2)) + mat = jax.random.normal(rngs[1], (dim, 2)) np.testing.assert_allclose( geom.apply_cost(mat, axis=axis), 
geom_lr.apply_cost(mat, axis=axis), rtol=1e-4 ) - vec = jax.random.normal(keys[1], (dim,)) + vec = jax.random.normal(rngs[1], (dim,)) np.testing.assert_allclose( geom.apply_cost(vec, axis=axis), geom_lr.apply_cost(vec, axis=axis), @@ -136,19 +136,19 @@ def test_add_lr_geoms( "scale,scale_cost,epsilon", [(0.1, "mean", None), (0.9, "max_cost", 1e-2)] ) def test_add_lr_geoms_scale_factor( - self, rng: jnp.ndarray, scale: float, scale_cost: str, + self, rng: jax.random.PRNGKeyArray, scale: float, scale_cost: str, epsilon: Optional[float] ): n, d = 71, 2 - key1, key2 = jax.random.split(rng, 2) + rng1, rng2 = jax.random.split(rng, 2) geom1 = pointcloud.PointCloud( - jax.random.normal(key1, (n, d)) + 10., + jax.random.normal(rng1, (n, d)) + 10., epsilon=epsilon, scale_cost=scale_cost ) geom2 = pointcloud.PointCloud( - jax.random.normal(key2, (n, d)) + 20., + jax.random.normal(rng2, (n, d)) + 20., epsilon=epsilon, scale_cost=scale_cost ) @@ -163,14 +163,14 @@ def test_add_lr_geoms_scale_factor( @pytest.mark.parametrize("axis", [0, 1]) @pytest.mark.parametrize("fn", [lambda x: x + 10, lambda x: x * 2]) def test_apply_affine_function_efficient( - self, rng: jnp.ndarray, fn: Callable[[jnp.ndarray], jnp.ndarray], - axis: int + self, rng: jax.random.PRNGKeyArray, fn: Callable[[jnp.ndarray], + jnp.ndarray], axis: int ): n, m, d = 21, 13, 3 - keys = jax.random.split(rng, 3) - x = jax.random.normal(keys[0], (n, d)) - y = jax.random.normal(keys[1], (m, d)) - vec = jax.random.normal(keys[2], (n if axis == 0 else m,)) + rngs = jax.random.split(rng, 3) + x = jax.random.normal(rngs[0], (n, d)) + y = jax.random.normal(rngs[1], (m, d)) + vec = jax.random.normal(rngs[2], (n if axis == 0 else m,)) geom = pointcloud.PointCloud(x, y) @@ -184,12 +184,12 @@ def test_apply_affine_function_efficient( np.testing.assert_allclose(res_ineff, res_eff, rtol=1e-4, atol=1e-4) @pytest.mark.parametrize("rank", [5, 1000]) - def test_point_cloud_to_lr(self, rng: jnp.ndarray, rank: int): + def test_point_cloud_to_lr(self, rng: jax.random.PRNGKeyArray, rank: int): n, m = 1500, 1000 scale = 2.0 - keys = jax.random.split(rng, 2) - x = jax.random.normal(keys[0], (n, rank)) - y = jax.random.normal(keys[1], (m, rank)) + rngs = jax.random.split(rng, 2) + x = jax.random.normal(rngs[0], (n, rank)) + y = jax.random.normal(rngs[1], (m, rank)) geom_pc = pointcloud.PointCloud(x, y) geom_lr = geom_pc.to_LRCGeometry(scale=scale) @@ -224,13 +224,15 @@ def assert_upper_bound( assert lhs <= rhs @pytest.mark.fast.with_args(rank=[2, 3], tol=[5e-1, 1e-2], only_fast=0) - def test_geometry_to_lr(self, rng: jnp.ndarray, rank: int, tol: float): - key1, key2 = jax.random.split(rng, 2) - x = jax.random.normal(key1, shape=(370, 3)) - y = jax.random.normal(key2, shape=(460, 3)) + def test_geometry_to_lr( + self, rng: jax.random.PRNGKeyArray, rank: int, tol: float + ): + rng1, rng2 = jax.random.split(rng, 2) + x = jax.random.normal(rng1, shape=(370, 3)) + y = jax.random.normal(rng2, shape=(460, 3)) geom = geometry.Geometry(cost_matrix=x @ y.T) - geom_lr = geom.to_LRCGeometry(rank=rank, tol=tol, seed=42) + geom_lr = geom.to_LRCGeometry(rank=rank, tol=tol, rng=jax.random.PRNGKey(0)) np.testing.assert_array_equal(geom.shape, geom_lr.shape) assert geom_lr.cost_rank == rank @@ -245,13 +247,13 @@ def test_geometry_to_lr(self, rng: jnp.ndarray, rank: int, tol: float): only_fast=1 ) def test_point_cloud_to_lr( - self, rng: jnp.ndarray, batch_size: Optional[int], + self, rng: jax.random.PRNGKeyArray, batch_size: Optional[int], scale_cost: Optional[str] ): rank, tol = 7, 
1e-1 - key1, key2 = jax.random.split(rng, 2) - x = jax.random.normal(key1, shape=(384, 10)) - y = jax.random.normal(key2, shape=(512, 10)) + rng1, rng2 = jax.random.split(rng, 2) + x = jax.random.normal(rng1, shape=(384, 10)) + y = jax.random.normal(rng2, shape=(512, 10)) geom = pointcloud.PointCloud( x, y, @@ -270,10 +272,10 @@ def test_point_cloud_to_lr( assert geom_lr.cost_rank == rank self.assert_upper_bound(geom, geom_lr, rank=rank, tol=tol) - def test_to_lrc_geometry_noop(self, rng: jnp.ndarray): - key1, key2 = jax.random.split(rng, 2) - cost1 = jax.random.normal(key1, shape=(32, 2)) - cost2 = jax.random.normal(key2, shape=(23, 2)) + def test_to_lrc_geometry_noop(self, rng: jax.random.PRNGKeyArray): + rng1, rng2 = jax.random.split(rng, 2) + cost1 = jax.random.normal(rng1, shape=(32, 2)) + cost2 = jax.random.normal(rng2, shape=(23, 2)) geom = low_rank.LRCGeometry(cost1, cost2) geom_lrc = geom.to_LRCGeometry(rank=10) @@ -292,11 +294,11 @@ def test_apply_transport_from_potentials(self): np.testing.assert_allclose(res, 1.1253539e-07, rtol=1e-6, atol=1e-6) @pytest.mark.limit_memory("190 MB") - def test_large_scale_factorization(self, rng: jnp.ndarray): + def test_large_scale_factorization(self, rng: jax.random.PRNGKeyArray): rank, tol = 4, 1e-2 - key1, key2 = jax.random.split(rng, 2) - x = jax.random.normal(key1, shape=(10_000, 7)) - y = jax.random.normal(key2, shape=(11_000, 7)) + rng1, rng2 = jax.random.split(rng, 2) + x = jax.random.normal(rng1, shape=(10_000, 7)) + y = jax.random.normal(rng2, shape=(11_000, 7)) geom = pointcloud.PointCloud(x, y, epsilon=1e-2, cost_fn=costs.Cosine()) geom_lr = geom.to_LRCGeometry(rank=rank, tol=tol) @@ -323,10 +325,10 @@ def test_conversion_grid(self): cost_matrix, cost_matrix_lrc, rtol=1e-5, atol=1e-5 ) - def test_full_to_lrc_geometry(self, rng: jnp.ndarray): - key1, key2 = jax.random.split(rng, 2) - x = jax.random.normal(key1, shape=(13, 7)) - y = jax.random.normal(key2, shape=(29, 7)) + def test_full_to_lrc_geometry(self, rng: jax.random.PRNGKeyArray): + rng1, rng2 = jax.random.split(rng, 2) + x = jax.random.normal(rng1, shape=(13, 7)) + y = jax.random.normal(rng2, shape=(29, 7)) geom = pointcloud.PointCloud(x, y, cost_fn=costs.PNormP(1.4)) geom_lrc = geom.to_LRCGeometry(rank=0) np.testing.assert_allclose( diff --git a/tests/geometry/pointcloud_test.py b/tests/geometry/pointcloud_test.py index 97f1c0879..5cfa2cf7c 100644 --- a/tests/geometry/pointcloud_test.py +++ b/tests/geometry/pointcloud_test.py @@ -26,15 +26,15 @@ @pytest.mark.fast class TestPointCloudApply: - def test_apply_cost_and_kernel(self, rng: jnp.ndarray): + def test_apply_cost_and_kernel(self, rng: jax.random.PRNGKeyArray): """Test consistency of cost/kernel apply to vec.""" n, m, p, b = 5, 8, 10, 7 - keys = jax.random.split(rng, 5) - x = jax.random.normal(keys[0], (n, p)) - y = jax.random.normal(keys[1], (m, p)) + 1 + rngs = jax.random.split(rng, 5) + x = jax.random.normal(rngs[0], (n, p)) + y = jax.random.normal(rngs[1], (m, p)) + 1 cost = jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1) - vec0 = jax.random.normal(keys[2], (n, b)) - vec1 = jax.random.normal(keys[3], (m, b)) + vec0 = jax.random.normal(rngs[2], (n, b)) + vec1 = jax.random.normal(rngs[3], (m, b)) geom = pointcloud.PointCloud(x, y, batch_size=3) prod0_online = geom.apply_cost(vec0, axis=0) @@ -70,14 +70,14 @@ def test_apply_cost_and_kernel(self, rng: jnp.ndarray): np.testing.assert_allclose(prod0_online, prod0, rtol=1e-03, atol=1e-02) np.testing.assert_allclose(prod1_online, prod1, rtol=1e-03, atol=1e-02) - def 
test_general_cost_fn(self, rng: jnp.ndarray):
+  def test_general_cost_fn(self, rng: jax.random.PRNGKeyArray):
     """Test non-vec cost apply to vec."""
     n, m, p, b = 5, 8, 10, 7
-    keys = jax.random.split(rng, 5)
-    x = jax.random.normal(keys[0], (n, p))
-    y = jax.random.normal(keys[1], (m, p)) + 1
-    vec0 = jax.random.normal(keys[2], (n, b))
-    vec1 = jax.random.normal(keys[3], (m, b))
+    rngs = jax.random.split(rng, 5)
+    x = jax.random.normal(rngs[0], (n, p))
+    y = jax.random.normal(rngs[1], (m, p)) + 1
+    vec0 = jax.random.normal(rngs[2], (n, b))
+    vec1 = jax.random.normal(rngs[3], (m, b))
 
     geom = pointcloud.PointCloud(x, y, cost_fn=costs.Cosine(), batch_size=None)
     cost = geom.cost_matrix
@@ -99,10 +99,10 @@ def test_correct_shape(self):
     np.testing.assert_array_equal(pc.shape, (n, m))
 
   @pytest.mark.parametrize("axis", [0, 1])
-  def test_apply_cost_without_norm(self, rng: jnp.ndarray, axis: 1):
-    key1, key2 = jax.random.split(rng, 2)
-    x = jax.random.normal(key1, shape=(17, 3))
-    y = jax.random.normal(key2, shape=(12, 3))
+  def test_apply_cost_without_norm(self, rng: jax.random.PRNGKeyArray, axis: int):
+    rng1, rng2 = jax.random.split(rng, 2)
+    x = jax.random.normal(rng1, shape=(17, 3))
+    y = jax.random.normal(rng2, shape=(12, 3))
     pc = pointcloud.PointCloud(x, y, cost_fn=costs.Cosine())
 
     arr = jnp.ones((pc.shape[0],)) if axis == 0 else jnp.ones((pc.shape[1],))
@@ -124,11 +124,11 @@ class TestPointCloudCosineConversion:
       "scale_cost", ["mean", "median", "max_cost", "max_norm", 41]
   )
   def test_cosine_to_sqeucl_conversion(
-      self, rng: jnp.ndarray, scale_cost: Union[str, float]
+      self, rng: jax.random.PRNGKeyArray, scale_cost: Union[str, float]
   ):
-    key1, key2 = jax.random.split(rng, 2)
-    x = jax.random.normal(key1, shape=(101, 4))
-    y = jax.random.normal(key2, shape=(123, 4))
+    rng1, rng2 = jax.random.split(rng, 2)
+    x = jax.random.normal(rng1, shape=(101, 4))
+    y = jax.random.normal(rng2, shape=(123, 4))
     cosine = pointcloud.PointCloud(
         x, y, cost_fn=costs.Cosine(), scale_cost=scale_cost
     )
@@ -157,11 +157,12 @@ def test_cosine_to_sqeucl_conversion(
   )
   @pytest.mark.parametrize("axis", [0, 1])
   def test_apply_cost_cosine_to_sqeucl(
-      self, rng: jnp.ndarray, axis: int, scale_cost: Union[str, float]
+      self, rng: jax.random.PRNGKeyArray, axis: int, scale_cost: Union[str,
+                                                                       float]
   ):
-    key1, key2 = jax.random.split(rng, 2)
-    x = jax.random.normal(key1, shape=(17, 5))
-    y = jax.random.normal(key2, shape=(12, 5))
+    rng1, rng2 = jax.random.split(rng, 2)
+    x = jax.random.normal(rng1, shape=(17, 5))
+    y = jax.random.normal(rng2, shape=(12, 5))
     cosine = pointcloud.PointCloud(
         x, y, cost_fn=costs.Cosine(), scale_cost=scale_cost
     )
diff --git a/tests/geometry/scaling_cost_test.py b/tests/geometry/scaling_cost_test.py
index d07ef5ff3..d6c0d4d7e 100644
--- a/tests/geometry/scaling_cost_test.py
+++ b/tests/geometry/scaling_cost_test.py
@@ -28,7 +28,7 @@ class TestScaleCost:
 
   @pytest.fixture(autouse=True)
-  def initialize(self, rng: jnp.ndarray):
+  def initialize(self, rng: jax.random.PRNGKeyArray):
     self.dim = 4
     self.n = 7
     self.m = 9
@@ -204,9 +204,9 @@ def test_max_scale_cost_low_rank_with_batch(self, batch_size: int):
 
   def test_max_scale_cost_low_rank_large_array(self):
     """Test max_cost options for large matrices."""
-    _, *keys = jax.random.split(self.rng, 3)
-    cost1 = jax.random.uniform(keys[0], (10000, 2))
-    cost2 = jax.random.uniform(keys[1], (11000, 2))
+    _, *rngs = jax.random.split(self.rng, 3)
+    cost1 = jax.random.uniform(rngs[0], (10000, 2))
+    cost2 = jax.random.uniform(rngs[1], (11000, 2))
     max_cost_lr = 
jnp.max(jnp.dot(cost1, cost2.T)) geom0 = low_rank.LRCGeometry(cost1, cost2, scale_cost='max_cost') diff --git a/tests/geometry/subsetting_test.py b/tests/geometry/subsetting_test.py index bc32f643e..0b44f0f18 100644 --- a/tests/geometry/subsetting_test.py +++ b/tests/geometry/subsetting_test.py @@ -13,14 +13,14 @@ @pytest.fixture() def pc_masked( - rng: jnp.ndarray + rng: jax.random.PRNGKeyArray ) -> Tuple[pointcloud.PointCloud, pointcloud.PointCloud]: n, m = 20, 30 - key1, key2 = jax.random.split(rng, 2) + rng1, rng2 = jax.random.split(rng, 2) # x = jnp.full((n,), fill_value=1.) # y = jnp.full((m,), fill_value=2.) - x = jax.random.normal(key1, shape=(n, 3)) - y = jax.random.normal(key1, shape=(m, 3)) + x = jax.random.normal(rng1, shape=(n, 3)) + y = jax.random.normal(rng1, shape=(m, 3)) src_mask = jnp.asarray([0, 1, 2]) tgt_mask = jnp.asarray([3, 5, 6]) @@ -54,14 +54,14 @@ class TestMaskPointCloud: "clazz", [geometry.Geometry, pointcloud.PointCloud, low_rank.LRCGeometry] ) def test_mask( - self, rng: jnp.ndarray, clazz: Type[geometry.Geometry], + self, rng: jax.random.PRNGKeyArray, clazz: Type[geometry.Geometry], src_ixs: Optional[Union[int, Sequence[int]]], tgt_ixs: Optional[Union[int, Sequence[int]]] ): - key1, key2 = jax.random.split(rng, 2) + rng1, rng2 = jax.random.split(rng, 2) new_batch_size = 7 - x = jax.random.normal(key1, shape=(10, 3)) - y = jax.random.normal(key2, shape=(20, 3)) + x = jax.random.normal(rng1, shape=(10, 3)) + y = jax.random.normal(rng2, shape=(20, 3)) if clazz is geometry.Geometry: geom = clazz(cost_matrix=x @ y.T, scale_cost="mean") @@ -128,9 +128,10 @@ def test_masked_summary( ) def test_mask_permutation( - self, geom_masked: Tuple[Geom_t, pointcloud.PointCloud], rng: jnp.ndarray + self, geom_masked: Tuple[Geom_t, pointcloud.PointCloud], + rng: jax.random.PRNGKeyArray ): - key1, key2 = jax.random.split(rng) + rng1, rng2 = jax.random.split(rng) geom, _ = geom_masked n, m = geom.shape @@ -141,8 +142,8 @@ def test_mask_permutation( children, aux_data = geom.tree_flatten() gt_geom = type(geom).tree_unflatten(aux_data, children) - geom._src_mask = jax.random.permutation(key1, jnp.arange(n)) - geom._tgt_mask = jax.random.permutation(key2, jnp.arange(m)) + geom._src_mask = jax.random.permutation(rng1, jnp.arange(n)) + geom._tgt_mask = jax.random.permutation(rng2, jnp.arange(m)) np.testing.assert_allclose(geom.mean_cost_matrix, gt_geom.mean_cost_matrix) np.testing.assert_allclose( @@ -150,15 +151,16 @@ def test_mask_permutation( ) def test_boolean_mask( - self, geom_masked: Tuple[Geom_t, pointcloud.PointCloud], rng: jnp.ndarray + self, geom_masked: Tuple[Geom_t, pointcloud.PointCloud], + rng: jax.random.PRNGKeyArray ): - key1, key2 = jax.random.split(rng) + rng1, rng2 = jax.random.split(rng) p = jnp.array([0.5, 0.5]) geom, _ = geom_masked n, m = geom.shape - src_mask = jax.random.choice(key1, jnp.array([False, True]), (n,), p=p) - tgt_mask = jax.random.choice(key1, jnp.array([False, True]), (m,), p=p) + src_mask = jax.random.choice(rng1, jnp.array([False, True]), (n,), p=p) + tgt_mask = jax.random.choice(rng1, jnp.array([False, True]), (m,), p=p) geom._src_mask = src_mask geom._tgt_mask = tgt_mask gt_cost = geom.cost_matrix[src_mask, :][:, tgt_mask] diff --git a/tests/initializers/linear/sinkhorn_init_test.py b/tests/initializers/linear/sinkhorn_init_test.py index 81e407563..7d0abe78a 100644 --- a/tests/initializers/linear/sinkhorn_init_test.py +++ b/tests/initializers/linear/sinkhorn_init_test.py @@ -28,7 +28,7 @@ def create_sorting_problem( - rng: jnp.ndarray, + 
rng: jax.random.PRNGKeyArray, n: int, epsilon: float = 1e-2, batch_size: Optional[int] = None @@ -58,7 +58,7 @@ def create_sorting_problem( def create_ot_problem( - rng: jnp.ndarray, + rng: jax.random.PRNGKeyArray, n: int, m: int, d: int, @@ -172,7 +172,7 @@ def test_sorting_init(self, vector_min: bool, lse_mode: bool): assert sink_out_init.converged assert sink_out_base.n_iters > sink_out_init.n_iters - def test_sorting_init_online(self, rng: jnp.ndarray): + def test_sorting_init_online(self, rng: jax.random.PRNGKeyArray): n = 100 epsilon = 1e-2 @@ -183,7 +183,7 @@ def test_sorting_init_online(self, rng: jnp.ndarray): with pytest.raises(AssertionError, match=r"online"): sort_init.init_dual_a(ot_problem, lse_mode=True) - def test_sorting_init_square_cost(self, rng: jnp.ndarray): + def test_sorting_init_square_cost(self, rng: jax.random.PRNGKeyArray): n, m, d = 100, 150, 1 epsilon = 1e-2 @@ -192,7 +192,7 @@ def test_sorting_init_square_cost(self, rng: jnp.ndarray): with pytest.raises(AssertionError, match=r"square"): sort_init.init_dual_a(ot_problem, lse_mode=True) - def test_default_initializer(self, rng: jnp.ndarray): + def test_default_initializer(self, rng: jax.random.PRNGKeyArray): """Tests default initializer""" n, m, d = 200, 200, 2 epsilon = 1e-2 @@ -210,7 +210,7 @@ def test_default_initializer(self, rng: jnp.ndarray): np.testing.assert_array_equal(0., default_potential_a) np.testing.assert_array_equal(0., default_potential_b) - def test_gauss_pointcloud_geom(self, rng: jnp.ndarray): + def test_gauss_pointcloud_geom(self, rng: jax.random.PRNGKeyArray): n, m, d = 200, 200, 2 epsilon = 1e-2 @@ -231,7 +231,7 @@ def test_gauss_pointcloud_geom(self, rng: jnp.ndarray): @pytest.mark.parametrize("jit", [False, True]) @pytest.mark.parametrize("initializer", ["sorting", "gaussian", "subsample"]) def test_initializer_n_iter( - self, rng: jnp.ndarray, lse_mode: bool, jit: bool, + self, rng: jax.random.PRNGKeyArray, lse_mode: bool, jit: bool, initializer: Literal["sorting", "gaussian", "subsample"] ): """Tests Gaussian initializer""" @@ -288,7 +288,7 @@ def test_initializer_n_iter( assert default_out.n_iters >= init_out.n_iters @pytest.mark.parametrize('lse_mode', [True, False]) - def test_meta_initializer(self, rng: jnp.ndarray, lse_mode: bool): + def test_meta_initializer(self, rng: jax.random.PRNGKeyArray, lse_mode: bool): """Tests Meta initializer""" n, m, d = 200, 200, 2 epsilon = 1e-2 diff --git a/tests/initializers/linear/sinkhorn_lr_init_test.py b/tests/initializers/linear/sinkhorn_lr_init_test.py index 04a36aeb4..c85ff8e8f 100644 --- a/tests/initializers/linear/sinkhorn_lr_init_test.py +++ b/tests/initializers/linear/sinkhorn_lr_init_test.py @@ -28,7 +28,9 @@ class TestLRInitializers: @pytest.mark.fast.with_args("kind", ["pc", "lrc", "geom"], only_fast=0) - def test_create_default_initializer(self, rng: jnp.ndarray, kind: str): + def test_create_default_initializer( + self, rng: jax.random.PRNGKeyArray, kind: str + ): n, d, rank = 110, 2, 3 x = jax.random.normal(rng, (n, d)) geom = pointcloud.PointCloud(x) @@ -71,16 +73,16 @@ def test_explicitly_passing_initializer(self): ) @pytest.mark.parametrize("partial_init", ["q", "r", "g"]) def test_partial_initialization( - self, rng: jnp.ndarray, initializer: str, partial_init: str + self, rng: jax.random.PRNGKeyArray, initializer: str, partial_init: str ): n, d, rank = 100, 10, 6 - key1, key2, key3, key4 = jax.random.split(rng, 4) - x = jax.random.normal(key1, (n, d)) + rng1, rng2, rng3, rng4 = jax.random.split(rng, 4) + x = 
jax.random.normal(rng1, (n, d)) pc = pointcloud.PointCloud(x, epsilon=5e-1) prob = linear_problem.LinearProblem(pc) - q_init = jax.random.normal(key2, (n, rank)) - r_init = jax.random.normal(key2, (n, rank)) - g_init = jax.random.normal(key2, (rank,)) + q_init = jax.random.normal(rng2, (n, rank)) + r_init = jax.random.normal(rng2, (n, rank)) + g_init = jax.random.normal(rng2, (rank,)) solver = sinkhorn_lr.LRSinkhorn(rank=rank, initializer=initializer) initializer = solver.create_initializer(prob) @@ -99,7 +101,7 @@ def test_partial_initialization( @pytest.mark.fast.with_args("rank", [2, 4, 10, 13], only_fast=True) def test_generalized_k_means_has_correct_rank( - self, rng: jnp.ndarray, rank: int + self, rng: jax.random.PRNGKeyArray, rank: int ): n, d = 100, 10 x = jax.random.normal(rng, (n, d)) @@ -116,12 +118,14 @@ def test_generalized_k_means_has_correct_rank( assert jnp.linalg.matrix_rank(q) == rank assert jnp.linalg.matrix_rank(r) == rank - def test_generalized_k_means_matches_k_means(self, rng: jnp.ndarray): + def test_generalized_k_means_matches_k_means( + self, rng: jax.random.PRNGKeyArray + ): n, d, rank = 120, 15, 5 eps = 1e-1 - key1, key2 = jax.random.split(rng, 2) - x = jax.random.normal(key1, (n, d)) - y = jax.random.normal(key1, (n, d)) + rng1, rng2 = jax.random.split(rng, 2) + x = jax.random.normal(rng1, (n, d)) + y = jax.random.normal(rng1, (n, d)) pc = pointcloud.PointCloud(x, y, epsilon=eps) geom = geometry.Geometry(cost_matrix=pc.cost_matrix, epsilon=eps) @@ -146,11 +150,13 @@ def test_generalized_k_means_matches_k_means(self, rng: jnp.ndarray): ) @pytest.mark.parametrize("epsilon", [0., 1e-1]) - def test_better_initialization_helps(self, rng: jnp.ndarray, epsilon: float): + def test_better_initialization_helps( + self, rng: jax.random.PRNGKeyArray, epsilon: float + ): n, d, rank = 81, 13, 3 - key1, key2 = jax.random.split(rng, 2) - x = jax.random.normal(key1, (n, d)) - y = jax.random.normal(key2, (n, d)) + rng1, rng2 = jax.random.split(rng, 2) + x = jax.random.normal(rng1, (n, d)) + y = jax.random.normal(rng2, (n, d)) pc = pointcloud.PointCloud(x, y, epsilon=5e-1) prob = linear_problem.LinearProblem(pc) diff --git a/tests/initializers/quadratic/gw_init_test.py b/tests/initializers/quadratic/gw_init_test.py index f6c774e73..1b111f1ab 100644 --- a/tests/initializers/quadratic/gw_init_test.py +++ b/tests/initializers/quadratic/gw_init_test.py @@ -16,7 +16,6 @@ import pytest import jax -import jax.numpy as jnp import numpy as np from ott.geometry import geometry, pointcloud @@ -30,12 +29,14 @@ class TestQuadraticInitializers: @pytest.mark.parametrize("kind", ["pc", "lrc", "geom"]) - def test_create_default_lr_initializer(self, rng: jnp.ndarray, kind: str): + def test_create_default_lr_initializer( + self, rng: jax.random.PRNGKeyArray, kind: str + ): n, d1, d2, rank = 150, 2, 3, 5 eps = 1e-1 - key1, key2 = jax.random.split(rng, 2) - x = jax.random.normal(key1, (n, d1)) - y = jax.random.normal(key1, (n, d2)) + rng1, rng2 = jax.random.split(rng, 2) + x = jax.random.normal(rng1, (n, d1)) + y = jax.random.normal(rng1, (n, d2)) kwargs_init = {"foo": "bar"} geom_x = pointcloud.PointCloud(x, epsilon=eps) @@ -93,18 +94,20 @@ def test_explicitly_passing_initializer(self, rank: int): assert solver.quad_initializer.rank == rank @pytest.mark.parametrize("eps", [0., 1e-2]) - def test_gw_better_initialization_helps(self, rng: jnp.ndarray, eps: float): + def test_gw_better_initialization_helps( + self, rng: jax.random.PRNGKeyArray, eps: float + ): n, m, d1, d2, rank = 123, 124, 12, 10, 5 - 
key1, key2, key3, key4 = jax.random.split(rng, 4) + rng1, rng2, rng3, rng4 = jax.random.split(rng, 4) geom_x = pointcloud.PointCloud( - jax.random.normal(key1, (n, d1)), - jax.random.normal(key2, (n, d1)), + jax.random.normal(rng1, (n, d1)), + jax.random.normal(rng2, (n, d1)), epsilon=eps, ) geom_y = pointcloud.PointCloud( - jax.random.normal(key3, (m, d2)), - jax.random.normal(key4, (m, d2)), + jax.random.normal(rng3, (m, d2)), + jax.random.normal(rng4, (m, d2)), epsilon=eps, ) problem = quadratic_problem.QuadraticProblem(geom_x, geom_y) diff --git a/tests/math/lse_test.py b/tests/math/lse_test.py index 8c42e99ae..5f8665807 100644 --- a/tests/math/lse_test.py +++ b/tests/math/lse_test.py @@ -25,14 +25,14 @@ @pytest.mark.fast class TestGeometryLse: - def test_lse(self, rng: jnp.ndarray): + def test_lse(self, rng: jax.random.PRNGKeyArray): """Test consistency of custom lse's jvp.""" n, m = 12, 8 - keys = jax.random.split(rng, 5) - mat = jax.random.normal(keys[0], (n, m)) + rngs = jax.random.split(rng, 5) + mat = jax.random.normal(rngs[0], (n, m)) # picking potentially negative weights on purpose - b_0 = jax.random.normal(keys[1], (m,)) - b_1 = jax.random.normal(keys[2], (n, 1)) + b_0 = jax.random.normal(rngs[1], (m,)) + b_1 = jax.random.normal(rngs[2], (n, 1)) def lse_(x, axis, b, return_sign): out = mu.logsumexp(x, axis, False, b, return_sign) @@ -41,7 +41,7 @@ def lse_(x, axis, b, return_sign): lse = jax.value_and_grad(lse_, argnums=(0, 2)) for axis in (0, 1): _, g = lse(mat, axis, None, False) - delta_mat = jax.random.normal(keys[3], (n, m)) + delta_mat = jax.random.normal(rngs[3], (n, m)) eps = 1e-3 val_peps = lse(mat + eps * delta_mat, axis, None, False)[0] val_meps = lse(mat - eps * delta_mat, axis, None, False)[0] @@ -50,7 +50,7 @@ def lse_(x, axis, b, return_sign): rtol=1e-03, atol=1e-02) for b, dim, axis in zip((b_0, b_1), (m, n), (1, 0)): - delta_b = jax.random.normal(keys[4], (dim,)).reshape(b.shape) + delta_b = jax.random.normal(rngs[4], (dim,)).reshape(b.shape) _, g = lse(mat, axis, b, True) eps = 1e-3 val_peps = lse(mat + eps * delta_mat, axis, b + eps * delta_b, True)[0] diff --git a/tests/math/matrix_square_root_test.py b/tests/math/matrix_square_root_test.py index 505823852..71e46c1c8 100644 --- a/tests/math/matrix_square_root_test.py +++ b/tests/math/matrix_square_root_test.py @@ -23,24 +23,24 @@ from ott.math import matrix_square_root -def _get_random_spd_matrix(dim: int, key: jnp.ndarray): +def _get_random_spd_matrix(dim: int, rng: jax.random.PRNGKeyArray): # Get a random symmetric, positive definite matrix of a specified size. - key, subkey0, subkey1 = jax.random.split(key, num=3) + rng, subrng0, subrng1 = jax.random.split(rng, num=3) # Step 1: generate a random orthogonal matrix - m = jax.random.normal(key=subkey0, shape=[dim, dim]) + m = jax.random.normal(key=subrng0, shape=[dim, dim]) q, _ = jnp.linalg.qr(m) # Step 2: generate random eigenvalues in [1/2. , 2.] to ensure the condition # number is reasonable. - eigs = 2. ** (2. * jax.random.uniform(key=subkey1, shape=(dim,)) - 1.) + eigs = 2. ** (2. * jax.random.uniform(key=subrng1, shape=(dim,)) - 1.) 
return jnp.matmul(eigs[None, :] * q, jnp.transpose(q)) def _get_test_fn( - fn: Callable[[jnp.ndarray], jnp.ndarray], dim: int, key: jnp.ndarray, - **kwargs: Any + fn: Callable[[jnp.ndarray], jnp.ndarray], dim: int, + rng: jax.random.PRNGKeyArray, **kwargs: Any ) -> Callable[[jnp.ndarray], jnp.ndarray]: # We want to test gradients of a function fn that maps positive definite # matrices to positive definite matrices by comparing them to finite @@ -49,11 +49,11 @@ def _get_test_fn( # (2) maps the real to a positive definite matrix, # (3) applies fn, then # (4) maps the matrix-valued output of fn to a scalar. - key, subkey0, subkey1, subkey2, subkey3 = jax.random.split(key, num=5) - m0 = _get_random_spd_matrix(dim=dim, key=subkey0) - m1 = _get_random_spd_matrix(dim=dim, key=subkey1) - dx = _get_random_spd_matrix(dim=dim, key=subkey2) - unit = jax.random.normal(key=subkey3, shape=(dim, dim)) + rng, subrng0, subrng1, subrng2, subrng3 = jax.random.split(rng, num=5) + m0 = _get_random_spd_matrix(dim=dim, rng=subrng0) + m1 = _get_random_spd_matrix(dim=dim, rng=subrng1) + dx = _get_random_spd_matrix(dim=dim, rng=subrng2) + unit = jax.random.normal(key=subrng3, shape=(dim, dim)) unit /= jnp.sqrt(jnp.sum(unit ** 2.)) def _test_fn(x: jnp.ndarray, **kwargs: Any) -> jnp.ndarray: @@ -73,7 +73,7 @@ def _sqrt_plus_inv_sqrt(x: jnp.ndarray) -> jnp.ndarray: class TestMatrixSquareRoot: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.random.PRNGKeyArray): self.dim = 13 self.batch = 3 # Values for testing the Sylvester solver @@ -81,14 +81,14 @@ def initialize(self, rng: jnp.ndarray): # Shapes: A = (m, m), B = (n, n), C = (m, n), X = (m, n) m = 3 n = 2 - key, subkey0, subkey1, subkey2 = jax.random.split(rng, 4) - self.a = jax.random.normal(key=subkey0, shape=(2, m, m)) - self.b = jax.random.normal(key=subkey1, shape=(2, n, n)) - self.x = jax.random.normal(key=subkey2, shape=(2, m, n)) + rng, subrng0, subrng1, subrng2 = jax.random.split(rng, 4) + self.a = jax.random.normal(key=subrng0, shape=(2, m, m)) + self.b = jax.random.normal(key=subrng1, shape=(2, n, n)) + self.x = jax.random.normal(key=subrng2, shape=(2, m, n)) # make sure the system has a solution self.c = jnp.matmul(self.a, self.x) - jnp.matmul(self.x, self.b) - self.rng = key + self.rng = rng def test_sqrtm(self): """Sample a random p.s.d. (Wishart) matrix, check its sqrt matches.""" @@ -194,10 +194,10 @@ def test_grad( self, enable_x64, fn: Callable, n_tests: int, dim: int, epsilon: float, atol: float, rtol: float ): - key = self.rng + rng = self.rng for _ in range(n_tests): - key, subkey = jax.random.split(key) - test_fn = _get_test_fn(fn, dim=dim, key=subkey, threshold=1e-5) + rng, subrng = jax.random.split(rng) + test_fn = _get_test_fn(fn, dim=dim, rng=subrng, threshold=1e-5) expected = (test_fn(epsilon) - test_fn(-epsilon)) / (2. * epsilon) actual = jax.grad(test_fn)(0.) 
      np.testing.assert_allclose(actual, expected, atol=atol, rtol=rtol)
diff --git a/tests/problems/linear/potentials_test.py b/tests/problems/linear/potentials_test.py
index b5535eb33..b09fdc502 100644
--- a/tests/problems/linear/potentials_test.py
+++ b/tests/problems/linear/potentials_test.py
@@ -25,13 +25,13 @@ class TestEntropicPotentials:
   def test_device_put(self, rng: jax.random.PRNGKeyArray):
     n = 10
     device = jax.devices()[0]
-    keys = jax.random.split(rng, 5)
-    f = jax.random.normal(keys[0], (n,))
-    g = jax.random.normal(keys[1], (n,))
+    rngs = jax.random.split(rng, 5)
+    f = jax.random.normal(rngs[0], (n,))
+    g = jax.random.normal(rngs[1], (n,))
 
-    geom = pointcloud.PointCloud(jax.random.normal(keys[2], (n, 3)))
-    a = jax.random.normal(keys[4], (n, 3))
-    b = jax.random.normal(keys[5], (n, 3))
+    geom = pointcloud.PointCloud(jax.random.normal(rngs[2], (n, 3)))
+    a = jax.random.normal(rngs[3], (n, 3))
+    b = jax.random.normal(rngs[4], (n, 3))
     prob = linear_problem.LinearProblem(geom, a, b)
     pot = potentials.EntropicPotentials(f, g, prob)
 
@@ -39,16 +39,18 @@ def test_device_put(self, rng: jax.random.PRNGKeyArray):
     _ = jax.device_put(pot, device)
 
   @pytest.mark.fast.with_args(eps=[5e-2, 1e-1], only_fast=0)
-  def test_entropic_potentials_dist(self, rng: jnp.ndarray, eps: float):
+  def test_entropic_potentials_dist(
+      self, rng: jax.random.PRNGKeyArray, eps: float
+  ):
     n1, n2, d = 64, 96, 2
-    key1, key2, key3, key4 = jax.random.split(rng, 4)
+    rng1, rng2, rng3, rng4 = jax.random.split(rng, 4)
     mean1, mean2 = jnp.zeros(d), jnp.ones(d) * 2
     cov1, cov2 = jnp.eye(d), jnp.array([[2, 0], [0, 0.5]])
     g1 = gaussian.Gaussian.from_mean_and_cov(mean1, cov1)
     g2 = gaussian.Gaussian.from_mean_and_cov(mean2, cov2)
 
-    x = g1.sample(key1, n1)
-    y = g2.sample(key2, n2)
+    x = g1.sample(rng1, n1)
+    y = g2.sample(rng2, n2)
 
     geom = pointcloud.PointCloud(x, y, epsilon=eps)
     prob = linear_problem.LinearProblem(geom)
@@ -63,18 +65,18 @@
   @pytest.mark.fast.with_args(forward=[False, True], only_fast=0)
   def test_entropic_potentials_displacement(
-      self, rng: jnp.ndarray, forward: bool
+      self, rng: jax.random.PRNGKeyArray, forward: bool
   ):
     n1, n2, d = 96, 128, 2
     eps = 1e-2
-    key1, key2, key3, key4 = jax.random.split(rng, 4)
+    rng1, rng2, rng3, rng4 = jax.random.split(rng, 4)
     mean1, mean2 = jnp.zeros(d), jnp.ones(d) * 2
     cov1, cov2 = jnp.eye(d), jnp.array([[1.5, 0], [0, 0.8]])
     g1 = gaussian.Gaussian.from_mean_and_cov(mean1, cov1)
     g2 = gaussian.Gaussian.from_mean_and_cov(mean2, cov2)
 
-    x = g1.sample(key1, n1)
-    y = g2.sample(key2, n2)
+    x = g1.sample(rng1, n1)
+    y = g2.sample(rng2, n2)
 
     geom = pointcloud.PointCloud(x, y, epsilon=eps)
     prob = linear_problem.LinearProblem(geom)
@@ -82,8 +84,8 @@ def test_entropic_potentials_displacement(
     assert out.converged
     potentials = out.to_dual_potentials()
 
-    x_test = g1.sample(key3, n1 + 1)
-    y_test = g2.sample(key4, n2 + 2)
+    x_test = g1.sample(rng3, n1 + 1)
+    y_test = g2.sample(rng4, n2 + 2)
     if forward:
       expected_points = g1.transport(g2, x_test)
       actual_points = potentials.transport(x_test, forward=forward)
@@ -99,16 +101,16 @@
       p=[1.3, 2.2, 1.0], forward=[False, True], only_fast=0
   )
   def test_entropic_potentials_sqpnorm(
-      self, rng: jnp.ndarray, p: float, forward: bool
+      self, rng: jax.random.PRNGKeyArray, p: float, forward: bool
   ):
     epsilon = None
    cost_fn = costs.SqPNorm(p=p)
    n1, n2, d = 93, 127, 2
    eps = 1e-2
-    keys = jax.random.split(rng, 4)
+    rngs = jax.random.split(rng, 4)
 
-    x = 
jax.random.uniform(keys[0], (n1, d)) - y = jax.random.normal(keys[1], (n2, d)) + 2 + x = jax.random.uniform(rngs[0], (n1, d)) + y = jax.random.normal(rngs[1], (n2, d)) + 2 geom = pointcloud.PointCloud(x, y, epsilon=eps, cost_fn=cost_fn) prob = linear_problem.LinearProblem(geom) @@ -116,8 +118,8 @@ def test_entropic_potentials_sqpnorm( assert out.converged potentials = out.to_dual_potentials() - x_test = jax.random.uniform(keys[2], (n1 + 3, d)) - y_test = jax.random.normal(keys[3], (n2 + 5, d)) + 2 + x_test = jax.random.uniform(rngs[2], (n1 + 3, d)) + y_test = jax.random.normal(rngs[3], (n2 + 5, d)) + 2 sdiv = lambda x, y: sinkhorn_divergence.sinkhorn_divergence( pointcloud.PointCloud, x, y, cost_fn=cost_fn, epsilon=epsilon @@ -138,16 +140,16 @@ def test_entropic_potentials_sqpnorm( p=[1.45, 2.2, 1.0], forward=[False, True], only_fast=0 ) def test_entropic_potentials_pnorm( - self, rng: jnp.ndarray, p: float, forward: bool + self, rng: jax.random.PRNGKeyArray, p: float, forward: bool ): epsilon = None cost_fn = costs.PNormP(p=p) n1, n2, d = 43, 77, 2 eps = 1e-2 - keys = jax.random.split(rng, 4) + rngs = jax.random.split(rng, 4) - x = jax.random.uniform(keys[0], (n1, d)) - y = jax.random.normal(keys[1], (n2, d)) + 2 + x = jax.random.uniform(rngs[0], (n1, d)) + y = jax.random.normal(rngs[1], (n2, d)) + 2 geom = pointcloud.PointCloud(x, y, epsilon=eps, cost_fn=cost_fn) prob = linear_problem.LinearProblem(geom) @@ -155,8 +157,8 @@ def test_entropic_potentials_pnorm( assert out.converged potentials = out.to_dual_potentials() - x_test = jax.random.uniform(keys[2], (n1 + 3, d)) - y_test = jax.random.normal(keys[3], (n2 + 5, d)) + 2 + x_test = jax.random.uniform(rngs[2], (n1 + 3, d)) + y_test = jax.random.normal(rngs[3], (n2 + 5, d)) + 2 sdiv = lambda x, y: sinkhorn_divergence.sinkhorn_divergence( pointcloud.PointCloud, x, y, cost_fn=cost_fn, epsilon=epsilon @@ -177,14 +179,16 @@ def test_entropic_potentials_pnorm( assert div < .1 * div_0 # check we have moved points much closer to target. @pytest.mark.parametrize("jit", [False, True]) - def test_distance_differentiability(self, rng: jnp.ndarray, jit: bool): - key1, key2, key3 = jax.random.split(rng, 3) + def test_distance_differentiability( + self, rng: jax.random.PRNGKeyArray, jit: bool + ): + rng1, rng2, rng3 = jax.random.split(rng, 3) n, m, d = 18, 36, 5 - x = jax.random.normal(key1, (n, d)) - y = jax.random.normal(key2, (m, d)) + x = jax.random.normal(rng1, (n, d)) + y = jax.random.normal(rng2, (m, d)) prob = linear_problem.LinearProblem(pointcloud.PointCloud(x, y)) - v_x = jax.random.normal(key3, shape=x.shape) + v_x = jax.random.normal(rng3, shape=x.shape) v_x = (v_x / jnp.linalg.norm(v_x, axis=-1, keepdims=True)) * 1e-3 pots = sinkhorn.Sinkhorn()(prob).to_dual_potentials() @@ -199,15 +203,17 @@ def test_distance_differentiability(self, rng: jnp.ndarray, jit: bool): np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4) @pytest.mark.parametrize("eps", [None, 1e-1, 1e1, 1e2, 1e3]) - def test_potentials_sinkhorn_divergence(self, rng: jnp.ndarray, eps: float): - key1, key2, key3 = jax.random.split(rng, 3) + def test_potentials_sinkhorn_divergence( + self, rng: jax.random.PRNGKeyArray, eps: float + ): + rng1, rng2, rng3 = jax.random.split(rng, 3) n, m, d = 32, 36, 4 fwd = True mu0, mu1 = -5., 5. 
- x = jax.random.normal(key1, (n, d)) + mu0 - y = jax.random.normal(key2, (m, d)) + mu1 - x_test = jax.random.normal(key3, (n, d)) + mu0 + x = jax.random.normal(rng1, (n, d)) + mu0 + y = jax.random.normal(rng2, (m, d)) + mu1 + x_test = jax.random.normal(rng3, (n, d)) + mu0 geom = pointcloud.PointCloud(x, y, epsilon=eps) prob = linear_problem.LinearProblem(geom) diff --git a/tests/solvers/linear/continuous_barycenter_test.py b/tests/solvers/linear/continuous_barycenter_test.py index 2c11263e8..e1a336f9a 100644 --- a/tests/solvers/linear/continuous_barycenter_test.py +++ b/tests/solvers/linear/continuous_barycenter_test.py @@ -52,8 +52,8 @@ class TestBarycenter: }, ) def test_euclidean_barycenter( - self, rng: jnp.ndarray, rank: int, epsilon: float, init_random: bool, - jit: bool + self, rng: jax.random.PRNGKeyArray, rank: int, epsilon: float, + init_random: bool, jit: bool ): rngs = jax.random.split(rng, 20) # Sample 2 point clouds, each of size 113, the first around [0,1]^4, @@ -117,7 +117,9 @@ def test_euclidean_barycenter( assert jnp.all(out.x.ravel() > .7) @pytest.mark.parametrize("segment_before", [False, True]) - def test_barycenter_jit(self, rng: jnp.ndarray, segment_before: bool): + def test_barycenter_jit( + self, rng: jax.random.PRNGKeyArray, segment_before: bool + ): @functools.partial(jax.jit, static_argnums=(2, 3)) def barycenter( @@ -181,7 +183,7 @@ def barycenter( ) def test_bures_barycenter( self, - rng: jnp.ndarray, + rng: jax.random.PRNGKeyArray, lse_mode: bool, epsilon: float, jit: bool, @@ -278,7 +280,7 @@ def test_bures_barycenter( ) def test_bures_barycenter_different_number_of_components( self, - rng: jnp.ndarray, + rng: jax.random.PRNGKeyArray, alpha: float, epsilon: float, dim: int, @@ -293,21 +295,21 @@ def test_bures_barycenter_different_number_of_components( b_cost = costs.Bures(dimension=dim) # keys for random number generation - keys = jax.random.split(rng, num=4) + rngs = jax.random.split(rng, num=4) # test for non-uniform barycentric weights barycentric_weights = jax.random.dirichlet( - keys[0], alpha=jnp.ones(num_measures) * alpha + rngs[0], alpha=jnp.ones(num_measures) * alpha ) ridges = jnp.array([jnp.ones(dim), 5 * jnp.ones(dim)]) stdev_means = 0.1 * jnp.mean(ridges, axis=1) stdev_covs = jax.random.uniform( - keys[1], shape=(num_measures,), minval=0., maxval=10. + rngs[1], shape=(num_measures,), minval=0., maxval=10. 
) seeds = jax.random.randint( - keys[2], shape=(num_measures,), minval=0, maxval=100 + rngs[2], shape=(num_measures,), minval=0, maxval=100 ) gmm_generators = [ @@ -339,7 +341,7 @@ def test_bures_barycenter_different_number_of_components( # random initialization of the barycenter gmm_generator = gaussian_mixture.GaussianMixture.from_random( - keys[3], n_components=bar_size, n_dimensions=dim + rngs[3], n_components=bar_size, n_dimensions=dim ) x_init_means = gmm_generator.loc x_init_covs = gmm_generator.covariance diff --git a/tests/solvers/linear/sinkhorn_diff_test.py b/tests/solvers/linear/sinkhorn_diff_test.py index a929fe618..3abe2cd83 100644 --- a/tests/solvers/linear/sinkhorn_diff_test.py +++ b/tests/solvers/linear/sinkhorn_diff_test.py @@ -31,7 +31,7 @@ class TestSinkhornImplicit: """Check implicit and autodiff match for Sinkhorn.""" @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.random.PRNGKeyArray): self.dim = 3 self.n = 38 self.m = 73 @@ -140,17 +140,18 @@ class TestSinkhornJacobian: only_fast=0, ) def test_autograd_sinkhorn( - self, rng: jnp.ndarray, lse_mode: bool, shape_data: Tuple[int, int] + self, rng: jax.random.PRNGKeyArray, lse_mode: bool, shape_data: Tuple[int, + int] ): """Test gradient w.r.t. probability weights.""" n, m = shape_data d = 3 eps = 1e-3 # perturbation magnitude - keys = jax.random.split(rng, 5) - x = jax.random.uniform(keys[0], (n, d)) - y = jax.random.uniform(keys[1], (m, d)) - a = jax.random.uniform(keys[2], (n,)) + eps - b = jax.random.uniform(keys[3], (m,)) + eps + rngs = jax.random.split(rng, 5) + x = jax.random.uniform(rngs[0], (n, d)) + y = jax.random.uniform(rngs[1], (m, d)) + a = jax.random.uniform(rngs[2], (n,)) + eps + b = jax.random.uniform(rngs[3], (m,)) + eps # Adding zero weights to test proper handling a = a.at[0].set(0) b = b.at[3].set(0) @@ -166,7 +167,7 @@ def reg_ot(a: jnp.ndarray, b: jnp.ndarray) -> float: reg_ot_and_grad = jax.jit(jax.grad(reg_ot)) grad_reg_ot = reg_ot_and_grad(a, b) - delta = jax.random.uniform(keys[4], (n,)) + delta = jax.random.uniform(rngs[4], (n,)) delta = delta * (a > 0) # ensures only perturbing non-zero coords. delta = delta - jnp.sum(delta) / jnp.sum(a > 0) # center perturbation delta = delta * (a > 0) # ensures only perturbing non-zero coords. @@ -185,13 +186,14 @@ def reg_ot(a: jnp.ndarray, b: jnp.ndarray) -> float: "lse_mode,shape_data", [(True, (7, 9)), (False, (11, 5))] ) def test_gradient_sinkhorn_geometry( - self, rng: jnp.ndarray, lse_mode: bool, shape_data: Tuple[int, int] + self, rng: jax.random.PRNGKeyArray, lse_mode: bool, shape_data: Tuple[int, + int] ): """Test gradient w.r.t. cost matrix.""" n, m = shape_data - keys = jax.random.split(rng, 2) - cost_matrix = jnp.abs(jax.random.normal(keys[0], (n, m))) - delta = jax.random.normal(keys[1], (n, m)) + rngs = jax.random.split(rng, 2) + cost_matrix = jnp.abs(jax.random.normal(rngs[0], (n, m))) + delta = jax.random.normal(rngs[1], (n, m)) delta = delta / jnp.sqrt(jnp.vdot(delta, delta)) eps = 1e-3 # perturbation magnitude @@ -248,19 +250,19 @@ def loss_fn(cm: jnp.ndarray): only_fast=[0, 1], ) def test_gradient_sinkhorn_euclidean( - self, rng: jnp.ndarray, lse_mode: bool, implicit: bool, min_iter: int, - max_iter: int, epsilon: float, cost_fn: costs.CostFn + self, rng: jax.random.PRNGKeyArray, lse_mode: bool, implicit: bool, + min_iter: int, max_iter: int, epsilon: float, cost_fn: costs.CostFn ): """Test gradient w.r.t. 
locations x of reg-ot-cost.""" # TODO(cuturi): ensure scaling mode works with backprop. d = 3 n, m = 11, 13 - keys = jax.random.split(rng, 4) - x = jax.random.normal(keys[0], (n, d)) / 10 - y = jax.random.normal(keys[1], (m, d)) / 10 + rngs = jax.random.split(rng, 4) + x = jax.random.normal(rngs[0], (n, d)) / 10 + y = jax.random.normal(rngs[1], (m, d)) / 10 - a = jax.random.uniform(keys[2], (n,)) - b = jax.random.uniform(keys[3], (m,)) + a = jax.random.uniform(rngs[2], (n,)) + b = jax.random.uniform(rngs[3], (m,)) # Adding zero weights to test proper handling a = a.at[0].set(0) b = b.at[3].set(0) @@ -285,7 +287,7 @@ def loss_fn(x: jnp.ndarray, out = solver(prob) return out.reg_ot_cost, out - delta = jax.random.normal(keys[0], (n, d)) + delta = jax.random.normal(rngs[0], (n, d)) delta = delta / jnp.sqrt(jnp.vdot(delta, delta)) eps = 1e-5 # perturbation magnitude @@ -324,7 +326,7 @@ def loss_fn(x: jnp.ndarray, ) np.testing.assert_array_equal(jnp.isnan(custom_grad), False) - def test_autoepsilon_differentiability(self, rng: jnp.ndarray): + def test_autoepsilon_differentiability(self, rng: jax.random.PRNGKeyArray): cost = jax.random.uniform(rng, (15, 17)) def reg_ot_cost(c: jnp.ndarray) -> float: @@ -336,7 +338,7 @@ def reg_ot_cost(c: jnp.ndarray) -> float: np.testing.assert_array_equal(jnp.isnan(gradient), False) @pytest.mark.fast - def test_differentiability_with_jit(self, rng: jnp.ndarray): + def test_differentiability_with_jit(self, rng: jax.random.PRNGKeyArray): def reg_ot_cost(c: jnp.ndarray) -> float: geom = geometry.Geometry(c, epsilon=1e-2) @@ -357,8 +359,8 @@ def reg_ot_cost(c: jnp.ndarray) -> float: only_fast=0, ) def test_apply_transport_jacobian( - self, rng: jnp.ndarray, lse_mode: bool, tau_a: float, tau_b: float, - shape: Tuple[int, int], arg: int, axis: int + self, rng: jax.random.PRNGKeyArray, lse_mode: bool, tau_a: float, + tau_b: float, shape: Tuple[int, int], arg: int, axis: int ): """Tests Jacobian of application of OT to vector, w.r.t. @@ -465,8 +467,8 @@ def apply_ot(a: jnp.ndarray, x: jnp.ndarray, implicit: bool) -> jnp.ndarray: only_fast=0, ) def test_potential_jacobian_sinkhorn( - self, rng: jnp.ndarray, lse_mode: bool, tau_a: float, tau_b: float, - shape: Tuple[int, int], arg: int + self, rng: jax.random.PRNGKeyArray, lse_mode: bool, tau_a: float, + tau_b: float, shape: Tuple[int, int], arg: int ): """Test Jacobian of optimal potential w.r.t. weights and locations.""" n, m = shape @@ -544,18 +546,18 @@ class TestSinkhornGradGrid: @pytest.mark.parametrize("lse_mode", [False, True]) def test_diff_sinkhorn_x_grid_x_perturbation( - self, rng: jnp.ndarray, lse_mode: bool + self, rng: jax.random.PRNGKeyArray, lse_mode: bool ): """Test gradient w.r.t. 
probability weights."""
     eps = 1e-3  # perturbation magnitude
-    keys = jax.random.split(rng, 6)
+    rngs = jax.random.split(rng, 6)
     x = (
         jnp.array([.0, 1.0]),
         jnp.array([.3, .4, .7]),
         jnp.array([1.0, 1.3, 2.4, 3.7])
     )
     grid_size = tuple(xs.shape[0] for xs in x)
-    a = jax.random.uniform(keys[0], grid_size) + 1.0
-    b = jax.random.uniform(keys[1], grid_size) + 1.0
+    a = jax.random.uniform(rngs[0], grid_size) + 1.0
+    b = jax.random.uniform(rngs[1], grid_size) + 1.0
     a = a.ravel() / jnp.sum(a)
     b = b.ravel() / jnp.sum(b)
 
@@ -567,7 +569,7 @@ def reg_ot(x: List[jnp.ndarray]) -> float:
     reg_ot_and_grad = jax.grad(reg_ot)
     grad_reg_ot = reg_ot_and_grad(x)
 
-    delta = [jax.random.uniform(keys[i], (g,)) for i, g in enumerate(grid_size)]
+    delta = [jax.random.uniform(rngs[i + 2], (g,)) for i, g in enumerate(grid_size)]
     x_p_delta = [(xs + eps * delt) for xs, delt in zip(x, delta)]
     x_m_delta = [(xs - eps * delt) for xs, delt in zip(x, delta)]
 
@@ -589,11 +591,11 @@ def reg_ot(x: List[jnp.ndarray]) -> float:
 
   @pytest.mark.parametrize("lse_mode", [False, True])
   def test_diff_sinkhorn_x_grid_weights_perturbation(
-      self, rng: jnp.ndarray, lse_mode: bool
+      self, rng: jax.random.PRNGKeyArray, lse_mode: bool
   ):
     """Test gradient w.r.t. probability weights."""
     eps = 1e-4  # perturbation magnitude
-    keys = jax.random.split(rng, 3)
+    rngs = jax.random.split(rng, 3)
     # yapf: disable
     x = (
         jnp.asarray([.0, 1.0]),
         jnp.asarray([.3, .4, .7]),
         jnp.asarray([1.0, 1.3, 2.4, 3.7])
     )
     # yapf: enable
     grid_size = tuple(xs.shape[0] for xs in x)
-    a = jax.random.uniform(keys[0], grid_size) + 1
-    b = jax.random.uniform(keys[1], grid_size) + 1
+    a = jax.random.uniform(rngs[0], grid_size) + 1
+    b = jax.random.uniform(rngs[1], grid_size) + 1
     a = a.ravel() / jnp.sum(a)
     b = b.ravel() / jnp.sum(b)
     geom = grid.Grid(x=x, epsilon=1)
@@ -615,7 +617,7 @@ def reg_ot(a: jnp.ndarray, b: jnp.ndarray) -> float:
     reg_ot_and_grad = jax.grad(reg_ot)
     grad_reg_ot = reg_ot_and_grad(a, b)
 
-    delta = jax.random.uniform(keys[2], grid_size).ravel()
+    delta = jax.random.uniform(rngs[2], grid_size).ravel()
 
     delta = delta - jnp.mean(delta)  # center perturbation
 
@@ -640,8 +642,8 @@ class TestSinkhornJacobianPreconditioning:
       only_fast=[0, -1],
   )
   def test_potential_jacobian_sinkhorn(
-      self, rng: jnp.ndarray, lse_mode: bool, tau_a: float, tau_b: float,
-      shape: Tuple[int, int], arg: int
+      self, rng: jax.random.PRNGKeyArray, lse_mode: bool, tau_a: float,
+      tau_b: float, shape: Tuple[int, int], arg: int
   ):
     """Test Jacobian of optimal potential w.r.t. weights and locations."""
     n, m = shape
@@ -737,8 +739,8 @@ class TestSinkhornHessian:
      only_fast=-1
   )
   def test_hessian_sinkhorn(
-      self, rng: jnp.ndarray, lse_mode: bool, tau_a: float, tau_b: float,
-      shape: Tuple[int, int], arg: int
+      self, rng: jax.random.PRNGKeyArray, lse_mode: bool, tau_a: float,
+      tau_b: float, shape: Tuple[int, int], arg: int
   ):
     """Test hessian w.r.t. weights and locations."""
     # TODO(cuturi): reinstate this flag to True when JAX bug fixed.
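
Note for reviewers: every test-file hunk in this patch follows the same mechanical convention, so the churn above can be audited against a single template. Fixtures and arguments are annotated as jax.random.PRNGKeyArray instead of jnp.ndarray; integer seeds appear only at the single point where a key is first created; and keys derived with jax.random.split are named rng1, rng2, ... or rngs rather than key*. A minimal sketch of that pattern, assuming the suite's conftest.py exposes an rng fixture as below (the fixture body and the test are illustrative, not lines taken from this patch):

    import jax
    import pytest


    @pytest.fixture()
    def rng() -> jax.random.PRNGKeyArray:
      # The only place an integer seed appears: converting it into a PRNG key.
      return jax.random.PRNGKey(0)


    def test_prng_convention(rng: jax.random.PRNGKeyArray):
      # Derive independent streams by splitting the fixture's key;
      # each derived key is consumed by exactly one sampling call.
      rng1, rng2 = jax.random.split(rng, 2)
      x = jax.random.normal(rng1, (11, 3))
      y = jax.random.normal(rng2, (13, 3))
      assert x.shape == (11, 3)
      assert y.shape == (13, 3)

Passing keys explicitly keeps each test reproducible from its fixture and makes accidental key reuse visible at the call site, which is the property this patch standardizes across the code base.
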
diff --git a/tests/solvers/linear/sinkhorn_grid_test.py b/tests/solvers/linear/sinkhorn_grid_test.py index 77fecad5d..764a6c3fd 100644 --- a/tests/solvers/linear/sinkhorn_grid_test.py +++ b/tests/solvers/linear/sinkhorn_grid_test.py @@ -27,12 +27,12 @@ class TestSinkhornGrid: @pytest.mark.parametrize("lse_mode", [False, True]) - def test_separable_grid(self, rng: jnp.ndarray, lse_mode: bool): + def test_separable_grid(self, rng: jax.random.PRNGKeyArray, lse_mode: bool): """Two histograms in a grid of size 5 x 6 x 7 in the hypercube^3.""" grid_size = (5, 6, 7) - keys = jax.random.split(rng, 2) - a = jax.random.uniform(keys[0], grid_size) - b = jax.random.uniform(keys[1], grid_size) + rngs = jax.random.split(rng, 2) + a = jax.random.uniform(rngs[0], grid_size) + b = jax.random.uniform(rngs[1], grid_size) # adding zero weights to test proper handling, then ravel. a = a.at[0].set(0).ravel() a = a / jnp.sum(a) @@ -48,11 +48,13 @@ def test_separable_grid(self, rng: jnp.ndarray, lse_mode: bool): assert threshold > err @pytest.mark.fast.with_args("lse_mode", [False, True], only_fast=0) - def test_grid_vs_euclidean(self, rng: jnp.ndarray, lse_mode: bool): + def test_grid_vs_euclidean( + self, rng: jax.random.PRNGKeyArray, lse_mode: bool + ): grid_size = (5, 6, 7) - keys = jax.random.split(rng, 2) - a = jax.random.uniform(keys[0], grid_size) - b = jax.random.uniform(keys[1], grid_size) + rngs = jax.random.split(rng, 2) + a = jax.random.uniform(rngs[0], grid_size) + b = jax.random.uniform(rngs[1], grid_size) a = a.ravel() / jnp.sum(a) b = b.ravel() / jnp.sum(b) epsilon = 0.1 @@ -71,11 +73,13 @@ def test_grid_vs_euclidean(self, rng: jnp.ndarray, lse_mode: bool): ) @pytest.mark.fast.with_args("lse_mode", [False, True], only_fast=1) - def test_apply_transport_grid(self, rng: jnp.ndarray, lse_mode: bool): + def test_apply_transport_grid( + self, rng: jax.random.PRNGKeyArray, lse_mode: bool + ): grid_size = (5, 6, 7) - keys = jax.random.split(rng, 4) - a = jax.random.uniform(keys[0], grid_size) - b = jax.random.uniform(keys[1], grid_size) + rngs = jax.random.split(rng, 4) + a = jax.random.uniform(rngs[0], grid_size) + b = jax.random.uniform(rngs[1], grid_size) a = a.ravel() / jnp.sum(a) b = b.ravel() / jnp.sum(b) geom_grid = grid.Grid(grid_size=grid_size, epsilon=0.1) @@ -91,8 +95,8 @@ def test_apply_transport_grid(self, rng: jnp.ndarray, lse_mode: bool): batch_a = 3 batch_b = 4 - vec_a = jax.random.normal(keys[2], [batch_a, np.prod(np.array(grid_size))]) - vec_b = jax.random.normal(keys[3], [batch_b, np.prod(grid_size)]) + vec_a = jax.random.normal(rngs[2], [batch_a, np.prod(np.array(grid_size))]) + vec_b = jax.random.normal(rngs[3], [batch_b, np.prod(grid_size)]) vec_a = vec_a / jnp.sum(vec_a, axis=1)[:, jnp.newaxis] vec_b = vec_b / jnp.sum(vec_b, axis=1)[:, jnp.newaxis] @@ -120,7 +124,7 @@ def test_apply_transport_grid(self, rng: jnp.ndarray, lse_mode: bool): np.testing.assert_array_equal(jnp.isnan(mat_transport_t_vec_a), False) @pytest.mark.fast - def test_apply_cost(self, rng: jnp.ndarray): + def test_apply_cost(self, rng: jax.random.PRNGKeyArray): grid_size = (5, 6, 7) geom_grid = grid.Grid(grid_size=grid_size, epsilon=0.1) diff --git a/tests/solvers/linear/sinkhorn_lr_test.py b/tests/solvers/linear/sinkhorn_lr_test.py index fd13c03cb..f2ba405eb 100644 --- a/tests/solvers/linear/sinkhorn_lr_test.py +++ b/tests/solvers/linear/sinkhorn_lr_test.py @@ -26,7 +26,7 @@ class TestLRSinkhorn: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: 
jax.random.PRNGKeyArray): self.dim = 4 self.n = 33 self.m = 37 diff --git a/tests/solvers/linear/sinkhorn_misc_test.py b/tests/solvers/linear/sinkhorn_misc_test.py index 0540353f9..cb48c1c87 100644 --- a/tests/solvers/linear/sinkhorn_misc_test.py +++ b/tests/solvers/linear/sinkhorn_misc_test.py @@ -40,8 +40,8 @@ class TestSinkhornAnderson: only_fast=0, ) def test_anderson( - self, rng: jnp.ndarray, lse_mode: bool, tau_a: float, tau_b: float, - shape: Tuple[int, int], refresh_anderson_frequency: int + self, rng: jax.random.PRNGKeyArray, lse_mode: bool, tau_a: float, + tau_b: float, shape: Tuple[int, int], refresh_anderson_frequency: int ): """Test efficiency of Anderson acceleration. @@ -134,7 +134,8 @@ def initialize(self): @pytest.mark.parametrize("lse_mode", [False, True]) @pytest.mark.parametrize("unbalanced,thresh", [(False, 1e-3), (True, 1e-4)]) def test_bures_point_cloud( - self, rng: jnp.ndarray, lse_mode: bool, unbalanced: bool, thresh: float + self, rng: jax.random.PRNGKeyArray, lse_mode: bool, unbalanced: bool, + thresh: float ): """Two point clouds of Gaussians, tested with various parameters.""" if unbalanced: @@ -173,7 +174,7 @@ def test_regularized_unbalanced_bures_cost(self): class TestSinkhornOnline: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.random.PRNGKeyArray): self.dim = 3 self.n = 1000 self.m = 402 @@ -238,7 +239,7 @@ def callback(epsilon: float, batch_size: int) -> sinkhorn.SinkhornOutput: class TestSinkhornUnbalanced: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.random.PRNGKeyArray): self.dim = 4 self.n = 17 self.m = 23 @@ -319,7 +320,7 @@ class TestSinkhornJIT: """Check jitted and non jit match for Sinkhorn, and that everything jits.""" @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.random.PRNGKeyArray): self.dim = 3 self.n = 10 self.m = 11 diff --git a/tests/solvers/linear/sinkhorn_test.py b/tests/solvers/linear/sinkhorn_test.py index 9fde1f3aa..0a432121d 100644 --- a/tests/solvers/linear/sinkhorn_test.py +++ b/tests/solvers/linear/sinkhorn_test.py @@ -28,7 +28,7 @@ class TestSinkhorn: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.random.PRNGKeyArray): self.rng = rng self.dim = 4 self.n = 17 @@ -287,11 +287,11 @@ def test_online_vs_batch_euclidean_point_cloud(self, lse_mode: bool): def test_apply_transport_geometry_from_potentials(self): """Applying transport matrix P on vector without instantiating P.""" n, m, d = 160, 230, 6 - keys = jax.random.split(self.rng, 6) - x = jax.random.uniform(keys[0], (n, d)) - y = jax.random.uniform(keys[1], (m, d)) - a = jax.random.uniform(keys[2], (n,)) - b = jax.random.uniform(keys[3], (m,)) + rngs = jax.random.split(self.rng, 6) + x = jax.random.uniform(rngs[0], (n, d)) + y = jax.random.uniform(rngs[1], (m, d)) + a = jax.random.uniform(rngs[2], (n,)) + b = jax.random.uniform(rngs[3], (m,)) a = a / jnp.sum(a) b = b / jnp.sum(b) transport_t_vec_a = [None, None, None, None] @@ -299,8 +299,8 @@ def test_apply_transport_geometry_from_potentials(self): batch_b = 8 - vec_a = jax.random.normal(keys[4], (n,)) - vec_b = jax.random.normal(keys[5], (batch_b, m)) + vec_a = jax.random.normal(rngs[4], (n,)) + vec_b = jax.random.normal(rngs[5], (batch_b, m)) # test with lse_mode and online = True / False for j, lse_mode in enumerate([True, False]): @@ -341,11 +341,11 @@ def 
test_apply_transport_geometry_from_potentials(self): def test_apply_transport_geometry_from_scalings(self): """Applying transport matrix P on vector without instantiating P.""" n, m, d = 160, 230, 6 - keys = jax.random.split(self.rng, 6) - x = jax.random.uniform(keys[0], (n, d)) - y = jax.random.uniform(keys[1], (m, d)) - a = jax.random.uniform(keys[2], (n,)) - b = jax.random.uniform(keys[3], (m,)) + rngs = jax.random.split(self.rng, 6) + x = jax.random.uniform(rngs[0], (n, d)) + y = jax.random.uniform(rngs[1], (m, d)) + a = jax.random.uniform(rngs[2], (n,)) + b = jax.random.uniform(rngs[3], (m,)) a = a / jnp.sum(a) b = b / jnp.sum(b) transport_t_vec_a = [None, None, None, None] @@ -353,8 +353,8 @@ def test_apply_transport_geometry_from_scalings(self): batch_b = 8 - vec_a = jax.random.normal(keys[4], (n,)) - vec_b = jax.random.normal(keys[5], (batch_b, m)) + vec_a = jax.random.normal(rngs[4], (n,)) + vec_b = jax.random.normal(rngs[5], (batch_b, m)) # test with lse_mode and online = True / False for j, lse_mode in enumerate([True, False]): diff --git a/tests/solvers/nn/icnn_test.py b/tests/solvers/nn/icnn_test.py index dfd87066d..afa517a86 100644 --- a/tests/solvers/nn/icnn_test.py +++ b/tests/solvers/nn/icnn_test.py @@ -25,7 +25,7 @@ @pytest.mark.fast class TestICNN: - def test_icnn_convexity(self, rng: jnp.ndarray): + def test_icnn_convexity(self, rng: jax.random.PRNGKeyArray): """Tests convexity of ICNN.""" n_samples, n_features = 10, 2 dim_hidden = (64, 64) @@ -34,12 +34,12 @@ def test_icnn_convexity(self, rng: jnp.ndarray): model = models.ICNN(n_features, dim_hidden=dim_hidden) # initialize model - key1, key2, key3 = jax.random.split(rng, 3) - params = model.init(key1, jnp.ones(n_features))['params'] + rng1, rng2, rng3 = jax.random.split(rng, 3) + params = model.init(rng1, jnp.ones(n_features))['params'] # check convexity - x = jax.random.normal(key1, (n_samples, n_features)) * 0.1 - y = jax.random.normal(key2, (n_samples, n_features)) + x = jax.random.normal(rng1, (n_samples, n_features)) * 0.1 + y = jax.random.normal(rng2, (n_samples, n_features)) out_x = model.apply({'params': params}, x) out_y = model.apply({'params': params}, y) @@ -51,7 +51,7 @@ def test_icnn_convexity(self, rng: jnp.ndarray): np.testing.assert_array_equal(jnp.asarray(out) >= 0, True) - def test_icnn_hessian(self, rng: jnp.ndarray): + def test_icnn_hessian(self, rng: jax.random.PRNGKeyArray): """Tests if Hessian of ICNN is positive-semidefinite.""" # define icnn model @@ -60,11 +60,11 @@ def test_icnn_hessian(self, rng: jnp.ndarray): model = models.ICNN(n_features, dim_hidden=dim_hidden) # initialize model - key1, key2 = jax.random.split(rng) - params = model.init(key1, jnp.ones(n_features))['params'] + rng1, rng2 = jax.random.split(rng) + params = model.init(rng1, jnp.ones(n_features))['params'] # check if Hessian is positive-semidefinite via eigenvalues - data = jax.random.normal(key2, (n_features,)) + data = jax.random.normal(rng2, (n_features,)) # compute Hessian hessian = jax.hessian(model.apply, argnums=1)({'params': params}, data) diff --git a/tests/solvers/quadratic/fgw_test.py b/tests/solvers/quadratic/fgw_test.py index f4d4fa9aa..55ce18891 100644 --- a/tests/solvers/quadratic/fgw_test.py +++ b/tests/solvers/quadratic/fgw_test.py @@ -32,25 +32,25 @@ class TestFusedGromovWasserstein: # TODO(michalk8): refactor me in the future @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.random.PRNGKeyArray): d_x = 2 d_y = 3 d_xy = 4 self.n, self.m = 5, 6 - keys = 
jax.random.split(rng, 7) - self.x = jax.random.uniform(keys[0], (self.n, d_x)) - self.y = jax.random.uniform(keys[1], (self.m, d_y)) - self.x_2 = jax.random.uniform(keys[0], (self.n, d_xy)) - self.y_2 = jax.random.uniform(keys[1], (self.m, d_xy)) + rngs = jax.random.split(rng, 7) + self.x = jax.random.uniform(rngs[0], (self.n, d_x)) + self.y = jax.random.uniform(rngs[1], (self.m, d_y)) + self.x_2 = jax.random.uniform(rngs[0], (self.n, d_xy)) + self.y_2 = jax.random.uniform(rngs[1], (self.m, d_xy)) self.fused_penalty = 2.0 self.fused_penalty_2 = 0.05 - a = jax.random.uniform(keys[2], (self.n,)) + 0.1 - b = jax.random.uniform(keys[3], (self.m,)) + 0.1 + a = jax.random.uniform(rngs[2], (self.n,)) + 0.1 + b = jax.random.uniform(rngs[3], (self.m,)) + 0.1 self.a = a / jnp.sum(a) self.b = b / jnp.sum(b) - self.cx = jax.random.uniform(keys[4], (self.n, self.n)) - self.cy = jax.random.uniform(keys[5], (self.m, self.m)) - self.cxy = jax.random.uniform(keys[6], (self.n, self.m)) + self.cx = jax.random.uniform(rngs[4], (self.n, self.n)) + self.cy = jax.random.uniform(rngs[5], (self.m, self.m)) + self.cxy = jax.random.uniform(rngs[6], (self.n, self.m)) @pytest.mark.fast.with_args("jit", [False, True], only_fast=0) def test_gradient_marginals_fgw_solver(self, jit: bool): @@ -220,7 +220,7 @@ def reg_gw( @pytest.mark.limit_memory("400 MB") @pytest.mark.parametrize("jit", [False, True]) - def test_fgw_lr_memory(self, rng: jnp.ndarray, jit: bool): + def test_fgw_lr_memory(self, rng: jax.random.PRNGKeyArray, jit: bool): # Total memory allocated on CI: 342.5MiB (32bit) rngs = jax.random.split(rng, 4) n, m, d1, d2 = 15_000, 10_000, 2, 3 @@ -248,14 +248,15 @@ def test_fgw_lr_memory(self, rng: jnp.ndarray, jit: bool): @pytest.mark.parametrize("cost_rank", [4, (2, 3, 4)]) def test_fgw_lr_generic_cost_matrix( - self, rng: jnp.ndarray, cost_rank: Union[int, Tuple[int, int, int]] + self, rng: jax.random.PRNGKeyArray, cost_rank: Union[int, Tuple[int, int, + int]] ): - n, m = 70, 100 - key1, key2, key3, key4 = jax.random.split(rng, 4) - x = jax.random.normal(key1, shape=(n, 7)) - y = jax.random.normal(key2, shape=(m, 6)) - xx = jax.random.normal(key3, shape=(n, 5)) - yy = jax.random.normal(key4, shape=(m, 5)) + n, m = 20, 30 + rng1, rng2, rng3, rng4 = jax.random.split(rng, 4) + x = jax.random.normal(rng1, shape=(n, 7)) + y = jax.random.normal(rng2, shape=(m, 6)) + xx = jax.random.normal(rng3, shape=(n, 5)) + yy = jax.random.normal(rng4, shape=(m, 5)) geom_x = geometry.Geometry(cost_matrix=x @ x.T) geom_y = geometry.Geometry(cost_matrix=y @ y.T) @@ -268,7 +269,7 @@ def test_fgw_lr_generic_cost_matrix( lr_prob = prob.to_low_rank() assert lr_prob.is_low_rank - solver = gw_solver.GromovWasserstein(rank=5, epsilon=1.0) + solver = gw_solver.GromovWasserstein(rank=5, epsilon=10.0) out = solver(prob) assert solver.rank == 5 diff --git a/tests/solvers/quadratic/gw_barycenter_test.py b/tests/solvers/quadratic/gw_barycenter_test.py index 9a03f8c86..be5d71e3b 100644 --- a/tests/solvers/quadratic/gw_barycenter_test.py +++ b/tests/solvers/quadratic/gw_barycenter_test.py @@ -33,13 +33,13 @@ class TestGWBarycenter: def random_pc( n: int, d: int, - rng: jnp.ndarray, + rng: jax.random.PRNGKeyArray, m: Optional[int] = None, **kwargs: Any ) -> pointcloud.PointCloud: - key1, key2 = jax.random.split(rng, 2) - x = jax.random.normal(key1, (n, d)) - y = x if m is None else jax.random.normal(key2, (m, d)) + rng1, rng2 = jax.random.split(rng, 2) + x = jax.random.normal(rng1, (n, d)) + y = x if m is None else jax.random.normal(rng2, (m, d)) 
return pointcloud.PointCloud(x, y, **kwargs) @staticmethod @@ -67,7 +67,7 @@ def pad_cost_matrices( [("sqeucl", 17, None)] # , ("kl", 22, 1e-2)] ) def test_gw_barycenter( - self, rng: jnp.ndarray, gw_loss: str, bar_size: int, + self, rng: jax.random.PRNGKeyArray, gw_loss: str, bar_size: int, epsilon: Optional[float] ): tol = 1e-3 if gw_loss == "sqeucl" else 1e-1 @@ -121,7 +121,7 @@ def test_gw_barycenter( ) def test_fgw_barycenter( self, - rng: jnp.ndarray, + rng: jax.random.PRNGKeyArray, jit: bool, fused_penalty: float, scale_cost: str, @@ -157,12 +157,12 @@ def barycenter( bar_size, epsilon, = 10, 1e-1 num_per_segment = (7, 12) - key1, *rngs = jax.random.split(rng, len(num_per_segment) + 1) + rng1, *rngs = jax.random.split(rng, len(num_per_segment) + 1) y = jnp.concatenate([ self.random_pc(n, d=self.ndim, rng=rng).x for n, rng in zip(num_per_segment, rngs) ]) - rngs = jax.random.split(key1, len(num_per_segment)) + rngs = jax.random.split(rng1, len(num_per_segment)) y_fused = jnp.concatenate([ self.random_pc(n, d=self.ndim_f, rng=rng).x for n, rng in zip(num_per_segment, rngs) diff --git a/tests/solvers/quadratic/gw_test.py b/tests/solvers/quadratic/gw_test.py index bde02d30b..a9db95902 100644 --- a/tests/solvers/quadratic/gw_test.py +++ b/tests/solvers/quadratic/gw_test.py @@ -33,14 +33,15 @@ class TestQuadraticProblem: @pytest.mark.parametrize("as_pc", [False, True]) @pytest.mark.parametrize("rank", [-1, 5, (1, 2, 3), (2, 3, 5)]) def test_quad_to_low_rank( - self, rng: jnp.ndarray, as_pc: bool, rank: Union[int, Tuple[int, ...]] + self, rng: jax.random.PRNGKeyArray, as_pc: bool, + rank: Union[int, Tuple[int, ...]] ): n, m, d1, d2, d = 200, 300, 20, 25, 30 - k1, k2, k3, k4 = jax.random.split(rng, 4) - x = jax.random.normal(k1, (n, d1)) - y = jax.random.normal(k2, (m, d2)) - xx = jax.random.normal(k3, (n, d)) - yy = jax.random.normal(k4, (m, d)) + rng1, rng2, rng3, rng4 = jax.random.split(rng, 4) + x = jax.random.normal(rng1, (n, d1)) + y = jax.random.normal(rng2, (m, d2)) + xx = jax.random.normal(rng3, (n, d)) + yy = jax.random.normal(rng4, (m, d)) geom_xx = pointcloud.PointCloud(x) geom_yy = pointcloud.PointCloud(y) @@ -89,11 +90,13 @@ def test_quad_to_low_rank( assert lr_prob._is_low_rank_convertible assert lr_prob.to_low_rank() is lr_prob - def test_gw_implicit_conversion_mixed_input(self, rng: jnp.ndarray): + def test_gw_implicit_conversion_mixed_input( + self, rng: jax.random.PRNGKeyArray + ): n, m, d1, d2 = 200, 300, 20, 25 - k1, k2 = jax.random.split(rng, 2) - x = jax.random.normal(k1, (n, d1)) - y = jax.random.normal(k2, (m, d2)) + rng1, rng2 = jax.random.split(rng, 2) + x = jax.random.normal(rng1, (n, d1)) + y = jax.random.normal(rng2, (m, d2)) geom_xx = pointcloud.PointCloud(x) geom_yy = pointcloud.PointCloud(y).to_LRCGeometry() @@ -109,19 +112,19 @@ def test_gw_implicit_conversion_mixed_input(self, rng: jnp.ndarray): class TestGromovWasserstein: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.random.PRNGKeyArray): d_x = 2 d_y = 3 self.n, self.m = 6, 7 - keys = jax.random.split(rng, 6) - self.x = jax.random.uniform(keys[0], (self.n, d_x)) - self.y = jax.random.uniform(keys[1], (self.m, d_y)) - a = jax.random.uniform(keys[2], (self.n,)) + 1e-1 - b = jax.random.uniform(keys[3], (self.m,)) + 1e-1 + rngs = jax.random.split(rng, 6) + self.x = jax.random.uniform(rngs[0], (self.n, d_x)) + self.y = jax.random.uniform(rngs[1], (self.m, d_y)) + a = jax.random.uniform(rngs[2], (self.n,)) + 1e-1 + b = jax.random.uniform(rngs[3], 
(self.m,)) + 1e-1 self.a = a / jnp.sum(a) self.b = b / jnp.sum(b) - self.cx = jax.random.uniform(keys[4], (self.n, self.n)) - self.cy = jax.random.uniform(keys[5], (self.m, self.m)) + self.cx = jax.random.uniform(rngs[4], (self.n, self.n)) + self.cy = jax.random.uniform(rngs[5], (self.m, self.m)) self.tau_a = 0.8 self.tau_b = 0.9 @@ -306,7 +309,7 @@ def loss_thre(threshold: float) -> float: assert loss_thre(1e-3) >= loss_thre(1e-5) @pytest.mark.fast - def test_gw_lr(self, rng: jnp.ndarray): + def test_gw_lr(self, rng: jax.random.PRNGKeyArray): """Checking LR and Entropic have similar outputs on same problem.""" rngs = jax.random.split(rng, 4) n, m, d1, d2 = 24, 17, 2, 3 @@ -326,7 +329,7 @@ def test_gw_lr(self, rng: jnp.ndarray): ot_gw = solver(prob) np.testing.assert_allclose(ot_gwlr.costs, ot_gw.costs, rtol=5e-2) - def test_gw_lr_matches_fused(self, rng: jnp.ndarray): + def test_gw_lr_matches_fused(self, rng: jax.random.PRNGKeyArray): """Checking LR and Entropic have similar outputs on same fused problem.""" rngs = jax.random.split(rng, 5) n, m, d1, d2 = 24, 17, 2, 3 @@ -373,11 +376,11 @@ def test_gw_lr_apply(self, axis: int): np.testing.assert_allclose(res_apply, res_matrix, rtol=1e-5, atol=1e-5) - def test_gw_lr_warm_start_helps(self, rng: jnp.ndarray): + def test_gw_lr_warm_start_helps(self, rng: jax.random.PRNGKeyArray): rank = 3 - key1, key2 = jax.random.split(rng, 2) - geom_x = pointcloud.PointCloud(jax.random.normal(key1, (100, 5))) - geom_y = pointcloud.PointCloud(jax.random.normal(key2, (110, 6))) + rng1, rng2 = jax.random.split(rng, 2) + geom_x = pointcloud.PointCloud(jax.random.normal(rng1, (100, 5))) + geom_y = pointcloud.PointCloud(jax.random.normal(rng2, (110, 6))) prob = quadratic_problem.QuadraticProblem(geom_x, geom_y) solver_cold = gromov_wasserstein.GromovWasserstein( diff --git a/tests/tools/gaussian_mixture/fit_gmm_pair_test.py b/tests/tools/gaussian_mixture/fit_gmm_pair_test.py index c582d126f..63c3d01a9 100644 --- a/tests/tools/gaussian_mixture/fit_gmm_pair_test.py +++ b/tests/tools/gaussian_mixture/fit_gmm_pair_test.py @@ -29,7 +29,7 @@ class TestFitGmmPair: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.random.PRNGKeyArray): mean_generator0 = jnp.array([[2., -1.], [-2., 0.], [4., 3.]]) cov_generator0 = jnp.array([[[0.2, 0.], [0., 0.1]], [[0.6, 0.], [0., 0.3]], [[0.5, 0.4], [0.4, 0.5]]]) @@ -61,9 +61,9 @@ def initialize(self, rng: jnp.ndarray): self.rho = 0.1 self.tau = self.rho / (self.rho + self.epsilon) - self.key, subkey0, subkey1 = jax.random.split(rng, num=3) - self.samples_gmm0 = gmm_generator0.sample(key=subkey0, size=2000) - self.samples_gmm1 = gmm_generator1.sample(key=subkey1, size=2000) + self.rng, subrng0, subrng1 = jax.random.split(rng, num=3) + self.samples_gmm0 = gmm_generator0.sample(rng=subrng0, size=2000) + self.samples_gmm1 = gmm_generator1.sample(rng=subrng1, size=2000) # requires Schur decomposition, which jax does not implement on GPU @pytest.mark.cpu @@ -89,7 +89,7 @@ def test_fit_gmm(self, balanced, weighted): # Fit a GMM to the pooled samples samples = jnp.concatenate([self.samples_gmm0, self.samples_gmm1]) gmm_init = fit_gmm.initialize( - key=self.key, + rng=self.rng, points=samples, point_weights=weights_pooled, n_components=3, diff --git a/tests/tools/gaussian_mixture/fit_gmm_test.py b/tests/tools/gaussian_mixture/fit_gmm_test.py index e208f9fe6..647e4f7ff 100644 --- a/tests/tools/gaussian_mixture/fit_gmm_test.py +++ b/tests/tools/gaussian_mixture/fit_gmm_test.py @@ -26,7 +26,7 @@ 
class TestFitGmm: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.random.PRNGKeyArray): mean_generator = jnp.array([[2., -1.], [-2., 0.], [4., 3.]]) cov_generator = jnp.array([[[0.2, 0.], [0., 0.1]], [[0.6, 0.], [0., 0.3]], [[0.5, 0.4], [0.4, 0.5]]]) @@ -40,15 +40,15 @@ def initialize(self, rng: jnp.ndarray): ) ) - self.key, subkey = jax.random.split(rng) - self.samples = gmm_generator.sample(key=subkey, size=2000) + self.rng, subrng = jax.random.split(rng) + self.samples = gmm_generator.sample(rng=subrng, size=2000) def test_integration(self): # dumb integration test that makes sure nothing crashes # Fit a GMM to the samples gmm_init = fit_gmm.initialize( - key=self.key, + rng=self.rng, points=self.samples, point_weights=None, n_components=3, diff --git a/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py b/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py index 30b2b9612..ba81cdcba 100644 --- a/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py +++ b/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py @@ -25,20 +25,20 @@ class TestGaussianMixturePair: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.random.PRNGKeyArray): self.n_components = 3 self.n_dimensions = 2 self.epsilon = 1.e-3 self.rho = 0.1 self.tau = self.rho / (self.rho + self.epsilon) - self.key, subkey0, subkey1 = jax.random.split(rng, num=3) + self.rng, subrng0, subrng1 = jax.random.split(rng, num=3) self.gmm0 = gaussian_mixture.GaussianMixture.from_random( - key=subkey0, + rng=subrng0, n_components=self.n_components, n_dimensions=self.n_dimensions ) self.gmm1 = gaussian_mixture.GaussianMixture.from_random( - key=subkey1, + rng=subrng1, n_components=self.n_components, n_dimensions=self.n_dimensions ) diff --git a/tests/tools/gaussian_mixture/gaussian_mixture_test.py b/tests/tools/gaussian_mixture/gaussian_mixture_test.py index 9e0c68e4f..1fdfbd8db 100644 --- a/tests/tools/gaussian_mixture/gaussian_mixture_test.py +++ b/tests/tools/gaussian_mixture/gaussian_mixture_test.py @@ -26,18 +26,18 @@ class TestGaussianMixture: def test_get_summary_stats_from_points_and_assignment_probs( - self, rng: jnp.ndarray + self, rng: jax.random.PRNGKeyArray ): n = 50 - key, subkey0, subkey1 = jax.random.split(rng, num=3) - points0 = jax.random.normal(key=subkey0, shape=(n, 2)) + rng, subrng0, subrng1 = jax.random.split(rng, num=3) + points0 = jax.random.normal(key=subrng0, shape=(n, 2)) points1 = ( - 2. * jax.random.normal(key=subkey1, shape=(n, 2)) + jnp.array([6., 8.]) + 2. 
* jax.random.normal(key=subrng1, shape=(n, 2)) + jnp.array([6., 8.]) ) points = jnp.concatenate([points0, points1], axis=0) - key, subkey0, subkey1 = jax.random.split(key, num=3) - weights0 = jax.random.uniform(key=subkey0, shape=(n,)) - weights1 = jax.random.uniform(key=subkey1, shape=(n,)) + rng, subrng0, subrng1 = jax.random.split(rng, num=3) + weights0 = jax.random.uniform(key=subrng0, shape=(n,)) + weights1 = jax.random.uniform(key=subrng1, shape=(n,)) weights = jnp.concatenate([weights0, weights1], axis=0) aprobs0 = jnp.stack([jnp.ones((n,)), jnp.zeros((n,))], axis=-1) aprobs1 = jnp.stack([jnp.zeros((n,)), jnp.ones((n,))], axis=-1) @@ -59,9 +59,9 @@ def test_get_summary_stats_from_points_and_assignment_probs( np.testing.assert_allclose(expected_cov, cov, atol=1e-4, rtol=1e-4) np.testing.assert_allclose(expected_wt, comp_wt, atol=1e-4, rtol=1e-4) - def test_from_random(self, rng: jnp.ndarray): + def test_from_random(self, rng: jax.random.PRNGKeyArray): gmm = gaussian_mixture.GaussianMixture.from_random( - key=rng, n_components=3, n_dimensions=2 + rng=rng, n_components=3, n_dimensions=2 ) np.testing.assert_array_equal([gmm.n_components, gmm.n_dimensions], (3, 2)) @@ -82,9 +82,9 @@ def test_from_mean_cov_component_weights(self,): comp_wts, gmm.component_weights, atol=1e-4, rtol=1e-4 ) - def test_covariance(self, rng: jnp.ndarray): + def test_covariance(self, rng: jax.random.PRNGKeyArray): gmm = gaussian_mixture.GaussianMixture.from_random( - key=rng, n_components=3, n_dimensions=2 + rng=rng, n_components=3, n_dimensions=2 ) cov = gmm.covariance for i, component in enumerate(gmm.components()): @@ -92,13 +92,13 @@ def test_covariance(self, rng: jnp.ndarray): cov[i], component.covariance(), atol=1e-4, rtol=1e-4 ) - def test_sample(self, rng: jnp.ndarray): + def test_sample(self, rng: jax.random.PRNGKeyArray): gmm = gaussian_mixture.GaussianMixture.from_mean_cov_component_weights( mean=jnp.array([[-1., 0.], [1., 0.]]), cov=jnp.array([[[0.01, 0.], [0., 0.01]], [[0.01, 0.], [0., 0.01]]]), component_weights=jnp.array([0.2, 0.8]) ) - samples = gmm.sample(key=rng, size=10000) + samples = gmm.sample(rng=rng, size=10000) frac_pos = jnp.mean(samples[:, 0] > 0.) 
np.testing.assert_array_equal(samples.shape, (10000, 2)) @@ -114,19 +114,19 @@ def test_sample(self, rng: jnp.ndarray): atol=1.e-1 ) - def test_log_prob(self, rng: jnp.ndarray): + def test_log_prob(self, rng: jax.random.PRNGKeyArray): n_components = 3 size = 100 - subkey0, subkey1 = jax.random.split(rng, num=2) + subrng0, subrng1 = jax.random.split(rng, num=2) gmm = gaussian_mixture.GaussianMixture.from_random( - key=subkey0, + rng=subrng0, n_components=3, n_dimensions=2, stdev_mean=1., stdev_cov=1., stdev_weights=1 ) - x = gmm.sample(key=subkey1, size=size) + x = gmm.sample(rng=subrng1, size=size) actual = gmm.log_prob(x) prob = jnp.zeros(size) @@ -138,9 +138,9 @@ def test_log_prob(self, rng: jnp.ndarray): np.testing.assert_allclose(expected, actual, atol=1e-4, rtol=1e-4) - def test_log_component_posterior(self, rng: jnp.ndarray): + def test_log_component_posterior(self, rng: jax.random.PRNGKeyArray): gmm = gaussian_mixture.GaussianMixture.from_random( - key=rng, n_components=3, n_dimensions=2 + rng=rng, n_components=3, n_dimensions=2 ) x = jnp.zeros(shape=(1, 2)) px_c = jnp.exp(gmm.conditional_log_prob(x)) @@ -152,18 +152,18 @@ def test_log_component_posterior(self, rng: jnp.ndarray): expected, gmm.get_log_component_posterior(x), atol=1e-4, rtol=1e-4 ) - def test_flatten_unflatten(self, rng: jnp.ndarray): + def test_flatten_unflatten(self, rng: jax.random.PRNGKeyArray): gmm = gaussian_mixture.GaussianMixture.from_random( - key=rng, n_components=3, n_dimensions=2 + rng=rng, n_components=3, n_dimensions=2 ) children, aux_data = jax.tree_util.tree_flatten(gmm) gmm_new = jax.tree_util.tree_unflatten(aux_data, children) assert gmm == gmm_new - def test_pytree_mapping(self, rng: jnp.ndarray): + def test_pytree_mapping(self, rng: jax.random.PRNGKeyArray): gmm = gaussian_mixture.GaussianMixture.from_random( - key=rng, n_components=3, n_dimensions=2 + rng=rng, n_components=3, n_dimensions=2 ) gmm_x_2 = jax.tree_map(lambda x: 2 * x, gmm) np.testing.assert_allclose(2. 
* gmm.loc, gmm_x_2.loc, atol=1e-4, rtol=1e-4) diff --git a/tests/tools/gaussian_mixture/gaussian_test.py b/tests/tools/gaussian_mixture/gaussian_test.py index 6334f2d51..b09c9c8fe 100644 --- a/tests/tools/gaussian_mixture/gaussian_test.py +++ b/tests/tools/gaussian_mixture/gaussian_test.py @@ -25,8 +25,8 @@ @pytest.mark.fast class TestGaussian: - def test_from_random(self, rng: jnp.ndarray): - g = gaussian.Gaussian.from_random(key=rng, n_dimensions=3) + def test_from_random(self, rng: jax.random.PRNGKeyArray): + g = gaussian.Gaussian.from_random(rng=rng, n_dimensions=3) np.testing.assert_array_equal(g.loc.shape, (3,)) np.testing.assert_array_equal(g.covariance().shape, (3, 3)) @@ -39,14 +39,14 @@ def test_from_mean_and_cov(self): np.testing.assert_array_equal(mean, g.loc) np.testing.assert_allclose(cov, g.covariance(), atol=1e-4, rtol=1e-4) - def test_to_z(self, rng: jnp.ndarray): + def test_to_z(self, rng: jax.random.PRNGKeyArray): g = gaussian.Gaussian( loc=jnp.array([1., 2.]), scale=scale_tril.ScaleTriL( params=jnp.array([0., 0.25, jnp.log(0.5)]), size=2 ) ) - samples = g.sample(key=rng, size=1000) + samples = g.sample(rng=rng, size=1000) z = g.to_z(samples) sample_mean = jnp.mean(z, axis=0) sample_cov = jnp.cov(z, rowvar=False) @@ -55,47 +55,47 @@ def test_to_z(self, rng: jnp.ndarray): np.testing.assert_allclose(sample_mean, jnp.zeros(2), atol=0.1) np.testing.assert_allclose(sample_cov, jnp.eye(2), atol=0.1) - def test_from_z(self, rng: jnp.ndarray): + def test_from_z(self, rng: jax.random.PRNGKeyArray): g = gaussian.Gaussian( loc=jnp.array([0., 0.]), scale=scale_tril.ScaleTriL( params=jnp.array([jnp.log(2.), 0., 0.]), size=2 ) ) - x = g.sample(key=rng, size=100) + x = g.sample(rng=rng, size=100) z = g.to_z(x) xnew = g.from_z(z) np.testing.assert_allclose(x, xnew, atol=1e-4, rtol=1e-4) - def test_log_prob(self, rng: jnp.ndarray): + def test_log_prob(self, rng: jax.random.PRNGKeyArray): g = gaussian.Gaussian( loc=jnp.array([0., 0.]), scale=scale_tril.ScaleTriL( params=jnp.array([jnp.log(2.), 0., 0.]), size=2 ) ) - x = g.sample(key=rng, size=100) + x = g.sample(rng=rng, size=100) actual = g.log_prob(x) expected = jnp.log( jax.scipy.stats.multivariate_normal.pdf(x, g.loc, g.covariance()) ) np.testing.assert_allclose(expected, actual, atol=1e-5, rtol=1e-5) - def test_sample(self, rng: jnp.ndarray): + def test_sample(self, rng: jax.random.PRNGKeyArray): mean = jnp.array([1., 2.]) cov = jnp.diag(jnp.array([1., 4.])) g = gaussian.Gaussian.from_mean_and_cov(mean, cov) - samples = g.sample(key=rng, size=10000) + samples = g.sample(rng=rng, size=10000) sample_mean = jnp.mean(samples, axis=0) sample_cov = jnp.cov(samples, rowvar=False) np.testing.assert_allclose(sample_mean, mean, atol=3. * 2. / 100.) 
np.testing.assert_allclose(sample_cov, cov, atol=2e-1) - def test_w2_dist(self, rng: jnp.ndarray): + def test_w2_dist(self, rng: jax.random.PRNGKeyArray): # make sure distance between a random normal and itself is 0 - key, subkey = jax.random.split(rng) - n = gaussian.Gaussian.from_random(key=subkey, n_dimensions=3) + rng, subrng = jax.random.split(rng) + n = gaussian.Gaussian.from_random(rng=subrng, n_dimensions=3) w2 = n.w2_dist(n) np.testing.assert_almost_equal(w2, 0., decimal=5) @@ -103,12 +103,12 @@ def test_w2_dist(self, rng: jnp.ndarray): # distance between covariances = frobenius norm^2 of (delta cholesky) # see https://djalil.chafai.net/blog/2010/04/30/wasserstein-distance-between-two-gaussians/ # pylint: disable=line-too-long size = 4 - key, subkey0, subkey1 = jax.random.split(key, num=3) - loc0 = jax.random.normal(key=subkey0, shape=(size,)) - loc1 = jax.random.normal(key=subkey1, shape=(size,)) - key, subkey0, subkey1 = jax.random.split(key, num=3) - diag0 = jnp.exp(jax.random.normal(key=subkey0, shape=(size,))) - diag1 = jnp.exp(jax.random.normal(key=subkey1, shape=(size,))) + rng, subrng0, subrng1 = jax.random.split(rng, num=3) + loc0 = jax.random.normal(key=subrng0, shape=(size,)) + loc1 = jax.random.normal(key=subrng1, shape=(size,)) + rng, subrng0, subrng1 = jax.random.split(rng, num=3) + diag0 = jnp.exp(jax.random.normal(key=subrng0, shape=(size,))) + diag1 = jnp.exp(jax.random.normal(key=subrng1, shape=(size,))) g0 = gaussian.Gaussian( loc=loc0, scale=scale_tril.ScaleTriL.from_covariance(jnp.diag(diag0)) ) @@ -121,7 +121,7 @@ def test_w2_dist(self, rng: jnp.ndarray): expected = delta_mean + delta_sigma np.testing.assert_allclose(expected, w2, rtol=1e-6, atol=1e-6) - def test_transport(self, rng: jnp.ndarray): + def test_transport(self, rng: jax.random.PRNGKeyArray): diag0 = jnp.array([1.]) diag1 = jnp.array([4.]) g0 = gaussian.Gaussian( @@ -137,14 +137,14 @@ def test_transport(self, rng: jnp.ndarray): expected = 2. * points + 1. 
np.testing.assert_allclose(expected, actual, atol=1e-5, rtol=1e-5) - def test_flatten_unflatten(self, rng: jnp.ndarray): + def test_flatten_unflatten(self, rng: jax.random.PRNGKeyArray): g = gaussian.Gaussian.from_random(rng, n_dimensions=3) children, aux_data = jax.tree_util.tree_flatten(g) g_new = jax.tree_util.tree_unflatten(aux_data, children) assert g == g_new - def test_pytree_mapping(self, rng: jnp.ndarray): + def test_pytree_mapping(self, rng: jax.random.PRNGKeyArray): g = gaussian.Gaussian.from_random(rng, n_dimensions=3) g_x_2 = jax.tree_map(lambda x: 2 * x, g) diff --git a/tests/tools/gaussian_mixture/linalg_test.py b/tests/tools/gaussian_mixture/linalg_test.py index 210eab8cb..927d8f849 100644 --- a/tests/tools/gaussian_mixture/linalg_test.py +++ b/tests/tools/gaussian_mixture/linalg_test.py @@ -25,7 +25,7 @@ @pytest.mark.fast class TestLinalg: - def test_get_mean_and_var(self, rng: jnp.ndarray): + def test_get_mean_and_var(self, rng: jax.random.PRNGKeyArray): points = jax.random.normal(key=rng, shape=(10, 2)) weights = jnp.ones(10) expected_mean = jnp.mean(points, axis=0) @@ -36,7 +36,9 @@ def test_get_mean_and_var(self, rng: jnp.ndarray): np.testing.assert_allclose(expected_mean, actual_mean, atol=1E-5, rtol=1E-5) np.testing.assert_allclose(expected_var, actual_var, atol=1E-5, rtol=1E-5) - def test_get_mean_and_var_nonuniform_weights(self, rng: jnp.ndarray): + def test_get_mean_and_var_nonuniform_weights( + self, rng: jax.random.PRNGKeyArray + ): points = jax.random.normal(key=rng, shape=(10, 2)) weights = jnp.concatenate([jnp.ones(5), jnp.zeros(5)], axis=-1) expected_mean = jnp.mean(points[:5], axis=0) @@ -47,7 +49,7 @@ def test_get_mean_and_var_nonuniform_weights(self, rng: jnp.ndarray): np.testing.assert_allclose(expected_mean, actual_mean, rtol=1e-6, atol=1e-6) np.testing.assert_allclose(expected_var, actual_var, rtol=1e-6, atol=1e-6) - def test_get_mean_and_cov(self, rng: jnp.ndarray): + def test_get_mean_and_cov(self, rng: jax.random.PRNGKeyArray): points = jax.random.normal(key=rng, shape=(10, 2)) weights = jnp.ones(10) expected_mean = jnp.mean(points, axis=0) @@ -58,7 +60,9 @@ def test_get_mean_and_cov(self, rng: jnp.ndarray): np.testing.assert_allclose(expected_mean, actual_mean, atol=1e-5, rtol=1e-5) np.testing.assert_allclose(expected_cov, actual_cov, atol=1e-5, rtol=1e-5) - def test_get_mean_and_cov_nonuniform_weights(self, rng: jnp.ndarray): + def test_get_mean_and_cov_nonuniform_weights( + self, rng: jax.random.PRNGKeyArray + ): points = jax.random.normal(key=rng, shape=(10, 2)) weights = jnp.concatenate([jnp.ones(5), jnp.zeros(5)], axis=-1) expected_mean = jnp.mean(points[:5], axis=0) @@ -69,7 +73,7 @@ def test_get_mean_and_cov_nonuniform_weights(self, rng: jnp.ndarray): np.testing.assert_allclose(expected_mean, actual_mean, rtol=1e-6, atol=1e-6) np.testing.assert_allclose(expected_cov, actual_cov, rtol=1e-6, atol=1e-6) - def test_flat_to_tril(self, rng: jnp.ndarray): + def test_flat_to_tril(self, rng: jax.random.PRNGKeyArray): size = 3 x = jax.random.normal(key=rng, shape=(5, 4, size * (size + 1) // 2)) m = linalg.flat_to_tril(x, size) @@ -89,7 +93,7 @@ def test_flat_to_tril(self, rng: jnp.ndarray): actual = linalg.tril_to_flat(m) np.testing.assert_allclose(x, actual) - def test_tril_to_flat(self, rng: jnp.ndarray): + def test_tril_to_flat(self, rng: jax.random.PRNGKeyArray): size = 3 m = jax.random.normal(key=rng, shape=(5, 4, size, size)) for i in range(size): @@ -106,7 +110,7 @@ def test_tril_to_flat(self, rng: jnp.ndarray): inverted = 
linalg.flat_to_tril(flat, size) np.testing.assert_allclose(m, inverted) - def test_apply_to_diag(self, rng: jnp.ndarray): + def test_apply_to_diag(self, rng: jax.random.PRNGKeyArray): size = 3 m = jax.random.normal(key=rng, shape=(5, 4, size, size)) mnew = linalg.apply_to_diag(m, jnp.exp) @@ -117,9 +121,9 @@ def test_apply_to_diag(self, rng: jnp.ndarray): else: np.testing.assert_allclose(jnp.exp(m[..., i, j]), mnew[..., i, j]) - def test_matrix_powers(self, rng: jnp.ndarray): - key, subkey = jax.random.split(rng) - m = jax.random.normal(key=subkey, shape=(4, 4)) + def test_matrix_powers(self, rng: jax.random.PRNGKeyArray): + rng, subrng = jax.random.split(rng) + m = jax.random.normal(key=subrng, shape=(4, 4)) m += jnp.swapaxes(m, axis1=-2, axis2=-1) # symmetric m = jnp.matmul(m, m) # symmetric, pos def inv_m = jnp.linalg.inv(m) @@ -128,22 +132,22 @@ def test_matrix_powers(self, rng: jnp.ndarray): np.testing.assert_allclose(m, actual[0], rtol=1.e-5) np.testing.assert_allclose(inv_m, actual[1], rtol=1.e-4) - def test_invmatvectril(self, rng: jnp.ndarray): - key, subkey = jax.random.split(rng) - m = jax.random.normal(key=subkey, shape=(2, 2)) + def test_invmatvectril(self, rng: jax.random.PRNGKeyArray): + rng, subrng = jax.random.split(rng) + m = jax.random.normal(key=subrng, shape=(2, 2)) m += jnp.swapaxes(m, axis1=-2, axis2=-1) # symmetric m = jnp.matmul(m, m) # symmetric, pos def cholesky = jnp.linalg.cholesky(m) # lower triangular - key, subkey = jax.random.split(key) - x = jax.random.normal(key=subkey, shape=(10, 2)) + rng, subrng = jax.random.split(rng) + x = jax.random.normal(key=subrng, shape=(10, 2)) inv_cholesky = jnp.linalg.inv(cholesky) expected = jnp.transpose(jnp.matmul(inv_cholesky, jnp.transpose(x))) actual = linalg.invmatvectril(m=cholesky, x=x, lower=True) np.testing.assert_allclose(expected, actual, atol=1e-4, rtol=1.e-4) - def test_get_random_orthogonal(self, rng: jnp.ndarray): - key, subkey = jax.random.split(rng) - q = linalg.get_random_orthogonal(key=subkey, dim=3) + def test_get_random_orthogonal(self, rng: jax.random.PRNGKeyArray): + rng, subrng = jax.random.split(rng) + q = linalg.get_random_orthogonal(rng=subrng, dim=3) qt = jnp.transpose(q) expected = jnp.eye(3) actual = jnp.matmul(q, qt) diff --git a/tests/tools/gaussian_mixture/probabilities_test.py b/tests/tools/gaussian_mixture/probabilities_test.py index 44220dde0..65dd727c1 100644 --- a/tests/tools/gaussian_mixture/probabilities_test.py +++ b/tests/tools/gaussian_mixture/probabilities_test.py @@ -44,9 +44,9 @@ def test_log_probs(self): def test_from_random(self): n_dimensions = 4 - key = jax.random.PRNGKey(0) + rng = jax.random.PRNGKey(0) pp = probabilities.Probabilities.from_random( - key=key, n_dimensions=n_dimensions, stdev=0.1 + rng=rng, n_dimensions=n_dimensions, stdev=0.1 ) np.testing.assert_array_equal(pp.probs().shape, (4,)) @@ -59,7 +59,7 @@ def test_sample(self): p = 0.4 probs = jnp.array([p, 1. - p]) pp = probabilities.Probabilities.from_probs(probs) - samples = pp.sample(key=jax.random.PRNGKey(0), size=10000) + samples = pp.sample(rng=jax.random.PRNGKey(0), size=10000) sd = jnp.sqrt(p * (1. - p)) np.testing.assert_allclose(jnp.mean(samples == 0), p, atol=3. 
* sd) diff --git a/tests/tools/gaussian_mixture/scale_tril_test.py b/tests/tools/gaussian_mixture/scale_tril_test.py index 2e07b10de..6aa28b4ae 100644 --- a/tests/tools/gaussian_mixture/scale_tril_test.py +++ b/tests/tools/gaussian_mixture/scale_tril_test.py @@ -50,27 +50,27 @@ def test_log_det_covariance(self, chol: scale_tril.ScaleTriL): actual = chol.log_det_covariance() np.testing.assert_almost_equal(actual, expected) - def test_from_random(self, rng: jnp.ndarray): + def test_from_random(self, rng: jax.random.PRNGKeyArray): n_dimensions = 4 cov = scale_tril.ScaleTriL.from_random( - key=rng, n_dimensions=n_dimensions, stdev=0.1 + rng=rng, n_dimensions=n_dimensions, stdev=0.1 ) np.testing.assert_array_equal( cov.cholesky().shape, (n_dimensions, n_dimensions) ) - def test_from_cholesky(self, rng: jnp.ndarray): + def test_from_cholesky(self, rng: jax.random.PRNGKeyArray): n_dimensions = 4 cholesky = scale_tril.ScaleTriL.from_random( - key=rng, n_dimensions=n_dimensions, stdev=1. + rng=rng, n_dimensions=n_dimensions, stdev=1. ).cholesky() scale = scale_tril.ScaleTriL.from_cholesky(cholesky) np.testing.assert_allclose(cholesky, scale.cholesky(), atol=1e-4, rtol=1e-4) - def test_w2_dist(self, rng: jnp.ndarray): + def test_w2_dist(self, rng: jax.random.PRNGKeyArray): # make sure distance between a random normal and itself is 0 - key, subkey = jax.random.split(rng) - s = scale_tril.ScaleTriL.from_random(key=subkey, n_dimensions=3) + rng, subrng = jax.random.split(rng) + s = scale_tril.ScaleTriL.from_random(rng=subrng, n_dimensions=3) w2 = s.w2_dist(s) expected = 0. np.testing.assert_allclose(expected, w2, atol=1e-4, rtol=1e-4) @@ -79,37 +79,37 @@ def test_w2_dist(self, rng: jnp.ndarray): # distance between covariances = Frobenius norm^2 of (delta sqrt(cov)) # see https://djalil.chafai.net/blog/2010/04/30/wasserstein-distance-between-two-gaussians/ # pylint: disable=line-too-long size = 4 - key, subkey0, subkey1 = jax.random.split(key, num=3) - diag0 = jnp.exp(jax.random.normal(key=subkey0, shape=(size,))) - diag1 = jnp.exp(jax.random.normal(key=subkey1, shape=(size,))) + rng, subrng0, subrng1 = jax.random.split(rng, num=3) + diag0 = jnp.exp(jax.random.normal(key=subrng0, shape=(size,))) + diag1 = jnp.exp(jax.random.normal(key=subrng1, shape=(size,))) s0 = scale_tril.ScaleTriL.from_covariance(jnp.diag(diag0)) s1 = scale_tril.ScaleTriL.from_covariance(jnp.diag(diag1)) w2 = s0.w2_dist(s1) delta_sigma = jnp.sum((jnp.sqrt(diag0) - jnp.sqrt(diag1)) ** 2.) 
np.testing.assert_allclose(delta_sigma, w2, atol=1e-4, rtol=1e-4) - def test_transport(self, rng: jnp.ndarray): + def test_transport(self, rng: jax.random.PRNGKeyArray): size = 4 - key, subkey0, subkey1 = jax.random.split(rng, num=3) - diag0 = jnp.exp(jax.random.normal(key=subkey0, shape=(size,))) + rng, subrng0, subrng1 = jax.random.split(rng, num=3) + diag0 = jnp.exp(jax.random.normal(key=subrng0, shape=(size,))) s0 = scale_tril.ScaleTriL.from_covariance(jnp.diag(diag0)) - diag1 = jnp.exp(jax.random.normal(key=subkey1, shape=(size,))) + diag1 = jnp.exp(jax.random.normal(key=subrng1, shape=(size,))) s1 = scale_tril.ScaleTriL.from_covariance(jnp.diag(diag1)) - key, subkey = jax.random.split(key) - x = jax.random.normal(key=subkey, shape=(100, size)) + rng, subrng = jax.random.split(rng) + x = jax.random.normal(key=subrng, shape=(100, size)) transported = s0.transport(s1, points=x) expected = x * jnp.sqrt(diag1)[None] / jnp.sqrt(diag0)[None] np.testing.assert_allclose(expected, transported, atol=1e-4, rtol=1e-4) - def test_flatten_unflatten(self, rng: jnp.ndarray): - scale = scale_tril.ScaleTriL.from_random(key=rng, n_dimensions=3) + def test_flatten_unflatten(self, rng: jax.random.PRNGKeyArray): + scale = scale_tril.ScaleTriL.from_random(rng=rng, n_dimensions=3) children, aux_data = jax.tree_util.tree_flatten(scale) scale_new = jax.tree_util.tree_unflatten(aux_data, children) np.testing.assert_array_equal(scale.params, scale_new.params) assert scale == scale_new - def test_pytree_mapping(self, rng: jnp.ndarray): - scale = scale_tril.ScaleTriL.from_random(key=rng, n_dimensions=3) + def test_pytree_mapping(self, rng: jax.random.PRNGKeyArray): + scale = scale_tril.ScaleTriL.from_random(rng=rng, n_dimensions=3) scale_x_2 = jax.tree_map(lambda x: 2 * x, scale) np.testing.assert_allclose(2. 
* scale.params, scale_x_2.params) diff --git a/tests/tools/k_means_test.py b/tests/tools/k_means_test.py index de9fb0876..82fd31a30 100644 --- a/tests/tools/k_means_test.py +++ b/tests/tools/k_means_test.py @@ -51,14 +51,14 @@ def compute_assignment( class TestKmeansPlusPlus: @pytest.mark.fast.with_args("n_local_trials", [None, 1, 5], only_fast=-1) - def test_n_local_trials(self, rng: jnp.ndarray, n_local_trials): + def test_n_local_trials(self, rng: jax.random.PRNGKeyArray, n_local_trials): n, k = 150, 4 - key1, key2 = jax.random.split(rng) + rng1, rng2 = jax.random.split(rng) geom, _, c = make_blobs( n_samples=n, centers=k, cost_fn='sqeucl', random_state=0 ) - centers1 = k_means._k_means_plus_plus(geom, k, key1, n_local_trials) - centers2 = k_means._k_means_plus_plus(geom, k, key2, 20) + centers1 = k_means._k_means_plus_plus(geom, k, rng1, n_local_trials) + centers2 = k_means._k_means_plus_plus(geom, k, rng2, 20) shift1 = jnp.linalg.norm(centers1 - c, ord="fro") ** 2 shift2 = jnp.linalg.norm(centers2 - c, ord="fro") ** 2 @@ -66,7 +66,7 @@ def test_n_local_trials(self, rng: jnp.ndarray, n_local_trials): assert shift1 > shift2 @pytest.mark.fast.with_args("k", [4, 5, 10], only_fast=0) - def test_matches_sklearn(self, rng: jnp.ndarray, k: int): + def test_matches_sklearn(self, rng: jax.random.PRNGKeyArray, k: int): ndim = 2 geom, _, _ = make_blobs( n_samples=200, @@ -91,11 +91,11 @@ def test_matches_sklearn(self, rng: jnp.ndarray, k: int): # the largest was 70.56378 assert jnp.abs(pred_inertia - gt_inertia) <= 100 - def test_initialization_differentiable(self, rng: jnp.ndarray): + def test_initialization_differentiable(self, rng: jax.random.PRNGKeyArray): def callback(x: jnp.ndarray) -> float: geom = pointcloud.PointCloud(x) - centers = k_means._k_means_plus_plus(geom, k=3, key=rng) + centers = k_means._k_means_plus_plus(geom, k=3, rng=rng) _, inertia = compute_assignment(x, centers) return inertia @@ -111,7 +111,7 @@ class TestKmeans: @pytest.mark.fast @pytest.mark.parametrize("k", [1, 6]) - def test_k_means_output(self, rng: jnp.ndarray, k: int): + def test_k_means_output(self, rng: jax.random.PRNGKeyArray, k: int): max_iter, ndim = 10, 4 geom, gt_assignment, _ = make_blobs( n_samples=50, n_features=ndim, centers=k, random_state=42 @@ -119,7 +119,7 @@ def test_k_means_output(self, rng: jnp.ndarray, k: int): gt_assignment = np.array(gt_assignment) res = k_means.k_means( - geom, k, max_iterations=max_iter, store_inner_errors=False, key=rng + geom, k, max_iterations=max_iter, store_inner_errors=False, rng=rng ) pred_assignment = np.array(res.assignment) @@ -149,7 +149,7 @@ def test_k_means_simple_example(self): ["k-means++", "random", "callable", "wrong-callable"], only_fast=1, ) - def test_init_method(self, rng: jnp.ndarray, init: str): + def test_init_method(self, rng: jax.random.PRNGKeyArray, init: str): if init == "callable": init_fn = lambda geom, k, _: geom.x[:k] elif init == "wrong-callable": @@ -165,37 +165,41 @@ def test_init_method(self, rng: jnp.ndarray, init: str): else: _ = k_means.k_means(geom, k, init=init_fn) - def test_k_means_plus_plus_better_than_random(self, rng: jnp.ndarray): + def test_k_means_plus_plus_better_than_random( + self, rng: jax.random.PRNGKeyArray + ): k = 5 - key1, key2 = jax.random.split(rng, 2) + rng1, rng2 = jax.random.split(rng, 2) geom, _, _ = make_blobs(n_samples=50, centers=k, random_state=10) - res_random = k_means.k_means(geom, k, init="random", key=key1) - res_kpp = k_means.k_means(geom, k, init="k-means++", key=key2) + res_random = 
k_means.k_means(geom, k, init="random", rng=rng1) + res_kpp = k_means.k_means(geom, k, init="k-means++", rng=rng2) assert res_random.converged assert res_kpp.converged assert res_kpp.iteration < res_random.iteration assert res_kpp.error <= res_random.error - def test_larger_n_init_helps(self, rng: jnp.ndarray): + def test_larger_n_init_helps(self, rng: jax.random.PRNGKeyArray): k = 10 geom, _, _ = make_blobs(n_samples=150, centers=k, random_state=0) - res = k_means.k_means(geom, k, n_init=3, key=rng) - res_larger_n_init = k_means.k_means(geom, k, n_init=20, key=rng) + res = k_means.k_means(geom, k, n_init=3, rng=rng) + res_larger_n_init = k_means.k_means(geom, k, n_init=20, rng=rng) assert res_larger_n_init.error < res.error @pytest.mark.parametrize("max_iter", [8, 16]) - def test_store_inner_errors(self, rng: jnp.ndarray, max_iter: int): + def test_store_inner_errors( + self, rng: jax.random.PRNGKeyArray, max_iter: int + ): ndim, k = 10, 4 geom, _, _ = make_blobs( n_samples=40, n_features=ndim, centers=k, random_state=43 ) res = k_means.k_means( - geom, k, max_iterations=max_iter, store_inner_errors=True, key=rng + geom, k, max_iterations=max_iter, store_inner_errors=True, rng=rng ) errors = res.inner_errors @@ -204,12 +208,12 @@ def test_store_inner_errors(self, rng: jnp.ndarray, max_iter: int): # check if error is decreasing np.testing.assert_array_equal(jnp.diff(errors[::-1]) >= 0., True) - def test_strict_tolerance(self, rng: jnp.ndarray): + def test_strict_tolerance(self, rng: jax.random.PRNGKeyArray): k = 11 geom, _, _ = make_blobs(n_samples=200, centers=k, random_state=39) - res = k_means.k_means(geom, k=k, tol=1., key=rng) - res_strict = k_means.k_means(geom, k=k, tol=0., key=rng) + res = k_means.k_means(geom, k=k, tol=1., rng=rng) + res_strict = k_means.k_means(geom, k=k, tol=0., rng=rng) assert res.converged assert res_strict.converged @@ -218,7 +222,9 @@ def test_strict_tolerance(self, rng: jnp.ndarray): @pytest.mark.parametrize( "tol", [1e-3, 0.], ids=["weak-convergence", "strict-convergence"] ) - def test_convergence_force_scan(self, rng: jnp.ndarray, tol: float): + def test_convergence_force_scan( + self, rng: jax.random.PRNGKeyArray, tol: float + ): k, n_iter = 9, 20 geom, _, _ = make_blobs(n_samples=100, centers=k, random_state=37) @@ -229,14 +235,14 @@ def test_convergence_force_scan(self, rng: jnp.ndarray, tol: float): min_iterations=n_iter, max_iterations=n_iter, store_inner_errors=True, - key=rng + rng=rng ) assert res.converged assert res.iteration == n_iter np.testing.assert_array_equal(res.inner_errors == -1, False) - def test_k_means_min_iterations(self, rng: jnp.ndarray): + def test_k_means_min_iterations(self, rng: jax.random.PRNGKeyArray): k, min_iter = 8, 12 geom, _, _ = make_blobs(n_samples=160, centers=k, random_state=38) @@ -247,17 +253,19 @@ def test_k_means_min_iterations(self, rng: jnp.ndarray): min_iterations=min_iter, max_iterations=20, tol=0., - key=rng + rng=rng ) assert res.converged assert jnp.sum(res.inner_errors != -1) >= min_iter - def test_weight_scaling_effects_only_inertia(self, rng: jnp.ndarray): + def test_weight_scaling_effects_only_inertia( + self, rng: jax.random.PRNGKeyArray + ): k = 10 - key1, key2 = jax.random.split(rng) + rng1, rng2 = jax.random.split(rng) geom, _, _ = make_blobs(n_samples=130, centers=k, random_state=3) - weights = jnp.abs(jax.random.normal(key1, shape=(geom.shape[0],))) + weights = jnp.abs(jax.random.normal(rng1, shape=(geom.shape[0],))) weights_scaled = weights / jnp.sum(weights) res = k_means.k_means(geom, k=k - 
1, weights=weights) @@ -274,7 +282,7 @@ def test_weight_scaling_effects_only_inertia(self, rng: jnp.ndarray): ) @pytest.mark.fast - def test_empty_weights(self, rng: jnp.ndarray): + def test_empty_weights(self, rng: jax.random.PRNGKeyArray): n, ndim, k, d = 20, 2, 3, 5. x = np.random.normal(size=(n, ndim)) x[:, 0] += d @@ -293,7 +301,7 @@ def test_empty_weights(self, rng: jnp.ndarray): weights = jnp.ones((x.shape[0],)).at[:n].set(0.) expected_centroids = jnp.stack([w.mean(0), z.mean(0), y.mean(0)]) - res = k_means.k_means(x, k=k, weights=weights, key=rng) + res = k_means.k_means(x, k=k, weights=weights, rng=rng) cost = pointcloud.PointCloud(res.centroids, expected_centroids).cost_matrix ixs = jnp.argmin(cost, axis=1) @@ -321,12 +329,12 @@ def test_cosine_cost_fn(self): @pytest.mark.fast.with_args("init", ["k-means++", "random"], only_fast=0) def test_k_means_jitting( - self, rng: jnp.ndarray, init: Literal["k-means++", "random"] + self, rng: jax.random.PRNGKeyArray, init: Literal["k-means++", "random"] ): def callback(x: jnp.ndarray) -> k_means.KMeansOutput: return k_means.k_means( - x, k=k, init=init, store_inner_errors=True, key=rng + x, k=k, init=init, store_inner_errors=True, rng=rng ) k = 7 @@ -354,7 +362,7 @@ def callback(x: jnp.ndarray) -> k_means.KMeansOutput: ids=["jit-while-loop", "nojit-for-loop"] ) def test_k_means_differentiability( - self, rng: jnp.ndarray, jit: bool, force_scan: bool + self, rng: jax.random.PRNGKeyArray, jit: bool, force_scan: bool ): def inertia(x: jnp.ndarray, w: jnp.ndarray) -> float: @@ -364,17 +372,17 @@ def inertia(x: jnp.ndarray, w: jnp.ndarray) -> float: weights=w, min_iterations=20 if force_scan else 1, max_iterations=20, - key=key1, + rng=rng1, ).error k, eps, tol = 4, 1e-3, 1e-3 x, _, _ = make_blobs(n_samples=150, centers=k, random_state=41) - key1, key2, key3, key4 = jax.random.split(rng, 4) - w = jnp.abs(jax.random.normal(key2, (x.shape[0],))) + rng1, rng2, rng3, rng4 = jax.random.split(rng, 4) + w = jnp.abs(jax.random.normal(rng2, (x.shape[0],))) - v_x = jax.random.normal(key3, shape=x.shape) + v_x = jax.random.normal(rng3, shape=x.shape) v_x = (v_x / jnp.linalg.norm(v_x, axis=-1, keepdims=True)) * eps - v_w = jax.random.normal(key4, shape=w.shape) * eps + v_w = jax.random.normal(rng4, shape=w.shape) * eps v_w = (v_w / jnp.linalg.norm(v_w, axis=-1, keepdims=True)) * eps grad_fn = jax.grad(inertia, (0, 1)) @@ -393,12 +401,12 @@ def inertia(x: jnp.ndarray, w: jnp.ndarray) -> float: @pytest.mark.parametrize("tol", [1e-3, 0.]) @pytest.mark.parametrize("n,k", [(37, 4), (128, 6)]) def test_clustering_matches_sklearn( - self, rng: jnp.ndarray, n: int, k: int, tol: float + self, rng: jax.random.PRNGKeyArray, n: int, k: int, tol: float ): x, _, _ = make_blobs(n_samples=n, centers=k, random_state=41) res_kmeans = KMeans(n_clusters=k, n_init=20, tol=tol, random_state=0).fit(x) - res_ours = k_means.k_means(x, k, n_init=20, tol=tol, key=rng) + res_ours = k_means.k_means(x, k, n_init=20, tol=tol, rng=rng) gt_labels = res_kmeans.labels_ pred_labels = np.array(res_ours.assignment) diff --git a/tests/tools/segment_sinkhorn_test.py b/tests/tools/segment_sinkhorn_test.py index 960dc7f6c..8a5b9e923 100644 --- a/tests/tools/segment_sinkhorn_test.py +++ b/tests/tools/segment_sinkhorn_test.py @@ -29,7 +29,7 @@ class TestSegmentSinkhorn: @pytest.fixture(autouse=True) - def setUp(self, rng: jnp.ndarray): + def setUp(self, rng: jax.random.PRNGKeyArray): self._dim = 4 self._num_points = 13, 17 self._max_measure_size = 20 diff --git 
a/tests/tools/sinkhorn_divergence_test.py b/tests/tools/sinkhorn_divergence_test.py index d35c89446..71e5c1846 100644 --- a/tests/tools/sinkhorn_divergence_test.py +++ b/tests/tools/sinkhorn_divergence_test.py @@ -29,7 +29,7 @@ class TestSinkhornDivergence: @pytest.fixture(autouse=True) - def setUp(self, rng: jnp.ndarray): + def setUp(self, rng: jax.random.PRNGKeyArray): self._dim = 4 self._num_points = 13, 17 self.rng, *rngs = jax.random.split(rng, 3) @@ -405,7 +405,7 @@ def test_euclidean_momentum_params( class TestSinkhornDivergenceGrad: @pytest.fixture(autouse=True) - def initialize(self, rng: jnp.ndarray): + def initialize(self, rng: jax.random.PRNGKeyArray): self._dim = 3 self._num_points = 13, 12 self.rng, *rngs = jax.random.split(rng, 3) diff --git a/tests/tools/soft_sort_test.py b/tests/tools/soft_sort_test.py index c54f6f295..034198ae0 100644 --- a/tests/tools/soft_sort_test.py +++ b/tests/tools/soft_sort_test.py @@ -29,14 +29,16 @@ class TestSoftSort: @pytest.mark.parametrize("shape", [(20,), (20, 1)]) - def test_sort_one_array(self, rng: jnp.ndarray, shape: Tuple[int, ...]): + def test_sort_one_array( + self, rng: jax.random.PRNGKeyArray, shape: Tuple[int, ...] + ): x = jax.random.uniform(rng, shape) xs = soft_sort.sort(x, axis=0) np.testing.assert_array_equal(x.shape, xs.shape) np.testing.assert_array_equal(jnp.diff(xs, axis=0) >= 0.0, True) - def test_sort_array_squashing_momentum(self, rng: jnp.ndarray): + def test_sort_array_squashing_momentum(self, rng: jax.random.PRNGKeyArray): shape = (33, 1) x = jax.random.uniform(rng, shape) xs_lin = soft_sort.sort( @@ -63,7 +65,7 @@ def test_sort_array_squashing_momentum(self, rng: jnp.ndarray): @pytest.mark.fast @pytest.mark.parametrize("k", [-1, 4, 100]) - def test_topk_one_array(self, rng: jnp.ndarray, k: int): + def test_topk_one_array(self, rng: jax.random.PRNGKeyArray, k: int): n = 20 x = jax.random.uniform(rng, (n,)) axis = 0 @@ -77,7 +79,7 @@ def test_topk_one_array(self, rng: jnp.ndarray, k: int): np.testing.assert_allclose(xs, jnp.sort(x, axis=axis)[-outsize:], atol=0.01) @pytest.mark.fast.with_args("topk", [-1, 2, 5, 11], only_fast=-1) - def test_sort_batch(self, rng: jnp.ndarray, topk: int): + def test_sort_batch(self, rng: jax.random.PRNGKeyArray, topk: int): x = jax.random.uniform(rng, (32, 10, 6, 4)) axis = 1 xs = soft_sort.sort(x, axis=axis, topk=topk) @@ -87,7 +89,7 @@ def test_sort_batch(self, rng: jnp.ndarray, topk: int): np.testing.assert_array_equal(xs.shape, expected_shape) np.testing.assert_array_equal(jnp.diff(xs, axis=axis) >= 0.0, True) - def test_rank_one_array(self, rng: jnp.ndarray): + def test_rank_one_array(self, rng: jax.random.PRNGKeyArray): x = jax.random.uniform(rng, (20,)) expected_ranks = jnp.argsort(jnp.argsort(x, axis=0), axis=0).astype(float) @@ -106,7 +108,7 @@ def test_quantile(self, level: float): np.testing.assert_approx_equal(q, level, significant=1) - def test_quantile_on_several_axes(self, rng: jnp.ndarray): + def test_quantile_on_several_axes(self, rng: jax.random.PRNGKeyArray): batch, height, width, channels = 16, 100, 100, 3 x = jax.random.uniform(rng, shape=(batch, height, width, channels)) q = soft_sort.quantile( @@ -118,7 +120,7 @@ def test_quantile_on_several_axes(self, rng: jnp.ndarray): q, 0.5 * np.ones((batch, 1, channels)), atol=3e-2 ) - def test_soft_quantile_normalization(self, rng: jnp.ndarray): + def test_soft_quantile_normalization(self, rng: jax.random.PRNGKeyArray): rngs = jax.random.split(rng, 2) x = jax.random.uniform(rngs[0], shape=(100,)) mu, sigma = 2.0, 1.2 @@ 
-131,7 +133,7 @@ def test_soft_quantile_normalization(self, rng: jnp.ndarray): [mu_target, sigma_target], rtol=0.05) - def test_sort_with(self, rng: jnp.ndarray): + def test_sort_with(self, rng: jax.random.PRNGKeyArray): n, d = 20, 4 inputs = jax.random.uniform(rng, shape=(n, d)) criterion = jnp.linspace(0.1, 1.2, n) @@ -158,7 +160,9 @@ def test_quantize(self): np.testing.assert_allclose(min_distances, min_distances, atol=0.05) @pytest.mark.parametrize("implicit", [False, True]) - def test_soft_sort_jacobian(self, rng: jnp.ndarray, implicit: bool): + def test_soft_sort_jacobian( + self, rng: jax.random.PRNGKeyArray, implicit: bool + ): b, n = 10, 40 idx_column = 5 rngs = jax.random.split(rng, 3)
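
The test hunks above all apply one pattern: OTT entry points that previously took an integer seed or a `key=` keyword now take an explicit PRNG key named `rng`, created with `jax.random.PRNGKey` and threaded through `jax.random.split`, while the raw `jax.random.*` samplers keep their native `key` argument. A minimal sketch of the new calling convention, using `k_means` and `GaussianMixture` as exercised in the tests above (array sizes and component counts are illustrative only):

    import jax

    from ott.tools import k_means
    from ott.tools.gaussian_mixture import gaussian_mixture

    rng = jax.random.PRNGKey(0)
    rng, rng_pts, rng_km, rng_gmm = jax.random.split(rng, 4)

    # jax.random samplers are unchanged: they still take a key positionally.
    points = jax.random.normal(rng_pts, (100, 2))

    # OTT APIs now consume PRNG keys via `rng=` instead of `key=` / `seed=`.
    res = k_means.k_means(points, k=3, rng=rng_km)

    gmm = gaussian_mixture.GaussianMixture.from_random(
        rng=rng_gmm, n_components=3, n_dimensions=2
    )
    samples = gmm.sample(rng=rng, size=500)

Passing keys explicitly, rather than deriving them from integer seeds inside each function, matches JAX's splittable-PRNG model and keeps fixtures like the ones above, which repeatedly re-split `rng`, reproducible and free of accidental key reuse across call sites.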