Skip to content

Commit

Permalink
refactoring UnivariateSolver (#472)
Browse files Browse the repository at this point in the history
Merging as docs issue comes from recent `neural` refactoring.

* refactoring

* refactoring following michal comments

* fix tests

* last fixes
  • Loading branch information
marcocuturi committed Nov 27, 2023
1 parent 75cdd11 commit fb7b76a
Show file tree
Hide file tree
Showing 10 changed files with 524 additions and 290 deletions.
1 change: 1 addition & 0 deletions docs/geometry.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ Cost Functions
costs.ElasticSTVS
costs.ElasticSqKOverlap
costs.SoftDTW
distrib_costs.UnivariateWasserstein

Utilities
---------
Expand Down
1 change: 1 addition & 0 deletions src/ott/geometry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
from . import (
costs,
distrib_costs,
epsilon_scheduler,
geometry,
graph,
Expand Down
84 changes: 84 additions & 0 deletions src/ott/geometry/distrib_costs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional

import jax
import jax.numpy as jnp

from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import univariate

__all__ = [
"UnivariateWasserstein",
]


@jax.tree_util.register_pytree_node_class
class UnivariateWasserstein(costs.CostFn):
"""1D Wasserstein cost for two 1D distributions.
This ground cost between considers vectors as a family of values. The
Wasserstein distance between them is the 1D OT cost, using a user-defined
ground cost.
Args:
kwargs: arguments passed on when calling the
:class:`~ott.solvers.linear.univariate.UnivariateSolver`. May include
random key, or specific instructions to subsample or compute using
quantiles.
"""

def __init__(
self,
ground_cost: Optional[costs.TICost] = None,
solver: Optional[univariate.UnivariateSolver] = None,
**kwargs: Any
):
from ott.solvers.linear import univariate
super().__init__()

self.ground_cost = (
costs.SqEuclidean() if ground_cost is None else ground_cost
)

self._solver = univariate.UnivariateSolver() if solver is None else solver
self._kwargs_solve = kwargs

def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
"""Wasserstein distance between :math:`x` and :math:`y` seen as a 1D dist.
Args:
x: vector, array of shape ``[n,]``
y: vector, array of shape ``[m,]``
Returns:
The transport cost.
"""
out = self._solver(
linear_problem.LinearProblem(
pointcloud.PointCloud(
x[:, None], y[:, None], cost_fn=self.ground_cost
)
), **self._kwargs_solve
)
return jnp.squeeze(out.ot_costs)

def tree_flatten(self): # noqa: D102
return (self.ground_cost,), (self._solver,)

@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
return cls(*children, *aux_data)
15 changes: 14 additions & 1 deletion src/ott/math/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import TYPE_CHECKING, Optional, Sequence, Union
from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union

import jax
import jax.numpy as jnp
Expand All @@ -29,6 +29,7 @@
"gen_js",
"logsumexp",
"softmin",
"sort_and_argsort",
"barycentric_projection",
]

Expand Down Expand Up @@ -220,3 +221,15 @@ def barycentric_projection(
return jax.vmap(
lambda m, y: cost_fn.barycenter(m, y)[0], in_axes=[0, None]
)(matrix, y)


def sort_and_argsort(
x: jnp.array,
*,
argsort: bool = False
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
"""Unified function that returns both sort and argsort, if latter needed."""
if argsort:
i_x = jnp.argsort(x)
return x[i_x], i_x
return jnp.sort(x), None
10 changes: 10 additions & 0 deletions src/ott/problems/linear/linear_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ def is_balanced(self) -> bool:
"""Whether the problem is balanced."""
return self.tau_a == 1.0 and self.tau_b == 1.0

@property
def is_uniform(self) -> bool:
"""True if no weights ``a,b`` were passed, and have defaulted to uniform."""
return self._a is None and self._b is None

@property
def is_equal_size(self) -> bool:
"""True if square shape, i.e. ``n == m``."""
return self.geom.shape[0] == self.geom.shape[1]

@property
def epsilon(self) -> float:
"""Entropic regularization."""
Expand Down
Loading

0 comments on commit fb7b76a

Please sign in to comment.