Expose solver in TICost.h_transform
michalk8 committed Jul 3, 2024
1 parent 80be1c2 commit e6afcf2
Showing 2 changed files with 61 additions and 13 deletions.
50 changes: 38 additions & 12 deletions src/ott/geometry/costs.py
@@ -39,6 +39,8 @@
     "SoftDTW",
 ]
 
+Func = Callable[[jnp.ndarray], float]
+
 
 @jtu.register_pytree_node_class
 class CostFn(abc.ABC):
@@ -204,8 +206,10 @@ def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
 
   def h_transform(
       self,
-      f: Callable[[jnp.ndarray], float],
+      f: Func,
       ridge: float = 1e-8,
+      solver: Optional[Callable[[Func, jnp.ndarray, jnp.ndarray, Any],
+                                jnp.ndarray]] = None,
   ) -> Callable[[jnp.ndarray, Optional[jnp.ndarray], Any], float]:
     r"""Compute the h-transform of a concave function.
@@ -226,11 +230,19 @@
     Args:
       f: Concave function.
       ridge: Regularizer to ensure strong convexity of the objective.
+      solver: Solver with the signature ``(func, x, x_init, **kwargs) -> sol``.
+        If :obj:`None`, use an :class:`~jaxopt.LBFGS` wrapper.
 
     Returns:
-      The h-transform of ``f``.
+      The h-transform :math:`f_h` of :math:`f`.
     """
 
+    def lbfgs(
+        fun: Func, x: jnp.ndarray, x_init: jnp.ndarray, **kwargs: Any
+    ) -> jnp.ndarray:
+      solver = jaxopt.LBFGS(fun=fun, **kwargs)
+      return solver.run(x_init, x=x).params
+
     def fun(z: jnp.ndarray, x: jnp.ndarray) -> float:
       return self.h(z) + ridge * jnp.sum(z ** 2) - f(x - z)
 
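[Reading aid, not part of the commit: combining ``fun`` above with ``f_h`` in the next hunk, the returned h-transform evaluates

.. math::

  f_h(x) = \min_{z} \; h(z) + \epsilon \|z\|_2^2 - f(x - z),

where :math:`\epsilon` is the ``ridge`` regularizer; the minimizing :math:`z` is produced by the (now pluggable) solver, with gradients stopped through it.]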
@@ -239,12 +251,25 @@ def f_h(
         x_init: Optional[jnp.ndarray] = None,
         **kwargs: Any
     ) -> float:
-      solver = jaxopt.LBFGS(fun=fun, **kwargs)
-      x0 = x if x_init is None else x_init
-      z = solver.run(x0, x=x).params
+      """h-transform of a concave function.
+
+      Args:
+        x: Array of shape ``[d,]`` where to evaluate the function.
+        x_init: Initial estimate. If :obj:`None`, use ``x``.
+        kwargs: Keyword arguments for the solver.
+
+      Returns:
+        The output :math:`f_h(x)`.
+      """
+      if x_init is None:
+        x_init = x
+      z = solver(fun, x, x_init, **kwargs)
       z = jax.lax.stop_gradient(z)
       return fun(z, x)
 
+    if solver is None:
+      solver = lbfgs
+
     return f_h
 
   def twist_operator(
@@ -333,11 +358,11 @@ def tree_unflatten(cls, aux_data, children):  # noqa: D102
 class RegTICost(TICost):
   r"""Regularized translation-invariant cost.
 
-  .. math:
+  .. math::
 
     \frac{\rho}{2}\|\cdot\|_2^2 + \text{regularizer}(\cdot)
 
   Args:
-    regularizer: Regularizer function.
+    regularizer: Regularization function.
     rho: Scaling factor.
   """

@@ -385,7 +410,7 @@ def bwd(q: jnp.ndarray, g: jnp.ndarray) -> Tuple[jnp.ndarray]:
 
   def h_transform(
       self,
-      f: Callable[[jnp.ndarray], float],
+      f: Func,
   ) -> Callable[[jnp.ndarray, Optional[jnp.ndarray], Any], float]:
     r"""Compute the h-transform of a concave function.
@@ -420,11 +445,11 @@ def f_h(
         x_init: Optional[jnp.ndarray] = None,
         **kwargs: Any
     ) -> float:
-      """H-transform of a concave function.
+      """h-transform of a concave function.
 
       Args:
         x: Array of shape ``[d,]`` where to evaluate the function.
-        x_init: Initial estimate.
+        x_init: Initial estimate. If :obj:`None`, use ``x``.
         kwargs: Keyword arguments for :class:`~jaxopt.ProximalGradient`.
 
       Returns:
@@ -435,8 +460,9 @@
           prox=lambda x, h, tau: h.prox(x, tau),
           **kwargs,
       )
-      x0 = x if x_init is None else x_init
-      z = solver.run(x0, self._h, x=x).params
+      if x_init is None:
+        x_init = x
+      z = solver.run(x_init, self._h, x=x).params
       z = jax.lax.stop_gradient(z)
       return self.h(z) - f(x - z)
 
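A minimal usage sketch of the new ``solver`` argument (illustrative, not part of the commit; it mirrors the test below and assumes ``ott`` with this change plus ``jaxopt`` are installed — the vectors ``u`` and ``x`` are made-up data):

import jax.numpy as jnp
import jaxopt

from ott.geometry import costs

def gd_solver(fun, x, x_init, **kwargs):
  # Any callable with signature ``(fun, x, x_init, **kwargs) -> sol`` works;
  # here plain gradient descent replaces the default LBFGS wrapper.
  solver = jaxopt.GradientDescent(fun=fun, **kwargs)
  return solver.run(x_init, x=x).params

cost_fn = costs.SqEuclidean()
u = jnp.array([0.3, 0.7])
concave_fn = lambda z: -cost_fn.h(z) + jnp.dot(z, u)  # concave, since h is convex

f_h_default = cost_fn.h_transform(concave_fn)                   # LBFGS wrapper
f_h_custom = cost_fn.h_transform(concave_fn, solver=gd_solver)  # custom solver

x = jnp.array([1.0, -2.0])
# Both evaluate f_h(x); the results agree up to solver tolerance.
print(f_h_default(x), f_h_custom(x))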
24 changes: 23 additions & 1 deletion tests/geometry/costs_test.py
@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
-from typing import Type
+from typing import Any, Type
 
 import pytest
 
 import jax
 import jax.numpy as jnp
+import jaxopt
 import numpy as np
 from tslearn import metrics as ts_metrics
 
@@ -167,6 +168,27 @@ def test_h_transform_matches_unreg(
 
     np.testing.assert_allclose(pred(x), gt(x), rtol=1e-5, atol=1e-5)
 
+  @pytest.mark.parametrize("cost_fn", [costs.SqEuclidean(), costs.PNormP(1.5)])
+  def test_h_transform_solver(self, rng: jax.Array, cost_fn: costs.TICost):
+
+    def gd_solver(
+        fun, x: jnp.ndarray, x_init: jnp.ndarray, **kwargs: Any
+    ) -> jnp.ndarray:
+      solver = jaxopt.GradientDescent(fun=fun, **kwargs)
+      return solver.run(x_init, x=x).params
+
+    n, d = 21, 6
+    rngs = jax.random.split(rng, 2)
+    u = jnp.abs(jax.random.uniform(rngs[0], (d,)))
+    x = jax.random.normal(rngs[1], (n, d))
+
+    concave_fn = lambda z: -cost_fn.h(z) + jnp.dot(z, u)
+
+    expected = jax.vmap(cost_fn.h_transform(concave_fn, solver=None))
+    actual = jax.vmap(cost_fn.h_transform(concave_fn, solver=gd_solver))
+
+    np.testing.assert_allclose(expected(x), actual(x), rtol=1e-4, atol=1e-4)
+
 
 @pytest.mark.fast()
 class TestRegTICost:
