Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate power in PointCloud, introduce TICost and use it to compute Entropic (Brenier) maps. #167

Merged
merged 25 commits into from
Nov 15, 2022
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions docs/notebooks/neural_dual.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,13 @@
"outputs": [],
"source": [
"@jax.jit\n",
"def sinkhorn_loss(x, y, epsilon=0.1, power=2.0):\n",
"def sinkhorn_loss(x, y, epsilon=0.1):\n",
" \"\"\"Computes transport between (x, a) and (y, b) via Sinkhorn algorithm.\"\"\"\n",
" a = jnp.ones(len(x)) / len(x)\n",
" b = jnp.ones(len(y)) / len(y)\n",
"\n",
" sdiv = sinkhorn_divergence(\n",
" pointcloud.PointCloud, x, y, power=power, epsilon=epsilon, a=a, b=b\n",
" pointcloud.PointCloud, x, y, epsilon=epsilon, a=a, b=b\n",
" )\n",
" return sdiv.divergence"
]
Expand Down Expand Up @@ -535,7 +535,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3.9.15 64-bit",
"language": "python",
"name": "python3"
},
Expand All @@ -549,11 +549,11 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.9.15"
},
"vscode": {
"interpreter": {
"hash": "ba22eb0a90cf9680fd06e72916a6996fb9b27a2ebc703f47aacd356a82bf9683"
"hash": "a665b5d41d17b532ea9890333293a1b812fa0b73c9c25c950b3cedf1bebd0438"
}
}
},
Expand Down
362,117 changes: 215,620 additions & 146,497 deletions docs/notebooks/point_clouds.ipynb

Large diffs are not rendered by default.

