Skip to content

Commit

Permalink
Fix jax.device_put for potentials
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 committed Nov 24, 2022
1 parent 156ef2b commit 9b28c55
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 19 deletions.
45 changes: 27 additions & 18 deletions ott/problems/linear/potentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
Wasserstein distance.
"""
src, tgt = jnp.atleast_2d(src), jnp.atleast_2d(tgt)

f = jax.vmap(self.f)

if self._corr:
Expand All @@ -100,11 +99,9 @@ def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float:
C = jnp.mean(jnp.sum(src ** 2, axis=-1))
C += jnp.mean(jnp.sum(tgt ** 2, axis=-1))
return 2. * (term1 + term2) + C
else:
g = jax.vmap(self.g)
C = jnp.mean(f(src))
C += jnp.mean(g(tgt))
return C

g = jax.vmap(self.g)
return jnp.mean(f(src)) + jnp.mean(g(tgt))

@property
def f(self) -> Potential_t:
Expand Down Expand Up @@ -137,7 +134,12 @@ def _grad_h_inv(self) -> Callable[[jnp.ndarray], jnp.ndarray]:
return jax.vmap(jax.grad(self.cost_fn.h_legendre))

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
return [self._f, self._g, self.cost_fn], {"corr": self._corr}
return [], {
"f": self._f,
"g": self._g,
"cost_fn": self.cost_fn,
"corr": self._corr
}

@classmethod
def tree_unflatten(
Expand Down Expand Up @@ -167,15 +169,6 @@ def __init__(
a: Optional[jnp.ndarray] = None,
b: Optional[jnp.ndarray] = None,
):
n, m = geom.shape
a = jnp.ones(n) / n if a is None else a
b = jnp.ones(m) / m if b is None else b

assert f.shape == (n,) and a.shape == (n,), \
f"Expected `f` and `a` to be of shape `{n,}`, found `{f.shape}`."
assert g.shape == (m,) and b.shape == (m,), \
f"Expected `g` and `b` to be of shape `{m,}`, found `{g.shape}`."

# we pass directly the arrays and override the properties
# since only the properties need to be callable
super().__init__(f, g, cost_fn=geom.cost_fn, corr=False)
Expand Down Expand Up @@ -213,14 +206,30 @@ def callback(x: jnp.ndarray) -> float:
# see proof of Prop. 2 in https://arxiv.org/pdf/2109.12004.pdf
potential = self._f
y = self._geom.x
prob_weights = self._a
prob_weights = self.a
else:
potential = self._g
y = self._geom.y
prob_weights = self._b
prob_weights = self.b

return callback

@property
def a(self) -> jnp.ndarray:
"""Probability weights of the first measure."""
if self._a is not None:
return self._a
n, _ = self._geom.shape
return jnp.ones(n) / n

@property
def b(self) -> jnp.ndarray:
"""Probability weights of the second measure."""
if self._b is not None:
return self._b
_, m = self._geom.shape
return jnp.ones(m) / m

@property
def epsilon(self) -> float:
"""Entropy regularizer."""
Expand Down
23 changes: 22 additions & 1 deletion tests/problems/linear/potentials_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,35 @@
import pytest

from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.problems.linear import linear_problem, potentials
from ott.solvers.linear import sinkhorn
from ott.tools import sinkhorn_divergence
from ott.tools.gaussian_mixture import gaussian


class TestDualPotentials:

def test_device_put(self):
pot = potentials.DualPotentials(
lambda x: x, lambda x: x, cost_fn=costs.SqEuclidean(), corr=True
)
_ = jax.device_put(pot, "cpu")


class TestEntropicPotentials:

def test_device_put(self, rng: jax.random.PRNGKeyArray):
n = 10
device = jax.devices()[0]
key1, key2, key3 = jax.random.split(rng, 3)
f = jax.random.normal(key1, (n,))
g = jax.random.normal(key2, (n,))
x = jax.random.normal(key3, (n, 3))

pot = potentials.EntropicPotentials(f, g, pointcloud.PointCloud(x))

_ = jax.device_put(pot, device)

@pytest.mark.fast.with_args(eps=[5e-2, 1e-1], only_fast=0)
def test_entropic_potentials_dist(self, rng: jnp.ndarray, eps: float):
n1, n2, d = 64, 96, 2
Expand Down

0 comments on commit 9b28c55

Please sign in to comment.