Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix jax.device_put for potentials #183

Merged
merged 4 commits into from
Nov 25, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 27 additions & 18 deletions ott/problems/linear/potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
Wasserstein distance.
"""
src, tgt = jnp.atleast_2d(src), jnp.atleast_2d(tgt)

f = jax.vmap(self.f)

if self._corr:
Expand All @@ -100,11 +99,9 @@ def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
C = jnp.mean(jnp.sum(src ** 2, axis=-1))
C += jnp.mean(jnp.sum(tgt ** 2, axis=-1))
return 2. * (term1 + term2) + C
else:
g = jax.vmap(self.g)
C = jnp.mean(f(src))
C += jnp.mean(g(tgt))
return C

g = jax.vmap(self.g)
return jnp.mean(f(src)) + jnp.mean(g(tgt))

@property
def f(self) -> Potential_t:
Expand Down Expand Up @@ -137,7 +134,12 @@ def _grad_h_inv(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
return jax.vmap(jax.grad(self.cost_fn.h_legendre))

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
return [self._f, self._g, self.cost_fn], {"corr": self._corr}
return [], {
"f": self._f,
"g": self._g,
"cost_fn": self.cost_fn,
"corr": self._corr
}

@classmethod
def tree_unflatten(
Expand Down Expand Up @@ -167,15 +169,6 @@ def __init__(
a: Optional[jnp.ndarray] = None,
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
b: Optional[jnp.ndarray] = None,
):
n, m = geom.shape
a = jnp.ones(n) / n if a is None else a
b = jnp.ones(m) / m if b is None else b

assert f.shape == (n,) and a.shape == (n,), \
f"Expected `f` and `a` to be of shape `{n,}`, found `{f.shape}`."
assert g.shape == (m,) and b.shape == (m,), \
f"Expected `g` and `b` to be of shape `{m,}`, found `{g.shape}`."

# we pass directly the arrays and override the properties
# since only the properties need to be callable
super().__init__(f, g, cost_fn=geom.cost_fn, corr=False)
Expand Down Expand Up @@ -213,14 +206,30 @@ def callback(x: jnp.ndarray) -> float:
# see proof of Prop. 2 in https://arxiv.org/pdf/2109.12004.pdf
potential = self._f
y = self._geom.x
prob_weights = self._a
prob_weights = self.a
else:
potential = self._g
y = self._geom.y
prob_weights = self._b
prob_weights = self.b

return callback

@property
michalk8 marked this conversation as resolved.
Show resolved Hide resolved
def a(self) -> jnp.ndarray:
"""Probability weights of the first measure."""
if self._a is not None:
return self._a
n, _ = self._geom.shape
return jnp.ones(n) / n

@property
def b(self) -> jnp.ndarray:
"""Probability weights of the second measure."""
if self._b is not None:
return self._b
_, m = self._geom.shape
return jnp.ones(m) / m

@property
def epsilon(self) -> float:
"""Entropy regularizer."""
Expand Down
23 changes: 22 additions & 1 deletion tests/problems/linear/potentials_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,35 @@
import pytest

from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.problems.linear import linear_problem, potentials
from ott.solvers.linear import sinkhorn
from ott.tools import sinkhorn_divergence
from ott.tools.gaussian_mixture import gaussian


class TestDualPotentials:

def test_device_put(self):
pot = potentials.DualPotentials(
lambda x: x, lambda x: x, cost_fn=costs.SqEuclidean(), corr=True
)
_ = jax.device_put(pot, "cpu")


class TestEntropicPotentials:

def test_device_put(self, rng: jax.random.PRNGKeyArray):
n = 10
device = jax.devices()[0]
key1, key2, key3 = jax.random.split(rng, 3)
f = jax.random.normal(key1, (n,))
g = jax.random.normal(key2, (n,))
x = jax.random.normal(key3, (n, 3))

pot = potentials.EntropicPotentials(f, g, pointcloud.PointCloud(x))

_ = jax.device_put(pot, device)

@pytest.mark.fast.with_args(eps=[5e-2, 1e-1], only_fast=0)
def test_entropic_potentials_dist(self, rng: jnp.ndarray, eps: float):
n1, n2, d = 64, 96, 2
Expand Down