28 changes: 28 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,16 @@ @Article{benamou:15
URL = {https://doi.org/10.1137/141000439},
eprint = {https://doi.org/10.1137/141000439}
}
@article{brenier:91,
title={Polar factorization and monotone rearrangement of vector-valued functions},
author={Brenier, Yann},
journal={Communications on pure and applied mathematics},
volume={44},
number={4},
pages={375--417},
year={1991},
publisher={Wiley Online Library}
}

@InProceedings{cuturi:13,
author = {Cuturi, Marco},
Expand Down Expand Up @@ -544,6 +554,17 @@ @Article{heitz:21
url = "https://doi.org/10.1007/s10851-020-00996-z"
}

@article{santambrogio:15,
title={Optimal transport for applied mathematicians},
author={Santambrogio, Filippo},
journal={Birk{\"a}user, NY},
volume={55},
number={58-63},
pages={94},
year={2015},
publisher={Springer}
}

@article{cholmod:08,
author = {Chen, Yanqing and Davis, Timothy A. and Hager, William W. and Rajamanickam, Sivasankaran},
title = {Algorithm 887: CHOLMOD, Supernodal Sparse Cholesky Factorization and Update/Downdate},
Expand Down Expand Up @@ -621,3 +642,10 @@ @inproceedings{korotin:21
year={2021},
url={https://openreview.net/forum?id=bEoxzW_EXsa}
}

@book{boyd:04,
title={Convex optimization},
author={Boyd, Stephen and Boyd, Stephen P and Vandenberghe, Lieven},
year={2004},
publisher={Cambridge university press}
}
5 changes: 3 additions & 2 deletions ott/core/bar_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,10 @@ def update_features(self, transports: jnp.ndarray,
transports = transports * inv_a[None, :, None]

if self._loss_name == "sqeucl":
cost = costs.SqEuclidean()
cost_fn = costs.SqEuclidean()
return jnp.sum(
weights * barycentric_projection(transports, y_fused, cost), axis=0
weights * barycentric_projection(transports, y_fused, cost_fn),
axis=0
)
raise NotImplementedError(self._loss_name)

Expand Down
3 changes: 2 additions & 1 deletion ott/core/neuraldual.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from typing_extensions import Literal

from ott.core import icnn, potentials
from ott.geometry import costs

Train_t = Dict[Literal["training_logs", "validation_logs"], List[float]]
Potentials_t = potentials.DualPotentials
Expand Down Expand Up @@ -306,7 +307,7 @@ def to_dual_potentials(self) -> potentials.DualPotentials:
"""Return the Kantorovich dual potentials from the trained potentials."""
f = lambda x: self.state_f.apply_fn({"params": self.state_f.params}, x)
g = lambda x: self.state_g.apply_fn({"params": self.state_g.params}, x)
return potentials.DualPotentials(f, g, cor=True)
return potentials.DualPotentials(f, g, costs.SqEuclidean(), corr=True)

@staticmethod
def _clip_weights_icnn(params):
Expand Down
67 changes: 38 additions & 29 deletions ott/core/potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import jax.tree_util as jtu
from typing_extensions import Literal

from ott.geometry import pointcloud
from ott.geometry import costs, pointcloud

__all__ = ["DualPotentials", "EntropicPotentials"]
Potential_t = Callable[[jnp.ndarray], float]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the same vein as cost_fn, I'm wondering if Potential_t could be renamed to PotentialFn_t or just PotentialFn.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had the same reflex when Michal used it for the first time, but I think it makes sense :) Here it turns out this is just a type (_t) and can be either a vector of a function.

Expand All @@ -22,23 +22,35 @@ class DualPotentials:
Args:
f: The first dual potential function.
g: The second dual potential function.
cor: whether the duals solve the problem in distance form, or correlation
form (as used for instance for ICNNs, see e.g. top right of p.3 in
http://proceedings.mlr.press/v119/makkuva20a/makkuva20a.pdf)
cost_fn: The cost function used to solve the OT problem.
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
corr: whether the duals solve the problem in distance form, or correlation form (as used for instance for ICNNs, see e.g. top right of p.3 in :cite:`makkuva:20`)
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(self, f: Potential_t, g: Potential_t, *, cor: bool = False):
def __init__(
self,
f: Potential_t,
g: Potential_t,
cost_fn: costs.CostFn,
*,
corr: bool = False
):
self._f = f
self._g = g
self._cor = cor
self.cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn
self._corr = corr

def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray:
"""Transport ``vec`` according to Brenier formula.
r"""Transport ``vec`` according to Brenier formula.

Uses Theorem 1.17 from :cite:`santambrogio:15` to compute an OT map when
given the Legendre transform of the dual potentials.

Theorem 1.17 in http://math.univ-lyon1.fr/~santambrogio/OTAM-cvgmt.pdf
for case h(.) = ||.||^2, ∇h(.) = 2 ., [∇h]^-1(.) = 0.5 * .
That OT map can be recovered as :math:`x- (\nabla h)^{-1}\circ \nabla f(x)`
For the case :math:`h(\cdot) = \|\cdot\|^2, \nabla h(\cdot) = 2 \cdot\,`,
and as a consequence :math:`h^*(\cdot) = \|.\|^2 / 4`, while one has that
:math:`\nabla h^*(\cdot) = (\nabla h)^{-1}(\cdot) = 0.5 \cdot\,`.

or, when solved in correlation form, as ∇g for forward, ∇f for backward.
When the dual potentials are solved in correlation form (only in the Squared Euclidean distance case), the maps are :math:`\nabla g` for forward, :math:`nabla f` for backward.
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved

Args:
vec: Points to transport, array of shape ``[n, d]``.
Expand All @@ -49,29 +61,32 @@ def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray:
The transported points.
"""
vec = jnp.atleast_2d(vec)
if self._cor:
if self._corr and isinstance(self.cost_fn, costs.SqEuclidean):
return self._grad_g(vec) if forward else self._grad_f(vec)
return vec - 0.5 * (self._grad_f(vec) if forward else self._grad_g(vec))
grad_h_inv = jax.vmap(jax.grad(self.cost_fn.h_legendre))
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
if forward:
return vec - grad_h_inv(self._grad_f(vec))
else:
return vec - grad_h_inv(self._grad_g(vec))

def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
"""Evaluate 2-Wasserstein distance between samples using dual potentials.

Uses Eq. 5 from :cite:`makkuva:20` when given in cor form, direct estimation
by integrating dual function against points when using dual form.
Uses Eq. 5 from :cite:`makkuva:20` when given in `corr` form, direct
estimation by integrating dual function against points when using dual form.

Args:
src: Samples from the source distribution, array of shape ``[n, d]``.
tgt: Samples from the target distribution, array of shape ``[m, d]``.

Returns:
Wasserstein distance :math:`W^2_2`, assuming :math:`|x-y|^2` as the
ground distance.
Wasserstein distance.
"""
src, tgt = jnp.atleast_2d(src), jnp.atleast_2d(tgt)

f = jax.vmap(self.f)

if self._cor:
if self._corr:
grad_g_y = self._grad_g(tgt)
term1 = -jnp.mean(f(src))
term2 = -jnp.mean(jnp.sum(tgt * grad_g_y, axis=-1) - f(grad_g_y))
Expand All @@ -85,9 +100,6 @@ def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
C += jnp.mean(g(tgt))
return C

# compute the final Wasserstein distance assuming ground metric |x-y|^2,
# thus an additional multiplication by 2

@property
def f(self) -> Potential_t:
"""The first dual potential function."""
Expand Down Expand Up @@ -141,7 +153,7 @@ def __init__(

# we pass directly the arrays and override the properties
# since only the properties need to be callable
super().__init__(f, g)
super().__init__(f, g, cost_fn=geom.cost_fn, corr=False)
self._geom = geom
self._a = a
self._b = b
Expand All @@ -160,16 +172,13 @@ def _create_potential_function(

def callback(x: jnp.ndarray) -> float:
cost = pointcloud.PointCloud(
jnp.atleast_2d(x),
y,
cost_fn=self._geom.cost_fn,
power=self._geom.power,
epsilon=1.0 # epsilon is not used
jnp.atleast_2d(x), y, cost_fn=self._geom.cost_fn
).cost_matrix
return -eps * jsp.special.logsumexp((potential - cost) / eps,
b=prob_weights)
z = (potential - cost) / epsilon
lse = -epsilon * jsp.special.logsumexp(z, b=prob_weights, axis=-1)
return jnp.squeeze(lse)

eps = self.epsilon
epsilon = self.epsilon
if kind == "g":
# When seeking to evaluate 2nd potential function, 1st set of potential
# values and support should be used,
Expand Down
10 changes: 6 additions & 4 deletions ott/core/quad_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# Because Protocol is not available in Python < 3.8
from typing_extensions import Literal, Protocol

from ott.core import _math_utils as mu
from ott.core import linear_problems, sinkhorn_lr
from ott.geometry import epsilon_scheduler, geometry, low_rank, pointcloud

Expand Down Expand Up @@ -58,16 +59,17 @@ class GWLoss(NamedTuple):
def make_square_loss() -> GWLoss:
f1 = Loss(lambda x: x ** 2, is_linear=False)
f2 = Loss(lambda y: y ** 2, is_linear=False)
h1 = Loss(lambda x: x, is_linear=True)
h2 = Loss(lambda y: 2.0 * y, is_linear=True)
h1 = Loss(lambda x: jnp.sqrt(2) * x, is_linear=True)
h2 = Loss(lambda y: jnp.sqrt(2) * y, is_linear=True)
return GWLoss(f1, f2, h1, h2)


def make_kl_loss(clipping_value: float = 1e-8) -> GWLoss:
def make_kl_loss(clipping_value: Optional[float] = None) -> GWLoss:
assert clipping_value is None, "Clipping deprecated in KL definition."
f1 = Loss(lambda x: -jax.scipy.special.entr(x) - x, is_linear=False)
f2 = Loss(lambda y: y, is_linear=True)
h1 = Loss(lambda x: x, is_linear=True)
h2 = Loss(lambda y: jnp.log(jnp.clip(y, clipping_value)), is_linear=False)
h2 = Loss(lambda y: mu.safe_log(y), is_linear=False)
return GWLoss(f1, f2, h1, h2)


Expand Down
76 changes: 70 additions & 6 deletions ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
pass

def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> float:
pass
raise NotImplementedError("Barycenter not yet implemented for this cost.")

@classmethod
def padder(cls, dim: int) -> jnp.ndarray:
Expand Down Expand Up @@ -90,17 +90,78 @@ def tree_unflatten(cls, aux_data, children):
return cls(*children)


@jax.tree_util.register_pytree_node_class
class RBFCost(CostFn):
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
"""A radial-basis function cost class for translation invariant costs.

Such costs are defined using a function :math:`h`, mapping vectors to
real-values, to be used as:

:math:`c(x,y) = h(z)`, where :math:`z := x-y`.

If that cost function is used to form an Entropic map using the
:cite:`brenier:91` theorem, then the user should ensure :math:`h` is
strictly convex, as well as provide the Legendre transform of :math:`h`,
whose gradient is necessarily the inverse of the gradient of :math:`h`.
"""

@abc.abstractmethod
def h(self, z: jnp.ndarray) -> float:
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
"""RBF function acting on difference of `x-y` to ouput cost."""

def h_legendre(self, z: jnp.ndarray) -> float:
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
"""Legendre transform of RBF function `h` (when latter is convex)."""
raise NotImplementedError("`h_legendre` not implemented.")

def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Evaluate h on difference between x and y."""
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
return self.h(x - y)


@jax.tree_util.register_pytree_node_class
class SqPNorm(RBFCost):
"""Squared p-norm of the difference of two vectors.

For details on the derivation of the Legendre transform of the norm, see e.g.
the reference :cite:`boyd:04`, p.93/94.
https://web.stanford.edu/~boyd/cvxbook/bv_cvxbook.pdf
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(self, p: float):
assert p > 1.0
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
self.p = p
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
self.q = 1. / (1 - 1 / self.p)

def h(self, z: jnp.ndarray) -> float:
return 0.5 * jnp.linalg.norm(z, self.p) ** 2

def h_legendre(self, z: jnp.ndarray) -> float:
return 0.5 * jnp.linalg.norm(z, self.q) ** 2

def tree_flatten(self):
return (), (self.p,)

@classmethod
def tree_unflatten(cls, aux_data, children):
del children
return cls(aux_data[0])


@jax.tree_util.register_pytree_node_class
class Euclidean(CostFn):
"""Euclidean distance."""
"""Euclidean distance.

Note that the Euclidean distance is not cast as a RBF cost, because this
would correspond to `h = abs`, whose gradient is not invertible.
"""

def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Compute Euclidean norm."""
return jnp.linalg.norm(x - y)


@jax.tree_util.register_pytree_node_class
class SqEuclidean(CostFn):
class SqEuclidean(RBFCost):
"""Squared Euclidean distance."""

def norm(self, x: jnp.ndarray) -> Union[float, jnp.ndarray]:
Expand All @@ -111,6 +172,12 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Compute minus twice the dot-product between vectors."""
return -2. * jnp.vdot(x, y)

def h(self, z: jnp.ndarray) -> float:
return jnp.sum(z ** 2)

def h_legendre(self, z: jnp.ndarray) -> float:
return 0.25 * jnp.sum(z ** 2)

def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray:
"""Output barycenter of vectors when using squared-Euclidean distance."""
return jnp.average(xs, weights=weights, axis=0)
Expand All @@ -134,9 +201,6 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
# similarity is in [-1, 1], clip because of numerical imprecisions
return jnp.clip(cosine_distance, 0., 2.)

def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> float:
raise NotImplementedError("Barycenter for cosine cost not yet implemented.")

@classmethod
def padder(cls, dim: int) -> jnp.ndarray:
return jnp.ones((1, dim))
Expand Down
Loading