Deprecate power in PointCloud, introduce TICost and use it to compute Entropic (Brenier) maps. (#167)

* deprecate `power`, introduce h maps in potentials

* Deprecate power and introduce h function in costs.

* linter

* linter

* revert abstractmethod.

* linter

* linter

* PNorm -> SqPNorm

* PNorm -> SqPNorm in tests.

* another fix for abstract method.

* fix abc.abstractmethod

* linter

* nb fix

* linter

* nb bug fix

* modify ipynb

* abc.abstractmethod for RBF

* fixes and additions.

* fix `cor` in neuraldual

* fix in neuraldual

* p-norm ** p implemented, fixes.

* various fixes. Change to `TICost`

* various fixes

* fix nb

* last fixes.
marcocuturi committed Nov 15, 2022
1 parent 7249909 commit c9ff6fb
Showing 19 changed files with 216,017 additions and 146,680 deletions.
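In a nutshell: `PointCloud`'s `power` argument is deprecated in favor of passing a translation-invariant cost (`costs.TICost`, e.g. `costs.SqPNorm`) explicitly via `cost_fn`. A minimal before/after sketch, assuming the constructor signatures shown in the diffs below (the value `p=1.5` is purely illustrative):

```python
import jax

from ott.geometry import costs, pointcloud

x = jax.random.normal(jax.random.PRNGKey(0), (30, 2))
y = jax.random.normal(jax.random.PRNGKey(1), (40, 2))

# Before this commit, the exponent was folded into a `power` argument:
# geom = pointcloud.PointCloud(x, y, power=2.0)  # now deprecated

# After: the ground cost is an explicit, translation-invariant `TICost`
# subclass, here a squared p-norm.
geom = pointcloud.PointCloud(x, y, cost_fn=costs.SqPNorm(p=1.5))
cost_matrix = geom.cost_matrix  # [30, 40] pairwise costs
```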
3 changes: 3 additions & 0 deletions docs/geometry.rst
@@ -50,6 +50,9 @@ Cost Functions
     :toctree: _autosummary

     costs.CostFn
+    costs.TICost
+    costs.SqPNorm
+    costs.PNorm
     costs.SqEuclidean
     costs.Euclidean
     costs.Cosine
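To illustrate the new base class, a hypothetical custom translation-invariant cost sketched against the `TICost` interface: the method names `h` and `h_legendre` come from the potentials.py diff below, while the assumption that `TICost` evaluates pairwise costs through `h(x - y)` is mine, not this commit's documentation.

```python
import jax.numpy as jnp

from ott.geometry import costs


class SqEuclideanLike(costs.TICost):
  """Toy TICost with h(z) = ||z||^2 and its Legendre transform."""

  def h(self, z: jnp.ndarray) -> float:
    return jnp.sum(z ** 2)

  def h_legendre(self, z: jnp.ndarray) -> float:
    # h*(w) = sup_z <w, z> - h(z) = ||w||^2 / 4 for h(z) = ||z||^2.
    return 0.25 * jnp.sum(z ** 2)


cost_fn = SqEuclideanLike()
# Assuming TICost routes evaluation through h(x - y), this returns 3.0.
print(cost_fn(jnp.ones(3), jnp.zeros(3)))
```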
10 changes: 5 additions & 5 deletions docs/notebooks/neural_dual.ipynb
@@ -158,13 +158,13 @@
  "outputs": [],
  "source": [
   "@jax.jit\n",
-  "def sinkhorn_loss(x, y, epsilon=0.1, power=2.0):\n",
+  "def sinkhorn_loss(x, y, epsilon=0.1):\n",
   "    \"\"\"Computes transport between (x, a) and (y, b) via Sinkhorn algorithm.\"\"\"\n",
   "    a = jnp.ones(len(x)) / len(x)\n",
   "    b = jnp.ones(len(y)) / len(y)\n",
   "\n",
   "    sdiv = sinkhorn_divergence(\n",
-  "        pointcloud.PointCloud, x, y, power=power, epsilon=epsilon, a=a, b=b\n",
+  "        pointcloud.PointCloud, x, y, epsilon=epsilon, a=a, b=b\n",
   "    )\n",
   "    return sdiv.divergence"
  ]
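For context, a self-contained sketch of the updated notebook cell; the import paths are assumptions based on how the notebook references `sinkhorn_divergence` and `pointcloud`:

```python
import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.tools.sinkhorn_divergence import sinkhorn_divergence  # assumed path


@jax.jit
def sinkhorn_loss(x, y, epsilon=0.1):
  """Computes transport between (x, a) and (y, b) via Sinkhorn algorithm."""
  # Uniform weights on both point clouds; note the `power` argument is gone.
  a = jnp.ones(len(x)) / len(x)
  b = jnp.ones(len(y)) / len(y)
  sdiv = sinkhorn_divergence(
      pointcloud.PointCloud, x, y, epsilon=epsilon, a=a, b=b
  )
  return sdiv.divergence


x = jax.random.normal(jax.random.PRNGKey(0), (16, 2))
y = jax.random.normal(jax.random.PRNGKey(1), (16, 2)) + 1.0
print(sinkhorn_loss(x, y))  # a scalar Sinkhorn divergence
```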
@@ -535,7 +535,7 @@
 ],
 "metadata": {
  "kernelspec": {
-  "display_name": "Python 3",
+  "display_name": "Python 3.9.15 64-bit",
   "language": "python",
   "name": "python3"
  },
@@ -549,11 +549,11 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-  "version": "3.9.12"
+  "version": "3.9.15"
  },
  "vscode": {
   "interpreter": {
-   "hash": "ba22eb0a90cf9680fd06e72916a6996fb9b27a2ebc703f47aacd356a82bf9683"
+   "hash": "a665b5d41d17b532ea9890333293a1b812fa0b73c9c25c950b3cedf1bebd0438"
   }
  }
 },
362,122 changes: 215,625 additions & 146,497 deletions docs/notebooks/point_clouds.ipynb

Large diffs are not rendered by default.

39 changes: 34 additions & 5 deletions docs/references.bib
@@ -297,6 +297,16 @@ @Article{benamou:15
   URL = {https://doi.org/10.1137/141000439},
   eprint = {https://doi.org/10.1137/141000439}
 }
+@article{brenier:91,
+  title = {Polar factorization and monotone rearrangement of vector-valued functions},
+  author = {Brenier, Yann},
+  journal = {Communications on Pure and Applied Mathematics},
+  volume = {44},
+  number = {4},
+  pages = {375--417},
+  year = {1991},
+  publisher = {Wiley Online Library}
+}

 @InProceedings{cuturi:13,
   author = {Cuturi, Marco},
@@ -544,6 +554,17 @@ @Article{heitz:21
   url = "https://doi.org/10.1007/s10851-020-00996-z"
 }

+@article{santambrogio:15,
+  title = {Optimal transport for applied mathematicians},
+  author = {Santambrogio, Filippo},
+  journal = {Birkh{\"a}user, NY},
+  volume = {55},
+  number = {58-63},
+  pages = {94},
+  year = {2015},
+  publisher = {Springer}
+}
+
 @article{cholmod:08,
   author = {Chen, Yanqing and Davis, Timothy A. and Hager, William W. and Rajamanickam, Sivasankaran},
   title = {Algorithm 887: CHOLMOD, Supernodal Sparse Cholesky Factorization and Update/Downdate},
@@ -615,9 +636,17 @@ @article{amos:22
 }

 @inproceedings{korotin:21,
-title={Wasserstein-2 Generative Networks},
-author={Alexander Korotin and Vage Egiazarian and Arip Asadulaev and Alexander Safin and Evgeny Burnaev},
-booktitle={International Conference on Learning Representations},
-year={2021},
-url={https://openreview.net/forum?id=bEoxzW_EXsa}
+  title = {Wasserstein-2 Generative Networks},
+  author = {Alexander Korotin and Vage Egiazarian and Arip Asadulaev and Alexander Safin and Evgeny Burnaev},
+  booktitle = {International Conference on Learning Representations},
+  year = {2021},
+  url = {https://openreview.net/forum?id=bEoxzW_EXsa}
 }
+
+@book{boyd:04,
+  title = {Convex Optimization},
+  author = {Boyd, Stephen and Vandenberghe, Lieven},
+  year = {2004},
+  url = {https://web.stanford.edu/~boyd/cvxbook/bv_cvxbook.pdf},
+  publisher = {Cambridge University Press}
+}
5 changes: 3 additions & 2 deletions ott/core/bar_problems.py
@@ -330,9 +330,10 @@ def update_features(self, transports: jnp.ndarray,
     transports = transports * inv_a[None, :, None]

     if self._loss_name == "sqeucl":
-      cost = costs.SqEuclidean()
+      cost_fn = costs.SqEuclidean()
       return jnp.sum(
-          weights * barycentric_projection(transports, y_fused, cost), axis=0
+          weights * barycentric_projection(transports, y_fused, cost_fn),
+          axis=0
       )
     raise NotImplementedError(self._loss_name)
3 changes: 2 additions & 1 deletion ott/core/neuraldual.py
@@ -24,6 +24,7 @@
 from typing_extensions import Literal

 from ott.core import icnn, potentials
+from ott.geometry import costs

 Train_t = Dict[Literal["training_logs", "validation_logs"], List[float]]
 Potentials_t = potentials.DualPotentials
@@ -306,7 +307,7 @@ def to_dual_potentials(self) -> potentials.DualPotentials:
     """Return the Kantorovich dual potentials from the trained potentials."""
     f = lambda x: self.state_f.apply_fn({"params": self.state_f.params}, x)
     g = lambda x: self.state_g.apply_fn({"params": self.state_g.params}, x)
-    return potentials.DualPotentials(f, g, cor=True)
+    return potentials.DualPotentials(f, g, costs.SqEuclidean(), corr=True)

   @staticmethod
   def _clip_weights_icnn(params):
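The change above reflects the new `DualPotentials` signature: a `cost_fn` is now required and `cor` was renamed to `corr`. A toy sketch of that signature; the closed-form potentials below stand in for trained ICNNs and are not part of this PR:

```python
import jax.numpy as jnp

from ott.core import potentials
from ott.geometry import costs

# Toy quadratic potentials; for g(x) = 0.5 * ||x||^2, grad(g) is the identity.
f = lambda x: 0.5 * jnp.sum(x ** 2)
g = lambda x: 0.5 * jnp.sum(x ** 2)

pots = potentials.DualPotentials(f, g, costs.SqEuclidean(), corr=True)
x = jnp.ones((5, 2))
y_hat = pots.transport(x)  # in correlation form the forward map is grad(g)(x) = x
```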
81 changes: 51 additions & 30 deletions ott/core/potentials.py
@@ -6,7 +6,7 @@
 import jax.tree_util as jtu
 from typing_extensions import Literal

-from ott.geometry import pointcloud
+from ott.geometry import costs, pointcloud

 __all__ = ["DualPotentials", "EntropicPotentials"]
 Potential_t = Callable[[jnp.ndarray], float]
@@ -22,23 +22,38 @@ class DualPotentials:
   Args:
     f: The first dual potential function.
     g: The second dual potential function.
-    cor: whether the duals solve the problem in distance form, or correlation
-      form (as used for instance for ICNNs, see e.g. top right of p.3 in
-      http://proceedings.mlr.press/v119/makkuva20a/makkuva20a.pdf)
+    cost_fn: The cost function used to solve the OT problem.
+    corr: Whether the duals solve the problem in distance form, or correlation
+      form (as used for instance for ICNNs, see e.g. top right of p. 3 in
+      :cite:`makkuva:20`).
   """

-  def __init__(self, f: Potential_t, g: Potential_t, *, cor: bool = False):
+  def __init__(
+      self,
+      f: Potential_t,
+      g: Potential_t,
+      cost_fn: costs.CostFn,
+      *,
+      corr: bool = False
+  ):
     self._f = f
     self._g = g
-    self._cor = cor
+    self.cost_fn = cost_fn
+    self._corr = corr
   def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray:
-    """Transport ``vec`` according to Brenier formula.
+    r"""Transport ``vec`` according to Brenier formula.

-    Theorem 1.17 in http://math.univ-lyon1.fr/~santambrogio/OTAM-cvgmt.pdf
-    for case h(.) = ||.||^2, ∇h(.) = 2 ., [∇h]^-1(.) = 0.5 * .
-    or, when solved in correlation form, as ∇g for forward, ∇f for backward.
+    Uses Theorem 1.17 from :cite:`santambrogio:15` to compute an OT map when
+    given the Legendre transform of the dual potentials.
+
+    That OT map can be recovered as :math:`x - (\nabla h)^{-1}\circ \nabla f(x)`.
+    For the case :math:`h(\cdot) = \|\cdot\|^2`, :math:`\nabla h(\cdot) = 2 \cdot\,`,
+    and as a consequence :math:`h^*(\cdot) = \|\cdot\|^2 / 4`, while one has
+    :math:`\nabla h^*(\cdot) = (\nabla h)^{-1}(\cdot) = 0.5 \cdot\,`.
+
+    When the dual potentials are solved in correlation form (only in the Sq.
+    Euclidean distance case), the maps are :math:`\nabla g` for forward,
+    :math:`\nabla f` for backward.

     Args:
       vec: Points to transport, array of shape ``[n, d]``.
@@ -49,29 +64,31 @@ def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray:
       The transported points.
     """
     vec = jnp.atleast_2d(vec)
-    if self._cor:
+    if self._corr and isinstance(self.cost_fn, costs.SqEuclidean):
       return self._grad_g(vec) if forward else self._grad_f(vec)
-    return vec - 0.5 * (self._grad_f(vec) if forward else self._grad_g(vec))
+    if forward:
+      return vec - self._grad_h_inv(self._grad_f(vec))
+    else:
+      return vec - self._grad_h_inv(self._grad_g(vec))

   def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
     """Evaluate 2-Wasserstein distance between samples using dual potentials.

-    Uses Eq. 5 from :cite:`makkuva:20` when given in cor form, direct estimation
-    by integrating dual function against points when using dual form.
+    Uses Eq. 5 from :cite:`makkuva:20` when given in `corr` form; uses direct
+    estimation, integrating the dual functions against the points, otherwise.

     Args:
       src: Samples from the source distribution, array of shape ``[n, d]``.
       tgt: Samples from the target distribution, array of shape ``[m, d]``.

     Returns:
-      Wasserstein distance :math:`W^2_2`, assuming :math:`|x-y|^2` as the
-      ground distance.
+      Wasserstein distance.
     """
     src, tgt = jnp.atleast_2d(src), jnp.atleast_2d(tgt)

     f = jax.vmap(self.f)

-    if self._cor:
+    if self._corr:
       grad_g_y = self._grad_g(tgt)
       term1 = -jnp.mean(f(src))
       term2 = -jnp.mean(jnp.sum(tgt * grad_g_y, axis=-1) - f(grad_g_y))
@@ -85,9 +102,6 @@ def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
     C += jnp.mean(g(tgt))
     return C

-    # compute the final Wasserstein distance assuming ground metric |x-y|^2,
-    # thus an additional multiplication by 2
-
   @property
   def f(self) -> Potential_t:
     """The first dual potential function."""
@@ -108,8 +122,16 @@ def _grad_g(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
     """Vectorized gradient of the potential function :attr:`g`."""
     return jax.vmap(jax.grad(self.g, argnums=0))

+  @property
+  def _grad_h_inv(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
+    assert isinstance(self.cost_fn, costs.TICost), (
+        "Cost must be a `TICost` and "
+        "provide access to Legendre transform of `h`."
+    )
+    return jax.vmap(jax.grad(self.cost_fn.h_legendre))
+
   def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
-    return [self._f, self._g], {"cor": self._cor}
+    return [self._f, self._g, self.cost_fn], {"corr": self._corr}

   @classmethod
   def tree_unflatten(
@@ -127,6 +149,8 @@ class EntropicPotentials(DualPotentials):
     g: The second dual potential vector of shape ``[m,]``.
     geom: Geometry used to compute the dual potentials using
       :class:`~ott.core.sinkhorn.Sinkhorn`.
+    a: Probability weights for the first measure.
+    b: Probability weights for the second measure.
   """

   def __init__(
@@ -141,7 +165,7 @@ def __init__(

     # we pass directly the arrays and override the properties
     # since only the properties need to be callable
-    super().__init__(f, g)
+    super().__init__(f, g, cost_fn=geom.cost_fn, corr=False)
     self._geom = geom
     self._a = a
     self._b = b
@@ -160,16 +184,13 @@ def _create_potential_function(

     def callback(x: jnp.ndarray) -> float:
       cost = pointcloud.PointCloud(
-          jnp.atleast_2d(x),
-          y,
-          cost_fn=self._geom.cost_fn,
-          power=self._geom.power,
-          epsilon=1.0  # epsilon is not used
+          jnp.atleast_2d(x), y, cost_fn=self._geom.cost_fn
       ).cost_matrix
-      return -eps * jsp.special.logsumexp((potential - cost) / eps,
-                                          b=prob_weights)
+      z = (potential - cost) / epsilon
+      lse = -epsilon * jsp.special.logsumexp(z, b=prob_weights, axis=-1)
+      return jnp.squeeze(lse)

-    eps = self.epsilon
+    epsilon = self.epsilon
     if kind == "g":
       # When seeking to evaluate 2nd potential function, 1st set of potential
       # values and support should be used,
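Stepping back from the diff: in display form, the map the reworked `transport` implements, per the docstring above (Theorem 1.17 of :cite:`santambrogio:15`, for a translation-invariant cost c(x, y) = h(x - y)) reads:

```latex
% Entropic (Brenier) map for a translation-invariant cost c(x, y) = h(x - y),
% with h^* the Legendre transform of h, as used by `_grad_h_inv` above.
\[
  T(x) = x - \nabla h^{*}\bigl(\nabla f(x)\bigr),
  \qquad
  h^{*}(w) = \sup_{z}\, \langle w, z\rangle - h(z).
\]
% Specializing to the squared Euclidean case recovers the familiar formula:
\[
  h(z) = \|z\|^{2}
  \;\Longrightarrow\;
  \nabla h^{*}(w) = \tfrac{1}{2}\, w,
  \qquad
  T(x) = x - \tfrac{1}{2}\, \nabla f(x).
\]
```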
10 changes: 6 additions & 4 deletions ott/core/quad_problems.py
@@ -21,6 +21,7 @@
 # Because Protocol is not available in Python < 3.8
 from typing_extensions import Literal, Protocol

+from ott.core import _math_utils as mu
 from ott.core import linear_problems, sinkhorn_lr
 from ott.geometry import epsilon_scheduler, geometry, low_rank, pointcloud
@@ -58,16 +59,17 @@ class GWLoss(NamedTuple):
 def make_square_loss() -> GWLoss:
   f1 = Loss(lambda x: x ** 2, is_linear=False)
   f2 = Loss(lambda y: y ** 2, is_linear=False)
-  h1 = Loss(lambda x: x, is_linear=True)
-  h2 = Loss(lambda y: 2.0 * y, is_linear=True)
+  h1 = Loss(lambda x: jnp.sqrt(2) * x, is_linear=True)
+  h2 = Loss(lambda y: jnp.sqrt(2) * y, is_linear=True)
   return GWLoss(f1, f2, h1, h2)


-def make_kl_loss(clipping_value: float = 1e-8) -> GWLoss:
+def make_kl_loss(clipping_value: Optional[float] = None) -> GWLoss:
+  assert clipping_value is None, "Clipping deprecated in KL definition."
   f1 = Loss(lambda x: -jax.scipy.special.entr(x) - x, is_linear=False)
   f2 = Loss(lambda y: y, is_linear=True)
   h1 = Loss(lambda x: x, is_linear=True)
-  h2 = Loss(lambda y: jnp.log(jnp.clip(y, clipping_value)), is_linear=False)
+  h2 = Loss(lambda y: mu.safe_log(y), is_linear=False)
   return GWLoss(f1, f2, h1, h2)
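A note on the sqrt(2) factors above: `GWLoss` decomposes a ground loss as L(x, y) = f1(x) + f2(y) - h1(x) h2(y), so this commit splits the cross term of the squared loss symmetrically between h1 and h2:

```latex
% Square-loss decomposition used by make_square_loss:
\[
  (x - y)^{2}
  = \underbrace{x^{2}}_{f_1(x)} + \underbrace{y^{2}}_{f_2(y)}
  - \underbrace{\bigl(\sqrt{2}\, x\bigr)}_{h_1(x)}
    \underbrace{\bigl(\sqrt{2}\, y\bigr)}_{h_2(y)},
\]
% replacing the previous asymmetric split h_1(x) = x, h_2(y) = 2y.
```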