Skip to content

Commit

Permalink
TI transport map (#559)
Browse files Browse the repository at this point in the history
* Add `transport_map` function to `TICost`

* Update test

* Add more tests

* Add test for `RegTICost`

* Polish the docs

* Rename variable

* Remove forward argument

* Test using linear assignment
  • Loading branch information
michalk8 committed Jul 12, 2024
1 parent 67a64c0 commit c76e2ac
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 7 deletions.
34 changes: 34 additions & 0 deletions src/ott/geometry/costs.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,40 @@ def twist_operator(
return vec + jax.grad(self.h_legendre)(-dual_vec)
return vec - jax.grad(self.h_legendre)(dual_vec)

def transport_map(self, g: Func) -> Callable[[jnp.ndarray, Any], jnp.ndarray]:
r"""Get an optimal transport map for a concave function :math:`g`.
Uses Proposition 1 from :cite:`klein:24` to define an OT map
:math:`x - (\nabla h^*) \circ \nabla \bar g^h(x)`, where :math:`h^*`
is the Legendre transform of :math:`h` and :math:`\bar g^h`
is the :meth:`h_transform` of a concave function :math:`g`.
Args:
g: Concave function.
Returns:
The transport map with a signature ``(x, **kwargs)``.
"""

def transport(x: jnp.ndarray, **kwargs: Any) -> jnp.ndarray:
"""Transport points from source to target.
Args:
x: Array of shape ``[n, d]``.
kwargs: Keyword arguments for the output of the
:meth:`h_transform` method.
Returns:
The transported points.
"""
g_h = functools.partial(self.h_transform(g), **kwargs)
grad_g_h = jax.vmap(jax.grad(g_h))
return jax.vmap(
self.twist_operator, in_axes=[0, 0, None]
)(x, grad_g_h(x), False)

return transport

def barycenter(self, weights: jnp.ndarray,
xs: jnp.ndarray) -> Tuple[jnp.ndarray, Any]:
"""Output barycenter of vectors."""
Expand Down
76 changes: 69 additions & 7 deletions tests/geometry/costs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import jax.numpy as jnp
import jaxopt
import numpy as np
import scipy as sp
from tslearn import metrics as ts_metrics

from ott.geometry import costs, pointcloud, regularizers
Expand Down Expand Up @@ -131,19 +132,53 @@ class TestTICost:

@pytest.mark.parametrize(
"cost_fn", [
costs.SqPNorm(p=1.0),
costs.SqPNorm(1.05),
costs.SqPNorm(2.4),
costs.PNormP(p=1.1),
costs.PNormP(1.1),
costs.PNormP(1.3),
costs.SqEuclidean()
]
)
def test_h_transform(self, rng: jax.Array, cost_fn: costs.TICost):
x = jax.random.normal(rng, (15, 3))
h_transform = cost_fn.h_transform(mu.logsumexp)
h_transform = jax.jit(jax.vmap(jax.grad(h_transform)))
def test_transport_map(self, rng: jax.Array, cost_fn: costs.TICost):
n, d = 15, 5
rng_x, rng_A = jax.random.split(rng)
x = jax.random.normal(rng_x, (n, d))
A = jax.random.normal(rng_A, (d, d * 2))
A = A @ A.T

np.testing.assert_array_equal(jnp.isfinite(h_transform(x)), True)
transport_fn = cost_fn.transport_map(lambda z: -jnp.sum(z * (A.dot(z))))
transport_fn = jax.jit(transport_fn)

y = transport_fn(x)
cost_matrix = cost_fn.all_pairs(x, y)

row_ixs, col_ixs = sp.optimize.linear_sum_assignment(cost_matrix)
np.testing.assert_array_equal(row_ixs, jnp.arange(n))
np.testing.assert_array_equal(col_ixs, jnp.arange(n))

@pytest.mark.parametrize(
"cost_fn", [
costs.SqEuclidean(),
costs.PNormP(2),
costs.RegTICost(regularizers.L2(lam=0.0), rho=2.0)
]
)
def test_sqeucl_transport(
self, rng: jax.Array, cost_fn: costs.TICost, enable_x64
):
x = jax.random.normal(rng, (12, 7))
f = mu.logsumexp

h_f = cost_fn.h_transform(f)
expected_fn = cost_fn.transport_map(f)
expected_fn = jax.jit(expected_fn)
if isinstance(cost_fn, costs.SqEuclidean):
# multiply by `0.5`, because `SqEuclidean := |x|_2^2`
actual_fn = jax.jit(jax.vmap(lambda x: x - 0.5 * jax.grad(h_f)(x)))
else:
actual_fn = jax.jit(jax.vmap(lambda x: x - jax.grad(h_f)(x)))

np.testing.assert_array_equal(expected_fn(x), actual_fn(x))

@pytest.mark.parametrize("cost_fn", [costs.SqEuclidean(), costs.PNormP(2)])
@pytest.mark.parametrize("d", [5, 10])
Expand Down Expand Up @@ -292,6 +327,33 @@ def test_stronger_regularization_increases_sparsity(
for fwd in [False, True]:
np.testing.assert_array_equal(np.diff(sparsity[fwd]) > 0.0, True)

@pytest.mark.parametrize(
"reg", [
regularizers.L1(lam=0.1),
regularizers.L2(lam=3.3),
regularizers.STVS(lam=1.0),
regularizers.SqKOverlap(k=3, lam=1.05)
]
)
def test_reg_transport_fn(
self, rng: jax.Array, reg: regularizers.ProximalOperator
):

@jax.jit
@functools.partial(jax.vmap, in_axes=0)
def expected_fn(x: jnp.ndarray) -> jnp.ndarray:
f_h = cost_fn.h_transform(f)
return x - cost_fn.regularizer.prox(jax.grad(f_h)(x))

x = jax.random.normal(rng, (11, 9))
cost_fn = costs.RegTICost(reg)
f = mu.logsumexp

actual_fn = cost_fn.transport_map(f)
actual_fn = jax.jit(actual_fn)

np.testing.assert_array_equal(expected_fn(x), actual_fn(x))


@pytest.mark.fast()
class TestSoftDTW:
Expand Down

0 comments on commit c76e2ac

Please sign in to comment.