Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add relative_epsilon option to GromovWasserstein #355

Merged
merged 3 commits into from
Jul 15, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 20 additions & 7 deletions src/ott/initializers/quadratic/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Tuple

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -121,15 +121,20 @@ class QuadraticInitializer(BaseQuadraticInitializer):
"""

def _create_geometry(
self, quad_prob: "quadratic_problem.QuadraticProblem", *, epsilon: float,
self,
quad_prob: "quadratic_problem.QuadraticProblem",
*,
epsilon: float,
relative_epsilon: Optional[bool] = None,
**kwargs: Any
) -> geometry.Geometry:
"""Compute initial geometry for linearization.

Args:
quad_prob: Quadratic OT problem.
epsilon: Epsilon regularization.
kwargs: Additional keyword arguments, unused.
relative_epsilon: Whether to use relative epsilon in the geometry.
kwargs: Keyword arguments for :class:`~ott.geometry.geometry.Geometry`.

Returns:
The initial geometry used to initialize the linearized problem.
Expand Down Expand Up @@ -160,8 +165,12 @@ def _create_geometry(
)
cost_matrix = marginal_cost.cost_matrix - tmp + unbalanced_correction

cost_matrix += quad_prob.fused_penalty * quad_prob._fused_cost_matrix()
return geometry.Geometry(cost_matrix=cost_matrix, epsilon=epsilon)
cost_matrix += quad_prob.fused_penalty * quad_prob._fused_cost_matrix
return geometry.Geometry(
cost_matrix=cost_matrix,
epsilon=epsilon,
relative_epsilon=relative_epsilon
)


class LRQuadraticInitializer(BaseQuadraticInitializer):
Expand All @@ -176,12 +185,16 @@ def __init__(self, lr_linear_initializer: "initializers_lr.LRInitializer"):
self._linear_lr_initializer = lr_linear_initializer

def _create_geometry(
self, quad_prob: "quadratic_problem.QuadraticProblem", **kwargs: Any
self,
quad_prob: "quadratic_problem.QuadraticProblem",
relative_epsilon: Optional[bool] = False,
**kwargs: Any
) -> geometry.Geometry:
"""Compute initial geometry for linearization.

Args:
quad_prob: Quadratic OT problem.
relative_epsilon: Whether to use relative epsilon in the geometry.
kwargs: Keyword arguments for
:meth:`~ott.initializers.linear.initializers_lr.LRInitializer.__call__`.

Expand All @@ -195,7 +208,7 @@ def _create_geometry(
q=q, r=r, g=g, costs=None, errors=None, ot_prob=None
)

return quad_prob.update_lr_geom(tmp_out)
return quad_prob.update_lr_geom(tmp_out, relative_epsilon=relative_epsilon)

@property
def rank(self) -> int:
Expand Down
63 changes: 23 additions & 40 deletions src/ott/problems/quadratic/quadratic_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,6 @@ def marginal_dependent_cost(
self,
marginal_1: jnp.ndarray,
marginal_2: jnp.ndarray,
*,
remove_scale: bool = False,
) -> low_rank.LRCGeometry:
r"""Initialize cost term that depends on the marginals of the transport.

