comments by Michal in #382
marcocuturi committed Jul 4, 2023
1 parent b2b7ebb commit 202ed76
Showing 1 changed file with 56 additions and 56 deletions.
112 changes: 56 additions & 56 deletions src/ott/tools/soft_sort.py
@@ -148,18 +148,19 @@ def sort(
For instance:
```
x = jax.random.uniform(rng, (100,))
x_sorted = sort(x)
```
.. code-block:: python
x = jax.random.uniform(rng, (100,))
x_sorted = sort(x)
will output sorted convex-combinations of values contained in ``x``, that are
differentiable approximations to the sorted vector of entries in ``x``.
These should be the values produced by :func:`jax.numpy.sort`,
```
x_ranks = jax.numpy.sort(x)
```
.. code-block:: python
x_sorted = jax.numpy.sort(x)
Args:
@@ -180,11 +181,10 @@ def sort(
to the user are ``squashing_fun``, which will redistribute the values in
``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
to solve the optimal transport problem;
attribute :attr:`~ott.geometry.pointcloud.cost_fn` of
attribute :class:`cost_fn <ott.geometry.costs.CostFn>` of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
cost function to transport from ``inputs`` to the ``num_targets`` target
values (:class:`~ott.geometry.costs.SqEuclidean` by default, see
:class:`~ott.geometry.pointcloud.PointCloud`); ``epsilon`` regularization
values ; ``epsilon`` regularization
parameter. Remaining ``kwargs`` are passed on to defined the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
@@ -216,19 +216,19 @@ def ranks(
For instance:
```
x = jax.random.uniform(rng, (100,))
x_ranks = ranks(x)
```
.. code-block:: python
x = jax.random.uniform(rng, (100,))
x_ranks = ranks(x)
will output fractional values, between 0 and 1, that are differentiable
approximations to the normalized ranks of entries in ``x``. These should be
compared to the non-differentiable rank vectors, namely the normalized inverse
permutation produced by :func:`jax.numpy.argsort`, which can be obtained as:
```
x_ranks = jax.numpy.argsort(jax.numpy.argsort(x)) / x.shape[0]
```
.. code-block:: python
x_ranks = jax.numpy.argsort(jax.numpy.argsort(x)) / x.shape[0]
Args:
inputs: jnp.ndarray<float> of any shape.
@@ -248,11 +248,10 @@ def ranks(
to the user are ``squashing_fun``, which will redistribute the values in
``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
to solve the optimal transport problem;
attribute :attr:`~ott.geometry.pointcloud.cost_fn` of
attribute :class:`cost_fn <ott.geometry.costs.CostFn>` of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
cost function to transport from ``inputs`` to the ``num_targets`` target
values (:class:`~ott.geometry.costs.SqEuclidean` by default, see
:class:`~ott.geometry.pointcloud.PointCloud`); ``epsilon`` regularization
values ; ``epsilon`` regularization
parameter. Remaining ``kwargs`` are passed on to defined the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
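Note on the ``ranks`` docstring above: the non-differentiable baseline it mentions, the normalized inverse permutation obtained by applying argsort twice, can be illustrated without JAX. The sketch below is plain Python; the helper names ``argsort`` and ``hard_normalized_ranks`` are hypothetical and are not part of ``ott``. It only mirrors what ``jax.numpy.argsort(jax.numpy.argsort(x)) / x.shape[0]`` computes for a 1-D input, assuming no ties.

```python
def argsort(values):
    """Return the indices that would sort `values` in ascending order."""
    return sorted(range(len(values)), key=lambda i: values[i])


def hard_normalized_ranks(values):
    """Hard (non-differentiable) normalized ranks of a 1-D sequence.

    Hypothetical helper for illustration only: argsort gives the sorting
    permutation, and argsort of that permutation gives its inverse, i.e. the
    rank of each entry. Dividing by the length maps ranks to [0, 1).
    """
    order = argsort(values)   # order[k] = index of the k-th smallest entry
    ranks = argsort(order)    # ranks[i] = position of entry i once sorted
    n = len(values)
    return [r / n for r in ranks]


x = [0.3, 0.1, 0.2]
x_ranks = hard_normalized_ranks(x)
```

Here the smallest entry receives rank 0 and the largest receives ``(n - 1) / n``; the soft version documented in the diff should approximate these values while remaining differentiable.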
@@ -266,19 +265,19 @@ def ranks(

def quantile(
inputs: jnp.ndarray,
q: jnp.ndarray,
axis: int = -1,
q: Optional[jnp.ndarray] = None,
weight: Optional[Union[float, jnp.ndarray]] = None,
**kwargs: Any,
) -> jnp.ndarray:
r"""Apply the soft quantiles operator on the input tensor.
For instance:
```
x = jax.random.uniform(rng, (100,))
x_quantiles = quantiles(x, q=jnp.array([0.2, 0.8]))
```
.. code-block:: python
x = jax.random.uniform(rng, (100,))
x_quantiles = quantiles(x, q=jnp.array([0.2, 0.8]))
``x_quantiles`` will hold an approximation to the 20 and 80 percentiles in
``x``, computed as a convex combination (a weighted mean, with weights summing
@@ -289,16 +288,17 @@ def quantile(
impact all values listed in ``x``, not just those indexed at 20 and 80).
The non-differentiable version is given by :func:`jax.numpy.quantile`, e.g.
```
x_quantiles = jax.numpy.quantile(x, q=jnp.array([0.2, 0.8]))
```
.. code-block:: python
x_quantiles = jax.numpy.quantile(x, q=jnp.array([0.2, 0.8]))
Args:
inputs: a jnp.ndarray<float> of any shape.
axis: the axis on which to apply the operator.
q: values of the quantile level to be computed, e.g. [0.5] for median.
These values should all lie in :math:`[0,1]` and are selected as
``[0.2, 0.5, 0.8]`` by default.
These values should all lie in :math:`[0,1]`.
axis: the axis on which to apply the operator.
weight: the weight assigned to each quantile target value in the OT problem.
This weight should be small, typically of the order of ``1/n``, where ``n``
is the size of ``x``. Note: Since the size of ``q`` times ``weight``
@@ -309,11 +309,10 @@ def quantile(
to the user are ``squashing_fun``, which will redistribute the values in
``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
to solve the optimal transport problem;
attribute :attr:`~ott.geometry.pointcloud.cost_fn` of
attribute :class:`cost_fn <ott.geometry.costs.CostFn>` of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
cost function to transport from ``inputs`` to the ``num_targets`` target
values (:class:`~ott.geometry.costs.SqEuclidean` by default, see
:class:`~ott.geometry.pointcloud.PointCloud`); ``epsilon`` regularization
values ; ``epsilon`` regularization
parameter. Remaining ``kwargs`` are passed on to defined the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
@@ -379,7 +378,7 @@ def _quantile(
jnp.ones((num_quantiles + 1, 1), dtype=bool)
],
axis=1).ravel()[:-1]
return (out[odds])[idx]
return out[odds][idx]

return apply_on_axis(_quantile, inputs, axis, q, weight, **kwargs)
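Note on the ``quantile`` docstring in this file: it contrasts the soft operator with the non-differentiable :func:`jax.numpy.quantile`. As a rough reference point, a hard quantile can be sketched in plain Python as below. This is an illustration under assumptions, not the library's code: the name ``hard_quantile`` is hypothetical, and the linear interpolation between neighboring order statistics is assumed to match the default behavior of ``jax.numpy.quantile``.

```python
import math


def hard_quantile(values, q):
    """Hard quantile of a 1-D sequence at level q in [0, 1].

    Hypothetical helper for illustration: sorts the values, then linearly
    interpolates between the two order statistics surrounding the
    fractional position q * (n - 1).
    """
    ordered = sorted(values)
    pos = q * (len(ordered) - 1)          # fractional index into sorted values
    lo = math.floor(pos)
    hi = min(lo + 1, len(ordered) - 1)    # clamp so q = 1 stays in range
    frac = pos - lo
    return (1.0 - frac) * ordered[lo] + frac * ordered[hi]


x = [4.0, 0.0, 3.0, 1.0, 2.0]
x_quantiles = [hard_quantile(x, level) for level in (0.2, 0.8)]
```

Only the two neighboring order statistics influence the result here, which is why the gradient of a hard quantile touches at most two entries; the soft quantile in the diff instead spreads weight over all entries of ``x``.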

@@ -412,27 +411,26 @@ def quantile_normalization(
inputs: array of any shape whose values will be changed to match those in
``targets``.
targets: sorted array (in ascending order) of dimension 1 describing a
discrete distribution. Note: the``targets`` values must be provided as
discrete distribution. Note: the ``targets`` values must be provided as
a sorted vector.
weights: vector of nonnegative weights, summing to :math:`1.0`, of the same
weights: vector of nonnegative weights, summing to :math:`1`, of the same
size as ``targets``. When not set, this defaults to the uniform
distribution.
axis: the axis along which the quantile transformation is applied.
kwargs: keyword arguments passed on to lower level functions. Of interest
to the user are ``squashing_fun``, which will redistribute the values in
``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
to solve the optimal transport problem;
attribute :attr:`~ott.geometry.pointcloud.cost_fn` of
attribute :class:`cost_fn <ott.geometry.costs.CostFn>` of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
cost function to transport from ``inputs`` to the ``num_targets`` target
values (:class:`~ott.geometry.costs.SqEuclidean` by default, see
:class:`~ott.geometry.pointcloud.PointCloud`); ``epsilon`` regularization
values ; ``epsilon`` regularization
parameter. Remaining ``kwargs`` are passed on to defined the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Returns:
A jnp.ndarray, which has the same shape as the input, except on the give
axis on which the dimension is 1.
An array, which has the same shape as the input, except on the give axis on
which the dimension is 1.
Raises:
A ValueError in case the weights and the targets are both set and not of
@@ -468,24 +466,23 @@ def sort_with(
smaller indices will contain combinations of vectors with smaller criterion.
Args:
inputs: the inputs as a jnp.ndarray[batch, dim].
inputs: Array of size [batch, dim].
criterion: the values according to which to sort the inputs. It has shape
[batch, 1].
topk: The number of outputs to keep.
kwargs: keyword arguments passed on to lower level functions. Of interest
to the user are ``squashing_fun``, which will redistribute the values in
``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
to solve the optimal transport problem;
attribute :attr:`~ott.geometry.pointcloud.cost_fn` of
attribute :class:`cost_fn <ott.geometry.costs.CostFn>` of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
cost function to transport from ``inputs`` to the ``num_targets`` target
values (:class:`~ott.geometry.costs.SqEuclidean` by default, see
:class:`~ott.geometry.pointcloud.PointCloud`); ``epsilon`` regularization
values ; ``epsilon`` regularization
parameter. Remaining ``kwargs`` are passed on to defined the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Returns:
A jnp.ndarray[batch | topk, dim].
An Array of size [batch | topk, dim].
"""
num_points = criterion.shape[0]
weights = jnp.ones(num_points, dtype=criterion.dtype) / num_points
@@ -538,19 +535,22 @@ def quantize(
differentiable.
Args:
inputs: the inputs as a jnp.ndarray[batch, dim].
num_levels: number of q available to quantize the signal.
inputs: an Array of size [batch, dim].
num_levels: number of quantiles available to quantize the signal.
axis: axis along which quantization is carried out.
kwargs: keyword arguments passed on to lower level functions. Of interest
to the user are ``squashing_fun``, which will redistribute the values in
``inputs`` to lie in [0,1] (sigmoid of whitened values by default) to
solve the optimal transport problem; ``cost_fn``, used in ``PointCloud``,
that defines the ground cost function to transport from ``inputs`` to the
``num_targets`` target values (squared Euclidean distance by default, see
``pointcloud.py`` for more details); ``epsilon`` values as well as other
parameters to shape the ``sinkhorn`` algorithm.
``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
to solve the optimal transport problem;
attribute :class:`cost_fn <ott.geometry.costs.CostFn>` of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
cost function to transport from ``inputs`` to the ``num_targets`` target
values ; ``epsilon`` regularization
parameter. Remaining ``kwargs`` are passed on to defined the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Returns:
A jnp.ndarray of the same size as ``inputs``.
An Array of the same size as ``inputs``.
"""
return apply_on_axis(_quantize, inputs, axis, num_levels, **kwargs)
