Compute multiple soft-quantiles in one execution without using vmap (#382)

* add quantiles

* numpy

* using Michal's fix

* chg threshold in kmeans test

* changing quantile API to match jnp's + pydocs

* impact chg in NB

* pydocs
marcocuturi committed Jul 3, 2023
1 parent 3b4f7b6 commit b2b7ebb
Showing 6 changed files with 250 additions and 97 deletions.
16 changes: 14 additions & 2 deletions docs/tutorials/notebooks/soft_sort.ipynb
@@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "Dg9-8gVqjq2H"
@@ -172,7 +173,7 @@
}
],
"source": [
"jnp.quantile(x, 0.5)"
"jnp.quantile(x, q=0.5)"
]
},
{
@@ -196,6 +197,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "mnEkXjwT-Z1C"
@@ -334,6 +336,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "JKNcCJOe9Dcl"
@@ -377,6 +380,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "4w4ggUy7zYQX"
@@ -442,6 +446,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "jQoRgYEd-pkj"
@@ -494,7 +499,7 @@
],
"source": [
"softquantile = jax.jit(soft_sort.quantile)\n",
"softquantile(x, level=0.5)"
"softquantile(x, q=0.5)"
]
},
{
@@ -533,6 +538,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -596,6 +602,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "4t3VrtNcmN0R"
@@ -611,6 +618,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "tqsCC0tunHQh"
@@ -663,6 +671,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "irSGHYZ7nWuY"
@@ -709,6 +718,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {
"id": "jK94muT8oAlQ"
@@ -929,6 +939,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
@@ -976,6 +987,7 @@
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
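The notebook edits above follow a rename in ott's soft sorting helper: soft_sort.quantile now takes a q keyword, mirroring jnp.quantile, instead of level. A minimal sketch of the updated call, assuming q may also be an array of levels (which is what the commit title describes) rather than only a scalar:

import jax.numpy as jnp
from ott.tools import soft_sort

x = jnp.array([0.3, 0.1, 0.7, 0.5, 0.9])

# Single soft quantile, using the jnp.quantile-style keyword shown in the notebook diff.
median = soft_sort.quantile(x, q=0.5)

# Several soft quantiles in one execution, without vmap; passing an array here is an
# assumption based on the commit title, not a line taken from this diff.
quartiles = soft_sort.quantile(x, q=jnp.array([0.25, 0.5, 0.75]))
print(median, quartiles)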
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -16,8 +16,8 @@ dependencies = [
"jax>=0.1.67",
"jaxlib>=0.1.47",
"jaxopt>=0.5.5",
-# https://github.com/google/jax/discussions/9951#discussioncomment-3017784
-"numpy>=1.18.4, !=1.23.0",
+## https://github.com/google/jax/discussions/9951#discussioncomment-3017784
+"numpy>=1.18.4, !=1.25.0",
"flax>=0.5.2",
"optax>=0.1.1",
"lineax>=0.0.1; python_version >= '3.9'"
4 changes: 3 additions & 1 deletion src/ott/solvers/linear/sinkhorn.py
@@ -732,7 +732,9 @@ class Sinkhorn:
gradients have been stopped. This is useful when carrying out first order
differentiation, and is only valid (as with ``implicit_differentiation``)
when the algorithm has converged with a low tolerance.
-initializer: how to compute the initial potentials/scalings.
+initializer: how to compute the initial potentials/scalings. This refers to
+a few possible classes implemented following the template in
+:class:`~ott.initializers.linear.SinkhornInitializer`.
progress_fn: callback function which gets called during the Sinkhorn
iterations, so the user can display the error at each iteration,
e.g., using a progress bar. See :func:`~ott.utils.default_progress_fn`
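The expanded initializer docstring above points to the SinkhornInitializer template. A short sketch of how an initializer is typically supplied to the solver; the module paths and the "default" string are assumptions about the library's API at this version, not part of the diff:

import jax
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key1, (50, 2))
y = jax.random.normal(key2, (60, 2)) + 1.0

# Build a regularized OT problem between two point clouds.
geom = pointcloud.PointCloud(x, y, epsilon=0.1)
prob = linear_problem.LinearProblem(geom)

# "default" is an assumed accepted value; an initializer instance following the
# SinkhornInitializer template (per the docstring above) could be passed instead.
solver = sinkhorn.Sinkhorn(initializer="default")
out = solver(prob)
print(out.reg_ot_cost)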
