Skip to content

Commit

Permalink
Re-add gromov_wasserstein.solve, polish docs
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 committed Jan 6, 2023
1 parent 037c94c commit ce59e8e
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 33 deletions.
1 change: 1 addition & 0 deletions docs/solvers/quadratic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Gromov-Wasserstein Solvers
.. autosummary::
:toctree: _autosummary

gromov_wasserstein.solve
gromov_wasserstein.GromovWasserstein
gromov_wasserstein.GWOutput

Expand Down
16 changes: 8 additions & 8 deletions src/ott/problems/linear/linear_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,20 @@ class LinearProblem:
r"""Linear OT problem.
This class describes the main ingredients appearing in a linear OT problem.
Namely, a `geom` object (including cost structure/points) describing point
clouds or the support of measures, followed by probability masses `a` and `b`.
Unabalancedness of the problem is also kept track of, through two coefficients
`tau_a` and `tau_b`, which are both kept between 0 and 1
Namely, a ``geom`` object (including cost structure/points) describing point
clouds or the support of measures, followed by probability masses ``a`` and
``b``. Unbalancedness of the problem is also kept track of, through two
coefficients ``tau_a`` and ``tau_b``, which are both kept between 0 and 1
(1 corresponding to a balanced OT problem).
Args:
geom: The ground geometry cost of the linear problem.
a: The first marginal. If `None`, it will be uniform.
b: The second marginal. If `None`, it will be uniform.
tau_a: If smaller than `1`, defines how much unbalanced the problem is on
the first marginal.
tau_b: If smaller than `1`, defines how much unbalanced the problem is on
the second marginal.
tau_a: If `< 1`, defines how much unbalanced the problem is
on the first marginal.
tau_b: If `< 1`, defines how much unbalanced the problem is
on the second marginal.
"""

