Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate power in PointCloud, introduce TICost and use it to compute Entropic (Brenier) maps. #167

Merged
merged 25 commits into from
Nov 15, 2022
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/notebooks/neural_dual.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,13 @@
"outputs": [],
"source": [
"@jax.jit\n",
"def sinkhorn_loss(x, y, epsilon=0.1, power=2.0):\n",
"def sinkhorn_loss(x, y, epsilon=0.1):\n",
" \"\"\"Computes transport between (x, a) and (y, b) via Sinkhorn algorithm.\"\"\"\n",
" a = jnp.ones(len(x)) / len(x)\n",
" b = jnp.ones(len(y)) / len(y)\n",
"\n",
" sdiv = sinkhorn_divergence(\n",
" pointcloud.PointCloud, x, y, power=power, epsilon=epsilon, a=a, b=b\n",
" pointcloud.PointCloud, x, y, epsilon=epsilon, a=a, b=b\n",
" )\n",
" return sdiv.divergence"
]
Expand Down
14 changes: 9 additions & 5 deletions docs/notebooks/point_clouds.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -88981,11 +88981,15 @@
}
],
"source": [
"plot_ots(\n",
" optimize(\n",
" x, y, num_iter=400, epsilon=1e-2, power=0.5, cost_fn=costs.Euclidean()\n",
" )\n",
")"
"@jax.tree_util.register_pytree_node_class\n",
"class Custom(costs.CostFn):\n",
" \"\"\"Custom cost, sqrt of Euclidean norm.\"\"\"\n",
"\n",
" def pairwise(self, x, y):\n",
" return jnp.sqrt(jnp.abs(jnp.linalg.norm(x - y)))\n",
"\n",
"\n",
"plot_ots(optimize(x, y, num_iter=400, epsilon=1e-2, cost_fn=Custom()))"
]
},
{
Expand Down
7 changes: 7 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -621,3 +621,10 @@ @inproceedings{korotin:21
year={2021},
url={https://openreview.net/forum?id=bEoxzW_EXsa}
}

@book{boyd:04,
title={Convex optimization},
author={Boyd, Stephen and Boyd, Stephen P and Vandenberghe, Lieven},
year={2004},
publisher={Cambridge university press}
}
5 changes: 3 additions & 2 deletions ott/core/bar_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,10 @@ def update_features(self, transports: jnp.ndarray,
transports = transports * inv_a[None, :, None]

if self._loss_name == "sqeucl":
cost = costs.SqEuclidean()
cost_fn = costs.SqEuclidean()
return jnp.sum(
weights * barycentric_projection(transports, y_fused, cost), axis=0
weights * barycentric_projection(transports, y_fused, cost_fn),
axis=0
)
raise NotImplementedError(self._loss_name)

Expand Down
47 changes: 27 additions & 20 deletions ott/core/potentials.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import Any, Callable, Dict, Sequence, Tuple
from typing import Any, Callable, Dict, Optional, Sequence, Tuple

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.tree_util as jtu
from typing_extensions import Literal

from ott.geometry import pointcloud
from ott.geometry import costs, pointcloud

__all__ = ["DualPotentials", "EntropicPotentials"]
Potential_t = Callable[[jnp.ndarray], float]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the same vein as cost_fn, I'm wondering if Potential_t could be renamed to PotentialFn_t or just PotentialFn.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had the same reflex when Michal used it for the first time, but I think it makes sense :) Here it turns out this is just a type (_t) and can be either a vector of a function.

Expand All @@ -22,21 +22,31 @@ class DualPotentials:
Args:
f: The first dual potential function.
g: The second dual potential function.
cost_fn: The cost function used to solve the OT problem.
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
cor: whether the duals solve the problem in distance form, or correlation
form (as used for instance for ICNNs, see e.g. top right of p.3 in
http://proceedings.mlr.press/v119/makkuva20a/makkuva20a.pdf)
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(self, f: Potential_t, g: Potential_t, *, cor: bool = False):
def __init__(
self,
f: Potential_t,
g: Potential_t,
*,
cost_fn: Optional[costs.CostFn] = None,
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
cor: bool = False
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
):
self._f = f
self._g = g
self.cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn
self._cor = cor

def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray:
"""Transport ``vec`` according to Brenier formula.

Theorem 1.17 in http://math.univ-lyon1.fr/~santambrogio/OTAM-cvgmt.pdf
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
for case h(.) = ||.||^2, ∇h(.) = 2 ., [∇h]^-1(.) = 0.5 * .
for case h(.) = ||.||^2, ∇h(.) = 2 .,
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
h*(.) = ||.||^2 / 4, [∇h*](.) = [∇h]^-1(.) = 0.5 * .

or, when solved in correlation form, as ∇g for forward, ∇f for backward.

Expand All @@ -49,9 +59,13 @@ def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray:
The transported points.
"""
vec = jnp.atleast_2d(vec)
if self._cor:
if self._cor and isinstance(self.cost_fn, costs.SqEuclidean):
return self._grad_g(vec) if forward else self._grad_f(vec)
return vec - 0.5 * (self._grad_f(vec) if forward else self._grad_g(vec))
grad_h_inv = jax.vmap(jax.grad(self.cost_fn.h_legendre))
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
if forward:
return vec - grad_h_inv(self._grad_f(vec))
else:
return vec - grad_h_inv(self._grad_g(vec))

def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
"""Evaluate 2-Wasserstein distance between samples using dual potentials.
Expand All @@ -64,8 +78,7 @@ def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
tgt: Samples from the target distribution, array of shape ``[m, d]``.

Returns:
Wasserstein distance :math:`W^2_2`, assuming :math:`|x-y|^2` as the
ground distance.
Wasserstein distance.
"""
src, tgt = jnp.atleast_2d(src), jnp.atleast_2d(tgt)

Expand All @@ -85,9 +98,6 @@ def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
C += jnp.mean(g(tgt))
return C

# compute the final Wasserstein distance assuming ground metric |x-y|^2,
# thus an additional multiplication by 2

@property
def f(self) -> Potential_t:
"""The first dual potential function."""
Expand Down Expand Up @@ -141,7 +151,7 @@ def __init__(

# we pass directly the arrays and override the properties
# since only the properties need to be callable
super().__init__(f, g)
super().__init__(f, g, cost_fn=geom.cost_fn, cor=False)
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
self._geom = geom
self._a = a
self._b = b
Expand All @@ -160,16 +170,13 @@ def _create_potential_function(

def callback(x: jnp.ndarray) -> float:
cost = pointcloud.PointCloud(
jnp.atleast_2d(x),
y,
cost_fn=self._geom.cost_fn,
power=self._geom.power,
epsilon=1.0 # epsilon is not used
jnp.atleast_2d(x), y, cost_fn=self._geom.cost_fn
).cost_matrix
return -eps * jsp.special.logsumexp((potential - cost) / eps,
b=prob_weights)
z = (potential - cost) / epsilon
lse = -epsilon * jsp.special.logsumexp(z, b=prob_weights, axis=-1)
return jnp.squeeze(lse)

eps = self.epsilon
epsilon = self.epsilon
if kind == "g":
# When seeking to evaluate 2nd potential function, 1st set of potential
# values and support should be used,
Expand Down
10 changes: 6 additions & 4 deletions ott/core/quad_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# Because Protocol is not available in Python < 3.8
from typing_extensions import Literal, Protocol

from ott.core import _math_utils as mu
from ott.core import linear_problems, sinkhorn_lr
from ott.geometry import epsilon_scheduler, geometry, low_rank, pointcloud

Expand Down Expand Up @@ -58,16 +59,17 @@ class GWLoss(NamedTuple):
def make_square_loss() -> GWLoss:
f1 = Loss(lambda x: x ** 2, is_linear=False)
f2 = Loss(lambda y: y ** 2, is_linear=False)
h1 = Loss(lambda x: x, is_linear=True)
h2 = Loss(lambda y: 2.0 * y, is_linear=True)
h1 = Loss(lambda x: jnp.sqrt(2) * x, is_linear=True)
h2 = Loss(lambda y: jnp.sqrt(2) * y, is_linear=True)
return GWLoss(f1, f2, h1, h2)


def make_kl_loss(clipping_value: float = 1e-8) -> GWLoss:
def make_kl_loss(clipping_value: Optional[float] = None) -> GWLoss:
assert clipping_value is None, "Clipping deprecated in KL definition."
f1 = Loss(lambda x: -jax.scipy.special.entr(x) - x, is_linear=False)
f2 = Loss(lambda y: y, is_linear=True)
h1 = Loss(lambda x: x, is_linear=True)
h2 = Loss(lambda y: jnp.log(jnp.clip(y, clipping_value)), is_linear=False)
h2 = Loss(lambda y: mu.safe_log(y), is_linear=False)
return GWLoss(f1, f2, h1, h2)


Expand Down
86 changes: 80 additions & 6 deletions ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
pass

def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> float:
pass
raise NotImplementedError("Barycenter not yet implemented for this cost.")

@classmethod
def padder(cls, dim: int) -> jnp.ndarray:
Expand Down Expand Up @@ -90,17 +90,88 @@ def tree_unflatten(cls, aux_data, children):
return cls(*children)


@jax.tree_util.register_pytree_node_class
class RBFCost(CostFn):
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
"""A radial-basis function cost class for translation invariant costs.

Such costs are defined as
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved

c(x,y) = h(z), where z := x-y.

where h is a function strictly convex (or concave) function mapping vectors
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a small repetition here: I think you meant "where h is a strictly convex (or concave) function ...". It's minor, I know ;)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great catch! thanks.

to real-values.

For completeness (and differentiation using the Brenier theorem), the user
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
is also supposed to provide the Legendre transform of `h`, whose gradient (the
inverse of the gradient of `h`) will be used to form a Brenier map.
"""

def h(self, z: jnp.ndarray) -> float:
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
pass

def h_legendre(self, z: jnp.ndarray) -> float:
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
pass

def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Evaluate h on difference between x and y."""
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
return self.h(x - y)

def tree_flatten(self):
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
return (), None

def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> float:
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
pass

@classmethod
def tree_unflatten(cls, aux_data, children):
del aux_data
return cls(*children)


@jax.tree_util.register_pytree_node_class
class SqPNorm(RBFCost):
"""Squared p-norm of the difference of two vectors.

For details on the derivation of the Legendre transform of the norm, see e.g.
the reference :cite:`boyd:04`, p.93/94.
https://web.stanford.edu/~boyd/cvxbook/bv_cvxbook.pdf
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
"""
p: float
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, p: float):
self.p = p
marcocuturi marked this conversation as resolved.
Show resolved Hide resolved
self.q = 1. / (1 - 1 / self.p)

def h(self, z: jnp.ndarray) -> float:
return 0.5 * jnp.linalg.norm(z, self.p) ** 2

def h_legendre(self, z: jnp.ndarray) -> float:
return 0.5 * jnp.linalg.norm(z, self.q) ** 2

def tree_flatten(self):
return (), (self.p,)

@classmethod
def tree_unflatten(cls, aux_data, children):
del children
return cls(aux_data[0])


@jax.tree_util.register_pytree_node_class
class Euclidean(CostFn):
"""Euclidean distance."""
"""Euclidean distance.

Note that the Euclidean distance is not cast as a RBF cost, because this
would correspond to `h = abs`, whose gradient is not invertible.
"""

def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Compute Euclidean norm."""
return jnp.linalg.norm(x - y)


@jax.tree_util.register_pytree_node_class
class SqEuclidean(CostFn):
class SqEuclidean(RBFCost):
"""Squared Euclidean distance."""

def norm(self, x: jnp.ndarray) -> Union[float, jnp.ndarray]:
Expand All @@ -111,6 +182,12 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Compute minus twice the dot-product between vectors."""
return -2. * jnp.vdot(x, y)

def h(self, z: jnp.ndarray) -> float:
return jnp.sum(z ** 2)

def h_legendre(self, z: jnp.ndarray) -> float:
return 0.25 * jnp.sum(z ** 2)

def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray:
"""Output barycenter of vectors when using squared-Euclidean distance."""
return jnp.average(xs, weights=weights, axis=0)
Expand All @@ -134,9 +211,6 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
# similarity is in [-1, 1], clip because of numerical imprecisions
return jnp.clip(cosine_distance, 0., 2.)

def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> float:
raise NotImplementedError("Barycenter for cosine cost not yet implemented.")

@classmethod
def padder(cls, dim: int) -> jnp.ndarray:
return jnp.ones((1, dim))
Expand Down
Loading