From 57fb3273c36e868761e0900d35d6ab68a5d4dea4 Mon Sep 17 00:00:00 2001 From: Marco Cuturi Date: Tue, 4 Jul 2023 13:53:53 +0200 Subject: [PATCH] incorporate feedback in #382 (#385) * comments by Michal in #382 * test --- src/ott/tools/soft_sort.py | 112 +++++++++++++++++----------------- tests/tools/soft_sort_test.py | 6 +- 2 files changed, 58 insertions(+), 60 deletions(-) diff --git a/src/ott/tools/soft_sort.py b/src/ott/tools/soft_sort.py index 8f5ed10af..477a8204f 100644 --- a/src/ott/tools/soft_sort.py +++ b/src/ott/tools/soft_sort.py @@ -148,18 +148,19 @@ def sort( For instance: - ``` - x = jax.random.uniform(rng, (100,)) - x_sorted = sort(x) - ``` + .. code-block:: python + + x = jax.random.uniform(rng, (100,)) + x_sorted = sort(x) + will output sorted convex-combinations of values contained in ``x``, that are differentiable approximations to the sorted vector of entries in ``x``. These should be the values produced by :func:`jax.numpy.sort`, - ``` - x_ranks = jax.numpy.sort(x) - ``` + .. code-block:: python + + x_sorted = jax.numpy.sort(x) Args: @@ -180,11 +181,10 @@ def sort( to the user are ``squashing_fun``, which will redistribute the values in ``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default) to solve the optimal transport problem; - attribute :attr:`~ott.geometry.pointcloud.cost_fn` of + attribute :class:`cost_fn ` of :class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground cost function to transport from ``inputs`` to the ``num_targets`` target - values (:class:`~ott.geometry.costs.SqEuclidean` by default, see - :class:`~ott.geometry.pointcloud.PointCloud`); ``epsilon`` regularization + values ; ``epsilon`` regularization parameter. Remaining ``kwargs`` are passed on to defined the :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver. @@ -216,19 +216,19 @@ def ranks( For instance: - ``` - x = jax.random.uniform(rng, (100,)) - x_ranks = ranks(x) - ``` + .. code-block:: python + + x = jax.random.uniform(rng, (100,)) + x_ranks = ranks(x) will output fractional values, between 0 and 1, that are differentiable approximations to the normalized ranks of entries in ``x``. These should be compared to the non-differentiable rank vectors, namely the normalized inverse permutation produced by :func:`jax.numpy.argsort`, which can be obtained as: - ``` - x_ranks = jax.numpy.argsort(jax.numpy.argsort(x)) / x.shape[0] - ``` + .. code-block:: python + + x_ranks = jax.numpy.argsort(jax.numpy.argsort(x)) / x.shape[0] Args: inputs: jnp.ndarray of any shape. @@ -248,11 +248,10 @@ def ranks( to the user are ``squashing_fun``, which will redistribute the values in ``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default) to solve the optimal transport problem; - attribute :attr:`~ott.geometry.pointcloud.cost_fn` of + attribute :class:`cost_fn ` of :class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground cost function to transport from ``inputs`` to the ``num_targets`` target - values (:class:`~ott.geometry.costs.SqEuclidean` by default, see - :class:`~ott.geometry.pointcloud.PointCloud`); ``epsilon`` regularization + values ; ``epsilon`` regularization parameter. Remaining ``kwargs`` are passed on to defined the :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver. @@ -266,8 +265,8 @@ def ranks( def quantile( inputs: jnp.ndarray, + q: jnp.ndarray, axis: int = -1, - q: Optional[jnp.ndarray] = None, weight: Optional[Union[float, jnp.ndarray]] = None, **kwargs: Any, ) -> jnp.ndarray: @@ -275,10 +274,10 @@ def quantile( For instance: - ``` - x = jax.random.uniform(rng, (100,)) - x_quantiles = quantiles(x, q=jnp.array([0.2, 0.8])) - ``` + .. code-block:: python + + x = jax.random.uniform(rng, (100,)) + x_quantiles = quantiles(x, q=jnp.array([0.2, 0.8])) ``x_quantiles`` will hold an approximation to the 20 and 80 percentiles in ``x``, computed as a convex combination (a weighted mean, with weights summing @@ -289,16 +288,17 @@ def quantile( impact all values listed in ``x``, not just those indexed at 20 and 80). The non-differentiable version is given by :func:`jax.numpy.quantile`, e.g. - ``` - x_quantiles = jax.numpy.quantile(x, q=jnp.array([0.2, 0.8])) - ``` + + .. code-block:: python + + x_quantiles = jax.numpy.quantile(x, q=jnp.array([0.2, 0.8])) + Args: inputs: a jnp.ndarray of any shape. - axis: the axis on which to apply the operator. q: values of the quantile level to be computed, e.g. [0.5] for median. - These values should all lie in :math:`[0,1]` and are selected as - ``[0.2, 0.5, 0.8]`` by default. + These values should all lie in :math:`[0,1]`. + axis: the axis on which to apply the operator. weight: the weight assigned to each quantile target value in the OT problem. This weight should be small, typically of the order of ``1/n``, where ``n`` is the size of ``x``. Note: Since the size of ``q`` times ``weight`` @@ -309,11 +309,10 @@ def quantile( to the user are ``squashing_fun``, which will redistribute the values in ``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default) to solve the optimal transport problem; - attribute :attr:`~ott.geometry.pointcloud.cost_fn` of + attribute :class:`cost_fn ` of :class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground cost function to transport from ``inputs`` to the ``num_targets`` target - values (:class:`~ott.geometry.costs.SqEuclidean` by default, see - :class:`~ott.geometry.pointcloud.PointCloud`); ``epsilon`` regularization + values ; ``epsilon`` regularization parameter. Remaining ``kwargs`` are passed on to defined the :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver. @@ -379,7 +378,7 @@ def _quantile( jnp.ones((num_quantiles + 1, 1), dtype=bool) ], axis=1).ravel()[:-1] - return (out[odds])[idx] + return out[odds][idx] return apply_on_axis(_quantile, inputs, axis, q, weight, **kwargs) @@ -412,9 +411,9 @@ def quantile_normalization( inputs: array of any shape whose values will be changed to match those in ``targets``. targets: sorted array (in ascending order) of dimension 1 describing a - discrete distribution. Note: the``targets`` values must be provided as + discrete distribution. Note: the ``targets`` values must be provided as a sorted vector. - weights: vector of nonnegative weights, summing to :math:`1.0`, of the same + weights: vector of nonnegative weights, summing to :math:`1`, of the same size as ``targets``. When not set, this defaults to the uniform distribution. axis: the axis along which the quantile transformation is applied. @@ -422,17 +421,16 @@ def quantile_normalization( to the user are ``squashing_fun``, which will redistribute the values in ``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default) to solve the optimal transport problem; - attribute :attr:`~ott.geometry.pointcloud.cost_fn` of + attribute :class:`cost_fn ` of :class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground cost function to transport from ``inputs`` to the ``num_targets`` target - values (:class:`~ott.geometry.costs.SqEuclidean` by default, see - :class:`~ott.geometry.pointcloud.PointCloud`); ``epsilon`` regularization + values ; ``epsilon`` regularization parameter. Remaining ``kwargs`` are passed on to defined the :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver. Returns: - A jnp.ndarray, which has the same shape as the input, except on the give - axis on which the dimension is 1. + An array, which has the same shape as the input, except on the give axis on + which the dimension is 1. Raises: A ValueError in case the weights and the targets are both set and not of @@ -468,7 +466,7 @@ def sort_with( smaller indices will contain combinations of vectors with smaller criterion. Args: - inputs: the inputs as a jnp.ndarray[batch, dim]. + inputs: Array of size [batch, dim]. criterion: the values according to which to sort the inputs. It has shape [batch, 1]. topk: The number of outputs to keep. @@ -476,16 +474,15 @@ def sort_with( to the user are ``squashing_fun``, which will redistribute the values in ``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default) to solve the optimal transport problem; - attribute :attr:`~ott.geometry.pointcloud.cost_fn` of + attribute :class:`cost_fn ` of :class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground cost function to transport from ``inputs`` to the ``num_targets`` target - values (:class:`~ott.geometry.costs.SqEuclidean` by default, see - :class:`~ott.geometry.pointcloud.PointCloud`); ``epsilon`` regularization + values ; ``epsilon`` regularization parameter. Remaining ``kwargs`` are passed on to defined the :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver. Returns: - A jnp.ndarray[batch | topk, dim]. + An Array of size [batch | topk, dim]. """ num_points = criterion.shape[0] weights = jnp.ones(num_points, dtype=criterion.dtype) / num_points @@ -538,19 +535,22 @@ def quantize( differentiable. Args: - inputs: the inputs as a jnp.ndarray[batch, dim]. - num_levels: number of q available to quantize the signal. + inputs: an Array of size [batch, dim]. + num_levels: number of quantiles available to quantize the signal. axis: axis along which quantization is carried out. kwargs: keyword arguments passed on to lower level functions. Of interest to the user are ``squashing_fun``, which will redistribute the values in - ``inputs`` to lie in [0,1] (sigmoid of whitened values by default) to - solve the optimal transport problem; ``cost_fn``, used in ``PointCloud``, - that defines the ground cost function to transport from ``inputs`` to the - ``num_targets`` target values (squared Euclidean distance by default, see - ``pointcloud.py`` for more details); ``epsilon`` values as well as other - parameters to shape the ``sinkhorn`` algorithm. + ``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default) + to solve the optimal transport problem; + attribute :class:`cost_fn ` of + :class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground + cost function to transport from ``inputs`` to the ``num_targets`` target + values ; ``epsilon`` regularization + parameter. Remaining ``kwargs`` are passed on to defined the + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver. + Returns: - A jnp.ndarray of the same size as ``inputs``. + An Array of the same size as ``inputs``. """ return apply_on_axis(_quantize, inputs, axis, num_levels, **kwargs) diff --git a/tests/tools/soft_sort_test.py b/tests/tools/soft_sort_test.py index 8e2085b4c..771243a44 100644 --- a/tests/tools/soft_sort_test.py +++ b/tests/tools/soft_sort_test.py @@ -116,13 +116,11 @@ def test_quantile_on_several_axes(self, rng: jax.random.PRNGKeyArray): ) @pytest.mark.fast() - def test_quantiles(self): - inputs = jax.random.uniform(jax.random.PRNGKey(0), (200, 2, 3)) + def test_quantiles(self, rng: jax.random.PRNGKeyArray): + inputs = jax.random.uniform(rng, (200, 2, 3)) q = jnp.array([.1, .8, .4]) m1 = soft_sort.quantile(inputs, q=q, weight=None, axis=0) np.testing.assert_allclose(m1.mean(axis=[1, 2]), q, atol=5e-2) - m2 = soft_sort.quantile(inputs, q=q, weight=None, axis=0) - np.testing.assert_allclose(m2.mean(axis=[1, 2]), q, atol=5e-2) def test_soft_quantile_normalization(self, rng: jax.random.PRNGKeyArray): rngs = jax.random.split(rng, 2)