def __init__(
Expand Down
36 changes: 18 additions & 18 deletions src/ott/problems/quadratic/quadratic_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

@jax.tree_util.register_pytree_node_class
class QuadraticProblem:
r"""Quadratic regularized OT problem.
r"""Quadratic OT problem.
The quadratic loss of a single OT matrix is assumed to
have the form given in :cite:`peyre:16`, eq. 4.
Expand All @@ -48,10 +48,10 @@ class QuadraticProblem:
Args:
geom_xx: Ground geometry of the first space.
geom_yy: Ground geometry of the second space.
geom_xy: Geometry defining the linear penalty term for Fused Gromov
Wasserstein. If `None`, the problem reduces to a
plain Gromov Wasserstein problem.
fused_penalty: multiplier of the linear term in Fused Gromov Wasserstein,
geom_xy: Geometry defining the linear penalty term for
Fused Gromov-Wasserstein. If `None`, the problem reduces to a plain
Gromov-Wasserstein problem.
fused_penalty: multiplier of the linear term in Fused Gromov-Wasserstein,
i.e. problem = purely quadratic + fused_penalty * linear problem.
Ignored if ``geom_xy`` is not specified.
scale_cost: option to rescale the cost matrices:
Expand All @@ -63,22 +63,22 @@ class QuadraticProblem:
:class:`~ott.geometry.pointcloud.PointCloud`.
- if `None`, do not scale the cost matrices.
a: jnp.ndarray[n] representing the probability weights of the samples
from geom_xx. If None, it will be uniform.
b: jnp.ndarray[n] representing the probability weights of the samples
from geom_yy. If None, it will be uniform.
a: array representing the probability weights of the samples
from ``geom_xx``. If `None`, it will be uniform.
b: array representing the probability weights of the samples
from ``geom_yy``. If `None`, it will be uniform.
loss: a 2-tuple of 2-tuples of Callable. The first tuple is the linear
part of the loss (see in the pydoc of the class lin1, lin2). The second
one is the quadratic part (quad1, quad2). By default, the loss
is set as the 4 functions representing the squared Euclidean loss, and
this property is taken advantage of in subsequent computations. See
Alternatively, KL loss can be specified in no less optimized way.
tau_a: if lower that 1.0, defines how much unbalanced the problem is on
part of the loss. The second one is the quadratic part (quad1, quad2).
By default, the loss is set as the 4 functions representing the squared
Euclidean loss, and this property is taken advantage of in subsequent
computations. Alternatively, KL loss can be specified in no less optimized
way.
tau_a: if `< 1.0`, defines how much unbalanced the problem is on
the first marginal.
tau_b: if lower that 1.0, defines how much unbalanced the problem is on
tau_b: if `< 1.0`, defines how much unbalanced the problem is on
the second marginal.
gw_unbalanced_correction: Whether the unbalanced version of
:cite:`sejourne:21` is used. Otherwise ``tau_a`` and ``tau_b`` only affect
:cite:`sejourne:21` is used. Otherwise, ``tau_a`` and ``tau_b`` only affect
the inner Sinkhorn loop.
ranks: Ranks of the cost matrices, see
:meth:`~ott.geometry.geometry.Geometry.to_LRCGeometry`. Used when
Expand Down Expand Up @@ -274,7 +274,7 @@ def update_linearization(
If the problem is unbalanced (``tau_a < 1.0 or tau_b < 1.0``), two cases are
possible, as explained in :meth:`init_linearization` above.
Finally, it is also possible to consider a Fused Gromov Wasserstein problem.
Finally, it is also possible to consider a Fused Gromov-Wasserstein problem.
Details about the resulting cost matrix are also given in
:meth:`init_linearization`.
Expand Down
10 changes: 5 additions & 5 deletions src/ott/solvers/linear/sinkhorn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1085,16 +1085,16 @@ def solve(
tau_b: float = 1.0,
**kwargs: Any
) -> SinkhornOutput:
"""Solve regularized OT problem using Sinkhorn iterations.
"""Solve linear regularized OT problem.
Args:
geom: The ground geometry cost of the linear problem.
a: The first marginal. If `None`, it will be uniform.
b: The second marginal. If `None`, it will be uniform.
tau_a: If smaller than `1`, defines how much unbalanced the problem is on
the first marginal.
tau_b: If smaller than `1`, defines how much unbalanced the problem is on
the second marginal.
tau_a: If `< 1`, defines how much unbalanced the problem is
on the first marginal.
tau_b: If `< 1`, defines how much unbalanced the problem is
on the second marginal.
kwargs: Keyword arguments for
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.
Expand Down
103 changes: 101 additions & 2 deletions src/ott/solvers/quadratic/gromov_wasserstein.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@
from ott.initializers.quadratic import initializers as quad_initializers
from ott.math import fixed_point_loop
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.problems.quadratic import quadratic_costs, quadratic_problem
from ott.solvers import was_solver
from ott.solvers.linear import sinkhorn, sinkhorn_lr

__all__ = ["GWOutput", "GromovWasserstein"]
__all__ = ["GWOutput", "GromovWasserstein", "solve"]

LinearOutput = Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput]

Expand Down Expand Up @@ -397,3 +397,102 @@ def body_fn(
)

return solver.output_from_state(state)


def solve(
geom_xx: geometry.Geometry,
geom_yy: geometry.Geometry,
geom_xy: Optional[geometry.Geometry] = None,
fused_penalty: float = 1.0,
scale_cost: Optional[Union[bool, float, str]] = False,
a: Optional[jnp.ndarray] = None,
b: Optional[jnp.ndarray] = None,
loss: Union[Literal['sqeucl', 'kl'], quadratic_costs.GWLoss] = 'sqeucl',
tau_a: Optional[float] = 1.0,
tau_b: Optional[float] = 1.0,
gw_unbalanced_correction: bool = True,
ranks: Union[int, Tuple[int, ...]] = -1,
tolerances: Union[float, Tuple[float, ...]] = 1e-2,
**kwargs: Any,
) -> GWOutput:
r"""Solve quadratic regularized OT problem.
The quadratic loss of a single OT matrix is assumed to
have the form given in :cite:`peyre:16`, eq. 4.
The two geometries below parameterize matrices :math:`C` and :math:`\bar{C}`
in that equation. The function :math:`L` (of two real values) in that equation
is assumed to match the form given in eq. 5., with our notations:
.. math::
L(x, y) = lin1(x) + lin2(y) - quad1(x) * quad2(y)
Args:
geom_xx: Ground geometry of the first space.
geom_yy: Ground geometry of the second space.
geom_xy: Geometry defining the linear penalty term for
Fused Gromov-Wasserstein. If `None`, the problem reduces to
a plain Gromov-Wasserstein problem.
fused_penalty: multiplier of the linear term in Fused Gromov-Wasserstein,
i.e. problem = purely quadratic + fused_penalty * linear problem.
Ignored if ``geom_xy`` is not specified.
scale_cost: option to rescale the cost matrices:
- if `True`, use the default for each geometry.
- if `False`, keep the original scaling in geometries.
- if :class:`str`, use a specific method available in
:class:`~ott.geometry.geometry.Geometry` or
:class:`~ott.geometry.pointcloud.PointCloud`.
- if `None`, do not scale the cost matrices.
a: array representing the probability weights of the samples
from ``geom_xx``. If `None`, it will be uniform.
b: array representing the probability weights of the samples
from ``geom_yy``. If `None`, it will be uniform.
loss: a 2-tuple of 2-tuples of Callable. The first tuple is the linear
part of the loss. The second one is the quadratic part (quad1, quad2).
By default, the loss is set as the 4 functions representing the squared
Euclidean loss, and this property is taken advantage of in subsequent
computations. Alternatively, KL loss can be specified in no less optimized
way.
tau_a: if `< 1.0`, defines how much unbalanced the problem is on
the first marginal.
tau_b: if `< 1.0`, defines how much unbalanced the problem is on
the second marginal.
gw_unbalanced_correction: Whether the unbalanced version of
:cite:`sejourne:21` is used. Otherwise, ``tau_a`` and ``tau_b`` only affect
the inner Sinkhorn loop.
ranks: Ranks of the cost matrices, see
:meth:`~ott.geometry.geometry.Geometry.to_LRCGeometry`. Used when
geometries are *not* :class:`~ott.geometry.pointcloud.PointCloud` with
`'sqeucl'` cost function. If `-1`, the geometries will not be converted
to low-rank. If :class:`tuple`, it specifies the ranks of ``geom_xx``,
``geom_yy`` and ``geom_xy``, respectively. If :class:`int`, rank is shared
across all geometries.
tolerances: Tolerances used when converting geometries to low-rank. Used
when geometries are not :class:`~ott.geometry.pointcloud.PointCloud` with
`'sqeucl'` cost. If :class:`float`, it is shared across all geometries.
kwargs: Keyword arguments for
:class:`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein`.
Returns:
Gromov-Wasserstein output.
"""
prob = quadratic_problem.QuadraticProblem(
geom_xx,
geom_yy,
geom_xy=geom_xy,
fused_penalty=fused_penalty,
scale_cost=scale_cost,
a=a,
b=b,
loss=loss,
tau_a=tau_a,
tau_b=tau_b,
gw_unbalanced_correction=gw_unbalanced_correction,
ranks=ranks,
tolerances=tolerances
)
solver = GromovWasserstein(**kwargs)
return solver(prob)

0 comments on commit ce59e8e

Please sign in to comment.