Skip to content

Commit

Permalink
Update docs (#398)
Browse files Browse the repository at this point in the history
* Update docs

* Update docs
  • Loading branch information
michalk8 committed Jul 20, 2023
1 parent f3ccc3b commit 79050b1
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 23 deletions.
1 change: 1 addition & 0 deletions docs/tools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ Soft Sorting Algorithms
soft_sort.ranks
soft_sort.sort
soft_sort.sort_with
soft_sort.topk_mask

Clustering
----------
Expand Down
48 changes: 25 additions & 23 deletions src/ott/tools/soft_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

__all__ = ["sort", "ranks", "quantile"]
__all__ = [
"sort", "ranks", "sort_with", "quantile", "quantile_normalization",
"quantize", "topk_mask"
]


def transport_for_sort(
Expand Down Expand Up @@ -160,14 +163,13 @@ def sort(
x_sorted = jax.numpy.sort(x)
Args:
inputs: Array of any shape.
axis: the axis on which to apply the soft-sorting operator.
topk: if set to a positive value, the returned vector will only contain
the top-k values. This also reduces the complexity of soft-sorting, since
the number of target points to which the slice of the ``inputs`` tensor
will be mapped to will be equal to ``topk+1``.
will be mapped to will be equal to ``topk + 1``.
num_targets: if ``topk`` is not specified, a vector of size``num_targets``
is returned. This defines the number of (composite) sorted values computed
from the inputs (each value is a convex combination of values recorded in
Expand All @@ -183,7 +185,7 @@ def sort(
:class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
1D cost function to transport from ``inputs`` to the ``num_targets``
target values ; ``epsilon`` regularization parameter. Remaining ``kwargs``
target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Expand Down Expand Up @@ -242,9 +244,9 @@ def ranks(
axis: the axis on which to apply the soft-sorting operator.
target_weights: This vector contains weights (summing to 1) that describe
amount of mass shipped to targets.
num_targets: If `target_weights` is ``None``, ``num_targets`` is considered
to define the number of targets used to rank inputs. Each normalized rank
returned in the output will be a convex combination of
num_targets: If ``target_weights` is ``None``, ``num_targets`` is
considered to define the number of targets used to rank inputs. Each
normalized rank in the output will be a convex combination of
``{1, .., num_targets}/num_targets``. The weight of each of these points
is assumed to be uniform. If neither ``num_targets`` nor
``target_weights`` are specified, ``num_targets`` defaults to the size
Expand All @@ -256,14 +258,14 @@ def ranks(
:class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
1D cost function to transport from ``inputs`` to the ``num_targets``
target values ; ``epsilon`` regularization parameter. Remaining ``kwargs``
target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Returns:
An Array of the same shape as the input with soft-rank values
normalized to be in `[0, n-1]` where `n` is `inputs.shape[axis]`.
normalized to be in :math:`[0, n-1]` where :math:`n` is
``inputs.shape[axis]``.
"""
return apply_on_axis(
_ranks, inputs, axis, num_targets, target_weights, **kwargs
Expand All @@ -276,7 +278,7 @@ def topk_mask(
k: int = 1,
**kwargs: Any,
) -> jnp.ndarray:
r"""Soft top-$k$ selection mask.
r"""Soft :math:`\text{top-}k` selection mask.
For instance:
Expand All @@ -298,18 +300,20 @@ def topk_mask(
Args:
inputs: Array of any shape.
axis: the axis on which to apply the soft-sorting operator.
k : topk parameter. Should be smaller than ``inputs.shape[axis]``.
k: topk parameter. Should be smaller than ``inputs.shape[axis]``.
kwargs: keyword arguments passed on to lower level functions. Of interest
to the user are ``squashing_fun``, which will redistribute the values in
``inputs`` to lie in :math:`[0,1]` (sigmoid of whitened values by default)
to solve the optimal transport problem;
:class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
1D cost function to transport from ``inputs`` to the ``num_targets``
target values ; ``epsilon`` regularization parameter. Remaining ``kwargs``
target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Returns:
The soft :math:`\text{top-}k` selection mask.
"""
num_points = inputs.shape[axis]
assert k < num_points, (
Expand Down Expand Up @@ -357,7 +361,6 @@ def quantile(
x_quantiles = jax.numpy.quantile(x, q=jnp.array([0.2, 0.8]))
Args:
inputs: an Array of any shape.
q: values of the quantile level to be computed, e.g. [0.5] for median.
Expand All @@ -376,7 +379,7 @@ def quantile(
:class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
1D cost function to transport from ``inputs`` to the ``num_targets``
target values ; ``epsilon`` regularization parameter. Remaining ``kwargs``
target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Expand Down Expand Up @@ -462,9 +465,9 @@ def quantile_normalization(
targets: jnp.ndarray,
weights: Optional[jnp.ndarray] = None,
axis: int = -1,
**kwargs
**kwargs: Any,
) -> jnp.ndarray:
r"""Renormalize inputs so that its quantiles match those of targets/weights.
r"""Re-normalize inputs so that its quantiles match those of targets/weights.
Quantile normalization rearranges the values in inputs to values that match
the distribution of values described in the discrete distribution ``targets``
Expand All @@ -477,7 +480,7 @@ def quantile_normalization(
targets: sorted array (in ascending order) of dimension 1 describing a
discrete distribution. Note: the ``targets`` values must be provided as
a sorted vector.
weights: vector of nonnegative weights, summing to :math:`1`, of the same
weights: vector of non-negative weights, summing to :math:`1`, of the same
size as ``targets``. When not set, this defaults to the uniform
distribution.
axis: the axis along which the quantile transformation is applied.
Expand All @@ -488,7 +491,7 @@ def quantile_normalization(
:class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
1D cost function to transport from ``inputs`` to the ``num_targets``
target values ; ``epsilon`` regularization parameter. Remaining ``kwargs``
target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Expand Down Expand Up @@ -541,7 +544,7 @@ def sort_with(
:class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
1D cost function to transport from ``inputs`` to the ``num_targets``
target values ; ``epsilon`` regularization parameter. Remaining ``kwargs``
target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Expand Down Expand Up @@ -584,7 +587,7 @@ def quantize(
axis: int = -1,
**kwargs: Any,
) -> jnp.ndarray:
r"""Soft quantizes an input according using num_levels values along axis.
r"""Soft quantizes an input according using ``num_levels`` values along axis.
The quantization operator consists in concentrating several values around
a few predefined ``num_levels``. The soft quantization operator proposed here
Expand All @@ -609,11 +612,10 @@ def quantize(
:class:`cost_fn <ott.geometry.costs.CostFn>` object of
:class:`~ott.geometry.pointcloud.PointCloud`, which defines the ground
1D cost function to transport from ``inputs`` to the ``num_targets``
target values ; ``epsilon`` regularization parameter. Remaining ``kwargs``
target values; ``epsilon`` regularization parameter. Remaining ``kwargs``
are passed on to parameterize the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
Returns:
An Array of the same size as ``inputs``.
"""
Expand Down

0 comments on commit 79050b1

Please sign in to comment.