Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compute multiple soft-quantiles in one execution without using vmap #382

Merged
merged 7 commits into from
Jul 3, 2023

Conversation

marcocuturi
Copy link
Contributor

This PR is a follow up to
#373
This implements a quantiles function in the soft_sort module to return simultenaously multiple quantile values. Should be more efficient than the vmap proposed in the discussion, but will likely return very slightly different results.

@marcocuturi marcocuturi changed the title Possibility to compute multilple quantiles at once Possibility to compute multiple quantiles at once Jul 1, 2023
@marcocuturi marcocuturi changed the title Possibility to compute multiple quantiles at once Compute multiple soft-quantiles in one execution without using vmap Jul 1, 2023
@marcocuturi
Copy link
Contributor Author

numba, via jaxopt, seems to be causing the issue.

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@codecov-commenter
Copy link

codecov-commenter commented Jul 3, 2023

Codecov Report

Merging #382 (4268043) into main (3b4f7b6) will increase coverage by 0.02%.
The diff coverage is 96.55%.

❗ Your organization is not using the GitHub App Integration. As a result you may experience degraded service beginning May 15th. Please install the Github App Integration for your organization. Read more.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #382      +/-   ##
==========================================
+ Coverage   88.51%   88.53%   +0.02%     
==========================================
  Files          52       52              
  Lines        5660     5679      +19     
  Branches      839      841       +2     
==========================================
+ Hits         5010     5028      +18     
  Misses        530      530              
- Partials      120      121       +1     
Impacted Files Coverage Δ
src/ott/solvers/linear/sinkhorn.py 96.94% <ø> (ø)
src/ott/tools/soft_sort.py 95.23% <96.55%> (-0.09%) ⬇️

@marcocuturi marcocuturi requested a review from michalk8 July 3, 2023 14:23
@marcocuturi marcocuturi merged commit b2b7ebb into main Jul 3, 2023
10 checks passed
@marcocuturi marcocuturi deleted the mott4 branch July 3, 2023 21:48
tests/tools/soft_sort_test.py Show resolved Hide resolved
q = jnp.array([.1, .8, .4])
m1 = soft_sort.quantile(inputs, q=q, weight=None, axis=0)
np.testing.assert_allclose(m1.mean(axis=[1, 2]), q, atol=5e-2)
m2 = soft_sort.quantile(inputs, q=q, weight=None, axis=0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does m2 exist since it's the same as m1?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks! I think this was naive byproduct of starting PR with 2 different interfaces (quantile and quantiles)

@@ -141,28 +146,50 @@ def sort(
) -> jnp.ndarray:
r"""Apply the soft sort operator on a given axis of the input.

For instance:

```
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use the python code blocks (also in other places), e.g.,:

.. code-block:: python

  x = jax.random.uniform(rng, (100,))

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise, it doesn't render nicely, see here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks!

jnp.ones((num_quantiles + 1, 1), dtype=bool)
],
axis=1).ravel()[:-1]
return (out[odds])[idx]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary ()


Returns:
A jnp.ndarray of the same shape as the input with soft sorted values on the
A jnp.ndarray of the same shape as the input with soft-sorted values on the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it's a good time to refactor and say An array of the same shape ...

targets: sorted array (in ascending order) of dimension 1 describing a
discrete distribution. Note: the``targets`` values must be provided as
a sorted vector.
weights: vector of nonnegative weights, summing to :math:`1.0`, of the same
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

non-negative

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

targets: sorted array (in ascending order) of dimension 1 describing a
discrete distribution. Note: the``targets`` values must be provided as
a sorted vector.
weights: vector of nonnegative weights, summing to :math:`1.0`, of the same
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:math:`1`

inputs: array of any shape whose values will be changed to match those in
``targets``.
targets: sorted array (in ascending order) of dimension 1 describing a
discrete distribution. Note: the``targets`` values must be provided as
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the ``targets`` (missing space, doesn't render correctly)

specified by the optimal transport between values in ``inputs`` towards
those values. If not specified, ``num_targets`` is set by default to be
the size of the slices of the input that are sorted.
inputs: jnp.ndarray<float> of any shape.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Array of any shape.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also elsewhere, if possible.

num_points = inputs.shape[0]
q = jnp.array([0.2, 0.5, 0.8]) if q is None else jnp.atleast_1d(q)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are defaults needed?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, pure implementation question: why pass these as an array, not e.g., as a tuple? To be able to differentiate w.r.t. to it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's also a question of being able to re-run by changing quantile values, without jitting again, pending that the number of quantiles does not change.

Copy link
Contributor Author

@marcocuturi marcocuturi Jul 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you prefer no defaults? it was there because there was a default previously (the median) but maybe you're right, better stick to jax's quantile API

marcocuturi added a commit that referenced this pull request Jul 4, 2023
marcocuturi added a commit that referenced this pull request Jul 4, 2023
* comments by Michal in #382

* test
michalk8 pushed a commit that referenced this pull request Jun 27, 2024
…#382)

* add quantiles

* numpy

* using Michal's fix

* chg threshold in kmeans test

* changing quantile API to match jnp's + pydocs

* impact chg in NB

* pydocs
michalk8 pushed a commit that referenced this pull request Jun 27, 2024
* comments by Michal in #382

* test
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants