Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix jax.device_put for potentials #183

Merged
merged 4 commits into from
Nov 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -640,3 +640,14 @@ @book{boyd:04
url={https://web.stanford.edu/~boyd/cvxbook/bv_cvxbook.pdf},
publisher={Cambridge university press}
}

@misc{pooladian:22,
doi = {10.48550/ARXIV.2202.08919},
url = {https://arxiv.org/abs/2202.08919},
author = {Pooladian, Aram-Alexandre and Cuturi, Marco and Niles-Weed, Jonathan},
keywords = {Optimization and Control (math.OC), Statistics Theory (math.ST), FOS: Mathematics, FOS: Mathematics},
title = {Debiaser Beware: Pitfalls of Centering Regularized Transport Maps},
publisher = {arXiv},
year = {2022},
copyright = {Creative Commons Attribution 4.0 International}
}
62 changes: 28 additions & 34 deletions ott/problems/linear/potentials.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence, Tuple

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.tree_util as jtu
from typing_extensions import Literal

from ott.problems.linear import linear_problem

if TYPE_CHECKING:
from ott.geometry import costs, pointcloud
from ott.geometry import costs

__all__ = ["DualPotentials", "EntropicPotentials"]
Potential_t = Callable[[jnp.ndarray], float]
Expand Down Expand Up @@ -89,7 +91,6 @@ def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
Wasserstein distance.
"""
src, tgt = jnp.atleast_2d(src), jnp.atleast_2d(tgt)

f = jax.vmap(self.f)

if self._corr:
Expand All @@ -100,11 +101,9 @@ def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
C = jnp.mean(jnp.sum(src ** 2, axis=-1))
C += jnp.mean(jnp.sum(tgt ** 2, axis=-1))
return 2. * (term1 + term2) + C
else:
g = jax.vmap(self.g)
C = jnp.mean(f(src))
C += jnp.mean(g(tgt))
return C

g = jax.vmap(self.g)
return jnp.mean(f(src)) + jnp.mean(g(tgt))

@property
def f(self) -> Potential_t:
Expand Down Expand Up @@ -137,7 +136,12 @@ def _grad_h_inv(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
return jax.vmap(jax.grad(self.cost_fn.h_legendre))

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
return [self._f, self._g, self.cost_fn], {"corr": self._corr}
return [], {
"f": self._f,
"g": self._g,
"cost_fn": self.cost_fn,
"corr": self._corr
}

@classmethod
def tree_unflatten(
Expand All @@ -153,35 +157,21 @@ class EntropicPotentials(DualPotentials):
Args:
f: The first dual potential vector of shape ``[n,]``.
g: The second dual potential vector of shape ``[m,]``.
geom: Geometry used to compute the dual potentials using
prob: Linear problem with :class:`~ott.geometry.pointcloud.PointCloud`
geometry that was used to compute the dual potentials using, e.g.,
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.
a: Probability weights for the first measure. If `None`, use uniform.
b: Probability weights for the second measure. If `None`, use uniform.
"""

def __init__(
self,
f: jnp.ndarray,
g: jnp.ndarray,
geom: "pointcloud.PointCloud",
a: Optional[jnp.ndarray] = None,
b: Optional[jnp.ndarray] = None,
prob: linear_problem.LinearProblem,
):
n, m = geom.shape
a = jnp.ones(n) / n if a is None else a
b = jnp.ones(m) / m if b is None else b

assert f.shape == (n,) and a.shape == (n,), \
f"Expected `f` and `a` to be of shape `{n,}`, found `{f.shape}`."
assert g.shape == (m,) and b.shape == (m,), \
f"Expected `g` and `b` to be of shape `{m,}`, found `{g.shape}`."

# we pass directly the arrays and override the properties
# since only the properties need to be callable
super().__init__(f, g, cost_fn=geom.cost_fn, corr=False)
self._geom = geom
self._a = a
self._b = b
super().__init__(f, g, cost_fn=prob.geom.cost_fn, corr=False)
self._prob = prob

@property
def f(self) -> Potential_t:
Expand All @@ -206,25 +196,29 @@ def callback(x: jnp.ndarray) -> float:
lse = -epsilon * jsp.special.logsumexp(z, b=prob_weights, axis=-1)
return jnp.squeeze(lse)

assert isinstance(
self._prob.geom, pointcloud.PointCloud
), f"Expected point cloud geometry, found `{type(self._prob.geom)}`."
epsilon = self.epsilon

if kind == "g":
# When seeking to evaluate 2nd potential function, 1st set of potential
# values and support should be used,
# see proof of Prop. 2 in https://arxiv.org/pdf/2109.12004.pdf
potential = self._f
y = self._geom.x
prob_weights = self._a
y = self._prob.geom.x
prob_weights = self._prob.a
else:
potential = self._g
y = self._geom.y
prob_weights = self._b
y = self._prob.geom.y
prob_weights = self._prob.b

return callback

@property
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
def epsilon(self) -> float:
"""Entropy regularizer."""
return self._geom.epsilon
return self._prob.geom.epsilon

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
return [self._f, self._g, self._geom, self._a, self._b], {}
return [self._f, self._g, self._prob], {}
4 changes: 1 addition & 3 deletions ott/solvers/linear/sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,7 @@ def transport_mass(self) -> float:

def to_dual_potentials(self) -> potentials.EntropicPotentials:
"""Return the entropic map estimator."""
return potentials.EntropicPotentials(
self.f, self.g, geom=self.geom, a=self.a, b=self.b
)
return potentials.EntropicPotentials(self.f, self.g, self.ot_prob)


@jax.tree_util.register_pytree_node_class
Expand Down
9 changes: 5 additions & 4 deletions ott/tools/sinkhorn_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import jax.numpy as jnp

from ott.geometry import costs, geometry, pointcloud, segment
from ott.problems.linear import potentials
from ott.problems.linear import linear_problem, potentials
from ott.solvers.linear import sinkhorn

__all__ = [
Expand All @@ -38,14 +38,15 @@ class SinkhornDivergenceOutput(NamedTuple):
b: jnp.ndarray

def to_dual_potentials(self) -> "potentials.EntropicPotentials":
"""Return dual estimators, (8) in https://arxiv.org/pdf/2202.08919.pdf ."""
"""Return dual estimators :cite:`pooladian:22`, eq. 8."""
geom_xy, *_ = self.geoms
(f_xy, g_xy), (f_x, g_x), (f_y, g_y) = self.potentials
prob = linear_problem.LinearProblem(geom_xy, a=self.a, b=self.b)

(f_xy, g_xy), (f_x, g_x), (f_y, g_y) = self.potentials
f = f_xy - f_x
g = g_xy if g_y is None else (g_xy - g_y) # case when `static_b=True`

return potentials.EntropicPotentials(f, g, geom_xy, self.a, self.b)
return potentials.EntropicPotentials(f, g, prob)


def sinkhorn_divergence(
Expand Down
27 changes: 26 additions & 1 deletion tests/problems/linear/potentials_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,39 @@
import pytest

from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.problems.linear import linear_problem, potentials
from ott.solvers.linear import sinkhorn
from ott.tools import sinkhorn_divergence
from ott.tools.gaussian_mixture import gaussian


class TestDualPotentials:

def test_device_put(self):
pot = potentials.DualPotentials(
lambda x: x, lambda x: x, cost_fn=costs.SqEuclidean(), corr=True
)
_ = jax.device_put(pot, "cpu")


class TestEntropicPotentials:

def test_device_put(self, rng: jax.random.PRNGKeyArray):
n = 10
device = jax.devices()[0]
keys = jax.random.split(rng, 5)
f = jax.random.normal(keys[0], (n,))
g = jax.random.normal(keys[1], (n,))

geom = pointcloud.PointCloud(jax.random.normal(keys[2], (n, 3)))
a = jax.random.normal(keys[4], (n, 3))
b = jax.random.normal(keys[5], (n, 3))
prob = linear_problem.LinearProblem(geom, a, b)

pot = potentials.EntropicPotentials(f, g, prob)

_ = jax.device_put(pot, device)

@pytest.mark.fast.with_args(eps=[5e-2, 1e-1], only_fast=0)
def test_entropic_potentials_dist(self, rng: jnp.ndarray, eps: float):
n1, n2, d = 64, 96, 2
Expand Down