Expand All @@ -160,17 +158,11 @@ def marginal_dependent_cost(
Args:
marginal_1: [n,], first marginal of transport matrix.
marginal_2: [m,], second marginal of transport matrix.
remove_scale: Whether to remove any scaling from the cost matrices before
computing the linearization.

Returns:
Low-rank geometry of rank 2, storing normalization constants.
"""
geom_xx, geom_yy = self.geom_xx, self.geom_yy
if remove_scale:
geom_xx = geom_xx.set_scale_cost(1.0)
geom_yy = geom_yy.set_scale_cost(1.0)

if self._loss_name == "sqeucl": # quadratic apply, efficient for LR
tmp1 = geom_xx.apply_square_cost(marginal_1, axis=1)
tmp2 = geom_yy.apply_square_cost(marginal_2, axis=1)
Expand Down Expand Up @@ -251,14 +243,12 @@ def init_transport_mass(self) -> float:
def update_lr_geom(
self,
lr_sink: "sinkhorn_lr.LRSinkhornOutput",
remove_scale: bool = False,
relative_epsilon: Optional[bool] = None,
) -> geometry.Geometry:
"""Recompute (possibly LRC) linearization using LR Sinkhorn output."""
marginal_1 = lr_sink.marginal(1)
marginal_2 = lr_sink.marginal(0)
marginal_cost = self.marginal_dependent_cost(
marginal_1, marginal_2, remove_scale=remove_scale
)
marginal_cost = self.marginal_dependent_cost(marginal_1, marginal_2)

# Extract factors from LR Sinkhorn output
q, r, inv_sqg = lr_sink.q, lr_sink.r, 1.0 / jnp.sqrt(lr_sink.g)
Expand All @@ -268,28 +258,28 @@ def update_lr_geom(
# Handle LRC Geometry case.
h1, h2 = self.quad_loss
geom_xx, geom_yy, geom_xy = self.geom_xx, self.geom_yy, self.geom_xy
if remove_scale:
geom_xx = geom_xx.set_scale_cost(1.0)
geom_yy = geom_yy.set_scale_cost(1.0)
geom_xy = geom_xy.set_scale_cost(1.0) if self.is_fused else None
tmp1 = apply_cost(geom_xx, q, axis=1, fn=h1)
tmp2 = apply_cost(geom_yy, r, axis=1, fn=h2)
if self.is_low_rank:
geom = low_rank.LRCGeometry(cost_1=tmp1, cost_2=-tmp2) + marginal_cost
geom = low_rank.LRCGeometry(
cost_1=tmp1, cost_2=-tmp2, relative_epsilon=relative_epsilon
) + marginal_cost
if self.is_fused:
geom = geom + geom_xy
else:
cost_matrix = marginal_cost.cost_matrix - jnp.dot(tmp1, tmp2.T)
cost_matrix += self.fused_penalty * self._fused_cost_matrix(remove_scale)
geom = geometry.Geometry(cost_matrix=cost_matrix)
cost_matrix += self.fused_penalty * self._fused_cost_matrix
geom = geometry.Geometry(
cost_matrix=cost_matrix, relative_epsilon=relative_epsilon
)
return geom # noqa: RET504

def update_linearization(
self,
transport: Transport,
epsilon: Optional[Union[epsilon_scheduler.Epsilon, float]] = None,
old_transport_mass: float = 1.0,
remove_scale: bool = False,
relative_epsilon: Optional[bool] = None,
) -> linear_problem.LinearProblem:
"""Update linearization of GW problem by updating cost matrix.

Expand All @@ -307,11 +297,8 @@ def update_linearization(
epsilon: An epsilon scheduler or a float passed on to the linearization.
old_transport_mass: Sum of the elements of the transport matrix at the
previous iteration.
remove_scale: Whether to remove any scaling from the cost matrices when
computing the linearization of the quadratic cost. At the moment, this
is only used when doing this update at the last outer iteration of the
:class:`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein`
solver.
relative_epsilon: Whether to use relative epsilon in the linearized
geometry.

Returns:
Updated linear OT problem, a new local linearization of GW problem.
Expand All @@ -326,9 +313,7 @@ def update_linearization(

marginal_1 = transport.marginal(axis=1) * rescale_factor
marginal_2 = transport.marginal(axis=0) * rescale_factor
marginal_cost = self.marginal_dependent_cost(
marginal_1, marginal_2, remove_scale=remove_scale
)
marginal_cost = self.marginal_dependent_cost(marginal_1, marginal_2)

transport_matrix = transport.matrix * rescale_factor

Expand All @@ -342,18 +327,18 @@ def update_linearization(

h1, h2 = self.quad_loss
geom_xx, geom_yy = self.geom_xx, self.geom_yy
if remove_scale:
geom_xx = geom_xx.set_scale_cost(1.0)
geom_yy = geom_yy.set_scale_cost(1.0)

tmp = apply_cost(geom_xx, transport_matrix, axis=1, fn=h1)
tmp = apply_cost(geom_yy, tmp.T, axis=1, fn=h2).T

cost_matrix = marginal_cost.cost_matrix - tmp + unbalanced_correction
cost_matrix += self.fused_penalty * rescale_factor * \
self._fused_cost_matrix(remove_scale)
cost_matrix += self.fused_penalty * rescale_factor * self._fused_cost_matrix

geom = geometry.Geometry(cost_matrix=cost_matrix, epsilon=epsilon)
geom = geometry.Geometry(
cost_matrix=cost_matrix,
epsilon=epsilon,
relative_epsilon=relative_epsilon
)

return linear_problem.LinearProblem(
geom, self.a, self.b, tau_a=self.tau_a, tau_b=self.tau_b
Expand All @@ -363,24 +348,22 @@ def update_lr_linearization(
self,
lr_sink: "sinkhorn_lr.LRSinkhornOutput",
*,
remove_scale: bool = False,
relative_epsilon: Optional[bool] = None,
) -> linear_problem.LinearProblem:
"""Update a Quad problem linearization using a LR Sinkhorn."""
return linear_problem.LinearProblem(
self.update_lr_geom(lr_sink, remove_scale=remove_scale),
self.update_lr_geom(lr_sink, relative_epsilon=relative_epsilon),
self.a,
self.b,
tau_a=self.tau_a,
tau_b=self.tau_b
)

def _fused_cost_matrix(self,
unscale: bool = False) -> Union[float, jnp.ndarray]:
@property
def _fused_cost_matrix(self) -> Union[float, jnp.ndarray]:
if not self.is_fused:
return 0.0
geom_xy = self.geom_xy
if unscale:
geom_xy = geom_xy.set_scale_cost(1.0)

if isinstance(geom_xy, pointcloud.PointCloud) and geom_xy.is_online:
return geom_xy._compute_cost_matrix() * geom_xy.inv_scale_cost
Expand Down
36 changes: 21 additions & 15 deletions src/ott/solvers/quadratic/gromov_wasserstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,8 @@ class GromovWasserstein(was_solver.WassersteinSolver):
warm_start: Whether to initialize (low-rank) Sinkhorn calls using values
from the previous iteration. If `None`, warm starts are not used for
standard Sinkhorn, but used for low-rank Sinkhorn.
unscale_last_linearization: Whether to remove any scaling from the
cost matrices of the last linearization stored in
:attr:`~ott.solvers.quadratic.gromov_wasserstein.GWOutput.geom`.
This has the practical benefit that, while the OT coupling matrices
obtained with GW might have been computed by re-scaling cost matrices for
numerical stability, the last linearization stored in the geometry will be
unscaled and recomputed with the original cost values.
relative_epsilon: Whether to use relative epsilon in the linearized
geometry.
quad_initializer: Quadratic initializer. If the solver is entropic,
:class:`~ott.initializers.quadratic.initializers.QuadraticInitializer`
is always used. Otherwise, the quadratic initializer wraps the low-rank
Expand All @@ -194,7 +189,7 @@ def __init__(
self,
*args: Any,
warm_start: Optional[bool] = None,
unscale_last_linearization: bool = False,
relative_epsilon: Optional[bool] = None,
quad_initializer: Optional[
Union[Literal["random", "rank2", "k-means", "generalized-k-means"],
quad_initializers.BaseQuadraticInitializer]] = None,
Expand All @@ -204,7 +199,7 @@ def __init__(
):
super().__init__(*args, **kwargs)
self._warm_start = warm_start
self.unscale_last_linearization = unscale_last_linearization
self.relative_epsilon = relative_epsilon
self.quad_initializer = quad_initializer
self.progress_fn = progress_fn
self.kwargs_init = {} if kwargs_init is None else kwargs_init
Expand Down Expand Up @@ -236,21 +231,27 @@ def __call__(

if init is None:
initializer = self.create_initializer(prob)
init = initializer(prob, epsilon=self.epsilon, rng=rng1, **kwargs)
init = initializer(
prob,
epsilon=self.epsilon,
rng=rng1,
relative_epsilon=self.relative_epsilon,
**kwargs
)

out = iterations(self, prob, init, rng2)
# TODO(lpapaxanthoos): remove stop_gradient when using backprop
if self.is_low_rank:
linearization = prob.update_lr_linearization(
jax.lax.stop_gradient(out.linear_state),
remove_scale=self.unscale_last_linearization
relative_epsilon=self.relative_epsilon,
)
else:
linearization = prob.update_linearization(
jax.lax.stop_gradient(out.linear_state),
epsilon=self.epsilon,
old_transport_mass=jax.lax.stop_gradient(out.old_transport_mass),
remove_scale=self.unscale_last_linearization,
relative_epsilon=self.relative_epsilon,
)

linear_state = out.linear_state.set_cost(linearization, True, True)
Expand Down Expand Up @@ -366,7 +367,7 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102
children, aux_data = super().tree_flatten()
aux_data["warm_start"] = self._warm_start
aux_data["progress_fn"] = self.progress_fn
aux_data["unscale_last_linearization"] = self.unscale_last_linearization
aux_data["relative_epsilon"] = self.relative_epsilon
aux_data["quad_initializer"] = self.quad_initializer
aux_data["kwargs_init"] = self.kwargs_init
return children, aux_data
Expand Down Expand Up @@ -396,12 +397,17 @@ def body_fn(
rng = state.rngs[iteration]
init = (lin_state.q, lin_state.r,
lin_state.g) if solver.warm_start else (None, None, None)
linear_pb = prob.update_lr_linearization(state.linear_state)
linear_pb = prob.update_lr_linearization(
state.linear_state, relative_epsilon=solver.relative_epsilon
)
out = solver.linear_ot_solver(linear_pb, init=init, rng=rng)
else:
init = (lin_state.f, lin_state.g) if solver.warm_start else (None, None)
linear_pb = prob.update_linearization(
lin_state, solver.epsilon, state.old_transport_mass
lin_state,
solver.epsilon,
state.old_transport_mass,
relative_epsilon=solver.relative_epsilon,
)
out = solver.linear_ot_solver(linear_pb, init=init)

Expand Down
41 changes: 18 additions & 23 deletions tests/solvers/quadratic/gw_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,36 +392,31 @@ def test_gw_lr_warm_start_helps(self, rng: jax.random.PRNGKeyArray):
with pytest.raises(AssertionError):
np.testing.assert_allclose(out_cold.matrix, out_warm.matrix)

@pytest.mark.parametrize("scale_cost", [1.15, 2.3])
def test_unscale_last_linearization(
self, rng: jax.random.PRNGKeyArray, scale_cost: float
@pytest.mark.parametrize("scale_cost", [1.0, "mean"])
def test_relative_epsilon(
self,
rng: jax.random.PRNGKeyArray,
scale_cost: Union[float, str],
):
eps = 1e-2
rng1, rng2 = jax.random.split(rng, 2)
n, m = 7, 16
rtol = atol = 1e-3

geom_x = pointcloud.PointCloud(
jax.random.normal(rng1, (n, 2)), scale_cost=scale_cost
jax.random.normal(rng1, (49, 5)), scale_cost=scale_cost
)
geom_y = pointcloud.PointCloud(
jax.random.normal(rng2, (m, 6)), scale_cost=scale_cost
jax.random.normal(rng2, (78, 6)), scale_cost=scale_cost
)
# holds true only when `scale_cost` is the same for both geometries
expected = 1.0 / (geom_x.inv_scale_cost * geom_y.inv_scale_cost)

prob = quadratic_problem.QuadraticProblem(geom_x, geom_y)
solver_scaled = gromov_wasserstein.GromovWasserstein(
unscale_last_linearization=False
)
solver_unscaled = gromov_wasserstein.GromovWasserstein(
unscale_last_linearization=True

solver = gromov_wasserstein.GromovWasserstein(
epsilon=eps, relative_epsilon=True
)

out_scaled = solver_scaled(prob)
out_unscaled = solver_unscaled(prob)
actual = out_unscaled.primal_cost / out_scaled.primal_cost
out = solver(prob)

np.testing.assert_allclose(
out_scaled.matrix, out_unscaled.matrix, rtol=rtol, atol=atol
)
np.testing.assert_allclose(expected, actual, rtol=rtol, atol=atol)
if scale_cost == 1.0:
assert 40 < out.reg_gw_cost < 41
assert 38 < out.primal_cost < 39
else:
assert 0.215 < out.reg_gw_cost < 0.22
assert 0.19 < out.primal_cost < 0.20
Loading