Skip to content

Commit

Permalink
fix typos (#397)
Browse files Browse the repository at this point in the history
  • Loading branch information
pierreablin committed Jul 18, 2023
1 parent e04f016 commit 6becead
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/ott/tools/soft_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,12 +284,12 @@ def topk_mask(
k = 5
x = jax.random.uniform(rng, (100,))
mask = top_k_mask(x, k=k)
mask = topk_mask(x, k=k)
will output a vector of shape ``x.shape``, with values in :math:`[0,1]`, that
are differentiable approximations to the binary mask selecting the top $k$
entries in ``x``. These should be compared to the non-differentiable mask
obtained with :func:`jax.numpy.argsort`, which can be obtained as:
obtained with :func:`jax.numpy.sort`, which can be obtained as:
.. code-block:: python
Expand Down Expand Up @@ -341,7 +341,7 @@ def quantile(
.. code-block:: python
x = jax.random.uniform(rng, (100,))
x_quantiles = quantiles(x, q=jnp.array([0.2, 0.8]))
x_quantiles = quantile(x, q=jnp.array([0.2, 0.8]))
``x_quantiles`` will hold an approximation to the 20 and 80 percentiles in
``x``, computed as a convex combination (a weighted mean, with weights summing
Expand Down

0 comments on commit 6becead

Please sign in to comment.