Skip to content

Commit

Permalink
refresh tutorial NBs
Browse files Browse the repository at this point in the history
  • Loading branch information
marcocuturi committed Jun 27, 2024
1 parent d3b6c40 commit 96366f7
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 90 deletions.
21 changes: 10 additions & 11 deletions docs/tutorials/basic_ot_between_datasets.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
"\n",
"This short tutorial covers a basic use case for {mod}`ott`:\n",
"\n",
"- Compute an optimal transport between two point clouds. This solves a problem that is described by a {class}`~ott.geometry.pointcloud.PointCloud` geometry object (to describe pairwise distances between the points), which is then fed in the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` algorithm. \n",
"- Showcase the seamless integration with {mod}`jax`, to differentiate through that OT distance, and plot the gradient flow of that distance, to morph the first point cloud into the second."
"- Compute an optimal coupling between two point clouds. The problem is first described using a {class}`~ott.geometry.pointcloud.PointCloud` geometry object, storing those point clouds coordinates and their pairwise costs. The problem is then fed to a {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` solver to output various quantities and variables of interest.\n",
"- Showcase the seamless integration with {mod}`jax`, to differentiate through the regularized OT distance outputted by the solver, and plot the gradient flow of that distance, to morph the first point cloud into the second."
]
},
{
Expand Down Expand Up @@ -58,7 +58,7 @@
"id": "7d97950d",
"metadata": {},
"source": [
"{mod}`ott` is built on top of {mod}`jax`, so we use {mod}`jax` to instantiate all variables. We generate two 2-dimensional random point clouds of $7$ and $11$ points, respectively, and store them in variables `x` and `y`:"
"{mod}`ott` is built on top of {mod}`jax`, so we use {mod}`jax.numpy` arrays to instantiate all variables. We generate two 2-dimensional random point clouds of $7$ and $11$ points, respectively, and store them in variables `x` and `y`:"
]
},
{
Expand All @@ -82,7 +82,7 @@
"id": "082158c3",
"metadata": {},
"source": [
"Because these point clouds are 2-dimensional, we can use scatter plots to illustrate them."
"Because these point clouds are 2-dimensional, we use scatter plots to display them."
]
},
{
Expand Down Expand Up @@ -120,7 +120,7 @@
"source": [
"## Optimal transport with {mod}`ott`\n",
"\n",
"We will now use {mod}`ott` to compute the optimal transport between `x` and `y`. To do so, we first create a `geom` object that stores the geometry (a.k.a. the ground cost) between `x` and `y`:"
"We use {mod}`ott` to compute the optimal transport between `x` and `y`. To do so, we first create a `geom` object that stores the geometry (a.k.a. the ground cost) between `x` and `y`:"
]
},
{
Expand All @@ -139,7 +139,7 @@
"id": "aafe996a",
"metadata": {},
"source": [
"`geom` holds the two datasets `x` and `y`, as well as a `cost_fn`, a function used to quantify a cost between two points. Here, we passed no `cost_fn`; this defaults to `cost_fn` equal to {class}`~ott.geometry.costs.SqEuclidean`, the usual squared-Euclidean distance between two points, $c(x,y)=\\|x-y\\|^2_2$.\n",
"`geom` holds the two datasets `x` and `y`, as well as a `cost_fn`, a function used to quantify a cost between two points. Here, we passed no `cost_fn`; this defaults to `cost_fn` equal to {class}`~ott.geometry.costs.SqEuclidean`, the squared-Euclidean distance between two points, $c(x,y)=\\|x-y\\|^2_2$.\n",
"\n",
"In order to compute the optimal coupling corresponding to `geom`, we use the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` algorithm, wrapped in the convenience wrapper {func}`~ott.solvers.linear.solve`. The Sinkhorn algorithm will use a regularization hyperparameter `epsilon`, which is typically of the scale of $c(x,y)$ found across point-clouds. For this reason, {mod}`ott` stores that parameter in `geom`, and uses by default the twentieth of the mean cost between all points in `x` and `y`. While it is also possible to set probably weights `a` for each point in `x` (and `b` for `y`), these will default to uniform by default when not passed, here $1/7$ and $1/13$, since $n=7$ and $m=13$."
]
Expand Down Expand Up @@ -216,7 +216,7 @@
"id": "0ef65792",
"metadata": {},
"source": [
"`ot` is a {class}`~ott.solvers.linear.sinkhorn.SinkhornOutput` object that stores many more things, notably a lower, as well as an upper bound of the \"true\" squared 2-Wasserstein metric between `x` and `y` (the gap between these two bounds can be made arbitrarily small as `epsilon` decreases, when `geom` is instantiated)."
"`ot` is a {class}`~ott.solvers.linear.sinkhorn.SinkhornOutput` object that stores many more things, notably a lower, as well as an upper bound of the \"true\" OT cost between `x` and `y` (the gap between these two bounds can be made arbitrarily small as `epsilon` decreases, when `geom` is instantiated)."
]
},
{
Expand All @@ -234,9 +234,7 @@
}
],
"source": [
"print(\n",
" f\"2-Wasserstein: Lower bound = {ot.dual_cost:3f}, upper = {ot.primal_cost:3f}\"\n",
")"
"print(f\"OT Cost, lower bound = {ot.dual_cost:3f}, upper = {ot.primal_cost:3f}\")"
]
},
{
Expand All @@ -247,7 +245,7 @@
"source": [
"## Automatic differentiation using {mod}`jax`\n",
"\n",
"We finish this quick tour by illustrating one of the main features of {mod}`ott`: it can be seamlessly integrated into differentiable, end-to-end architectures built using {mod}`jax` (see also {doc}`Hessians`) for an example exploiting implicit differentiation).\n",
"We finish this quick tour by illustrating one of the main features of {mod}`ott`: it can be seamlessly integrated into differentiable, end-to-end architectures built using {mod}`jax` (see also {doc}`Hessians`) for an example exploiting unrolling or implicit differentiation).\n",
"\n",
"We provide a simple use-case where we differentiate the (regularized) OT transport cost w.r.t. `x`,\n",
"by defining a wrapper that takes `x` and `y` as input, to output their regularized OT cost."
Expand Down Expand Up @@ -354,6 +352,7 @@
"- {doc}`LRSinkhorn` for faster solvers that constraint coupling matrices (see plot above) to have a low-rank factorization, and exploit low-rank properties of {class}{class}`~ott.geometry.geometry.Geometry` objects, both for the standard OT problem\n",
"and its GW variant in {doc}`GWLRSinkhorn`.\n",
"- Wasserstein barycenters, as in {doc}`wasserstein_barycenters_gmms` or {doc}`Sinkhorn_Barycenters`,\n",
"- Multimarginal generalizations in {doc}`mmsink`,\n",
"- Differentiable sorting in {doc}`soft_sort`,\n",
"- Neural solvers in {doc}`neural_dual`, to estimate maps in functional form.\n",
"- Visual interface to plot progress bars in {doc}`tracking_progress`."
Expand Down
149 changes: 70 additions & 79 deletions docs/tutorials/gromov_wasserstein_multiomics.ipynb

Large diffs are not rendered by default.

0 comments on commit 96366f7

Please sign in to comment.