Skip to content

Commit

Permalink
update nb (#403)
Browse files Browse the repository at this point in the history
* update nb

* ref

* ref

* no need to add ref!

* pydocs

* bump flax

* comments + orbax fix?

* flax?

* Pin `flax` version in the CI

* Try `--no-deps`

* Fix some math rendering

---------

Co-authored-by: Michal Klein <46717574+michalk8@users.noreply.github.com>
  • Loading branch information
marcocuturi and michalk8 committed Aug 1, 2023
1 parent 738547b commit 091e6f4
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 70 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,12 @@ jobs:
- name: Install dependencies
# `jax[cuda]<0.4` because of Docker issues: https://github.com/google/jax/issues/13758
# `chex<0.1.7` because it requires `jax>=0.4.6`
# `flax<0.6.5` because it requires `jax>=0.4.2`
# `flax<0.6.5` because it requires `jax>=0.4.2`, --no-deps because of `orbax`
run: |
python3 -m pip install --upgrade pip
python3 -m pip install -e".[test]"
python3 -m pip install "flax<0.6.5" "chex<0.1.7"
python3 -m pip install "orbax-checkpoint" "orbax-export" "chex<0.1.7"
python3 -m pip install --no-deps "flax<0.6.5"
python3 -m pip install "jax[cuda]<0.4" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- name: Nvidia SMI
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/notebooks/MetaOT.ipynb
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"id": "7eb605bd",
"metadata": {},
Expand Down Expand Up @@ -535,6 +534,7 @@
"We lastly compare how much the initializers help\n",
"{class}`~ott.solvers.linear.sinkhorn.Sinkhorn` converge on these problems, \n",
"measured by the marginal error:\n",
"\n",
"$${\\rm err}(f,g; \\alpha, \\beta, c) := \\|P1_m-a\\|_1 + \\|P^\\top1_n-b\\|_1$$"
]
},
Expand Down
6 changes: 4 additions & 2 deletions docs/tutorials/notebooks/OTT_&_POT.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -74,17 +74,19 @@
"$$\\mu=\\sum_{i=1}^n a_i\\delta_{x_i}, \\nu =\\sum_{j=1}^n b_j\\delta_{y_j},$$\n",
"\n",
"to define the OT problem in its primal form,\n",
"\n",
"$$\\min_{P \\in U(a,b)} \\langle C, P \\rangle - \\varepsilon H(P).$$\n",
"\n",
"where $U(a,b):=\\{P \\in \\mathbf{R}_+^{n\\times n}, P\\mathbf{1}_{n}=b, P^T\\mathbf{1}_n=b\\}$, and $C = [ \\|x_i - y_j \\|^2 ]_{i,j}\\in \\mathbf{R}_+^{n\\times n}$.\n",
"\n",
"That problem is equivalent to the following dual form,\n",
"\n",
"$$\\max_{f, g} \\langle a, f \\rangle + \\langle b, g \\rangle - \\varepsilon \\langle e^{f/\\varepsilon},Ke^{g/\\varepsilon} \\rangle.$$\n",
"\n",
"These two problems are solved by `OTT` and `POT` using the {class}`~ott.solvers.linear.sinkhorn.Sinkhorn` iterations with a simple initialization for $u$, and subsequent updates $v \\leftarrow a / K^Tu, u \\leftarrow b / Kv$, where $K:=e^{-C/\\varepsilon}$.\n",
"\n",
"Upon convergence to fixed points $u^*, v^*$, one has $$P^*=D(u^*)KD(v^*)$$ or, alternatively, \n",
"$$f^*, g^* = \\varepsilon \\log(u^*), \\varepsilon\\log(v^*)$$"
"Upon convergence to fixed points $u^*, v^*$, one has $P^*=D(u^*)KD(v^*)$ or, alternatively, \n",
"$f^*, g^* = \\varepsilon \\log(u^*), \\varepsilon\\log(v^*)$."
]
},
{
Expand Down
211 changes: 149 additions & 62 deletions docs/tutorials/notebooks/sparse_monge_displacements.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ dependencies = [
"jaxopt>=0.5.5",
# numba/numpy compatibility issue in JAXOPT.
"numpy>=1.18.4, <1.25.0",
"flax>=0.5.2",
"flax>=0.6.6",
"optax>=0.1.1",
"lineax>=0.0.1; python_version >= '3.9'"
]
Expand Down
5 changes: 3 additions & 2 deletions src/ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,14 @@ class RegTICost(TICost, abc.ABC):
r"""Base class for regularized translation-invariant costs.
.. math::
\frac{1}{2} \|\cdot\|_2^2 + \text{scaling_reg} reg\left(\cdot\right)
\frac{1}{2} \|\cdot\|_2^2 + \text{scaling_reg} reg\left(matrix \cdot\right)
where :func:`reg` is the regularization function.
Args:
scaling_reg: Strength of the :meth:`regularization <reg>`.
matrix: :math:`p \times d` projection matrix with **orthogonal rows**.
matrix: :math:`p \times d` projection matrix in the Stiefel manifold,
namely with **orthonormalized rows**.
orthogonal: Whether to regularize in the orthogonal complement
to promote displacements in the span of ``matrix``.
"""
Expand Down

0 comments on commit 091e6f4

Please sign in to comment.