Deprecate power in PointCloud, introduce TICost and use it to compute Entropic (Brenier) maps. #167

Merged · 25 commits · Nov 15, 2022
3 changes: 3 additions & 0 deletions docs/geometry.rst
@@ -50,6 +50,9 @@ Cost Functions
:toctree: _autosummary

costs.CostFn
+costs.TICost
+costs.SqPNorm
+costs.PNorm
costs.SqEuclidean
costs.Euclidean
costs.Cosine
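With `power` deprecated, the ground cost is now specified entirely through a cost class. A minimal sketch of a custom translation-invariant cost (hypothetical class, assuming, as `ott/core/potentials.py` below does, that a `costs.TICost` subclass provides `h` and its Legendre transform `h_legendre`):

```python
import jax.numpy as jnp

from ott.geometry import costs


class HalfSqEuclidean(costs.TICost):
  """Translation-invariant cost c(x, y) = h(x - y) with h(z) = ||z||^2 / 2."""

  def h(self, z: jnp.ndarray) -> float:
    return 0.5 * jnp.sum(z ** 2)

  def h_legendre(self, z: jnp.ndarray) -> float:
    # h is self-conjugate: sup_w <z, w> - ||w||^2 / 2 = ||z||^2 / 2.
    return 0.5 * jnp.sum(z ** 2)
```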
10 changes: 5 additions & 5 deletions docs/notebooks/neural_dual.ipynb
@@ -158,13 +158,13 @@
"outputs": [],
"source": [
"@jax.jit\n",
"def sinkhorn_loss(x, y, epsilon=0.1, power=2.0):\n",
"def sinkhorn_loss(x, y, epsilon=0.1):\n",
" \"\"\"Computes transport between (x, a) and (y, b) via Sinkhorn algorithm.\"\"\"\n",
" a = jnp.ones(len(x)) / len(x)\n",
" b = jnp.ones(len(y)) / len(y)\n",
"\n",
" sdiv = sinkhorn_divergence(\n",
" pointcloud.PointCloud, x, y, power=power, epsilon=epsilon, a=a, b=b\n",
" pointcloud.PointCloud, x, y, epsilon=epsilon, a=a, b=b\n",
" )\n",
" return sdiv.divergence"
]
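A quick smoke test of the updated cell (hypothetical usage, relying on the notebook's earlier imports; not part of the diff):

```python
import jax

rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(rng_x, (128, 2))        # source samples
y = jax.random.normal(rng_y, (128, 2)) + 1.0  # shifted target samples

# The ground cost is now fixed by the geometry's `cost_fn`; no `power` argument.
loss = sinkhorn_loss(x, y, epsilon=0.1)
```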
@@ -535,7 +535,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3.9.15 64-bit",
"language": "python",
"name": "python3"
},
@@ -549,11 +549,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.9.15"
},
"vscode": {
"interpreter": {
"hash": "ba22eb0a90cf9680fd06e72916a6996fb9b27a2ebc703f47aacd356a82bf9683"
"hash": "a665b5d41d17b532ea9890333293a1b812fa0b73c9c25c950b3cedf1bebd0438"
}
}
},
362,122 changes: 215,625 additions & 146,497 deletions docs/notebooks/point_clouds.ipynb

Large diffs are not rendered by default.

39 changes: 34 additions & 5 deletions docs/references.bib
@@ -297,6 +297,16 @@ @Article{benamou:15
URL = {https://doi.org/10.1137/141000439},
eprint = {https://doi.org/10.1137/141000439}
}
+@article{brenier:91,
+  title={Polar factorization and monotone rearrangement of vector-valued functions},
+  author={Brenier, Yann},
+  journal={Communications on Pure and Applied Mathematics},
+  volume={44},
+  number={4},
+  pages={375--417},
+  year={1991},
+  publisher={Wiley}
+}

@InProceedings{cuturi:13,
author = {Cuturi, Marco},
@@ -544,6 +554,17 @@ @Article{heitz:21
url = "https://doi.org/10.1007/s10851-020-00996-z"
}

+@book{santambrogio:15,
+  title={Optimal transport for applied mathematicians},
+  author={Santambrogio, Filippo},
+  year={2015},
+  publisher={Birkh{\"a}user}
+}

@article{cholmod:08,
author = {Chen, Yanqing and Davis, Timothy A. and Hager, William W. and Rajamanickam, Sivasankaran},
title = {Algorithm 887: CHOLMOD, Supernodal Sparse Cholesky Factorization and Update/Downdate},
@@ -615,9 +636,17 @@ @article{amos:22
}

@inproceedings{korotin:21,
  title={Wasserstein-2 Generative Networks},
  author={Alexander Korotin and Vage Egiazarian and Arip Asadulaev and Alexander Safin and Evgeny Burnaev},
  booktitle={International Conference on Learning Representations},
  year={2021},
  url={https://openreview.net/forum?id=bEoxzW_EXsa}
}

+@book{boyd:04,
+  title={Convex optimization},
+  author={Boyd, Stephen and Vandenberghe, Lieven},
+  year={2004},
+  url={https://web.stanford.edu/~boyd/cvxbook/bv_cvxbook.pdf},
+  publisher={Cambridge University Press}
+}
5 changes: 3 additions & 2 deletions ott/core/bar_problems.py
@@ -330,9 +330,10 @@ def update_features(self, transports: jnp.ndarray,
transports = transports * inv_a[None, :, None]

if self._loss_name == "sqeucl":
cost = costs.SqEuclidean()
cost_fn = costs.SqEuclidean()
return jnp.sum(
weights * barycentric_projection(transports, y_fused, cost), axis=0
weights * barycentric_projection(transports, y_fused, cost_fn),
axis=0
)
raise NotImplementedError(self._loss_name)

3 changes: 2 additions & 1 deletion ott/core/neuraldual.py
@@ -24,6 +24,7 @@
from typing_extensions import Literal

from ott.core import icnn, potentials
+from ott.geometry import costs

Train_t = Dict[Literal["training_logs", "validation_logs"], List[float]]
Potentials_t = potentials.DualPotentials
@@ -306,7 +307,7 @@ def to_dual_potentials(self) -> potentials.DualPotentials:
"""Return the Kantorovich dual potentials from the trained potentials."""
f = lambda x: self.state_f.apply_fn({"params": self.state_f.params}, x)
g = lambda x: self.state_g.apply_fn({"params": self.state_g.params}, x)
-    return potentials.DualPotentials(f, g, cor=True)
+    return potentials.DualPotentials(f, g, costs.SqEuclidean(), corr=True)

@staticmethod
def _clip_weights_icnn(params):
81 changes: 51 additions & 30 deletions ott/core/potentials.py
@@ -6,7 +6,7 @@
import jax.tree_util as jtu
from typing_extensions import Literal

-from ott.geometry import pointcloud
+from ott.geometry import costs, pointcloud

__all__ = ["DualPotentials", "EntropicPotentials"]
Potential_t = Callable[[jnp.ndarray], float]
Contributor: In the same vein as cost_fn, I'm wondering if Potential_t could be renamed to PotentialFn_t or just PotentialFn.

Contributor (author): I had the same reflex when Michal used it for the first time, but I think it makes sense :) Here it turns out this is just a type (_t) and can be either a vector or a function.

@@ -22,23 +22,38 @@ class DualPotentials:
Args:
f: The first dual potential function.
g: The second dual potential function.
-    cor: whether the duals solve the problem in distance form, or correlation
-      form (as used for instance for ICNNs, see e.g. top right of p.3 in
-      http://proceedings.mlr.press/v119/makkuva20a/makkuva20a.pdf)
+    cost_fn: The cost function used to solve the OT problem.
+    corr: Whether the duals solve the problem in distance form, or correlation
+      form (as used for instance for ICNNs, see e.g. top right of p.3 in
+      :cite:`makkuva:20`).
  """

-  def __init__(self, f: Potential_t, g: Potential_t, *, cor: bool = False):
+  def __init__(
+      self,
+      f: Potential_t,
+      g: Potential_t,
+      cost_fn: costs.CostFn,
+      *,
+      corr: bool = False
+  ):
     self._f = f
     self._g = g
-    self._cor = cor
+    self.cost_fn = cost_fn
+    self._corr = corr

def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray:
"""Transport ``vec`` according to Brenier formula.
r"""Transport ``vec`` according to Brenier formula.
Theorem 1.17 in http://math.univ-lyon1.fr/~santambrogio/OTAM-cvgmt.pdf
for case h(.) = ||.||^2, ∇h(.) = 2 ., [∇h]^-1(.) = 0.5 * .
Uses Theorem 1.17 from :cite:`santambrogio:15` to compute an OT map when
given the Legendre transform of the dual potentials.
or, when solved in correlation form, as ∇g for forward, ∇f for backward.
That OT map can be recovered as :math:`x- (\nabla h)^{-1}\circ \nabla f(x)`
For the case :math:`h(\cdot) = \|\cdot\|^2, \nabla h(\cdot) = 2 \cdot\,`,
and as a consequence :math:`h^*(\cdot) = \|.\|^2 / 4`, while one has that
:math:`\nabla h^*(\cdot) = (\nabla h)^{-1}(\cdot) = 0.5 \cdot\,`.
When the dual potentials are solved in correlation form (only in the Sq.
Euclidean distance case), the maps are :math:`\nabla g` for forward,
:math:`\nabla f` for backward.
Args:
vec: Points to transport, array of shape ``[n, d]``.
@@ -49,29 +64,31 @@ def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray:
The transported points.
"""
vec = jnp.atleast_2d(vec)
-    if self._cor:
+    if self._corr and isinstance(self.cost_fn, costs.SqEuclidean):
       return self._grad_g(vec) if forward else self._grad_f(vec)
-    return vec - 0.5 * (self._grad_f(vec) if forward else self._grad_g(vec))
+    if forward:
+      return vec - self._grad_h_inv(self._grad_f(vec))
+    else:
+      return vec - self._grad_h_inv(self._grad_g(vec))
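As a sanity check of the new branch (a sketch, not part of the PR; it assumes `costs.SqEuclidean` is a `TICost`, so that `_grad_h_inv` below applies): under :math:`c(x, y) = \|x - y\|^2`, the potential :math:`f(x) = -2\langle\mu, x\rangle` has gradient :math:`-2\mu`, so the recovered map is the translation :math:`x \mapsto x - 0.5 \cdot (-2\mu) = x + \mu`.

```python
import jax.numpy as jnp

from ott.core import potentials
from ott.geometry import costs

mu = jnp.array([1.0, -2.0])
f = lambda x: -2.0 * jnp.dot(mu, x)  # potential of a translation plan
g = lambda y: 2.0 * jnp.dot(mu, y)   # hypothetical counterpart, unused by the forward map

pot = potentials.DualPotentials(f, g, cost_fn=costs.SqEuclidean(), corr=False)
x = jnp.array([[0.0, 0.0], [1.0, 1.0]])
pot.transport(x)  # ~ x + mu, i.e. [[1., -2.], [2., -1.]]
```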

def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
"""Evaluate 2-Wasserstein distance between samples using dual potentials.
-    Uses Eq. 5 from :cite:`makkuva:20` when given in cor form, direct estimation
-    by integrating dual function against points when using dual form.
+    Uses Eq. 5 from :cite:`makkuva:20` when given in `corr` form; otherwise, a
+    direct estimate obtained by integrating the dual functions against the
+    sample points.
Args:
src: Samples from the source distribution, array of shape ``[n, d]``.
tgt: Samples from the target distribution, array of shape ``[m, d]``.
Returns:
-      Wasserstein distance :math:`W^2_2`, assuming :math:`|x-y|^2` as the
-      ground distance.
+      Wasserstein distance.
"""
src, tgt = jnp.atleast_2d(src), jnp.atleast_2d(tgt)

f = jax.vmap(self.f)

-    if self._cor:
+    if self._corr:
grad_g_y = self._grad_g(tgt)
term1 = -jnp.mean(f(src))
term2 = -jnp.mean(jnp.sum(tgt * grad_g_y, axis=-1) - f(grad_g_y))
@@ -85,9 +102,6 @@ def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
C += jnp.mean(g(tgt))
return C

-    # compute the final Wasserstein distance assuming ground metric |x-y|^2,
-    # thus an additional multiplication by 2

@property
def f(self) -> Potential_t:
"""The first dual potential function."""
@@ -108,8 +122,16 @@ def _grad_g(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
"""Vectorized gradient of the potential function :attr:`g`."""
return jax.vmap(jax.grad(self.g, argnums=0))

+  @property
+  def _grad_h_inv(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
+    assert isinstance(self.cost_fn, costs.TICost), (
+        "Cost must be a `TICost` and "
+        "provide access to Legendre transform of `h`."
+    )
+    return jax.vmap(jax.grad(self.cost_fn.h_legendre))

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
return [self._f, self._g], {"cor": self._cor}
return [self._f, self._g, self.cost_fn], {"corr": self._corr}

@classmethod
def tree_unflatten(
@@ -127,6 +149,8 @@ class EntropicPotentials(DualPotentials):
g: The second dual potential vector of shape ``[m,]``.
geom: Geometry used to compute the dual potentials using
:class:`~ott.core.sinkhorn.Sinkhorn`.
+    a: Probability weights for the first measure.
+    b: Probability weights for the second measure.
"""

def __init__(
@@ -141,7 +165,7 @@ def __init__(

# we pass directly the arrays and override the properties
# since only the properties need to be callable
-    super().__init__(f, g)
+    super().__init__(f, g, cost_fn=geom.cost_fn, corr=False)
self._geom = geom
self._a = a
self._b = b
@@ -160,16 +184,13 @@ def _create_potential_function(

def callback(x: jnp.ndarray) -> float:
cost = pointcloud.PointCloud(
-          jnp.atleast_2d(x),
-          y,
-          cost_fn=self._geom.cost_fn,
-          power=self._geom.power,
-          epsilon=1.0  # epsilon is not used
+          jnp.atleast_2d(x), y, cost_fn=self._geom.cost_fn
       ).cost_matrix
-      return -eps * jsp.special.logsumexp((potential - cost) / eps,
-                                          b=prob_weights)
+      z = (potential - cost) / epsilon
+      lse = -epsilon * jsp.special.logsumexp(z, b=prob_weights, axis=-1)
+      return jnp.squeeze(lse)

-    eps = self.epsilon
+    epsilon = self.epsilon
if kind == "g":
# When seeking to evaluate 2nd potential function, 1st set of potential
# values and support should be used,
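For reference, the rewritten `callback` evaluates the entropic (soft) :math:`c`-transform of the discrete potential: with values :math:`g_j` on the support :math:`y_j` and weights :math:`b_j`,

```latex
f(x) = -\varepsilon \log \sum_j b_j \exp\Big(\frac{g_j - c(x, y_j)}{\varepsilon}\Big),
```

where the `logsumexp` runs over the support of the other measure (`axis=-1`). Composing :math:`\nabla f` with `_grad_h_inv` above is what produces the entropic (Brenier) maps of the PR title.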
10 changes: 6 additions & 4 deletions ott/core/quad_problems.py
@@ -21,6 +21,7 @@
# Because Protocol is not available in Python < 3.8
from typing_extensions import Literal, Protocol

+from ott.core import _math_utils as mu
from ott.core import linear_problems, sinkhorn_lr
from ott.geometry import epsilon_scheduler, geometry, low_rank, pointcloud

@@ -58,16 +59,17 @@ class GWLoss(NamedTuple):
def make_square_loss() -> GWLoss:
f1 = Loss(lambda x: x ** 2, is_linear=False)
f2 = Loss(lambda y: y ** 2, is_linear=False)
-  h1 = Loss(lambda x: x, is_linear=True)
-  h2 = Loss(lambda y: 2.0 * y, is_linear=True)
+  h1 = Loss(lambda x: jnp.sqrt(2) * x, is_linear=True)
+  h2 = Loss(lambda y: jnp.sqrt(2) * y, is_linear=True)
return GWLoss(f1, f2, h1, h2)


-def make_kl_loss(clipping_value: float = 1e-8) -> GWLoss:
+def make_kl_loss(clipping_value: Optional[float] = None) -> GWLoss:
+  assert clipping_value is None, "Clipping deprecated in KL definition."
   f1 = Loss(lambda x: -jax.scipy.special.entr(x) - x, is_linear=False)
   f2 = Loss(lambda y: y, is_linear=True)
   h1 = Loss(lambda x: x, is_linear=True)
-  h2 = Loss(lambda y: jnp.log(jnp.clip(y, clipping_value)), is_linear=False)
+  h2 = Loss(lambda y: mu.safe_log(y), is_linear=False)
return GWLoss(f1, f2, h1, h2)
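A quick check of the decomposition these helpers implement, :math:`L(x, y) = f_1(x) + f_2(y) - h_1(x)\,h_2(y)`: with the new symmetric choice :math:`h_1(x) = \sqrt{2}\,x`, :math:`h_2(y) = \sqrt{2}\,y`,

```latex
f_1(x) + f_2(y) - h_1(x)\,h_2(y) = x^2 + y^2 - 2xy = (x - y)^2,
```

the same product as the previous :math:`h_1(x) = x`, :math:`h_2(y) = 2y`, with the factor 2 now split evenly. Likewise, the KL loss recovers :math:`x \log(x/y) - x + y`, the generalized KL divergence, with `mu.safe_log` replacing the clipped logarithm.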

