Skip to content

Commit

Permalink
Fix jax.device_put for potentials (#183)
Browse files Browse the repository at this point in the history
* Fix `jax.device_put` for potentials

* Make weights not optional

* Use `LinearProblem` in `EntropicPotentials`

* Remove extra docs
  • Loading branch information
michalk8 committed Nov 25, 2022
1 parent 156ef2b commit be94b74
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 42 deletions.
11 changes: 11 additions & 0 deletions docs/references.bib
Original file line number Diff line number Diff line change
Expand Up @@ -640,3 +640,14 @@ @book{boyd:04
url={https://web.stanford.edu/~boyd/cvxbook/bv_cvxbook.pdf},
publisher={Cambridge university press}
}

@misc{pooladian:22,
doi = {10.48550/ARXIV.2202.08919},
url = {https://arxiv.org/abs/2202.08919},
author = {Pooladian, Aram-Alexandre and Cuturi, Marco and Niles-Weed, Jonathan},
keywords = {Optimization and Control (math.OC), Statistics Theory (math.ST), FOS: Mathematics, FOS: Mathematics},
title = {Debiaser Beware: Pitfalls of Centering Regularized Transport Maps},
publisher = {arXiv},
year = {2022},
copyright = {Creative Commons Attribution 4.0 International}
}
62 changes: 28 additions & 34 deletions ott/problems/linear/potentials.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence, Tuple

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.tree_util as jtu
from typing_extensions import Literal

from ott.problems.linear import linear_problem

if TYPE_CHECKING:
from ott.geometry import costs, pointcloud
from ott.geometry import costs

__all__ = ["DualPotentials", "EntropicPotentials"]
Potential_t = Callable[[jnp.ndarray], float]
Expand Down Expand Up @@ -89,7 +91,6 @@ def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
Wasserstein distance.
"""
src, tgt = jnp.atleast_2d(src), jnp.atleast_2d(tgt)

f = jax.vmap(self.f)

if self._corr:
Expand All @@ -100,11 +101,9 @@ def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
C = jnp.mean(jnp.sum(src ** 2, axis=-1))
C += jnp.mean(jnp.sum(tgt ** 2, axis=-1))
return 2. * (term1 + term2) + C
else:
g = jax.vmap(self.g)
C = jnp.mean(f(src))
C += jnp.mean(g(tgt))
return C

g = jax.vmap(self.g)
return jnp.mean(f(src)) + jnp.mean(g(tgt))

@property
def f(self) -> Potential_t:
Expand Down Expand Up @@ -137,7 +136,12 @@ def _grad_h_inv(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
return jax.vmap(jax.grad(self.cost_fn.h_legendre))

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
return [self._f, self._g, self.cost_fn], {"corr": self._corr}
return [], {
"f": self._f,
"g": self._g,
"cost_fn": self.cost_fn,
"corr": self._corr
}

@classmethod
def tree_unflatten(
Expand All @@ -153,35 +157,21 @@ class EntropicPotentials(DualPotentials):
Args:
f: The first dual potential vector of shape ``[n,]``.
g: The second dual potential vector of shape ``[m,]``.
geom: Geometry used to compute the dual potentials using
prob: Linear problem with :class:`~ott.geometry.pointcloud.PointCloud`
geometry that was used to compute the dual potentials using, e.g.,
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.
a: Probability weights for the first measure. If `None`, use uniform.
b: Probability weights for the second measure. If `None`, use uniform.
"""

def __init__(
self,
f: jnp.ndarray,
g: jnp.ndarray,
geom: "pointcloud.PointCloud",
a: Optional[jnp.ndarray] = None,
b: Optional[jnp.ndarray] = None,
prob: linear_problem.LinearProblem,
):
n, m = geom.shape
a = jnp.ones(n) / n if a is None else a
b = jnp.ones(m) / m if b is None else b

assert f.shape == (n,) and a.shape == (n,), \
f"Expected `f` and `a` to be of shape `{n,}`, found `{f.shape}`."
assert g.shape == (m,) and b.shape == (m,), \
f"Expected `g` and `b` to be of shape `{m,}`, found `{g.shape}`."

# we pass directly the arrays and override the properties
# since only the properties need to be callable
super().__init__(f, g, cost_fn=geom.cost_fn, corr=False)
self._geom = geom
self._a = a
self._b = b
super().__init__(f, g, cost_fn=prob.geom.cost_fn, corr=False)
self._prob = prob

@property
def f(self) -> Potential_t:
Expand All @@ -206,25 +196,29 @@ def callback(x: jnp.ndarray) -> float:
lse = -epsilon * jsp.special.logsumexp(z, b=prob_weights, axis=-1)
return jnp.squeeze(lse)

assert isinstance(
self._prob.geom, pointcloud.PointCloud
), f"Expected point cloud geometry, found `{type(self._prob.geom)}`."
epsilon = self.epsilon

if kind == "g":
# When seeking to evaluate 2nd potential function, 1st set of potential
# values and support should be used,
# see proof of Prop. 2 in https://arxiv.org/pdf/2109.12004.pdf
potential = self._f
y = self._geom.x
prob_weights = self._a
y = self._prob.geom.x
prob_weights = self._prob.a
else:
potential = self._g
y = self._geom.y
prob_weights = self._b
y = self._prob.geom.y
prob_weights = self._prob.b

return callback

@property
def epsilon(self) -> float:
"""Entropy regularizer."""
return self._geom.epsilon
return self._prob.geom.epsilon

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
return [self._f, self._g, self._geom, self._a, self._b], {}
return [self._f, self._g, self._prob], {}
4 changes: 1 addition & 3 deletions ott/solvers/linear/sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,7 @@ def transport_mass(self) -> float:

def to_dual_potentials(self) -> potentials.EntropicPotentials:
"""Return the entropic map estimator."""
return potentials.EntropicPotentials(
self.f, self.g, geom=self.geom, a=self.a, b=self.b
)
return potentials.EntropicPotentials(self.f, self.g, self.ot_prob)


@jax.tree_util.register_pytree_node_class
Expand Down
9 changes: 5 additions & 4 deletions ott/tools/sinkhorn_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import jax.numpy as jnp

from ott.geometry import costs, geometry, pointcloud, segment
from ott.problems.linear import potentials
from ott.problems.linear import linear_problem, potentials
from ott.solvers.linear import sinkhorn

__all__ = [
Expand All @@ -38,14 +38,15 @@ class SinkhornDivergenceOutput(NamedTuple):
b: jnp.ndarray

def to_dual_potentials(self) -> "potentials.EntropicPotentials":
"""Return dual estimators, (8) in https://arxiv.org/pdf/2202.08919.pdf ."""
"""Return dual estimators :cite:`pooladian:22`, eq. 8."""
geom_xy, *_ = self.geoms
(f_xy, g_xy), (f_x, g_x), (f_y, g_y) = self.potentials
prob = linear_problem.LinearProblem(geom_xy, a=self.a, b=self.b)

(f_xy, g_xy), (f_x, g_x), (f_y, g_y) = self.potentials
f = f_xy - f_x
g = g_xy if g_y is None else (g_xy - g_y) # case when `static_b=True`

return potentials.EntropicPotentials(f, g, geom_xy, self.a, self.b)
return potentials.EntropicPotentials(f, g, prob)


def sinkhorn_divergence(
Expand Down
27 changes: 26 additions & 1 deletion tests/problems/linear/potentials_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,39 @@
import pytest

from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.problems.linear import linear_problem, potentials
from ott.solvers.linear import sinkhorn
from ott.tools import sinkhorn_divergence
from ott.tools.gaussian_mixture import gaussian


class TestDualPotentials:

def test_device_put(self):
pot = potentials.DualPotentials(
lambda x: x, lambda x: x, cost_fn=costs.SqEuclidean(), corr=True
)
_ = jax.device_put(pot, "cpu")


class TestEntropicPotentials:

def test_device_put(self, rng: jax.random.PRNGKeyArray):
n = 10
device = jax.devices()[0]
keys = jax.random.split(rng, 5)
f = jax.random.normal(keys[0], (n,))
g = jax.random.normal(keys[1], (n,))

geom = pointcloud.PointCloud(jax.random.normal(keys[2], (n, 3)))
a = jax.random.normal(keys[4], (n, 3))
b = jax.random.normal(keys[5], (n, 3))
prob = linear_problem.LinearProblem(geom, a, b)

pot = potentials.EntropicPotentials(f, g, prob)

_ = jax.device_put(pot, device)

@pytest.mark.fast.with_args(eps=[5e-2, 1e-1], only_fast=0)
def test_entropic_potentials_dist(self, rng: jnp.ndarray, eps: float):
n1, n2, d = 64, 96, 2
Expand Down

0 comments on commit be94b74

Please sign in to comment.