diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 363a33365..000000000 --- a/.flake8 +++ /dev/null @@ -1,51 +0,0 @@ -[flake8] -max-line-length = 80 -ignore = - # line break before a binary operator -> black does not adhere to PEP8 - W503 - # line break occured after a binary operator -> black does not adhere to PEP8 - W504 - # line too long -> we accept long comment lines; black gets rid of long code lines - E501 - # whitespace before : -> black does not adhere to PEP8 - E203 - # missing whitespace after ,', ';', or ':' -> black does not adhere to PEP8 - E231 - # continuation line over-indented for hanging indent -> black does not adhere to PEP8 - E126 - # E266 too many leading '#' for block comment -> this is fine for indicating sections - E262 - # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient - E731 - # allow I, O, l as variable names -> I is the identity matrix, i, j, k, l is reasonable indexing notation - E741 - # Missing docstring in public package - D104 - # ... imported but unused - F401 - # Missing docstring in public module - D100 - # Missing docstring in __init__ - D107 - # Do not perform function calls in argument defaults. - B008 - # line break before binary operator - W503 - # Missing docstring in magic method - D105 - # whitespace before ':' - E203 - # format string does contain unindexed parameters - P101 - # indentation is not a multiple of 4 - E111, E114 - # Missing blank line before section - D411 -exclude = .git,__pycache__,build,docs/_build,dist -# C409: Unnecessary call - rewrite as a literal. -per-file-ignores = - tests/*: D,C408 - */__init__.py: F401 - examples/*: D101, D102, D103 - docs/*: D101, D102 - src/ott/types.py: D102 diff --git a/docs/conf.py b/docs/conf.py index 7b9aac0be..1c7bb12f4 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -26,9 +26,8 @@ import logging from datetime import datetime -from sphinx.util import logging as sphinx_logging - import ott +from sphinx.util import logging as sphinx_logging # -- Project information ----------------------------------------------------- needs_sphinx = "4.0" @@ -47,17 +46,17 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. 
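The deleted `.flake8` configuration above is superseded by the `[tool.ruff]` settings that appear further down in this diff. As a rough cross-check of what carries over, the two ignore lists can be compared directly; the snippet below is purely illustrative and simply copies both lists out of the configs in this diff:

```python
# Old flake8 ignore list (from the deleted .flake8 above).
flake8_ignores = {
  "W503", "W504", "E501", "E203", "E231", "E126", "E262", "E731", "E741",
  "D104", "F401", "D100", "D107", "B008", "D105", "P101", "E111", "E114",
  "D411",
}
# New ruff ignore list (from pyproject.toml below).
ruff_ignores = {"E731", "E741", "D104", "D100", "D107", "D105"}

# Codes that disappear: either ruff does not implement them (W503/W504,
# E203, E231, E126, P101, E111/E114) or they are now handled differently
# (E501 enforced; F401 via per-file ignores; B008 via extend-immutable-calls).
print(sorted(flake8_ignores - ruff_ignores))
```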
extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosummary', - 'sphinx.ext.intersphinx', - 'sphinx.ext.mathjax', - 'sphinx.ext.napoleon', - 'sphinx.ext.viewcode', - 'sphinxcontrib.bibtex', - 'sphinx_copybutton', - 'myst_nb', - 'IPython.sphinxext.ipython_console_highlighting', - 'sphinx_autodoc_typehints', + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.intersphinx", + "sphinx.ext.mathjax", + "sphinx.ext.napoleon", + "sphinx.ext.viewcode", + "sphinxcontrib.bibtex", + "sphinx_copybutton", + "myst_nb", + "IPython.sphinxext.ipython_console_highlighting", + "sphinx_autodoc_typehints", ] intersphinx_mapping = { @@ -68,27 +67,28 @@ "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None), "pot": ("https://pythonot.github.io/", None), "jaxopt": ("https://jaxopt.github.io/stable", None), + "optax": ("https://optax.readthedocs.io/en/latest/", None), "matplotlib": ("https://matplotlib.org/stable/", None), } -master_doc = 'index' +master_doc = "index" source_suffix = { - '.rst': 'restructuredtext', - '.ipynb': 'myst-nb', + ".rst": "restructuredtext", + ".ipynb": "myst-nb", } todo_include_todos = False autosummary_generate = True -autodoc_typehints = 'description' +autodoc_typehints = "description" # myst-nb myst_heading_anchors = 2 nb_execution_mode = "off" nb_mime_priority_overrides = [("spelling", "text/plain", 0)] myst_enable_extensions = [ - 'amsmath', - 'colon_fence', - 'dollarmath', + "amsmath", + "colon_fence", + "dollarmath", ] # bibliography @@ -108,37 +108,37 @@ ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. -exclude_patterns = ['_build', '**.ipynb_checkpoints'] +exclude_patterns = ["_build", "**.ipynb_checkpoints"] # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. -html_theme = 'sphinx_book_theme' -html_logo = '_static/images/logoOTT.png' -html_favicon = '_static/images/logoOTT.ico' +html_theme = "sphinx_book_theme" +html_logo = "_static/images/logoOTT.png" +html_favicon = "_static/images/logoOTT.ico" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". 
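The new ``optax`` entry in ``intersphinx_mapping`` above exists so that docstrings can cross-link optax symbols; the ``MetaInitializer`` change later in this diff adds exactly such a reference to ``optax.adam``. A hypothetical docstring showing how the link is written:

```python
import optax


def default_optimizer() -> optax.GradientTransformation:
  """Return the default optimizer used by the meta initializer.

  With ``optax`` registered in ``intersphinx_mapping``, the
  :func:`optax.adam` reference in this docstring resolves to the hosted
  optax documentation when Sphinx builds the docs.
  """
  return optax.adam(learning_rate=1e-3)
```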
-html_static_path = ['_static'] +html_static_path = ["_static"] html_theme_options = { - 'repository_url': 'https://github.com/ott-jax/ott', - 'repository_branch': 'main', - 'path_to_docs': 'docs/', - 'use_repository_button': True, - 'use_fullscreen_button': False, - 'logo_only': True, - 'launch_buttons': { - 'colab_url': 'https://colab.research.google.com', - 'binderhub_url': 'https://mybinder.org', - 'notebook_interface': 'jupyterlab', + "repository_url": "https://github.com/ott-jax/ott", + "repository_branch": "main", + "path_to_docs": "docs/", + "use_repository_button": True, + "use_fullscreen_button": False, + "logo_only": True, + "launch_buttons": { + "colab_url": "https://colab.research.google.com", + "binderhub_url": "https://mybinder.org", + "notebook_interface": "jupyterlab", }, } diff --git a/docs/initializers/linear.rst b/docs/initializers/linear.rst index cb24e3db1..3abd988c1 100644 --- a/docs/initializers/linear.rst +++ b/docs/initializers/linear.rst @@ -12,7 +12,7 @@ Sinkhorn Initializers initializers.DefaultInitializer initializers.GaussianInitializer - initializers.SinkhornInitializer + initializers.SortingInitializer initializers.SubsampleInitializer Low-rank Sinkhorn Initializers diff --git a/pyproject.toml b/pyproject.toml index c9660b78b..2babb8d63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -250,66 +250,43 @@ exclude = [ "dist" ] ignore = [ - # line too long -> we accept long comment lines; black gets rid of long code lines - "E501", # Do not assign a lambda expression, use a def -> lambda expression assignments are convenient "E731", # allow I, O, l as variable names -> I is the identity matrix, i, j, k, l is reasonable indexing notation "E741", # Missing docstring in public package "D104", - # ... imported but unused - "F401", # Missing docstring in public module "D100", # Missing docstring in __init__ "D107", - # Do not perform function calls in argument defaults. 
- "B008", # Missing docstring in magic method "D105", - # Missing blank line before section - "D411", - ## Flake8 rules not supported by ruff: - # line break before a binary operator -> black does not adhere to PEP8 - # "W503", - # line break occured after a binary operator -> black does not adhere to PEP8 - # "W504", - # whitespace before : -> black does not adhere to PEP8 - # "E203", - # whitespace before : -> black does not adhere to PEP8 - # "E203", - # missing whitespace after ,', ';', or ':' -> black does not adhere to PEP8 - # "E231", - # continuation line over-indented for hanging indent -> black does not adhere to PEP8 - # "E126", - # inline comment should start with '#' -> Scanpy allows them for specific explanations - # "E266", - # format string does contain unindexed parameters - # "P101", - # indentation is not a multiple of 4 - # "E111", - # "E114", ] line-length = 80 select = [ + "D", # flake8-docstrings "I", # isort "E", # pycodestyle "F", # pyflakes "W", # pycodestyle - # below are not autofixed + "Q", # flake8-quotes + "SIM", # flake8-simplify + "NPY", # NumPy-specific rules + "PT", # flake8-pytest-style + "B", # flake8-bugbear "UP", # pyupgrade "C4", # flake8-comprehensions - "B", # flake8-bugbear "BLE", # flake8-blind-except + "T20", # flake8-print + "RET", # flake8-raise ] -unfixable = ["B", "UP", "C4", "BLE"] +unfixable = ["B", "UP", "C4", "BLE", "T20", "RET"] target-version = "py38" [tool.ruff.per-file-ignores] - "tests/*" = ["D", "E", "F", "W", "I", "C408"] + "tests/*" = ["D", "PT004"] # TODO(michalk8): remove `self.initialize` in `tests/` "*/__init__.py" = ["F401"] - "examples/*" = ["D101", "D102", "D103"] - "docs/*" = ["E", "F", "W", "I", "D101", "D102"] + "docs/*" = ["D"] "src/ott/types.py" = ["D102"] [tool.ruff.pydocstyle] convention = "google" @@ -319,3 +296,7 @@ keep-runtime-typing = true [tool.ruff.flake8-tidy-imports] # Disallow all relative imports. ban-relative-imports = "parents" +[tool.ruff.flake8-bugbear] +extend-immutable-calls = ["jax.random.PRNGKey"] +[tool.ruff.flake8-quotes] +inline-quotes = "double" diff --git a/src/ott/_version.py b/src/ott/_version.py index 85ed093e4..972f66cbd 100644 --- a/src/ott/_version.py +++ b/src/ott/_version.py @@ -1,3 +1,16 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from importlib.metadata import PackageNotFoundError, version try: diff --git a/src/ott/geometry/costs.py b/src/ott/geometry/costs.py index 495a4131b..bddd632d6 100644 --- a/src/ott/geometry/costs.py +++ b/src/ott/geometry/costs.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Several cost/norm functions for relevant vector types.""" import abc import functools import math @@ -47,10 +46,12 @@ class CostFn(abc.ABC): that function is split into two norms -- evaluated on each input separately -- followed by a pairwise cost that involves both inputs, as in: - ``c(x,y) = norm(x) + norm(y) + pairwise(x,y)`` + .. math:: + + c(x,y) = norm(x) + norm(y) + pairwise(x,y) - If the :attr:`norm` function is not implemented, that value is handled as a 0, - and only :func:`pairwise` is used. + If the :attr:`norm` function is not implemented, that value is handled as + :math:`0`, and only :func:`pairwise` is used. """ # no norm function created by default. @@ -58,7 +59,15 @@ class CostFn(abc.ABC): @abc.abstractmethod def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: - pass + """Compute cost between :math:`x` and :math:`y`. + + Args: + x: Array. + y: Array. + + Returns: + The cost. + """ def barycenter(self, weights: jnp.ndarray, xs: jnp.ndarray) -> jnp.ndarray: """Barycentric operator. @@ -85,38 +94,51 @@ def _padder(cls, dim: int) -> jnp.ndarray: return jnp.zeros((1, dim)) def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + """Compute cost between :math:`x` and :math:`y`. + + Args: + x: Array. + y: Array. + + Returns: + The cost, optionally including the :attr:`norms ` of + :math:`x`/:math:`y`. + """ cost = self.pairwise(x, y) if self.norm is None: return cost return cost + self.norm(x) + self.norm(y) + # TODO(michalk8): unused def all_pairs(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: - """Compute matrix of all costs (including norms) for vectors in x / y. + """Compute matrix of all pairwise costs, including the :attr:`norms `. Args: - x: [num_a, d] jnp.ndarray - y: [num_b, d] jnp.ndarray + x: Array of shape ``[n, ...]``. + y: Array of shape ``[m, ...]``. + Returns: - [num_a, num_b] matrix of cost evaluations. + Array of shape ``[n, m]`` of cost evaluations. """ return jax.vmap(lambda x_: jax.vmap(lambda y_: self(x_, y_))(y))(x) def all_pairs_pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: - """Compute matrix of all pairwise-costs (no norms) for vectors in x / y. + """Compute matrix of all pairwise costs, excluding the :attr:`norms `. Args: - x: [num_a, d] jnp.ndarray - y: [num_b, d] jnp.ndarray + x: Array of shape ``[n, ...]``. + y: Array of shape ``[m, ...]``. + Returns: - [num_a, num_b] matrix of pairwise cost evaluations. + Array of shape ``[n, m]`` of cost evaluations. 
""" return jax.vmap(lambda x_: jax.vmap(lambda y_: self.pairwise(x_, y_))(y))(x) - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return (), None @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 del aux_data return cls(*children) @@ -298,10 +320,10 @@ def prox_reg(self, z: jnp.ndarray) -> jnp.ndarray: """Proximal operator of :func:`reg`.""" raise NotImplementedError("Proximal operator is not implemented.") - def h(self, z: jnp.ndarray) -> float: + def h(self, z: jnp.ndarray) -> float: # noqa: D102 return 0.5 * jnp.linalg.norm(z, ord=2) ** 2 + self.reg(z) - def h_legendre(self, z: jnp.ndarray) -> float: + def h_legendre(self, z: jnp.ndarray) -> float: # noqa: D102 q = jax.lax.stop_gradient(self.prox_reg(z)) return jnp.sum(q * z) - self.h(q) @@ -604,13 +626,10 @@ def barycenter( cov_bary = self.covariance_fixpoint_iter( covs=covs, weights=weights, **kwargs ) - barycenter = mean_and_cov_to_x(mu_bary, cov_bary, self._dimension) - return barycenter + return mean_and_cov_to_x(mu_bary, cov_bary, self._dimension) @classmethod def _padder(cls, dim: int) -> jnp.ndarray: - """Pad with concatenated zero means and \ - raveled identity covariance matrix.""" dimension = int((-1 + math.sqrt(1 + 4 * dim)) / 2) padding = mean_and_cov_to_x( jnp.zeros((dimension,)), jnp.eye(dimension), dimension @@ -744,7 +763,7 @@ class SoftDTW(CostFn): """Soft dynamic time warping (DTW) cost :cite:`cuturi:17`. Args: - gamma: Smoothing parameter for the soft-min operator. + gamma: Smoothing parameter :math:`> 0` for the soft-min operator. ground_cost: Ground cost function. If ``None``, use :class:`~ott.geometry.costs.SqEuclidean`. debiased: Whether to compute the debiased soft-DTW :cite:`blondel:21`. @@ -760,7 +779,7 @@ def __init__( self.ground_cost = SqEuclidean() if ground_cost is None else ground_cost self.debiased = debiased - def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: + def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float: # noqa: D102 c_xy = self._soft_dtw(x, y) if self.debiased: return c_xy - 0.5 * (self._soft_dtw(x, x) + self._soft_dtw(y, y)) @@ -810,11 +829,11 @@ def body( (_, carry), _ = jax.lax.scan(body, init, model_matrix[2:]) return carry[-1] - def tree_flatten(self): + def tree_flatten(self): # noqa: D102 return (self.gamma, self.ground_cost), {"debiased": self.debiased} @classmethod - def tree_unflatten(cls, aux_data, children): + def tree_unflatten(cls, aux_data, children): # noqa: D102 return cls(*children, **aux_data) @@ -825,10 +844,10 @@ def x_to_means_and_covs(x: jnp.ndarray, Args: x: [num_gaussians, dimension, (1 + dimension)] array of concatenated means and covariances (raveled) dimension: the dimension of the Gaussians. + dimension: Dimensionality of the Gaussians. Returns: - means: [num_gaussians, dimension] array that holds the means. - covariances: [num_gaussians, dimension] array that holds the covariances. + Means and covariances of shape ``[num_gaussian, dimension]``. 
""" x = jnp.atleast_2d(x) means = x[:, :dimension] @@ -842,5 +861,6 @@ def mean_and_cov_to_x( mean: jnp.ndarray, covariance: jnp.ndarray, dimension: int ) -> jnp.ndarray: """Ravel a Gaussian's mean and covariance matrix to d(1 + d) vector.""" - x = jnp.concatenate((mean, jnp.reshape(covariance, (dimension * dimension)))) - return x + return jnp.concatenate( + (mean, jnp.reshape(covariance, (dimension * dimension))) + ) diff --git a/src/ott/geometry/epsilon_scheduler.py b/src/ott/geometry/epsilon_scheduler.py index 3dabed2ff..6c4b1905e 100644 --- a/src/ott/geometry/epsilon_scheduler.py +++ b/src/ott/geometry/epsilon_scheduler.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""A class to define a scheduler for the entropic regularization epsilon.""" from typing import Any, Optional import jax diff --git a/src/ott/geometry/geometry.py b/src/ott/geometry/geometry.py index 6a0ee4a07..026ef007d 100644 --- a/src/ott/geometry/geometry.py +++ b/src/ott/geometry/geometry.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""A class describing operations used to instantiate and use a geometry.""" import functools from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Tuple, Union @@ -77,8 +76,8 @@ def __init__( kernel_matrix: Optional[jnp.ndarray] = None, epsilon: Optional[Union[float, epsilon_scheduler.Epsilon]] = None, relative_epsilon: Optional[bool] = None, - scale_cost: Union[bool, int, float, Literal['mean', 'max_cost', - 'median']] = 1.0, + scale_cost: Union[bool, int, float, Literal["mean", "max_cost", + "median"]] = 1.0, src_mask: Optional[jnp.ndarray] = None, tgt_mask: Optional[jnp.ndarray] = None, ): @@ -99,7 +98,6 @@ def __init__( @property def cost_rank(self) -> Optional[int]: """Output rank of cost matrix, if any was provided.""" - return None @property def cost_matrix(self) -> jnp.ndarray: @@ -127,8 +125,10 @@ def mean_cost_matrix(self) -> float: @property def kernel_matrix(self) -> jnp.ndarray: - """Kernel matrix, either provided by user or recomputed from \ - :attr:`cost_matrix`.""" + """Kernel matrix. + + Either provided by user or recomputed from :attr:`cost_matrix`. 
+ """ if self._kernel_matrix is None: return jnp.exp(-(self._cost_matrix * self.inv_scale_cost / self.epsilon)) return self._kernel_matrix ** self.inv_scale_cost @@ -198,13 +198,13 @@ def inv_scale_cost(self) -> float: (int, float)) or utils.is_jax_array(self._scale_cost): return 1.0 / self._scale_cost self = self._masked_geom(mask_value=jnp.nan) - if self._scale_cost == 'max_cost': + if self._scale_cost == "max_cost": return 1.0 / jnp.nanmax(self._cost_matrix) - if self._scale_cost == 'mean': + if self._scale_cost == "mean": return 1.0 / jnp.nanmean(self._cost_matrix) - if self._scale_cost == 'median': + if self._scale_cost == "median": return 1.0 / jnp.nanmedian(self._cost_matrix) - raise ValueError(f'Scaling {self._scale_cost} not implemented.') + raise ValueError(f"Scaling {self._scale_cost} not implemented.") def _set_scale_cost(self, scale_cost: Union[bool, float, str]) -> "Geometry": # case when `geom` doesn't have `scale_cost` or doesn't need to be modified @@ -215,7 +215,7 @@ def _set_scale_cost(self, scale_cost: Union[bool, float, str]) -> "Geometry": aux_data["scale_cost"] = scale_cost return type(self).tree_unflatten(aux_data, children) - def copy_epsilon(self, other: 'Geometry') -> "Geometry": + def copy_epsilon(self, other: "Geometry") -> "Geometry": """Copy the epsilon parameters from another geometry.""" other_epsilon = other._epsilon children, aux_data = self.tree_flatten() @@ -244,7 +244,7 @@ def apply_lse_kernel( vec: jnp.ndarray = None, axis: int = 0 ) -> jnp.ndarray: - r"""Apply :attr:`kernel_matrix` in log domain on a pair of dual potential variables. + r"""Apply :attr:`kernel_matrix` in log domain. This function applies the ground geometry's kernel in log domain, using a stabilized formulation. At a high level, this iteration performs either: @@ -265,8 +265,8 @@ def apply_lse_kernel( eps: float, regularization strength vec: jnp.ndarray [num_a or num_b,] , when not None, this has the effect of doing log-Kernel computations with an addition elementwise - multiplication of exp(g / eps) by a vector. This is carried out by adding - weights to the log-sum-exp function, and needs to handle signs + multiplication of exp(g / eps) by a vector. This is carried out by + adding weights to the log-sum-exp function, and needs to handle signs separately. axis: summing over axis 0 when doing (2), or over axis 1 when doing (1) @@ -419,11 +419,11 @@ def _softmax( self._center(f, g) / eps, b=vec, axis=axis, return_sign=True ) return eps * lse_output[0], lse_output[1] - else: - lse_output = mu.logsumexp( - self._center(f, g) / eps, axis=axis, return_sign=False - ) - return eps * lse_output, jnp.array([1.0]) + + lse_output = mu.logsumexp( + self._center(f, g) / eps, axis=axis, return_sign=False + ) + return eps * lse_output, jnp.array([1.0]) @functools.partial(jax.vmap, in_axes=[None, None, None, 0, None]) def _apply_transport_from_potentials( @@ -611,8 +611,8 @@ def prepare_divergences( """Instantiate 2 (or 3) geometries to compute a Sinkhorn divergence.""" size = 2 if static_b else 3 nones = [None, None, None] - cost_matrices = kwargs.pop('cost_matrix', args) - kernel_matrices = kwargs.pop('kernel_matrix', nones) + cost_matrices = kwargs.pop("cost_matrix", args) + kernel_matrices = kwargs.pop("kernel_matrix", nones) cost_matrices = cost_matrices if cost_matrices is not None else nones return tuple( cls(cost_matrix=arg1, kernel_matrix=arg2, **kwargs) @@ -625,7 +625,7 @@ def to_LRCGeometry( tol: float = 1e-2, rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0), scale: float = 1. 
- ) -> 'low_rank.LRCGeometry': + ) -> "low_rank.LRCGeometry": r"""Factorize the cost matrix using either SVD (full) or :cite:`indyk:19`. When `rank=min(n,m)` or `0` (by default), use :func:`jax.numpy.linalg.svd`. @@ -740,7 +740,7 @@ def subset_fn( arr = arr[jnp.atleast_1d(src_ixs)] if tgt_ixs is not None: arr = arr[:, jnp.atleast_1d(tgt_ixs)] - return arr + return arr # noqa: RET504 return self._mask_subset_helper( src_ixs, @@ -788,7 +788,7 @@ def mask_fn( arr = jnp.where(src_mask[:, None], arr, mask_value) if tgt_mask is not None: arr = jnp.where(tgt_mask[None, :], arr, mask_value) - return arr + return arr # noqa: RET504 src_mask = self._normalize_mask(src_mask, self.shape[0]) tgt_mask = self._normalize_mask(tgt_mask, self.shape[1]) @@ -861,16 +861,14 @@ def _masked_geom(self, mask_value: float = 0.) -> "Geometry": @property def _n_normed_ones(self) -> jnp.ndarray: - """Normalized array of shape ``[num_a,]`` \ - taking into account :attr:`src_mask`.""" + """Normalized array of shape ``[num_a,]``.""" mask = self.src_mask arr = jnp.ones(self.shape[0]) if mask is None else mask return arr / jnp.sum(arr) @property def _m_normed_ones(self) -> jnp.ndarray: - """Normalized array of shape ``[num_b,]`` \ - taking into account :attr:`tgt_mask`.""" + """Normalized array of shape ``[num_b,]``.""" mask = self.tgt_mask arr = jnp.ones(self.shape[1]) if mask is None else mask return arr / jnp.sum(arr) diff --git a/src/ott/geometry/grid.py b/src/ott/geometry/grid.py index f9449d731..c9eee4573 100644 --- a/src/ott/geometry/grid.py +++ b/src/ott/geometry/grid.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Implements a geometry class for points supported on a cartesian product.""" import itertools from typing import Any, List, NoReturn, Optional, Sequence, Tuple @@ -99,15 +98,15 @@ def __init__( self.num_a = np.prod(np.array(grid_size)) self.grid_dimension = len(self.grid_size) else: - raise ValueError('Input either grid_size t-uple or grid locations x.') + raise ValueError("Input either grid_size t-uple or grid locations x.") if cost_fns is None: cost_fns = [costs.SqEuclidean()] self.cost_fns = cost_fns self.kwargs = { - 'num_a': self.num_a, - 'grid_size': self.grid_size, - 'grid_dimension': self.grid_dimension + "num_a": self.num_a, + "grid_size": self.grid_size, + "grid_dimension": self.grid_dimension } super().__init__(**kwargs) @@ -131,7 +130,7 @@ def geometries(self) -> List[geometry.Geometry]: @property def median_cost_matrix(self) -> NoReturn: """Not implemented.""" - raise NotImplementedError('Median cost not implemented for grids.') + raise NotImplementedError("Median cost not implemented for grids.") @property def can_LRC(self) -> bool: # noqa: D102 @@ -295,10 +294,10 @@ def transport_from_potentials( ) -> NoReturn: """Not implemented, use :meth:`apply_transport_from_potentials` instead.""" raise ValueError( - 'Grid geometry cannot instantiate a transport matrix, use', - ' apply_transport_from_potentials(...) if you wish to ', - ' apply the transport matrix to a vector, or use a point ' - ' cloud geometry instead' + "Grid geometry cannot instantiate a transport matrix, use", + " apply_transport_from_potentials(...) 
if you wish to ", + " apply the transport matrix to a vector, or use a point " + " cloud geometry instead" ) def transport_from_scalings( @@ -306,10 +305,10 @@ def transport_from_scalings( ) -> NoReturn: """Not implemented, use :meth:`apply_transport_from_scalings` instead.""" raise ValueError( - 'Grid geometry cannot instantiate a transport matrix, use ', - 'apply_transport_from_scalings(...) if you wish to ', - 'apply the transport matrix to a vector, or use a point ' - 'cloud geometry instead.' + "Grid geometry cannot instantiate a transport matrix, use ", + "apply_transport_from_scalings(...) if you wish to ", + "apply the transport matrix to a vector, or use a point " + "cloud geometry instead." ) def subset( @@ -335,8 +334,8 @@ def prepare_divergences( **kwargs: Any ) -> Tuple["Grid", ...]: """Instantiate the geometries used for a divergence computation.""" - grid_size = kwargs.pop('grid_size', None) - x = kwargs.pop('x', args) + grid_size = kwargs.pop("grid_size", None) + x = kwargs.pop("x", args) sep_grid = cls(x=x, grid_size=grid_size, **kwargs) size = 2 if static_b else 3 @@ -375,6 +374,7 @@ def to_LRCGeometry( kwargs: Keyword arguments, such as ``rank``, to :meth:`~ott.geometry.geometry.Geometry.to_LRCGeometry` used when geometries on each slice are not low-rank. + Returns: :class:`~ott.geometry.low_rank.LRCGeometry` object. """ diff --git a/src/ott/geometry/low_rank.py b/src/ott/geometry/low_rank.py index 126137c31..516e7790e 100644 --- a/src/ott/geometry/low_rank.py +++ b/src/ott/geometry/low_rank.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""A class describing low-rank geometries.""" from typing import Any, Callable, Literal, Optional, Tuple, Union import jax @@ -58,8 +57,8 @@ def __init__( cost_2: jnp.ndarray, bias: float = 0.0, scale_factor: float = 1.0, - scale_cost: Union[bool, int, float, Literal['mean', 'max_bound', - 'max_cost']] = 1.0, + scale_cost: Union[bool, int, float, Literal["mean", "max_bound", + "max_cost"]] = 1.0, batch_size: Optional[int] = None, **kwargs: Any, ): @@ -68,7 +67,7 @@ def __init__( self._cost_2 = cost_2 self._bias = bias self._scale_factor = scale_factor - self._scale_cost = 'mean' if scale_cost is True else scale_cost + self._scale_cost = "mean" if scale_cost is True else scale_cost self.batch_size = batch_size @property @@ -114,19 +113,19 @@ def inv_scale_cost(self) -> float: # noqa: D102 (int, float)) or utils.is_jax_array(self._scale_cost): return 1.0 / self._scale_cost self = self._masked_geom() - if self._scale_cost == 'max_bound': + if self._scale_cost == "max_bound": x_norm = self._cost_1[:, 0].max() y_norm = self._cost_2[:, 1].max() max_bound = x_norm + y_norm + 2 * jnp.sqrt(x_norm * y_norm) return 1.0 / (max_bound + self._bias) - if self._scale_cost == 'mean': + if self._scale_cost == "mean": factor1 = jnp.dot(self._n_normed_ones, self._cost_1) factor2 = jnp.dot(self._cost_2.T, self._m_normed_ones) mean = jnp.dot(factor1, factor2) + self._bias return 1.0 / mean - if self._scale_cost == 'max_cost': + if self._scale_cost == "max_cost": return 1.0 / self.compute_max_cost() - raise ValueError(f'Scaling {self._scale_cost} not implemented.') + raise ValueError(f"Scaling {self._scale_cost} not implemented.") def apply_square_cost(self, arr: jnp.ndarray, axis: int = 0) -> jnp.ndarray: """Apply elementwise-square of cost matrix to array (vector or matrix).""" @@ -136,13 +135,13 @@ def 
apply_square_cost(self, arr: jnp.ndarray, axis: int = 0) -> jnp.ndarray: # and apply it. First is O(nm), the other is O((n+m)r^2). if n * m < (n + m) * r ** 2: # better use regular apply return super().apply_square_cost(arr, axis) - else: - new_cost_1 = self.cost_1[:, :, None] * self.cost_1[:, None, :] - new_cost_2 = self.cost_2[:, :, None] * self.cost_2[:, None, :] - return LRCGeometry( - cost_1=new_cost_1.reshape((n, r ** 2)), - cost_2=new_cost_2.reshape((m, r ** 2)) - ).apply_cost(arr, axis) + + new_cost_1 = self.cost_1[:, :, None] * self.cost_1[:, None, :] + new_cost_2 = self.cost_2[:, :, None] * self.cost_2[:, None, :] + return LRCGeometry( + cost_1=new_cost_1.reshape((n, r ** 2)), + cost_2=new_cost_2.reshape((m, r ** 2)) + ).apply_cost(arr, axis) def _apply_cost_to_vec( self, @@ -222,8 +221,7 @@ def body(carry, slice_idx): def finalize(carry): cost1, cost2 = carry - out_slice = jnp.dot(cost2[n_batch * batch_size:], cost1.T) - return out_slice + return jnp.dot(cost2[n_batch * batch_size:], cost1.T) _, out = jax.lax.scan(body, carry, jnp.arange(n_batch)) last_slice = finalize(carry) @@ -235,7 +233,7 @@ def to_LRCGeometry( rank: int = 0, tol: float = 1e-2, rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0), - ) -> 'LRCGeometry': + ) -> "LRCGeometry": """Return self.""" del rank, tol, rng return self @@ -304,7 +302,7 @@ def _mask_subset_helper( aux_data, [c1, c2, src_mask, tgt_mask] + children ) - def __add__(self, other: 'LRCGeometry') -> 'LRCGeometry': + def __add__(self, other: "LRCGeometry") -> "LRCGeometry": if not isinstance(other, LRCGeometry): return NotImplemented return LRCGeometry( @@ -330,8 +328,8 @@ def tree_flatten(self): # noqa: D102 self._bias, self._scale_factor, ), { - 'scale_cost': self._scale_cost, - 'batch_size': self.batch_size + "scale_cost": self._scale_cost, + "batch_size": self.batch_size } @classmethod diff --git a/src/ott/geometry/pointcloud.py b/src/ott/geometry/pointcloud.py index 22dad3830..8e6de1c6e 100644 --- a/src/ott/geometry/pointcloud.py +++ b/src/ott/geometry/pointcloud.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""A geometry defined using 2 point clouds and a cost function between them.""" import math from typing import Any, Callable, Literal, Optional, Tuple, Union @@ -64,8 +63,8 @@ def __init__( cost_fn: Optional[costs.CostFn] = None, batch_size: Optional[int] = None, scale_cost: Union[bool, int, float, - Literal['mean', 'max_norm', 'max_bound', 'max_cost', - 'median']] = 1.0, + Literal["mean", "max_norm", "max_bound", "max_cost", + "median"]] = 1.0, **kwargs: Any ): super().__init__(**kwargs) @@ -134,8 +133,7 @@ def is_squared_euclidean(self) -> bool: # noqa: D102 @property def is_online(self) -> bool: - """Whether :attr:`cost_matrix` or :attr:`kernel_matrix` \ - is computed on-the-fly.""" + """Whether the cost/kernel is computed on-the-fly.""" return self.batch_size is not None # TODO(michalk8): when refactoring, consider PC as a subclass of LR? 
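For context on the `LRCGeometry` methods above: the factored form `C = cost_1 @ cost_2.T` is never materialized, because applying `C` to a vector can be re-associated. A self-contained sketch of that trick, with illustrative shapes:

```python
import jax.numpy as jnp

n, m, r = 500, 400, 3
cost_1 = jnp.ones((n, r))  # low-rank factor, shape [n, r]
cost_2 = jnp.ones((m, r))  # low-rank factor, shape [m, r]
vec = jnp.ones((m,))

# Materializing C costs O(n * m) time and memory ...
dense = (cost_1 @ cost_2.T) @ vec
# ... while re-associating the product costs only O((n + m) * r),
# which is the whole point of the low-rank geometry.
factored = cost_1 @ (cost_2.T @ vec)

assert jnp.allclose(dense, factored)
```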
@@ -149,18 +147,18 @@ def inv_scale_cost(self) -> float: # noqa: D102 (int, float)) or utils.is_jax_array(self._scale_cost): return 1.0 / self._scale_cost self = self._masked_geom() - if self._scale_cost == 'max_cost': + if self._scale_cost == "max_cost": if self.is_online: return 1.0 / self._compute_summary_online(self._scale_cost) return 1.0 / jnp.max(self._compute_cost_matrix()) - if self._scale_cost == 'mean': + if self._scale_cost == "mean": if self.is_online: return 1.0 / self._compute_summary_online(self._scale_cost) if self.shape[0] > 0: geom = self._masked_geom(mask_value=jnp.nan)._compute_cost_matrix() return 1.0 / jnp.nanmean(geom) return 1.0 - if self._scale_cost == 'median': + if self._scale_cost == "median": if not self.is_online: geom = self._masked_geom(mask_value=jnp.nan) return 1.0 / jnp.nanmedian(geom._compute_cost_matrix()) @@ -168,11 +166,11 @@ def inv_scale_cost(self) -> float: # noqa: D102 "Using the median as scaling factor for " "the cost matrix with the online mode is not implemented." ) - if self._scale_cost == 'max_norm': + if self._scale_cost == "max_norm": if self.cost_fn.norm is not None: return 1.0 / jnp.maximum(self._norm_x.max(), self._norm_y.max()) return 1.0 - if self._scale_cost == 'max_bound': + if self._scale_cost == "max_bound": if self.is_squared_euclidean: x_argmax = jnp.argmax(self._norm_x) y_argmax = jnp.argmax(self._norm_y) @@ -186,7 +184,7 @@ def inv_scale_cost(self) -> float: # noqa: D102 "the cost matrix when the cost is not squared euclidean " "is not implemented." ) - raise ValueError(f'Scaling {self._scale_cost} not implemented.') + raise ValueError(f"Scaling {self._scale_cost} not implemented.") def _compute_cost_matrix(self) -> jnp.ndarray: cost_matrix = self.cost_fn.all_pairs_pairwise(self.x, self.y) @@ -302,11 +300,10 @@ def apply_kernel( # noqa: D102 self.x, self.y, self._norm_x, self._norm_y, scaling, eps, self.cost_fn, self.inv_scale_cost ) - if axis == 1: - return app( - self.y, self.x, self._norm_y, self._norm_x, scaling, eps, - self.cost_fn, self.inv_scale_cost - ) + return app( + self.y, self.x, self._norm_y, self._norm_x, scaling, eps, self.cost_fn, + self.inv_scale_cost + ) def transport_from_potentials( # noqa: D102 self, f: jnp.ndarray, g: jnp.ndarray @@ -385,26 +382,26 @@ def _apply_cost( self, arr: jnp.ndarray, axis: int = 0, fn=None ) -> jnp.ndarray: """See :meth:`apply_cost`.""" - if self.is_online: - app = jax.vmap( - _apply_cost_xy, - in_axes=[None, 0, None, self._axis_norm, None, None, None, None] - ) - if arr.ndim == 1: - arr = arr.reshape(-1, 1) - if axis == 0: - return app( - self.x, self.y, self._norm_x, self._norm_y, arr, self.cost_fn, - self.inv_scale_cost, fn - ) - if axis == 1: - return app( - self.y, self.x, self._norm_y, self._norm_x, arr, self.cost_fn, - self.inv_scale_cost, fn - ) - else: + if not self.is_online: return super().apply_cost(arr, axis, fn) + app = jax.vmap( + _apply_cost_xy, + in_axes=[None, 0, None, self._axis_norm, None, None, None, None] + ) + if arr.ndim == 1: + arr = arr.reshape(-1, 1) + + if axis == 0: + return app( + self.x, self.y, self._norm_x, self._norm_y, arr, self.cost_fn, + self.inv_scale_cost, fn + ) + return app( + self.y, self.x, self._norm_y, self._norm_x, arr, self.cost_fn, + self.inv_scale_cost, fn + ) + def vec_apply_cost( self, arr: jnp.ndarray, @@ -451,7 +448,7 @@ def _leading_slice(self, t: jnp.ndarray, i: int) -> jnp.ndarray: return jax.lax.dynamic_slice(t, start_indices, slice_sizes) def _compute_summary_online( - self, summary: Literal['mean', 'max_cost'] + self, 
summary: Literal["mean", "max_cost"] ) -> float: """Compute mean or max of cost matrix online, i.e. without instantiating it. @@ -500,13 +497,13 @@ def finalize(i: int): scale_cost ) - if summary == 'mean': + if summary == "mean": fn = _apply_cost_xy - elif summary == 'max_cost': + elif summary == "max_cost": fn = _apply_max_xy else: raise ValueError( - f'Scaling method {summary} does not exist for online mode.' + f"Scaling method {summary} does not exist for online mode." ) app = jax.vmap( fn, in_axes=[None, 0, None, self._axis_norm, None, None, None] @@ -527,13 +524,13 @@ def finalize(i: int): val_rest = finalize(n * self.batch_size) val_res = jnp.concatenate([val, val_rest]) - if summary == 'mean': + if summary == "mean": return jnp.sum(val_res * other) - if summary == 'max_cost': + if summary == "max_cost": # TODO(michalk8): explain why scaling is not needed return jnp.max(val_res) raise ValueError( - f'Scaling method {summary} does not exist for online mode.' + f"Scaling method {summary} does not exist for online mode." ) def barycenter(self, weights: jnp.ndarray) -> jnp.ndarray: @@ -571,8 +568,8 @@ def tree_flatten(self): # noqa: D102 self._epsilon_init, self.cost_fn, ), { - 'batch_size': self._batch_size, - 'scale_cost': self._scale_cost + "batch_size": self._batch_size, + "scale_cost": self._scale_cost } @classmethod @@ -588,7 +585,7 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102 **aux_data ) - def _cosine_to_sqeucl(self) -> 'PointCloud': + def _cosine_to_sqeucl(self) -> "PointCloud": assert isinstance(self.cost_fn, costs.Cosine), type(self.cost_fn) (x, y, *args, _), aux_data = self.tree_flatten() x = x / jnp.linalg.norm(x, axis=-1, keepdims=True) @@ -602,7 +599,7 @@ def to_LRCGeometry( self, scale: float = 1.0, **kwargs: Any, - ) -> Union[low_rank.LRCGeometry, 'PointCloud']: + ) -> Union[low_rank.LRCGeometry, "PointCloud"]: r"""Convert point cloud to low-rank geometry. Args: diff --git a/src/ott/geometry/segment.py b/src/ott/geometry/segment.py index 15a45d514..eb90a6248 100644 --- a/src/ott/geometry/segment.py +++ b/src/ott/geometry/segment.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Prepare point clouds for parallel computations.""" from typing import Callable, Optional, Tuple import jax diff --git a/src/ott/initializers/linear/initializers.py b/src/ott/initializers/linear/initializers.py index 6004237e1..bf93ebbc4 100644 --- a/src/ott/initializers/linear/initializers.py +++ b/src/ott/initializers/linear/initializers.py @@ -11,15 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Sinkhorn initializers.""" import abc -import functools -from typing import Any, Dict, Mapping, Optional, Sequence, Tuple +from typing import Any, Dict, Optional, Sequence, Tuple import jax import jax.numpy as jnp -from ott.geometry import geometry, pointcloud +from ott.geometry import pointcloud from ott.problems.linear import linear_problem __all__ = [ @@ -37,9 +35,18 @@ def init_dual_a( self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, - rng: Optional[jax.random.PRNGKeyArray] = jax.random.PRNGKey(0) + rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0) ) -> jnp.ndarray: - """Initialization for Sinkhorn potential/scaling f_u.""" + """Initialize Sinkhorn potential/scaling f_u. 
+ + Args: + ot_prob: Linear OT problem. + lse_mode: Return potential if ``True``, scaling if ``False``. + rng: Random number generator for stochastic initializers. + + Returns: + potential/scaling, array of size ``[n,]``. + """ @abc.abstractmethod def init_dual_b( @@ -48,7 +55,16 @@ def init_dual_b( lse_mode: bool, rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0) ) -> jnp.ndarray: - """Initialization for Sinkhorn potential/scaling g_v.""" + """Initialize Sinkhorn potential/scaling g_v. + + Args: + ot_prob: Linear OT problem. + lse_mode: Return potential if ``True``, scaling if ``False``. + rng: Random number generator for stochastic initializers. + + Returns: + potential/scaling, array of size ``[m,]``. + """ def __call__( self, @@ -62,17 +78,18 @@ def __call__( Args: ot_prob: Linear OT problem. - a: Initial potential/scaling f_u. If ``None``, it will be initialized using - :meth:`init_dual_a`. - b: Initial potential/scaling g_v. If ``None``, it will be initialized using - :meth:`init_dual_b`. - lse_mode: Return potentials if true, scalings otherwise. + a: Initial potential/scaling f_u. + If ``None``, it will be initialized using :meth:`init_dual_a`. + b: Initial potential/scaling g_v. + If ``None``, it will be initialized using :meth:`init_dual_b`. + lse_mode: Return potentials if ``True``, scalings if ``False``. + rng: Random number generator for stochastic initializers. Returns: The initial potentials/scalings. """ n, m = ot_prob.geom.shape - rng_x, rng_y = jax.random.split(rng) + rng_x, rng_y = jax.random.split(rng, 2) if a is None: a = self.init_dual_a(ot_prob, lse_mode=lse_mode, rng=rng_x) if b is None: @@ -91,11 +108,11 @@ def __call__( return a, b - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return [], {} @classmethod - def tree_unflatten( + def tree_unflatten( # noqa: D102 cls, aux_data: Dict[str, Any], children: Sequence[Any] ) -> "SinkhornInitializer": return cls(*children, **aux_data) @@ -105,78 +122,46 @@ def tree_unflatten( class DefaultInitializer(SinkhornInitializer): """Default initialization of Sinkhorn dual potentials/primal scalings.""" - def init_dual_a( + def init_dual_a( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, - rng: Optional[jax.random.PRNGKeyArray] = jax.random.PRNGKey(0) + rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0) ) -> jnp.ndarray: - """Initialize Sinkhorn potential/scaling f_u. - - Args: - ot_prob: OT problem between discrete distributions of size n and m. - lse_mode: Return potential if true, scaling if false. - rng: Random number generator for stochastic initializers. - - Returns: - potential/scaling, array of size n. - """ del rng - a = ot_prob.a - init_dual_a = jnp.zeros_like(a) if lse_mode else jnp.ones_like(a) - return init_dual_a + return jnp.zeros_like(ot_prob.a) if lse_mode else jnp.ones_like(ot_prob.a) - def init_dual_b( + def init_dual_b( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, - rng: Optional[jax.random.PRNGKeyArray] = jax.random.PRNGKey(0) + rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0) ) -> jnp.ndarray: - """Initialize Sinkhorn potential/scaling g_v. - - Args: - ot_prob: OT problem between discrete distributions of size n and m. - lse_mode: Return potential if true, scaling if false. - rng: Random number generator for stochastic initializers. - - Returns: - potential/scaling, array of size m. 
- """ - b = ot_prob.b - init_dual_b = jnp.zeros_like(b) if lse_mode else jnp.ones_like(b) - return init_dual_b + del rng + return jnp.zeros_like(ot_prob.b) if lse_mode else jnp.ones_like(ot_prob.b) @jax.tree_util.register_pytree_node_class class GaussianInitializer(DefaultInitializer): """Gaussian initializer :cite:`thornton2022rethinking:22`. - Compute Gaussian approximations of each :class:`~ott.geometry.pointcloud.PointCloud`, - then compute closed from Kantorovich potential between Gaussian approximations using - Brenier's theorem (adapt convex/Brenier potential to Kantorovich). - Use this Gaussian potential to initialize Sinkhorn potentials/scalings. + Compute Gaussian approximations of each + :class:`~ott.geometry.pointcloud.PointCloud`, then compute closed from + Kantorovich potential between Gaussian approximations using Brenier's theorem + (adapt convex/Brenier potential to Kantorovich). Use this Gaussian potential + to initialize Sinkhorn potentials/scalings. """ - def init_dual_a( + def init_dual_a( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, - rng: Optional[jax.random.PRNGKeyArray] = jax.random.PRNGKey(0) + rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0) ) -> jnp.ndarray: - """Gaussian initialization function. - - Args: - ot_prob: OT problem between discrete distributions of size n and m. - lse_mode: Return potential if true, scaling if false. - rng: Random number generator, not needed for this initializer. - - Returns: - potential/scaling, array of size n. - """ - del rng # import Gaussian here due to circular imports from ott.tools.gaussian_mixture import gaussian + del rng assert isinstance( ot_prob.geom, pointcloud.PointCloud ), "Gaussian initializer valid only for pointcloud geoms." @@ -189,22 +174,21 @@ def init_dual_a( # Brenier potential for cost ||x-y||^2/2, multiply by two for ||x-y||^2 f_potential = 2 * gaussian_a.f_potential(dest=gaussian_b, points=x) f_potential = f_potential - jnp.mean(f_potential) - f_u = f_potential if lse_mode else ot_prob.geom.scaling_from_potential( + return f_potential if lse_mode else ot_prob.geom.scaling_from_potential( f_potential ) - return f_u @jax.tree_util.register_pytree_node_class class SortingInitializer(DefaultInitializer): """Sorting initializer :cite:`thornton2022rethinking:22`. - Solves non-regularized OT problem via sorting, then compute potential through + Solve non-regularized OT problem via sorting, then compute potential through iterated minimum on C-transform and use this potential to initialize regularized potential. Args: - vectorized_update: Use vectorized inner loop if true. + vectorized_update: Whether to use vectorized loop. tolerance: DualSort convergence threshold. max_iter: Max DualSort steps. """ @@ -220,13 +204,6 @@ def __init__( self.max_iter = max_iter self.vectorized_update = vectorized_update - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 - return ([], { - 'tolerance': self.tolerance, - 'max_iter': self.max_iter, - 'vectorized_update': self.vectorized_update - }) - def _init_sorting_dual( self, modified_cost: jnp.ndarray, init_f: jnp.ndarray ) -> jnp.ndarray: @@ -266,20 +243,21 @@ def init_dual_a( self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, + rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0), init_f: Optional[jnp.ndarray] = None, - rng: Optional[jax.random.PRNGKeyArray] = jax.random.PRNGKey(0), ) -> jnp.ndarray: """Apply DualSort algorithm. Args: - ot_prob: OT problem. 
- lse_mode: Return potential if true, scaling if false. - init_f: potential f, array of size n. This is the starting potential, - which is then updated to make the init potential, so an init of an init. - rng: random number generator for initializer. Not needed for this initializer. + ot_prob: OT problem between discrete distributions. + lse_mode: Return potential if ``True``, scaling if ``False``. + rng: Random number generator for stochastic initializers, unused. + init_f: potential f, array of size ``[n,]``. This is the starting + potential, which is then updated to make the init potential, + so an init of an init. Returns: - potential/scaling f_u, array of size n. + potential/scaling f_u, array of size ``[n,]``. """ del rng assert not ot_prob.geom.is_online, \ @@ -298,46 +276,16 @@ def init_dual_a( f_potential = self._init_sorting_dual(modified_cost, init_f) f_potential = f_potential - jnp.mean(f_potential) - f_u = f_potential if lse_mode else ot_prob.geom.scaling_from_potential( + return f_potential if lse_mode else ot_prob.geom.scaling_from_potential( f_potential ) - return f_u - - -def _vectorized_update( - f: jnp.ndarray, modified_cost: jnp.ndarray -) -> jnp.ndarray: - """Inner loop DualSort Update. - - Args: - f: potential f, array of size n. - modified_cost: cost matrix minus diagonal column-wise. - - Returns: - updated potential vector, f. - """ - return jnp.min(modified_cost + f[None, :], axis=1) - - -def _coordinate_update( - f: jnp.ndarray, modified_cost: jnp.ndarray -) -> jnp.ndarray: - """Coordinate-wise updates within inner loop. - - Args: - f: potential f, array of size n. - modified_cost: cost matrix minus diagonal column-wise. - - Returns: - updated potential vector, f. - """ - - def body_fn(i: int, f: jnp.ndarray) -> jnp.ndarray: - new_f = jnp.min(modified_cost[i, :] + f) - return f.at[i].set(new_f) - - return jax.lax.fori_loop(0, len(f), body_fn, f) + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 + return ([], { + "tolerance": self.tolerance, + "max_iter": self.max_iter, + "vectorized_update": self.vectorized_update + }) @jax.tree_util.register_pytree_node_class @@ -345,43 +293,37 @@ class SubsampleInitializer(DefaultInitializer): """Subsample initializer :cite:`thornton2022rethinking:22`. Subsample each :class:`~ott.geometry.pointcloud.PointCloud`, then compute - :class:`Sinkhorn potential ` from the - subsampled approximations and use this potential to initialize Sinkhorn potentials/scalings - for the original problem. + :class:`Sinkhorn potential ` + from the subsampled approximations and use this potential to initialize + Sinkhorn potentials/scalings for the original problem. Args: - subsample_n_x: number of points to subsample from each x :class:`~ott.geometry.pointcloud.PointCloud`. - subsample_n_y: number of points to subsample from each y :class:`~ott.geometry.pointcloud.PointCloud`. - sinkhorn_kwargs: Sinkhorn solver args. + subsample_n_x: number of points to subsample from the first measure in + :class:`~ott.geometry.pointcloud.PointCloud`. + subsample_n_y: number of points to subsample from the second measure in + :class:`~ott.geometry.pointcloud.PointCloud`. + If ``None``, use ``subsample_n_x``. + kwargs: Keyword arguments for + :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`. 
""" def __init__( self, subsample_n_x: int, subsample_n_y: Optional[int] = None, - sinkhorn_kwargs: Optional[Mapping[str, Any]] = None, + **kwargs: Any, ): super().__init__() self.subsample_n_x = subsample_n_x self.subsample_n_y = subsample_n_y or subsample_n_x - self.sinkhorn_kwargs = sinkhorn_kwargs or {} + self.sinkhorn_kwargs = kwargs - def init_dual_a( + def init_dual_a( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0), ) -> jnp.ndarray: - """Subsample initializer function. - - Args: - ot_prob: OT problem between discrete distributions of size n and m. - lse_mode: Return potential if true, scaling if false. - rng: random number generator for subsampling. - - Returns: - potential/scaling, array of size n. - """ from ott.solvers.linear import sinkhorn assert isinstance( @@ -392,7 +334,7 @@ def init_dual_a( a, b = ot_prob.a, ot_prob.b # subsample - rng_x, rng_y = jax.random.split(rng) + rng_x, rng_y = jax.random.split(rng, 2) sub_x = jax.random.choice( key=rng_x, a=x, shape=(self.subsample_n_x,), replace=True, p=a, axis=0 ) @@ -416,14 +358,48 @@ def init_dual_a( dual_potentials = subsample_sink_out.to_dual_potentials() f_potential = jax.vmap(dual_potentials.f)(x) - f_u = f_potential if lse_mode else ot_prob.geom.scaling_from_potential( + return f_potential if lse_mode else ot_prob.geom.scaling_from_potential( f_potential ) - return f_u def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return ([], { - 'subsample_n_x': self.subsample_n_x, - 'subsample_n_y': self.subsample_n_y, - 'sinkhorn_kwargs': self.sinkhorn_kwargs + "subsample_n_x": self.subsample_n_x, + "subsample_n_y": self.subsample_n_y, + **self.sinkhorn_kwargs }) + + +def _vectorized_update( + f: jnp.ndarray, modified_cost: jnp.ndarray +) -> jnp.ndarray: + """Inner loop DualSort Update. + + Args: + f: potential f, array of size n. + modified_cost: cost matrix minus diagonal column-wise. + + Returns: + updated potential vector, f. + """ + return jnp.min(modified_cost + f[None, :], axis=1) + + +def _coordinate_update( + f: jnp.ndarray, modified_cost: jnp.ndarray +) -> jnp.ndarray: + """Coordinate-wise updates within inner loop. + + Args: + f: potential f, array of size n. + modified_cost: cost matrix minus diagonal column-wise. + + Returns: + updated potential vector, f. + """ + + def body_fn(i: int, f: jnp.ndarray) -> jnp.ndarray: + new_f = jnp.min(modified_cost[i, :] + f) + return f.at[i].set(new_f) + + return jax.lax.fori_loop(0, len(f), body_fn, f) diff --git a/src/ott/initializers/linear/initializers_lr.py b/src/ott/initializers/linear/initializers_lr.py index 74b20e9e1..c88350630 100644 --- a/src/ott/initializers/linear/initializers_lr.py +++ b/src/ott/initializers/linear/initializers_lr.py @@ -125,12 +125,12 @@ def init_g( @classmethod def from_solver( cls, - solver: Union['sinkhorn_lr.LRSinkhorn', - 'gromov_wasserstein.GromovWasserstein'], + solver: Union["sinkhorn_lr.LRSinkhorn", + "gromov_wasserstein.GromovWasserstein"], *, kind: Literal["random", "rank2", "k-means", "generalized-k-means"], **kwargs: Any, - ) -> 'LRInitializer': + ) -> "LRInitializer": """Create a low-rank initializer from a linear or quadratic solver. 
Args: @@ -216,11 +216,11 @@ def rank(self) -> int: """Rank of the transport matrix factorization.""" return self._rank - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return [], {**self._kwargs, "rank": self.rank} @classmethod - def tree_unflatten( + def tree_unflatten( # noqa: D102 cls, aux_data: Dict[str, Any], children: Sequence[Any] ) -> "LRInitializer": return cls(*children, **aux_data) diff --git a/src/ott/initializers/nn/initializers.py b/src/ott/initializers/nn/initializers.py index 3b979f176..79ec164c4 100644 --- a/src/ott/initializers/nn/initializers.py +++ b/src/ott/initializers/nn/initializers.py @@ -18,6 +18,7 @@ import jax.numpy as jnp import optax from flax import linen as nn +from flax.core import frozen_dict from flax.training import train_state from ott.geometry import geometry @@ -49,7 +50,8 @@ class MetaInitializer(initializers.DefaultInitializer): Args: geom: The fixed geometry of the problem instances. meta_model: The model to predict the potential :math:`f` from the measures. - opt: The optimizer to update the parameters. + opt: The optimizer to update the parameters. If ``None``, use + :func:`optax.adam` with :math:`0.001` learning rate. rng: The PRNG key to use for initializing the model. state: The training state of the model to start from. @@ -73,7 +75,8 @@ def __init__( self, geom: geometry.Geometry, meta_model: Optional[nn.Module] = None, - opt: optax.GradientTransformation = optax.adam(learning_rate=1e-3), + opt: Optional[optax.GradientTransformation ] = optax.adam(learning_rate=1e-3), # noqa: B008 rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0), state: Optional[train_state.TrainState] = None ): @@ -91,7 +94,7 @@ def __init__( # Initialize the model's training state. a_placeholder = jnp.zeros(na, dtype=self.dtype) b_placeholder = jnp.zeros(nb, dtype=self.dtype) - params = self.meta_model.init(rng, a_placeholder, b_placeholder)['params'] + params = self.meta_model.init(rng, a_placeholder, b_placeholder)["params"] self.state = train_state.TrainState.create( apply_fn=self.meta_model.apply, params=params, tx=opt ) @@ -115,35 +118,37 @@ def update( \min_\theta\; {\mathbb E}_{(\alpha,\beta)\sim{\mathcal{D}}}\; J(\hat f_\theta(a, b); \alpha, \beta), - where :math:`a,b` are the probabilities of the measures :math:`\alpha,\beta`, - :math:`\mathcal{D}` is a meta distribution of optimal transport problems, + where :math:`a,b` are the probabilities of the measures :math:`\alpha,\beta`, + :math:`\mathcal{D}` is a meta distribution of optimal transport problems, .. math:: J(f; \alpha, \beta, c) := \langle f, a\rangle + \langle g, b \rangle - - \varepsilon\left\langle \exp\{f/\varepsilon\}, K\exp\{g/\varepsilon\}\right\rangle + \varepsilon\left\langle \exp\{f/\varepsilon\}, K\exp\{g/\varepsilon\} + \right\rangle is the entropic dual objective, and :math:`K_{i,j} := \exp(-C_{i,j}/\varepsilon)` is the *Gibbs kernel*. Args: state: Optimizer state of the meta model. - a: Probabilites of the :math:`\alpha` measure's atoms. - b: Probabilites of the :math:`\beta` measure's atoms. + a: Probabilities of the :math:`\alpha` measure's atoms. + b: Probabilities of the :math:`\beta` measure's atoms. Returns: The training loss, :math:`f`, and updated state.
""" return self.update_impl(state, a, b) - def init_dual_a( + def init_dual_a( # noqa: D102 self, - ot_prob: 'linear_problem.LinearProblem', + ot_prob: "linear_problem.LinearProblem", lse_mode: bool, rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0) ) -> jnp.ndarray: del rng # Detect if the problem is batched. - assert ot_prob.a.ndim in (1, 2) and ot_prob.b.ndim in (1, 2) + assert ot_prob.a.ndim in (1, 2) + assert ot_prob.b.ndim in (1, 2) vmap_a_val = 0 if ot_prob.a.ndim == 2 else None vmap_b_val = 0 if ot_prob.b.ndim == 2 else None @@ -155,8 +160,7 @@ def init_dual_a( compute_f_maybe_batch = self._compute_f init_f = compute_f_maybe_batch(ot_prob.a, ot_prob.b, self.state.params) - f_u = init_f if lse_mode else ot_prob.geom.scaling_from_potential(init_f) - return f_u + return init_f if lse_mode else ot_prob.geom.scaling_from_potential(init_f) def _get_update_fn(self): """Return the implementation (and jitted) update function.""" @@ -190,23 +194,26 @@ def update(state, a, b): return update - def _compute_f(self, a, b, params): + def _compute_f( + self, a: jnp.ndarray, b: jnp.ndarray, + params: frozen_dict.FrozenDict[str, jnp.ndarray] + ) -> jnp.ndarray: r"""Predict the optimal :math:`f` potential. Args: - a: Probabilites of the :math:`\alpha` measure's atoms. - b: Probabilites of the :math:`\beta` measure's atoms. + a: Probabilities of the :math:`\alpha` measure's atoms. + b: Probabilities of the :math:`\beta` measure's atoms. params: The parameters of the Meta model. Returns: The :math:`f` potential. """ - return self.meta_model.apply({'params': params}, a, b) + return self.meta_model.apply({"params": params}, a, b) def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return [self.geom, self.meta_model, self.opt], { - 'rng': self.rng, - 'state': self.state + "rng": self.rng, + "state": self.state } @@ -241,5 +248,4 @@ def __call__(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: z = jnp.concatenate((a, b)) for _ in range(self.num_hidden_layers): z = nn.relu(nn.Dense(self.num_hidden_units, dtype=dtype)(z)) - f = nn.Dense(self.potential_size, dtype=dtype)(z) - return f + return nn.Dense(self.potential_size, dtype=dtype)(z) diff --git a/src/ott/initializers/quadratic/initializers.py b/src/ott/initializers/quadratic/initializers.py index 1d4614d9c..841bbd1a4 100644 --- a/src/ott/initializers/quadratic/initializers.py +++ b/src/ott/initializers/quadratic/initializers.py @@ -39,8 +39,8 @@ def __init__(self, **kwargs: Any): self._kwargs = kwargs def __call__( - self, quad_prob: 'quadratic_problem.QuadraticProblem', **kwargs: Any - ) -> 'linear_problem.LinearProblem': + self, quad_prob: "quadratic_problem.QuadraticProblem", **kwargs: Any + ) -> "linear_problem.LinearProblem": """Compute the initial linearization of a quadratic problem. Args: @@ -66,23 +66,23 @@ def __call__( @abc.abstractmethod def _create_geometry( - self, quad_prob: 'quadratic_problem.QuadraticProblem', **kwargs: Any + self, quad_prob: "quadratic_problem.QuadraticProblem", **kwargs: Any ) -> geometry.Geometry: """Compute initial geometry for linearization. Args: - quad_problem: Quadratic problem. + quad_prob: Quadratic problem. kwargs: Additional keyword arguments. Returns: Geometry used to initialize the linearized problem. 
""" - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return [], self._kwargs @classmethod - def tree_unflatten( + def tree_unflatten( # noqa: D102 cls, aux_data: Dict[str, Any], children: Sequence[Any] ) -> "BaseQuadraticInitializer": return cls(*children, **aux_data) @@ -121,7 +121,7 @@ class QuadraticInitializer(BaseQuadraticInitializer): """ def _create_geometry( - self, quad_prob: 'quadratic_problem.QuadraticProblem', *, epsilon: float, + self, quad_prob: "quadratic_problem.QuadraticProblem", *, epsilon: float, **kwargs: Any ) -> geometry.Geometry: """Compute initial geometry for linearization. @@ -171,12 +171,12 @@ class LRQuadraticInitializer(BaseQuadraticInitializer): lr_linear_initializer: Low-rank linear initializer. """ - def __init__(self, lr_linear_initializer: 'initializers_lr.LRInitializer'): + def __init__(self, lr_linear_initializer: "initializers_lr.LRInitializer"): super().__init__() self._linear_lr_initializer = lr_linear_initializer def _create_geometry( - self, quad_prob: 'quadratic_problem.QuadraticProblem', **kwargs: Any + self, quad_prob: "quadratic_problem.QuadraticProblem", **kwargs: Any ) -> geometry.Geometry: """Compute initial geometry for linearization. diff --git a/src/ott/math/decomposition.py b/src/ott/math/decomposition.py index 1099ffe11..5eba722ba 100644 --- a/src/ott/math/decomposition.py +++ b/src/ott/math/decomposition.py @@ -59,10 +59,10 @@ def solve(self, b: jnp.ndarray) -> jnp.ndarray: """Solve the linear system :math:`A * x = b`. Args: - b: Vector of shape ``[n,]``. + b: Vector of shape ``[n,]``. Returns: - The solution of shape ``[n,]``. + The solution of shape ``[n,]``. """ return self._solve(self.L, b) @@ -106,11 +106,11 @@ def L(self) -> Optional[T]: self._L = self._decompose(self.A) return self._L - def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: + def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return (self.A, self.L), {} @classmethod - def tree_unflatten( + def tree_unflatten( # noqa: D102 cls, aux_data: Mapping[str, Any], children: Sequence[Any] ) -> "CholeskySolver": A, L = children @@ -163,7 +163,7 @@ class SparseCholeskySolver( kwargs: Keyword arguments for :func:`sksparse.cholmod.cholesky`. """ - # TODO(michalk8): in the future, define a jax primitive + use CHOLMOD directly + # TODO(michalk8): deprecate or fix _FACTOR_CACHE = {} def __init__( diff --git a/src/ott/math/fixed_point_loop.py b/src/ott/math/fixed_point_loop.py index b8de353b8..f27ec835d 100644 --- a/src/ott/math/fixed_point_loop.py +++ b/src/ott/math/fixed_point_loop.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""jheek@ backprop-friendly implementation of fixed point loop.""" from typing import Any, Callable import jax diff --git a/src/ott/math/matrix_square_root.py b/src/ott/math/matrix_square_root.py index 8044d1c16..b79e8de0d 100644 --- a/src/ott/math/matrix_square_root.py +++ b/src/ott/math/matrix_square_root.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""A Jax backprop friendly version of Matrix square root.""" - import functools from typing import Tuple @@ -56,7 +54,7 @@ def sqrtm( norm_x = norm_x[..., jnp.newaxis, jnp.newaxis] def cond_fn(iteration, const, state): - """Stopping criterion. Checking decrease of objective is needed here.""" # noqa: D401 + """Stopping criterion. Checking decrease of objective is needed here.""" _, threshold = const errors, _, _ = state err = errors[iteration // inner_iterations - 1] @@ -184,7 +182,8 @@ def sqrtm_bwd( cotangent: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], ) -> Tuple[jnp.ndarray]: """Compute the derivative by solving a Sylvester equation.""" - del threshold, min_iterations, inner_iterations, max_iterations, regularization + del threshold, min_iterations, inner_iterations, \ + max_iterations, regularization sqrt_x, inv_sqrt_x = residual # ignores cotangent associated with errors cot_sqrt, cot_inv_sqrt, _ = cotangent @@ -250,7 +249,7 @@ def sqrtm_only( # noqa: D103 )[0] -def sqrtm_only_fwd( +def sqrtm_only_fwd( # noqa: D103 x: jnp.ndarray, threshold: float, min_iterations: int, inner_iterations: int, max_iterations: int, regularization: float ) -> Tuple[jnp.ndarray, jnp.ndarray]: @@ -261,13 +260,13 @@ def sqrtm_only_fwd( return sqrt_x, sqrt_x -def sqrtm_only_bwd( +def sqrtm_only_bwd( # noqa: D103 threshold: float, min_iterations: int, inner_iterations: int, max_iterations: int, regularization: float, sqrt_x: jnp.ndarray, cotangent: jnp.ndarray ) -> Tuple[jnp.ndarray]: - del threshold, min_iterations, inner_iterations - del max_iterations, regularization + del threshold, min_iterations, inner_iterations, \ + max_iterations, regularization vjp = jnp.swapaxes( solve_sylvester_bartels_stewart( a=sqrt_x, b=-sqrt_x, c=jnp.swapaxes(cotangent, axis1=-2, axis2=-1) @@ -296,7 +295,7 @@ def inv_sqrtm_only( # noqa: D103 )[1] -def inv_sqrtm_only_fwd( +def inv_sqrtm_only_fwd( # noqa: D103 x: jnp.ndarray, threshold: float, min_iterations: int, @@ -311,11 +310,14 @@ def inv_sqrtm_only_fwd( return inv_sqrt_x, inv_sqrt_x -def inv_sqrtm_only_bwd( +def inv_sqrtm_only_bwd( # noqa: D103 threshold: float, min_iterations: int, inner_iterations: int, max_iterations: int, regularization: float, residual: jnp.ndarray, cotangent: jnp.ndarray ) -> Tuple[jnp.ndarray]: + del threshold, min_iterations, inner_iterations, \ + max_iterations, regularization + inv_sqrt_x = residual inv_x = jnp.matmul(inv_sqrt_x, inv_sqrt_x) vjp = jnp.swapaxes( diff --git a/src/ott/math/unbalanced_functions.py b/src/ott/math/unbalanced_functions.py index 776c9b18e..5927ec3fc 100644 --- a/src/ott/math/unbalanced_functions.py +++ b/src/ott/math/unbalanced_functions.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Functions useful to define unbalanced OT problems.""" from typing import Callable import jax.numpy as jnp @@ -24,7 +23,7 @@ def phi_star(h: jnp.ndarray, rho: float) -> jnp.ndarray: # TODO(cuturi): use jax.grad directly. 
def derivative_phi_star(f: jnp.ndarray, rho: float) -> jnp.ndarray: - """Derivative of Legendre transform of phi_starKL, see phi_star.""" # noqa: D401 + """Derivative of Legendre transform of phi_starKL, see phi_star.""" return jnp.exp(f / rho) diff --git a/src/ott/math/utils.py b/src/ott/math/utils.py index 2bae4f817..620d9886e 100644 --- a/src/ott/math/utils.py +++ b/src/ott/math/utils.py @@ -112,8 +112,7 @@ def logsumexp_jvp(axis, keepdims, return_sign, primals, tangents): res += jnp.sum(tan_b * centered_exp, axis=axis, keepdims=keepdims) if return_sign: return (lse, sign), (sign * res, jnp.zeros_like(sign)) - else: - return lse, res + return lse, res @functools.partial(jax.custom_vjp, nondiff_argnums=(2,)) @@ -124,7 +123,7 @@ def softmin( Args: x: Input data. - gamma: Smoothing parameter. + gamma: Smoothing parameter :math:`> 0`. axis: Axis or axes over which to operate. If ``None``, use flattened input. Returns: diff --git a/src/ott/problems/linear/barycenter_problem.py b/src/ott/problems/linear/barycenter_problem.py index 136585327..b7ed68082 100644 --- a/src/ott/problems/linear/barycenter_problem.py +++ b/src/ott/problems/linear/barycenter_problem.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Classes defining OT problem(s) (objective function + utilities).""" from typing import Any, Dict, Optional, Sequence, Tuple import jax @@ -156,7 +155,7 @@ def weights(self) -> jnp.ndarray: # By default, we assume that weights sum to 1, and enforce this if needed. weights = self._weights / jnp.sum(self._weights) if self.debiased: - weights = jnp.concatenate((weights, jnp.array([-0.5]))) + return jnp.concatenate((weights, jnp.array([-0.5]))) return weights @property @@ -165,9 +164,9 @@ def _is_segmented(self) -> bool: def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return ([self._y, self._b, self._weights], { - 'cost_fn': self.cost_fn, - 'epsilon': self.epsilon, - 'debiased': self.debiased, + "cost_fn": self.cost_fn, + "epsilon": self.epsilon, + "debiased": self.debiased, **self._kwargs, }) @@ -211,13 +210,12 @@ def num_measures(self) -> int: def weights(self) -> jnp.ndarray: """Barycenter weights of shape ``[num_measures,]`` that sum to :math`1`.""" if self._weights is None: - weights = jnp.ones((self.num_measures,)) / self.num_measures - else: - # Check that the number of measures coincides with the weights' size. - assert self._weights.shape[0] == self.num_measures - # By default, we assume that weights sum to 1, and enforce this if needed. - weights = self._weights / jnp.sum(self._weights) - return weights + return jnp.ones((self.num_measures,)) / self.num_measures + + # check that the number of measures coincides with the weights' size + assert self._weights.shape[0] == self.num_measures + # by default, we assume that weights sum to 1, and enforce this if needed + return self._weights / jnp.sum(self._weights) def tree_flatten(self): # noqa: D102 return [self.geom, self.a, self._weights], None diff --git a/src/ott/problems/linear/linear_problem.py b/src/ott/problems/linear/linear_problem.py index 60baf731d..f5f87c594 100644 --- a/src/ott/problems/linear/linear_problem.py +++ b/src/ott/problems/linear/linear_problem.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
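For the ``softmin`` whose docstring is sharpened above: it is the :math:`\gamma`-smoothed minimum :math:`\operatorname{softmin}_\gamma(x) = -\gamma \log \sum_i e^{-x_i/\gamma}`, whose gradient is a softmax concentrated on the smallest entries, which is why a custom VJP is cheap to supply. A minimal standalone re-derivation (not the library code):

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def softmin(x: jnp.ndarray, gamma: float) -> jnp.ndarray:
  # -gamma * log sum exp(-x / gamma): a smooth, differentiable minimum.
  return -gamma * logsumexp(-x / gamma, axis=-1)

x = jnp.array([1.0, 2.0, 3.0])
print(softmin(x, 0.01))           # ~1.0, close to the hard minimum
print(jax.grad(softmin)(x, 0.5))  # softmax weights, peaked at index 0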
-"""Classes defining OT problem(s) (objective function + utilities).""" - from typing import Any, Callable, Dict, Optional, Sequence, Tuple import jax @@ -109,8 +107,8 @@ def get_transport_functions( def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 return ([self.geom, self._a, self._b], { - 'tau_a': self.tau_a, - 'tau_b': self.tau_b + "tau_a": self.tau_a, + "tau_b": self.tau_b }) @classmethod diff --git a/src/ott/problems/linear/potentials.py b/src/ott/problems/linear/potentials.py index 9c68de9a6..299d64f42 100644 --- a/src/ott/problems/linear/potentials.py +++ b/src/ott/problems/linear/potentials.py @@ -64,7 +64,7 @@ def __init__( f: Potential_t, g: Potential_t, *, - cost_fn: 'costs.CostFn', + cost_fn: "costs.CostFn", corr: bool = False ): self._f = f @@ -102,8 +102,7 @@ def transport(self, vec: jnp.ndarray, forward: bool = True) -> jnp.ndarray: return self._grad_f(vec) if forward else self._grad_g(vec) if forward: return vec - self._grad_h_inv(self._grad_f(vec)) - else: - return vec - self._grad_h_inv(self._grad_g(vec)) + return vec - self._grad_h_inv(self._grad_g(vec)) def distance(self, src: jnp.ndarray, tgt: jnp.ndarray) -> float: """Evaluate 2-Wasserstein distance between samples using dual potentials. @@ -206,13 +205,13 @@ def plot_ot_map( raise RuntimeError("Please install `matplotlib` first.") if scatter_kwargs is None: - scatter_kwargs = {'alpha': 0.5} + scatter_kwargs = {"alpha": 0.5} if legend_kwargs is None: legend_kwargs = { - 'ncol': 3, - 'loc': 'upper center', - 'bbox_to_anchor': (0.5, -0.05), - 'edgecolor': 'k' + "ncol": 3, + "loc": "upper center", + "bbox_to_anchor": (0.5, -0.05), + "edgecolor": "k" } if ax is None: @@ -233,14 +232,14 @@ def plot_ot_map( source[:, 0], source[:, 1], color=source_color, - label='source', + label="source", **scatter_kwargs, ) ax.scatter( target[:, 0], target[:, 1], color=target_color, - label='target', + label="target", **scatter_kwargs, ) diff --git a/src/ott/problems/nn/dataset.py b/src/ott/problems/nn/dataset.py index 9d4c96d2c..0a780d05b 100644 --- a/src/ott/problems/nn/dataset.py +++ b/src/ott/problems/nn/dataset.py @@ -11,18 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Toy datasets for neural OT.""" - import dataclasses -from typing import Iterable, Iterator, Literal, NamedTuple, Tuple +from typing import Iterator, Literal, NamedTuple, Tuple import jax import jax.numpy as jnp import numpy as np -__all__ = ['create_gaussian_mixture_samplers', 'Dataset', 'GaussianMixture'] +__all__ = ["create_gaussian_mixture_samplers", "Dataset", "GaussianMixture"] -Name_t = Literal['simple', 'circle', 'square_five', 'square_four'] +Name_t = Literal["simple", "circle", "square_five", "square_four"] class Dataset(NamedTuple): @@ -32,8 +30,8 @@ class Dataset(NamedTuple): source_iter: loader for the source measure target_iter: loader for the target measure """ - source_iter: Iterable[jnp.ndarray] - target_iter: Iterable[jnp.ndarray] + source_iter: Iterator[jnp.ndarray] + target_iter: Iterator[jnp.ndarray] @dataclasses.dataclass @@ -43,11 +41,12 @@ class GaussianMixture: Args: name: the name specifying the centers of the mixture components: - - ``simple`` (data clustered in one center), - - ``circle`` (two-dimensional Gaussians arranged on a circle), - - ``square_five`` (two-dimensional Gaussians on a square with - one Gaussian in the center), and - - ``square_four`` (two-dimensional Gaussians in the corners of a rectangle) + - ``simple`` - data clustered in one center, + - ``circle`` - two-dimensional Gaussians arranged on a circle, + - ``square_five`` - two-dimensional Gaussians on a square with + one Gaussian in the center, and + - ``square_four`` - two-dimensional Gaussians in the corners of a + rectangle batch_size: batch size of the samples init_rng: initial PRNG key @@ -62,9 +61,9 @@ class GaussianMixture: def __post_init__(self): gaussian_centers = { - 'simple': + "simple": np.array([[0, 0]]), - 'circle': + "circle": np.array([ (1, 0), (-1, 0), @@ -75,26 +74,26 @@ def __post_init__(self): (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)), (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)), ]), - 'square_five': + "square_five": np.array([[0, 0], [1, 1], [-1, 1], [-1, -1], [1, -1]]), - 'square_four': + "square_four": np.array([[1, 0], [0, 1], [-1, 0], [0, -1]]), } if self.name not in gaussian_centers: raise ValueError( - f'{self.name} is not a valid dataset for GaussianMixture' + f"{self.name} is not a valid dataset for GaussianMixture" ) self.centers = gaussian_centers[self.name] def __iter__(self) -> Iterator[jnp.array]: - return self.create_sample_generators() - - def create_sample_generators(self) -> Iterator[jnp.array]: """Random sample generator from Gaussian mixture. Returns: A generator of samples from the Gaussian mixture. """ + return self._create_sample_generators() + + def _create_sample_generators(self) -> Iterator[jnp.array]: rng = self.init_rng while True: rng1, rng2, rng = jax.random.split(rng, 3) @@ -111,7 +110,7 @@ def create_gaussian_mixture_samplers( valid_batch_size: int = 2048, rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0), ) -> Tuple[Dataset, Dataset, int]: - """Creates Gaussian samplers for :class:`~ott.solvers.nn.neuraldual.W2NeuralDual`. + """Gaussian samplers for :class:`~ott.solvers.nn.neuraldual.W2NeuralDual`. Args: name_source: name of the source sampler diff --git a/src/ott/problems/quadratic/__init__.py b/src/ott/problems/quadratic/__init__.py index 18ff1c517..d2bb647ee 100644 --- a/src/ott/problems/quadratic/__init__.py +++ b/src/ott/problems/quadratic/__init__.py @@ -1 +1,14 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 from . import gw_barycenter, quadratic_costs, quadratic_problem
diff --git a/src/ott/problems/quadratic/gw_barycenter.py b/src/ott/problems/quadratic/gw_barycenter.py
index db8d428ce..dfe562d98 100644
--- a/src/ott/problems/quadratic/gw_barycenter.py
+++ b/src/ott/problems/quadratic/gw_barycenter.py
@@ -1,3 +1,16 @@
+# Copyright OTT-JAX
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import functools
 from typing import Any, Dict, Literal, Optional, Sequence, Tuple, Union
@@ -53,7 +66,7 @@ def __init__(
       costs: Optional[jnp.ndarray] = None,
       y_fused: Optional[jnp.ndarray] = None,
       fused_penalty: float = 1.0,
-      gw_loss: Literal['sqeucl', 'kl'] = 'sqeucl',
+      gw_loss: Literal["sqeucl", "kl"] = "sqeucl",
       scale_cost: Union[int, float, Literal["mean", "max_cost"]] = 1.0,
       **kwargs: Any,
   ):
@@ -119,7 +132,7 @@ def project(
       )
       return transport @ tmp

-    fn = None if self._loss_name == 'sqeucl' else self.gw_loss.h2
+    fn = None if self._loss_name == "sqeucl" else self.gw_loss.h2
     y, b = self.segmented_y_b
     weights = self.weights[:, None, None]
@@ -129,8 +142,8 @@ def project(
     # TODO(michalk8): in future, use `isinstanceof(self.gw_loss, ...)`
     # once refactoring has been done
-    if self._loss_name == 'kl':
-      barycenter = jnp.exp(barycenter)
+    if self._loss_name == "kl":
+      return jnp.exp(barycenter)
     return barycenter

   def update_features(self, transports: jnp.ndarray,
@@ -221,7 +234,7 @@ def _create_fused_geometry(
   def _create_problem(
       self,
-      state: 'GWBarycenterState',  # noqa: F821
+      state: "GWBarycenterState",  # noqa: F821
       y: jnp.ndarray,
       b: jnp.ndarray,
       f: Optional[jnp.ndarray] = None
@@ -257,8 +270,7 @@ def is_fused(self) -> bool:
   @property
   def segmented_y_fused(self) -> Optional[jnp.ndarray]:
-    """Feature array of shape ``[num_measures, max_measure_size, ndim_fused]`` \
-    used in the fused case."""
+    """Features of shape ``[num_measures, max_measure_size, ndim_fused]``."""
     if not self.is_fused or self._y_fused.ndim == 3:
       return self._y_fused
     y_fused, _ = segment.segment_point_cloud(
@@ -284,9 +296,9 @@ def gw_loss(self) -> quadratic_costs.GWLoss:
     # `https://jax.readthedocs.io/en/latest/notebooks/
     # Writing_custom_interpreters_in_Jax.html#your-first-interpreter-invert`
     # might be useful to invert some fns
-    if self._loss_name == 'sqeucl':
+    if self._loss_name == "sqeucl":
       return quadratic_costs.make_square_loss()
-    if self._loss_name == 'kl':
+    if self._loss_name == "kl":
       return quadratic_costs.make_kl_loss()
     raise NotImplementedError(
         f"Loss `{self._loss_name}` is not yet implemented."
@@ -298,9 +310,9 @@ def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]: # noqa: D102 children = [None, b, weights, y] else: children = [y, b, weights, None] - aux['fused_penalty'] = self.fused_penalty - aux['gw_loss'] = self._loss_name - aux['scale_cost'] = self.scale_cost + aux["fused_penalty"] = self.fused_penalty + aux["gw_loss"] = self._loss_name + aux["scale_cost"] = self.scale_cost return children + [self._y_fused], aux @classmethod diff --git a/src/ott/problems/quadratic/quadratic_costs.py b/src/ott/problems/quadratic/quadratic_costs.py index 2460dfbf7..e55acc739 100644 --- a/src/ott/problems/quadratic/quadratic_costs.py +++ b/src/ott/problems/quadratic/quadratic_costs.py @@ -1,3 +1,16 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Callable, NamedTuple import jax @@ -6,12 +19,12 @@ __all__ = ["make_square_loss", "make_kl_loss"] -class Loss(NamedTuple): +class Loss(NamedTuple): # noqa: D101 func: Callable[[jnp.ndarray], jnp.ndarray] is_linear: bool -class GWLoss(NamedTuple): +class GWLoss(NamedTuple): # noqa: D101 f1: Loss f2: Loss h1: Loss diff --git a/src/ott/problems/quadratic/quadratic_problem.py b/src/ott/problems/quadratic/quadratic_problem.py index f73f63747..395d43714 100644 --- a/src/ott/problems/quadratic/quadratic_problem.py +++ b/src/ott/problems/quadratic/quadratic_problem.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Classes defining OT problem(s) (objective function + utilities).""" - from typing import TYPE_CHECKING, Literal, Optional, Tuple, Union import jax @@ -78,8 +76,8 @@ class QuadraticProblem: tau_b: if `< 1.0`, defines how much unbalanced the problem is on the second marginal. gw_unbalanced_correction: Whether the unbalanced version of - :cite:`sejourne:21` is used. Otherwise, ``tau_a`` and ``tau_b`` only affect - the inner Sinkhorn loop. + :cite:`sejourne:21` is used. Otherwise, ``tau_a`` and ``tau_b`` + only affect the inner Sinkhorn loop. ranks: Ranks of the cost matrices, see :meth:`~ott.geometry.geometry.Geometry.to_LRCGeometry`. 
Used when geometries are *not* :class:`~ott.geometry.pointcloud.PointCloud` with @@ -101,7 +99,7 @@ def __init__( scale_cost: Optional[Union[bool, float, str]] = False, a: Optional[jnp.ndarray] = None, b: Optional[jnp.ndarray] = None, - loss: Union[Literal['sqeucl', 'kl'], quadratic_costs.GWLoss] = 'sqeucl', + loss: Union[Literal["sqeucl", "kl"], quadratic_costs.GWLoss] = "sqeucl", tau_a: Optional[float] = 1.0, tau_b: Optional[float] = 1.0, gw_unbalanced_correction: bool = True, @@ -124,9 +122,9 @@ def __init__( self.tolerances = tolerances self._loss_name = loss - if self._loss_name == 'sqeucl': + if self._loss_name == "sqeucl": self.loss = quadratic_costs.make_square_loss() - elif loss == 'kl': + elif loss == "kl": self.loss = quadratic_costs.make_kl_loss() else: self.loss = loss @@ -161,7 +159,7 @@ def marginal_dependent_cost( Returns: Low-rank geometry of rank 2, storing normalization constants. """ - if self._loss_name == 'sqeucl': # quadratic apply, efficient for LR + if self._loss_name == "sqeucl": # quadratic apply, efficient for LR tmp1 = self.geom_xx.apply_square_cost(marginal_1, axis=1) tmp2 = self.geom_yy.apply_square_cost(marginal_2, axis=1) else: @@ -239,7 +237,7 @@ def init_transport_mass(self) -> float: return a.sum() * b.sum() def update_lr_geom( - self, lr_sink: 'sinkhorn_lr.LRSinkhornOutput' + self, lr_sink: "sinkhorn_lr.LRSinkhornOutput" ) -> geometry.Geometry: """Recompute (possibly LRC) linearization using LR Sinkhorn output.""" marginal_1 = lr_sink.marginal(1) @@ -263,7 +261,7 @@ def update_lr_geom( cost_matrix = marginal_cost.cost_matrix - jnp.dot(tmp1, tmp2.T) cost_matrix += self.fused_penalty * self._fused_cost_matrix geom = geometry.Geometry(cost_matrix=cost_matrix) - return geom + return geom # noqa: RET504 def update_linearization( self, @@ -326,7 +324,7 @@ def update_linearization( ) def update_lr_linearization( - self, lr_sink: 'sinkhorn_lr.LRSinkhornOutput' + self, lr_sink: "sinkhorn_lr.LRSinkhornOutput" ) -> linear_problem.LinearProblem: """Update a Quad problem linearization using a LR Sinkhorn.""" return linear_problem.LinearProblem( @@ -468,14 +466,14 @@ def is_balanced(self) -> bool: def tree_flatten(self): # noqa: D102 return ([self.geom_xx, self.geom_yy, self.geom_xy, self._a, self._b], { - 'tau_a': self.tau_a, - 'tau_b': self.tau_b, - 'loss': self._loss_name, - 'fused_penalty': self.fused_penalty, - 'scale_cost': self.scale_cost, - 'gw_unbalanced_correction': self.gw_unbalanced_correction, - 'ranks': self.ranks, - 'tolerances': self.tolerances + "tau_a": self.tau_a, + "tau_b": self.tau_b, + "loss": self._loss_name, + "fused_penalty": self.fused_penalty, + "scale_cost": self.scale_cost, + "gw_unbalanced_correction": self.gw_unbalanced_correction, + "ranks": self.ranks, + "tolerances": self.tolerances }) @classmethod @@ -484,7 +482,7 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102 return cls(*geoms, a=a, b=b, **aux_data) -def update_epsilon_unbalanced( +def update_epsilon_unbalanced( # noqa: D103 epsilon: Union[float, epsilon_scheduler.Epsilon], transport_mass: float ) -> epsilon_scheduler.Epsilon: if not isinstance(epsilon, epsilon_scheduler.Epsilon): @@ -492,7 +490,7 @@ def update_epsilon_unbalanced( return epsilon.set(scale_epsilon=epsilon._scale_epsilon * transport_mass) -def apply_cost( +def apply_cost( # noqa: D103 geom: geometry.Geometry, arr: jnp.ndarray, *, axis: int, fn: quadratic_costs.Loss ) -> jnp.ndarray: diff --git a/src/ott/solvers/linear/acceleration.py b/src/ott/solvers/linear/acceleration.py index 93f137874..9422a5dff 
100644 --- a/src/ott/solvers/linear/acceleration.py +++ b/src/ott/solvers/linear/acceleration.py @@ -57,8 +57,8 @@ def extrapolation(self, xs: jnp.ndarray, fxs: jnp.ndarray) -> jnp.ndarray: return jnp.where(jnp.isfinite(combination), combination, -jnp.inf) def update( - self, state: 'sinkhorn.SinkhornState', iteration: int, pb, lse_mode: bool - ) -> 'sinkhorn.SinkhornState': + self, state: "sinkhorn.SinkhornState", iteration: int, pb, lse_mode: bool + ) -> "sinkhorn.SinkhornState": """Anderson acceleration update. When using Anderson acceleration, first update the dual variable f_u with @@ -107,15 +107,15 @@ def update( return state.set(fu=fu, old_fus=old_fus) def init_maps( - self, pb, state: 'sinkhorn.SinkhornState' - ) -> 'sinkhorn.SinkhornState': + self, pb, state: "sinkhorn.SinkhornState" + ) -> "sinkhorn.SinkhornState": """Initialize log matrix used in Anderson acceleration with nan values.""" fus = jnp.ones((pb.geom.shape[0], self.memory)) * jnp.nan return state.set(old_fus=fus, old_mapped_fus=fus) def update_history( - self, state: 'sinkhorn.SinkhornState', pb, lse_mode: bool - ) -> 'sinkhorn.SinkhornState': + self, state: "sinkhorn.SinkhornState", pb, lse_mode: bool + ) -> "sinkhorn.SinkhornState": """Update history of mapped dual variables.""" f = state.fu if lse_mode else pb.geom.potential_from_scaling(state.fu) mapped = jnp.concatenate((state.old_mapped_fus[:, 1:], f[:, None]), axis=1) @@ -124,29 +124,30 @@ def update_history( @utils.register_pytree_node class Momentum: - """Momentum for Sinkhorn updates, either constant :cite:`thibault:21` or \ - adaptive :cite:`lehmann:21`.""" + """Momentum for Sinkhorn updates. + + Can be either constant :cite:`thibault:21` or adaptive :cite:`lehmann:21`. + """ start: int = 0 error_threshold: float = jnp.inf value: float = 1.0 inner_iterations: int = 1 - def weight(self, state: 'sinkhorn.SinkhornState', iteration: int) -> float: + def weight(self, state: "sinkhorn.SinkhornState", iteration: int) -> float: """Compute momentum term if needed, using previously seen errors.""" if self.start == 0: return self.value idx = self.start // self.inner_iterations - weight = jax.lax.cond( + return jax.lax.cond( jnp.logical_and( iteration >= self.start, state.errors[idx - 1, -1] < self.error_threshold ), lambda state: self.lehmann(state), lambda state: self.value, state ) - return weight - def lehmann(self, state: 'sinkhorn.SinkhornState') -> float: + def lehmann(self, state: "sinkhorn.SinkhornState") -> float: """Momentum formula :cite:`lehmann:21`, eq. 5.""" idx = self.start // self.inner_iterations error_ratio = jnp.minimum( @@ -165,6 +166,5 @@ def __call__( # noqa: D102 if lse_mode: value = jnp.where(jnp.isfinite(value), value, 0.0) return (1.0 - weight) * value + weight * new_value - else: - value = jnp.where(value > 0.0, value, 1.0) - return value ** (1.0 - weight) * new_value ** weight + value = jnp.where(value > 0.0, value, 1.0) + return value ** (1.0 - weight) * new_value ** weight diff --git a/src/ott/solvers/linear/continuous_barycenter.py b/src/ott/solvers/linear/continuous_barycenter.py index d584082b5..d1d504de2 100644 --- a/src/ott/solvers/linear/continuous_barycenter.py +++ b/src/ott/solvers/linear/continuous_barycenter.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
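``Momentum.lehmann`` above implements eq. 5 of :cite:`lehmann:21`: estimate the contraction ratio from two successive Sinkhorn errors and turn it into an over-relaxation weight. A sketch under the assumption that the ratio is clipped at 0.99 and rescaled by ``1 / inner_iterations`` (the hunk is cut off before those constants, so they are assumptions here):

import jax.numpy as jnp

def lehmann_weight(err_prev: float, err_curr: float,
                   inner_iterations: int = 1) -> float:
  # Estimated linear convergence rate, clipped away from 1 (assumed constant).
  ratio = jnp.minimum(err_curr / err_prev, 0.99)
  power = 1.0 / inner_iterations
  # Optimal over-relaxation weight for a linearly converging fixed point.
  return 2.0 / (1.0 + jnp.sqrt(1.0 - ratio ** power))

print(lehmann_weight(1e-2, 5e-3))  # value in (1, 2): over-relaxed update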
-"""A Jax version of the W barycenter algorithm (Cuturi Doucet 2014).""" import functools from typing import Any, NamedTuple, Optional, Tuple @@ -47,14 +46,14 @@ class FreeBarycenterState(NamedTuple): x: Optional[jnp.ndarray] = None a: Optional[jnp.ndarray] = None - def set(self, **kwargs: Any) -> 'FreeBarycenterState': + def set(self, **kwargs: Any) -> "FreeBarycenterState": """Return a copy of self, possibly with overwrites.""" return self._replace(**kwargs) def update( self, iteration: int, bar_prob: barycenter_problem.FreeBarycenterProblem, linear_ot_solver: Any, store_errors: bool - ) -> 'FreeBarycenterState': + ) -> "FreeBarycenterState": """Update the state of the solver. Args: @@ -129,7 +128,7 @@ def solve_linear_ot( @jax.tree_util.register_pytree_node_class class FreeWassersteinBarycenter(was_solver.WassersteinSolver): - """Continuous Wassertsein barycenter solver.""" + """Continuous Wassertstein barycenter solver :cite:`cuturi:14`.""" def __call__( # noqa: D102 self, diff --git a/src/ott/solvers/linear/discrete_barycenter.py b/src/ott/solvers/linear/discrete_barycenter.py index 481b5c68b..dcfdc1470 100644 --- a/src/ott/solvers/linear/discrete_barycenter.py +++ b/src/ott/solvers/linear/discrete_barycenter.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Implementation of :cite:`janati:20` Wasserstein barycenter algorithm.""" - import functools from typing import NamedTuple, Optional, Sequence @@ -108,7 +106,7 @@ def __call__( )[jnp.newaxis, :] if self.debiased and not geom.is_symmetric: - raise ValueError('Geometry must be symmetric to use debiased option.') + raise ValueError("Geometry must be symmetric to use debiased option.") norm_error = (self.norm_error,) return _discrete_barycenter( geom, a, weights, dual_initialization, self.threshold, norm_error, @@ -118,7 +116,7 @@ def __call__( def tree_flatten(self): # noqa: D102 aux = vars(self).copy() - aux.pop('threshold') + aux.pop("threshold") return [ self.threshold, ], aux diff --git a/src/ott/solvers/linear/sinkhorn.py b/src/ott/solvers/linear/sinkhorn.py index 139e47853..1b8e7b085 100644 --- a/src/ott/solvers/linear/sinkhorn.py +++ b/src/ott/solvers/linear/sinkhorn.py @@ -57,7 +57,7 @@ class SinkhornState(NamedTuple): old_fus: Optional[jnp.ndarray] = None old_mapped_fus: Optional[jnp.ndarray] = None - def set(self, **kwargs: Any) -> 'SinkhornState': + def set(self, **kwargs: Any) -> "SinkhornState": """Return a copy of self, with potential overwrites.""" return self._replace(**kwargs) @@ -84,7 +84,7 @@ def solution_error( parallel_dual_updates=parallel_dual_updates ) - def ent_reg_cost( + def ent_reg_cost( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, lse_mode: bool ) -> float: return ent_reg_cost(self.fu, self.gv, ot_prob, lse_mode) @@ -222,10 +222,9 @@ def marginal_error( else: marginal = geom.marginal_from_scalings(f_u, g_v, axis=axis) norm_error = jnp.asarray(norm_error) - error = jnp.sum( + return jnp.sum( jnp.abs(marginal - target) ** norm_error[:, jnp.newaxis], axis=1 ) ** (1.0 / norm_error) - return error def ent_reg_cost( @@ -292,14 +291,14 @@ class SinkhornOutput(NamedTuple): reg_ot_cost: Optional[float] = None ot_prob: Optional[linear_problem.LinearProblem] = None - def set(self, **kwargs: Any) -> 'SinkhornOutput': + def set(self, **kwargs: Any) -> "SinkhornOutput": """Return a copy of self, with potential overwrites.""" return self._replace(**kwargs) 
def set_cost( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, lse_mode: bool, use_danskin: bool - ) -> 'SinkhornOutput': + ) -> "SinkhornOutput": f = jax.lax.stop_gradient(self.f) if use_danskin else self.f g = jax.lax.stop_gradient(self.g) if use_danskin else self.g return self.set(reg_ot_cost=ent_reg_cost(f, g, ot_prob, lse_mode)) @@ -709,7 +708,7 @@ def __init__( use_danskin: Optional[bool] = None, jit: bool = True, implicit_diff: Optional[implicit_lib.ImplicitDiff - ] = implicit_lib.ImplicitDiff(), # noqa: E124 + ] = implicit_lib.ImplicitDiff(), # noqa: B008 initializer: Union[Literal["default", "gaussian", "sorting", "subsample"], init_lib.SinkhornInitializer] = "default", progress_fn: Optional[ProgressCallbackFn_t] = None, @@ -1015,8 +1014,8 @@ def create_initializer(self) -> init_lib.SinkhornInitializer: # noqa: D102 def tree_flatten(self): # noqa: D102 aux = vars(self).copy() - aux['norm_error'] = aux.pop('_norm_error') - aux.pop('threshold') + aux["norm_error"] = aux.pop("_norm_error") + aux.pop("threshold") return [self.threshold], aux @classmethod @@ -1126,7 +1125,7 @@ def solve( tau_b: float = 1.0, rank: int = -1, **kwargs: Any -) -> Union[SinkhornOutput, 'LRSinkhornOutput']: +) -> Union[SinkhornOutput, "LRSinkhornOutput"]: """Solve linear regularized OT problem using Sinkhorn iterations. Args: diff --git a/src/ott/solvers/linear/sinkhorn_lr.py b/src/ott/solvers/linear/sinkhorn_lr.py index 106457416..6e8ae2558 100644 --- a/src/ott/solvers/linear/sinkhorn_lr.py +++ b/src/ott/solvers/linear/sinkhorn_lr.py @@ -48,27 +48,29 @@ class LRSinkhornState(NamedTuple): errors: jnp.ndarray crossed_threshold: bool - def compute_error(self, previous_state: "LRSinkhornState") -> float: + def compute_error( # noqa: D102 + self, previous_state: "LRSinkhornState" + ) -> float: err_1 = mu.js(self.q, previous_state.q, c=1.) err_2 = mu.js(self.r, previous_state.r, c=1.) err_3 = mu.js(self.g, previous_state.g, c=1.) return ((1. / self.gamma) ** 2) * (err_1 + err_2 + err_3) - def reg_ot_cost( + def reg_ot_cost( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, use_danskin: bool = False ) -> float: return compute_reg_ot_cost(self.q, self.r, self.g, ot_prob, use_danskin) - def solution_error( + def solution_error( # noqa: D102 self, ot_prob: linear_problem.LinearProblem, norm_error: Tuple[int, ...], lse_mode: bool ) -> jnp.ndarray: return solution_error(self.q, self.r, ot_prob, norm_error, lse_mode) - def set(self, **kwargs: Any) -> 'LRSinkhornState': + def set(self, **kwargs: Any) -> "LRSinkhornState": """Return a copy of self, with potential overwrites.""" return self._replace(**kwargs) @@ -148,7 +150,7 @@ class LRSinkhornOutput(NamedTuple): # TODO(michalk8): Optional is an artifact of the current impl., refactor reg_ot_cost: Optional[float] = None - def set(self, **kwargs: Any) -> 'LRSinkhornOutput': + def set(self, **kwargs: Any) -> "LRSinkhornOutput": """Return a copy of self, with potential overwrites.""" return self._replace(**kwargs) @@ -157,7 +159,7 @@ def set_cost( # noqa: D102 ot_prob: linear_problem.LinearProblem, lse_mode: bool, use_danskin: bool = False - ) -> 'LRSinkhornOutput': + ) -> "LRSinkhornOutput": del lse_mode return self.set(reg_ot_cost=self.compute_reg_ot_cost(ot_prob, use_danskin)) @@ -251,18 +253,13 @@ class LRSinkhorn(sinkhorn.Sinkhorn): described in :cite:`scetbon:22b`. epsilon: Entropic regularization added on top of low-rank problem. initializer: How to initialize the :math:`Q`, :math:`R` and :math:`g` - factors. 
Valid options are:
-
-      - `'random'` - :class:`~ott.initializers.linear.initializers_lr.RandomInitializer`.
-      - `'rank2'` - :class:`~ott.initializers.linear.initializers_lr.Rank2Initializer`.
-      - `'k-means'` - :class:`~ott.initializers.linear.initializers_lr.KMeansInitializer`.
-      - `'generalized-k-means'` - :class:`~ott.initializers.linear.initializers_lr.GeneralizedKMeansInitializer`.
-
-      If `None`, :class:`~ott.initializers.linear.initializers_lr.KMeansInitializer`
+      factors. Valid options are `'random'`, `'rank2'`, `'k-means'`, and
+      `'generalized-k-means'`. If `None`,
+      :class:`~ott.initializers.linear.initializers_lr.KMeansInitializer`
       is used when the linear problem's geometry is
       :class:`~ott.geometry.pointcloud.PointCloud` or
-      :class:`~ott.geometry.low_rank.LRCGeometry`.
-      Otherwise, use :class:`~ott.initializers.linear.initializers_lr.RandomInitializer`.
+      :class:`~ott.geometry.low_rank.LRCGeometry`. Otherwise, use
+      :class:`~ott.initializers.linear.initializers_lr.RandomInitializer`.
     lse_mode: Whether to run computations in lse or kernel mode. At the moment,
       only ``lse_mode = True`` is implemented.
diff --git a/src/ott/solvers/nn/conjugate_solvers.py b/src/ott/solvers/nn/conjugate_solvers.py
index 19ec83bfb..2cace8464 100644
--- a/src/ott/solvers/nn/conjugate_solvers.py
+++ b/src/ott/solvers/nn/conjugate_solvers.py
@@ -11,8 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Implementation of :cite:`amos:17` input convex neural networks (ICNN)."""
-
 import abc
 from typing import Callable, Literal, NamedTuple, Optional
@@ -37,7 +35,6 @@ class ConjugateResults(NamedTuple):
     grad: the gradient, i.e., :math:`\nabla f^\star(y)`
     num_iter: the number of iterations taken by the solver
   """
-
   val: float
   grad: jnp.ndarray
   num_iter: int
@@ -85,9 +82,9 @@ class FenchelConjugateLBFGS(FenchelConjugateSolver):
   gtol: float = 1e-3
   max_iter: int = 10
   max_linesearch_iter: int = 10
-  linesearch_type: Literal['zoom', 'backtracking'] = 'backtracking'
+  linesearch_type: Literal["zoom", "backtracking"] = "backtracking"
   decrease_factor: float = 0.66
-  ls_method: Literal['wolf', 'strong-wolfe'] = 'strong-wolfe'
+  ls_method: Literal["wolf", "strong-wolfe"] = "strong-wolfe"

   def solve(  # noqa: D102
       self,
       f: potentials.PotentialValueFn_t,
       y: jnp.ndarray,
       x_init: Optional[jnp.array] = None
   ) -> ConjugateResults:
-    assert y.ndim == 1
+    assert y.ndim == 1, y.ndim

     solver = LBFGS(
         fun=lambda x: f(x) - x.dot(y),
@@ -118,5 +115,5 @@ def solve(  # noqa: D102
     gtol=1e-5,
     max_iter=20,
     max_linesearch_iter=20,
-    linesearch_type='backtracking',
+    linesearch_type="backtracking",
 )
diff --git a/src/ott/solvers/nn/layers.py b/src/ott/solvers/nn/layers.py
index ea5883566..3c24d3de3 100644
--- a/src/ott/solvers/nn/layers.py
+++ b/src/ott/solvers/nn/layers.py
@@ -11,8 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Layers used in input convex neural networks :cite:`amos:17,bunne:22`."""
-
 from typing import Any, Callable, Tuple
 import flax.linen as nn
@@ -58,12 +56,13 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:

     Args:
       inputs: Array to be transformed.
+
     Returns:
       The transformed input.
""" inputs = jnp.asarray(inputs, self.dtype) kernel = self.param( - 'kernel', self.kernel_init, (inputs.shape[-1], self.dim_hidden) + "kernel", self.kernel_init, (inputs.shape[-1], self.dim_hidden) ) kernel = self.rectifier_fn(kernel) y = jax.lax.dot_general( @@ -72,9 +71,9 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: precision=self.precision ) if self.use_bias: - bias = self.param('bias', self.bias_init, (self.dim_hidden,)) + bias = self.param("bias", self.bias_init, (self.dim_hidden,)) bias = jnp.asarray(bias, self.dtype) - y = y + bias + return y + bias return y @@ -133,5 +132,4 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray: ) y = 0.5 * y * y - out = jnp.sum(y.reshape((-1, self.num_potentials, self.dim_data)), axis=2) - return out + return jnp.sum(y.reshape((-1, self.num_potentials, self.dim_data)), axis=2) diff --git a/src/ott/solvers/nn/models.py b/src/ott/solvers/nn/models.py index 7823b0339..ad1de4878 100644 --- a/src/ott/solvers/nn/models.py +++ b/src/ott/solvers/nn/models.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Neural potential models.""" - import abc from typing import Callable, Optional, Sequence, Tuple, Union @@ -78,6 +76,7 @@ def potential_value_fn( constructs the value of the potential from the gradient with .. math:: + g(y) = -f(\nabla_y g(y)) + y^T \nabla_y g(y) where :math:`\nabla_y g(y)` is detached for the envelope theorem @@ -86,29 +85,29 @@ def potential_value_fn( Args: params: parameters of the module - x: point to evaluate the value at - other_potential_value: function giving the value of the other potential. - Only needed when :attr:`is_potential` is ``False``. + other_potential_value_fn: function giving the value of the other + potential. Only needed when :attr:`is_potential` is ``False``. Returns: A function that can be evaluated to obtain the potential's value """ if self.is_potential: return lambda x: self.apply({"params": params}, x) - else: - assert other_potential_value_fn is not None, \ - "The value of the gradient-based potential depends on the value of the other potential" - def value_fn(x: jnp.ndarray) -> jnp.ndarray: - squeeze = x.ndim == 1 - if squeeze: - x = jnp.expand_dims(x, 0) - grad_g_x = jax.lax.stop_gradient(self.apply({"params": params}, x)) - value = -other_potential_value_fn(grad_g_x) + \ - jax.vmap(jnp.dot)(grad_g_x, x) - return value.squeeze(0) if squeeze else value + assert other_potential_value_fn is not None, \ + "The value of the gradient-based potential depends " \ + "on the value of the other potential." + + def value_fn(x: jnp.ndarray) -> jnp.ndarray: + squeeze = x.ndim == 1 + if squeeze: + x = jnp.expand_dims(x, 0) + grad_g_x = jax.lax.stop_gradient(self.apply({"params": params}, x)) + value = -other_potential_value_fn(grad_g_x) + \ + jax.vmap(jnp.dot)(grad_g_x, x) + return value.squeeze(0) if squeeze else value - return value_fn + return value_fn def potential_gradient_fn( self, @@ -124,8 +123,7 @@ def potential_gradient_fn( """ if self.is_potential: return jax.vmap(jax.grad(self.potential_value_fn(params))) - else: - return lambda x: self.apply({'params': params}, x) + return lambda x: self.apply({"params": params}, x) class ICNN(ModelBase): @@ -150,14 +148,13 @@ class ICNN(ModelBase): initialization scheme based on Gaussian approximation of input and target measure (if ``None``, identity initialization is used). 
""" - dim_data: int dim_hidden: Sequence[int] init_std: float = 1e-2 init_fn: Callable = jax.nn.initializers.normal - act_fn: Callable = nn.relu + act_fn: Callable[[jnp.ndarray], jnp.ndarray] = nn.relu pos_weights: bool = True - gaussian_map: Tuple[jnp.ndarray, jnp.ndarray] = None + gaussian_map: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None @property def is_potential(self) -> bool: # noqa: D102 @@ -213,7 +210,7 @@ def setup(self) -> None: # noqa: D102 use_bias=True, ) - # subsequent layers reinjected into convex functions + # subsequent layers re-injected into convex functions w_xs = [] for i in range(self.num_hidden): w_xs.append( @@ -255,8 +252,7 @@ def compute_moments(x, reg=1e-4, sqrt_inv=False): if sqrt_inv: sigma_sqrt, sigma_inv_sqrt, _ = matrix_square_root.sqrtm(sigma) return sigma, sigma_sqrt, sigma_inv_sqrt, mu - else: - return sigma, mu + return sigma, mu source, target = inputs _, covs_sqrt, covs_inv_sqrt, mus = compute_moments(source, sqrt_inv=True) @@ -268,14 +264,12 @@ def compute_moments(x, reg=1e-4, sqrt_inv=False): A = jnp.dot(jnp.dot(covs_inv_sqrt, mo), covs_inv_sqrt) b = jnp.squeeze(mus) - jnp.linalg.solve(A, jnp.squeeze(mut)) A = matrix_square_root.sqrtm_only(A) - return jnp.expand_dims(A, 0), jnp.expand_dims(b, 0) @staticmethod def _compute_identity_map(input_dim: int) -> Tuple[jnp.ndarray, jnp.ndarray]: A = jnp.eye(input_dim).reshape((1, input_dim, input_dim)) b = jnp.zeros((1, input_dim)) - return A, b @nn.compact @@ -292,6 +286,7 @@ def create_train_state( rng: jax.random.PRNGKeyArray, optimizer: optax.OptState, input: Union[int, Tuple[int, ...]], + # TODO(michalk8): do not ignore or delete in code? params: Optional[frozen_dict.FrozenDict[str, jnp.ndarray]] = None, ) -> NeuralTrainState: """Create initial `TrainState`.""" diff --git a/src/ott/solvers/nn/neuraldual.py b/src/ott/solvers/nn/neuraldual.py index 5af1e9d41..0c0da69df 100644 --- a/src/ott/solvers/nn/neuraldual.py +++ b/src/ott/solvers/nn/neuraldual.py @@ -11,13 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""A Jax implementation of the neural-based Kantorovich dual.""" - import warnings from typing import ( Callable, Dict, - Iterable, + Iterator, List, Literal, Optional, @@ -46,26 +44,21 @@ class W2NeuralDual: r"""Solver for the Wasserstein-2 Kantorovich dual between Euclidean spaces. Learn the Wasserstein-2 optimal transport between two measures - :math:`\alpha` and :math:`\beta` in - :math:`n`-dimensional Euclidean space, - denoted source and target, respectively. - This is achieved by parameterizing a Kantorovich potential - :math:`f_\theta: \mathbb{R}^n\rightarrow\mathbb{R}` - associated with the :math:`\alpha` measure with - an :class:`~ott.solvers.nn.models.ICNN`, - :class:`~ott.solvers.nn.models.MLP`, or other - :class:`~ott.solvers.nn.models.ModelBase`, where - :math:`\nabla f` transports source to target cells. - This potential is learned by optimizing the dual - form associated with the negative inner product cost + :math:`\alpha` and :math:`\beta` in :math:`n`-dimensional Euclidean space, + denoted source and target, respectively. 
This is achieved by parameterizing
+  a Kantorovich potential :math:`f_\theta: \mathbb{R}^n\rightarrow\mathbb{R}`
+  associated with the :math:`\alpha` measure with an
+  :class:`~ott.solvers.nn.models.ICNN`, :class:`~ott.solvers.nn.models.MLP`,
+  or other :class:`~ott.solvers.nn.models.ModelBase`, where :math:`\nabla f`
+  transports source to target cells. This potential is learned by optimizing
+  the dual form associated with the negative inner product cost

   .. math::

     \text{argsup}_{\theta}\; -\mathbb{E}_{x\sim\alpha}[f_\theta(x)] -
     \mathbb{E}_{y\sim\beta}[f^\star_\theta(y)],

-  where
-  :math:`f^\star(y) := -\inf_{x\in\mathbb{R}^n} f(x)-\langle x, y\rangle`
+  where :math:`f^\star(y) := -\inf_{x\in\mathbb{R}^n} f(x)-\langle x, y\rangle`
   is the convex conjugate.
   :math:`\nabla f^\star` transports from the target
   to source cells and provides the inverse optimal
@@ -89,13 +82,15 @@ class W2NeuralDual:
   Args:
     dim_data: input dimensionality of data required for network init
     neural_f: network architecture for potential :math:`f`.
-    neural_g: network architecture for the conjugate potential :math:`g\approx f^\star`
+    neural_g: network architecture for the conjugate potential
+      :math:`g\approx f^\star`
     optimizer_f: optimizer function for potential :math:`f`
     optimizer_g: optimizer function for the conjugate potential :math:`g`
     num_train_iters: number of total training iterations
-    num_inner_iters: number of training iterations of :math:`g` per iteration of :math:`f`
-    back_and_forth: alternate between updating the forward and backward directions.
-      Inspired from :cite:`jacobs:20`
+    num_inner_iters: number of training iterations of :math:`g` per iteration
+      of :math:`f`
+    back_and_forth: alternate between updating the forward and backward
+      directions. Inspired by :cite:`jacobs:20`
     valid_freq: frequency with which model is validated
     log_freq: frequency with training and validation are logged
     logging: option to return logs
     pos_weights: option to train networks with positive weights or regularizer
     beta: regularization parameter when not training with positive weights
     conjugate_solver: numerical solver for the Fenchel conjugate.
-    amortization_loss: amortization loss for the conjugate :math:`g\approx f^\star`.
-      Options are 'objective' :cite:`makkuva:20` or 'regression' :cite:`amos:23`.
+    amortization_loss: amortization loss for the conjugate
+      :math:`g\approx f^\star`. Options are `'objective'` :cite:`makkuva:20` or
+      `'regression'` :cite:`amos:23`.
parallel_updates: Update :math:`f` and :math:`g` at the same time init_f_params: initial parameters for :math:`f` init_g_params: initial parameters for :math:`g` @@ -127,7 +123,7 @@ def __init__( pos_weights: bool = True, beta: float = 1.0, conjugate_solver: Conj_t = conjugate_solvers.DEFAULT_CONJUGATE_SOLVER, - amortization_loss: Literal['objective', 'regression'] = 'regression', + amortization_loss: Literal["objective", "regression"] = "regression", parallel_updates: bool = True, init_f_params: Optional[frozen_dict.FrozenDict[str, jnp.ndarray]] = None, init_g_params: Optional[frozen_dict.FrozenDict[str, jnp.ndarray]] = None, @@ -214,8 +210,8 @@ def setup( else: if self.parallel_updates: warnings.warn( - 'parallel_updates set to True but disabling it ' - 'because num_inner_iters>1', + "parallel_updates set to True but disabling it " + "because num_inner_iters>1", stacklevel=2 ) if self.back_and_forth: @@ -230,10 +226,10 @@ def setup( def __call__( # noqa: D102 self, - trainloader_source: Iterable[jnp.ndarray], - trainloader_target: Iterable[jnp.ndarray], - validloader_source: Iterable[jnp.ndarray], - validloader_target: Iterable[jnp.ndarray], + trainloader_source: Iterator[jnp.ndarray], + trainloader_target: Iterator[jnp.ndarray], + validloader_source: Iterator[jnp.ndarray], + validloader_target: Iterator[jnp.ndarray], callback: Optional[Callback_t] = None, ) -> Union[potentials.DualPotentials, Tuple[potentials.DualPotentials, Train_t]]: @@ -250,20 +246,19 @@ def __call__( # noqa: D102 def train_neuraldual_parallel( self, - trainloader_source: Iterable[jnp.ndarray], - trainloader_target: Iterable[jnp.ndarray], - validloader_source: Iterable[jnp.ndarray], - validloader_target: Iterable[jnp.ndarray], + trainloader_source: Iterator[jnp.ndarray], + trainloader_target: Iterator[jnp.ndarray], + validloader_source: Iterator[jnp.ndarray], + validloader_target: Iterator[jnp.ndarray], callback: Optional[Callback_t] = None, ) -> Train_t: - """Implementation of the training and validation with parallel updates.""" # noqa: D401 + """Training and validation with parallel updates.""" try: from tqdm.auto import tqdm except ImportError: tqdm = lambda _: _ # define dict to contain source and target batch - train_batch = {} - valid_batch = {} + train_batch, valid_batch = {}, {} # set logging dictionaries train_logs = {"loss_f": [], "loss_g": [], "w_dist": [], "directions": []} @@ -274,24 +269,26 @@ def train_neuraldual_parallel( if update_forward: train_batch["source"] = jnp.asarray(next(trainloader_source)) train_batch["target"] = jnp.asarray(next(trainloader_target)) - self.state_f, self.state_g, loss, loss_f, loss_g, w_dist = self.train_step_parallel( - self.state_f, - self.state_g, - train_batch, - ) + (self.state_f, self.state_g, loss, loss_f, loss_g, + w_dist) = self.train_step_parallel( + self.state_f, + self.state_g, + train_batch, + ) else: train_batch["target"] = jnp.asarray(next(trainloader_source)) train_batch["source"] = jnp.asarray(next(trainloader_target)) - self.state_g, self.state_f, loss, loss_f, loss_g, w_dist = self.train_step_parallel( - self.state_g, - self.state_f, - train_batch, - ) + (self.state_g, self.state_f, loss, loss_f, loss_g, + w_dist) = self.train_step_parallel( + self.state_g, + self.state_f, + train_batch, + ) if self.logging and step % self.log_freq == 0: self._update_logs(train_logs, loss_f, loss_g, w_dist) train_logs["directions"].append( - 'forward' if update_forward else 'backward' + "forward" if update_forward else "backward" ) if callback is not None: @@ -324,21 
+321,19 @@ def train_neuraldual_parallel( def train_neuraldual_alternating( self, - trainloader_source: Iterable[jnp.ndarray], - trainloader_target: Iterable[jnp.ndarray], - validloader_source: Iterable[jnp.ndarray], - validloader_target: Iterable[jnp.ndarray], + trainloader_source: Iterator[jnp.ndarray], + trainloader_target: Iterator[jnp.ndarray], + validloader_source: Iterator[jnp.ndarray], + validloader_target: Iterator[jnp.ndarray], callback: Optional[Callback_t] = None, ) -> Train_t: - """Implementation of the training and validation with alternating updates.""" # noqa: D401 + """Training and validation with alternating updates.""" try: from tqdm.auto import tqdm except ImportError: tqdm = lambda _: _ # define dict to contain source and target batch - batch_g = {} - batch_f = {} - valid_batch = {} + batch_g, batch_f, valid_batch = {}, {}, {} # set logging dictionaries train_logs = {"loss_f": [], "loss_g": [], "w_dist": []} @@ -374,7 +369,7 @@ def train_neuraldual_alternating( if self.logging and step % self.log_freq == 0: self._update_logs(train_logs, loss_f, loss_g, w_dist) - # report the loss on an validuation dataset periodically + # report the loss on validation dataset periodically if step != 0 and step % self.valid_freq == 0: # get batch valid_batch["source"] = jnp.asarray(next(validloader_source)) @@ -395,7 +390,7 @@ def train_neuraldual_alternating( return {"train_logs": train_logs, "valid_logs": valid_logs} def get_step_fn( - self, train: bool, to_optimize: Literal["f", "g", "parallel"] + self, train: bool, to_optimize: Literal["f", "g", "parallel", "both"] ): """Create a parallel training and evaluation function.""" @@ -431,9 +426,9 @@ def g_value_partial(y: jnp.ndarray) -> jnp.ndarray: dual_target = f_star_target.mean() dual_loss = dual_source + dual_target - if self.amortization_loss == 'regression': + if self.amortization_loss == "regression": amor_loss = ((init_source_hat - source_hat_detach) ** 2).mean() - elif self.amortization_loss == 'objective': + elif self.amortization_loss == "objective": f_value_parameters_detached = lambda x: f_value( jax.lax.stop_gradient(params_f), x ) @@ -484,41 +479,37 @@ def step_fn(state_f, state_g, batch): state_g.potential_gradient_fn, batch, ) - # update state if to_optimize == "both": - return state_f.apply_gradients(grads=grads_f), \ - state_g.apply_gradients(grads=grads_g), \ - loss, loss_f, loss_g, W2_dist - elif to_optimize == "f": - return state_f.apply_gradients(grads=grads_f), \ - loss_f, W2_dist - elif to_optimize == "g": - return state_g.apply_gradients(grads=grads_g), \ - loss_g, W2_dist - else: - raise ValueError("Optimization target has been misspecified.") - - else: - # compute loss and gradients - (loss, (loss_f, loss_g, W2_dist)), (grads_f, grads_g) = grad_fn( - state_f.params, - state_g.params, - state_f.potential_value_fn, - state_g.potential_value_fn, - state_g.potential_gradient_fn, - batch, - ) + return ( + state_f.apply_gradients(grads=grads_f), + state_g.apply_gradients(grads=grads_g), loss, loss_f, loss_g, + W2_dist + ) + if to_optimize == "f": + return state_f.apply_gradients(grads=grads_f), loss_f, W2_dist + if to_optimize == "g": + return state_g.apply_gradients(grads=grads_g), loss_g, W2_dist + raise ValueError("Optimization target has been misspecified.") + + # compute loss and gradients + (loss, (loss_f, loss_g, W2_dist)), _ = grad_fn( + state_f.params, + state_g.params, + state_f.potential_value_fn, + state_g.potential_value_fn, + state_g.potential_gradient_fn, + batch, + ) - # do not update state - if 
to_optimize == "both": - return loss_f, loss_g, W2_dist - elif to_optimize == "f": - return loss_f, W2_dist - elif to_optimize == "g": - return loss_g, W2_dist - else: - raise ValueError("Optimization target has been misspecified.") + # do not update state + if to_optimize == "both": + return loss_f, loss_g, W2_dist + if to_optimize == "f": + return loss_f, W2_dist + if to_optimize == "g": + return loss_g, W2_dist + raise ValueError("Optimization target has been misspecified.") return step_fn @@ -543,12 +534,11 @@ def g_value_finetuned(y: jnp.ndarray) -> jnp.ndarray: grad_g_y = jax.lax.stop_gradient( self.conjugate_solver.solve(f_value, y, x_init=x_hat).grad ) - g_y = -f_value(grad_g_y) + jnp.dot(grad_g_y, y) - return g_y + return -f_value(grad_g_y) + jnp.dot(grad_g_y, y) return potentials.DualPotentials( f=f_value, - g=g_value_prediction if not finetune_g else g_value_finetuned, + g=g_value_finetuned if finetune_g else g_value_prediction, cost_fn=costs.SqEuclidean(), corr=True ) @@ -556,7 +546,7 @@ def g_value_finetuned(y: jnp.ndarray) -> jnp.ndarray: @staticmethod def _clip_weights_icnn(params): params = params.unfreeze() - for k in params.keys(): + for k in params: if k.startswith("w_z"): params[k]["kernel"] = jnp.clip(params[k]["kernel"], a_min=0) @@ -572,7 +562,7 @@ def _penalize_weights_icnn(params: Dict[str, jnp.ndarray]) -> float: @staticmethod def _update_logs( - logs: Dict[str, Union[float, str]], + logs: Dict[str, List[Union[float, str]]], loss_f: jnp.ndarray, loss_g: jnp.ndarray, w_dist: jnp.ndarray, diff --git a/src/ott/solvers/quadratic/gromov_wasserstein.py b/src/ott/solvers/quadratic/gromov_wasserstein.py index 875e8311c..cedffb29a 100644 --- a/src/ott/solvers/quadratic/gromov_wasserstein.py +++ b/src/ott/solvers/quadratic/gromov_wasserstein.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""A Jax version of the regularised GW Solver (Peyre et al. 2016).""" from typing import ( Any, Dict, @@ -67,7 +66,7 @@ class GWOutput(NamedTuple): # Intermediate values. old_transport_mass: float = 1.0 - def set(self, **kwargs: Any) -> 'GWOutput': + def set(self, **kwargs: Any) -> "GWOutput": """Return a copy of self, possibly with overwrites.""" return self._replace(**kwargs) @@ -96,7 +95,7 @@ def primal_cost(self) -> float: class GWState(NamedTuple): - """Holds the state of the Gromov-Wasserstein solver. + """State of the Gromov-Wasserstein solver. Attributes: costs: Holds the sequence of regularized GW costs seen through the outer @@ -121,15 +120,15 @@ class GWState(NamedTuple): rngs: Optional[jax.random.PRNGKeyArray] = None errors: Optional[jnp.ndarray] = None - def set(self, **kwargs: Any) -> 'GWState': + def set(self, **kwargs: Any) -> "GWState": """Return a copy of self, possibly with overwrites.""" return self._replace(**kwargs) - def update( + def update( # noqa: D102 self, iteration: int, linear_sol: LinearOutput, linear_pb: linear_problem.LinearProblem, store_errors: bool, old_transport_mass: float - ) -> 'GWState': + ) -> "GWState": costs = self.costs.at[iteration].set(linear_sol.reg_ot_cost) errors = None if store_errors and self.errors is not None: @@ -150,7 +149,7 @@ def update( @jax.tree_util.register_pytree_node_class class GromovWasserstein(was_solver.WassersteinSolver): - """Gromov-Wasserstein solver. + """Gromov-Wasserstein solver :cite:`peyre:16`. 
Args: args: Positional arguments for @@ -161,23 +160,13 @@ class GromovWasserstein(was_solver.WassersteinSolver): quad_initializer: Quadratic initializer. If the solver is entropic, :class:`~ott.initializers.quadratic.initializers.QuadraticInitializer` is always used. Otherwise, the quadratic initializer wraps the low-rank - Sinkhorn initializers: - - - `'random'` - :class:`~ott.initializers.linear.initializers_lr.RandomInitializer`. - - `'rank2'` - :class:`~ott.initializers.linear.initializers_lr.Rank2Initializer`. - - `'k-means'` - :class:`~ott.initializers.linear.initializers_lr.KMeansInitializer`. - - `'generalized-k-means'` - :class:`~ott.initializers.linear.initializers_lr.GeneralizedKMeansInitializer`. - - If `None`, the low-rank initializer will be selected in a problem-specific - manner: - - - if both :attr:`~ott.problems.quadratic.quadratic_problem.QuadraticProblem.geom_xx` - and :attr:`~ott.problems.quadratic.quadratic_problem.QuadraticProblem.geom_yy` - are :class:`~ott.geometry.pointcloud.PointCloud` or :class:`~ott.geometry.low_rank.LRCGeometry`, - :class:`~ott.initializers.linear.initializers_lr.KMeansInitializer` - is used. - - otherwise, use :class:`~ott.initializers.linear.initializers_lr.RandomInitializer`. - + Sinkhorn initializers. If `None`, the low-rank initializer will be + selected in a problem-specific manner. If both ``geom_xx`` and ``geom_yy`` + are :class:`~ott.geometry.pointcloud.PointCloud` or + :class:`~ott.geometry.low_rank.LRCGeometry`, use + :class:`~ott.initializers.linear.initializers_lr.KMeansInitializer`. + Otherwise, use + :class:`~ott.initializers.linear.initializers_lr.RandomInitializer`. kwargs_init: Keyword arguments when creating the initializer. kwargs: Keyword arguments for :class:`~ott.solvers.was_solver.WassersteinSolver`. @@ -211,6 +200,7 @@ def __call__( prob: Quadratic OT problem. init: Initial linearization of the quadratic problem. If `None`, it will be computed using the initializer. + rng: Random number key. kwargs: Keyword arguments used when calling the initializer. Returns: @@ -410,7 +400,7 @@ def solve( scale_cost: Optional[Union[bool, float, str]] = False, a: Optional[jnp.ndarray] = None, b: Optional[jnp.ndarray] = None, - loss: Union[Literal['sqeucl', 'kl'], quadratic_costs.GWLoss] = 'sqeucl', + loss: Union[Literal["sqeucl", "kl"], quadratic_costs.GWLoss] = "sqeucl", tau_a: Optional[float] = 1.0, tau_b: Optional[float] = 1.0, gw_unbalanced_correction: bool = True, @@ -464,8 +454,8 @@ def solve( tau_b: if `< 1.0`, defines how unbalanced the problem is on the second marginal. gw_unbalanced_correction: Whether the unbalanced version of - :cite:`sejourne:21` is used. Otherwise, ``tau_a`` and ``tau_b`` only affect - the inner Sinkhorn loop. + :cite:`sejourne:21` is used. Otherwise, ``tau_a`` and ``tau_b`` only + affect the inner Sinkhorn loop. ranks: Ranks of the cost matrices, see :meth:`~ott.geometry.geometry.Geometry.to_LRCGeometry`. Used when geometries are *not* :class:`~ott.geometry.pointcloud.PointCloud` with diff --git a/src/ott/solvers/quadratic/gw_barycenter.py b/src/ott/solvers/quadratic/gw_barycenter.py index 76a769845..730790f90 100644 --- a/src/ott/solvers/quadratic/gw_barycenter.py +++ b/src/ott/solvers/quadratic/gw_barycenter.py @@ -28,11 +28,10 @@ class GWBarycenterState(NamedTuple): - """Holds the state of the \ - :class:`~ott.problems.quadratic.gw_barycenter.GWBarycenterProblem`. + """State of the GW barycenter problem. Args: - c: Barycenter cost matrix of shape ``[bar_size, bar_size]``.
+ cost: Barycenter cost matrix of shape ``[bar_size, bar_size]``. x: Barycenter features of shape ``[bar_size, ndim_fused]``. Only used in the fused case. a: Weights of the barycenter of shape ``[bar_size,]``. @@ -50,15 +49,14 @@ class GWBarycenterState(NamedTuple): costs: Optional[jnp.ndarray] = None gw_convergence: Optional[jnp.ndarray] = None - def set(self, **kwargs: Any) -> 'GWBarycenterState': + def set(self, **kwargs: Any) -> "GWBarycenterState": """Return a copy of self, possibly with overwrites.""" return self._replace(**kwargs) @jax.tree_util.register_pytree_node_class class GromovWassersteinBarycenter(was_solver.WassersteinSolver): - """Gromov-Wasserstein barycenter solver of the \ - :class:`~ott.problems.quadratic.gw_barycenter.GWBarycenterProblem`. + """Gromov-Wasserstein barycenter solver. Args: epsilon: Entropy regulariser. @@ -296,7 +294,7 @@ def init_transports( return solver(problem).matrix -def iterations( +def iterations( # noqa: D103 solver: GromovWassersteinBarycenter, problem: gw_barycenter.GWBarycenterProblem, init_state: GWBarycenterState ) -> GWBarycenterState: @@ -317,7 +315,7 @@ def body_fn( solver, problem = constants return solver.update_state(state, iteration, problem) - state = fixed_point_loop.fixpoint_iter( + return fixed_point_loop.fixpoint_iter( cond_fn=cond_fn, body_fn=body_fn, min_iterations=solver.min_iterations, @@ -326,4 +324,3 @@ def body_fn( constants=(solver, problem), state=init_state, ) - return state diff --git a/src/ott/tools/gaussian_mixture/fit_gmm.py b/src/ott/tools/gaussian_mixture/fit_gmm.py index 35e16ff38..30a98e42d 100644 --- a/src/ott/tools/gaussian_mixture/fit_gmm.py +++ b/src/ott/tools/gaussian_mixture/fit_gmm.py @@ -166,7 +166,7 @@ def fit_model_em( assignment_probs = e_step_fn(gmm, points) gmm_new = m_step_fn(points, point_weights, assignment_probs) if gmm_new.has_nans(): - raise ValueError('NaNs in fit.') + raise ValueError("NaNs in fit.") if verbose: loss = loss_fn(gmm_new, points, point_weights) q = get_q_fn( @@ -175,7 +175,7 @@ def fit_model_em( points=points, point_weights=point_weights ) - print(f'{i} q={q} -log prob={loss}', flush=True) + print(f"{i} q={q} -log prob={loss}") # noqa: T201 gmm = gmm_new return gmm @@ -299,5 +299,5 @@ def initialize( ) except ValueError: if verbose: - print(f'Failed to initialize, attempt {attempt}.', flush=True) - raise ValueError('Failed to initialize.') + print(f"Failed to initialize, attempt {attempt}.") # noqa: T201 + raise ValueError("Failed to initialize.") diff --git a/src/ott/tools/gaussian_mixture/fit_gmm_pair.py b/src/ott/tools/gaussian_mixture/fit_gmm_pair.py index 688e0961a..ccf02fbab 100644 --- a/src/ott/tools/gaussian_mixture/fit_gmm_pair.py +++ b/src/ott/tools/gaussian_mixture/fit_gmm_pair.py @@ -194,18 +194,17 @@ def print_losses( transport_penalty = sinkhorn_output.reg_ot_cost objective = q0 + q1 - weight_transport * transport_penalty - print(( - f'{iteration:3d} {q0:.3f} {q1:.3f} ' - f'transport:{transport_penalty:.3f} ' - f'objective:{objective:.3f}' - ), - flush=True) + print( # noqa: T201 + f"{iteration:3d} {q0:.3f} {q1:.3f} " + f"transport:{transport_penalty:.3f} " + f"objective:{objective:.3f}" + ) # The E-step for a single GMM -def do_e_step( +def do_e_step( # noqa: D103 e_step_fn: Callable[[gaussian_mixture.GaussianMixture, jnp.ndarray], jnp.ndarray], gmm: gaussian_mixture.GaussianMixture, @@ -236,17 +235,6 @@ def get_m_step_fn(learning_rate: float, objective_fn, jit: bool): Returns: A function that performs the M-step of EM. 
""" - grad_objective_fn = jax.grad(objective_fn, argnums=(0,)) - gmm_m_step_fn = gaussian_mixture.GaussianMixture.from_points_and_assignment_probs - if jit: - grad_objective_fn = jax.jit(grad_objective_fn) - gmm_m_step_fn = jax.jit(gmm_m_step_fn) - - opt_init, opt_update = optax.chain( - # Set the parameters of Adam. Note the learning_rate is not here. - optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8), - optax.scale(learning_rate) - ) def _m_step_fn( pair: gaussian_mixture_pair.GaussianMixturePair, @@ -273,9 +261,19 @@ def _m_step_fn( (pair,) = optax.apply_updates((pair,), updates) for j, gmm in enumerate((pair.gmm0, pair.gmm1)): if gmm.has_nans(): - raise ValueError(f'NaN in gmm{j}') + raise ValueError(f"NaN in gmm{j}") return pair + grad_objective_fn = jax.grad(objective_fn, argnums=(0,)) + if jit: + grad_objective_fn = jax.jit(grad_objective_fn) + + opt_init, opt_update = optax.chain( + # Set the parameters of Adam. Note the learning_rate is not here. + optax.scale_by_adam(b1=0.9, b2=0.999, eps=1e-8), + optax.scale(learning_rate) + ) + return _m_step_fn diff --git a/src/ott/tools/gaussian_mixture/gaussian.py b/src/ott/tools/gaussian_mixture/gaussian.py index a28414539..064263c6f 100644 --- a/src/ott/tools/gaussian_mixture/gaussian.py +++ b/src/ott/tools/gaussian_mixture/gaussian.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Pytree for a normal distribution.""" - import math from typing import Optional, Union @@ -39,11 +37,11 @@ def from_samples( cls, points: jnp.ndarray, weights: Optional[jnp.ndarray] = None - ) -> 'Gaussian': + ) -> "Gaussian": """Construct a Gaussian from weighted samples. - Unbiased, weighted covariance formula from https://en.wikipedia.org/wiki/Sample_mean_and_covariance#Weighted_samples - and https://www.gnu.org/software/gsl/doc/html/statistics.html?highlight=weighted#weighted-samples + Unbiased, weighted covariance formula from `GSL + `_. Args: points: [n x d] array of samples @@ -69,16 +67,18 @@ def from_random( n_dimensions: int, stdev_mean: float = 0.1, stdev_cov: float = 0.1, - ridge: Union[float, jnp.array] = 0, + ridge: Union[float, jnp.ndarray] = 0, dtype: Optional[jnp.dtype] = None - ) -> 'Gaussian': + ) -> "Gaussian": """Construct a random Gaussian. Args: rng: jax.random key n_dimensions: desired covariance dimensions - stdev: standard deviation of loc and log eigenvalues + stdev_mean: standard deviation of loc and log eigenvalues (means for both are 0) + stdev_cov: standard deviated of the covariance + ridge: Offset for means. dtype: data type Returns: @@ -94,7 +94,7 @@ def from_random( return cls(loc=loc, scale=scale) @classmethod - def from_mean_and_cov(cls, mean: jnp.ndarray, cov: jnp.ndarray) -> 'Gaussian': + def from_mean_and_cov(cls, mean: jnp.ndarray, cov: jnp.ndarray) -> "Gaussian": """Construct a Gaussian from a mean and covariance.""" scale = scale_tril.ScaleTriL.from_covariance(cov) return cls(loc=mean, scale=scale) @@ -149,7 +149,7 @@ def sample(self, rng: jax.random.PRNGKeyArray, size: int) -> jnp.ndarray: ) ) - def w2_dist(self, other: 'Gaussian') -> jnp.ndarray: + def w2_dist(self, other: "Gaussian") -> jnp.ndarray: r"""Wasserstein distance W_2^2 to another Gaussian. 
W_2^2 = ||\mu_0-\mu_1||^2 + @@ -165,7 +165,7 @@ def w2_dist(self, other: 'Gaussian') -> jnp.ndarray: delta_sigma = self.scale.w2_dist(other.scale) return delta_mean + delta_sigma - def f_potential(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray: + def f_potential(self, dest: "Gaussian", points: jnp.ndarray) -> jnp.ndarray: """Optimal potential for W2 distance between Gaussians. Evaluated on points. Args: @@ -189,7 +189,7 @@ def batch_inner_product(x, y): points.dot(dest.loc) ) - def transport(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray: + def transport(self, dest: "Gaussian", points: jnp.ndarray) -> jnp.ndarray: """Transport points according to map between two Gaussian measures. Args: diff --git a/src/ott/tools/gaussian_mixture/gaussian_mixture.py b/src/ott/tools/gaussian_mixture/gaussian_mixture.py index bbd37b325..ba52f94c5 100644 --- a/src/ott/tools/gaussian_mixture/gaussian_mixture.py +++ b/src/ott/tools/gaussian_mixture/gaussian_mixture.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Pytree for a Gaussian mixture model.""" - from typing import List, Optional, Tuple, Union import jax @@ -88,7 +86,7 @@ def from_random( stdev_weights: float = 0.1, ridge: Union[float, jnp.array] = 0, dtype: Optional[jnp.dtype] = None - ) -> 'GaussianMixture': + ) -> "GaussianMixture": """Construct a random GMM.""" loc = [] scale_params = [] @@ -133,7 +131,7 @@ def from_points_and_assignment_probs( points: jnp.ndarray, point_weights: jnp.ndarray, assignment_probs: jnp.ndarray, - ) -> 'GaussianMixture': + ) -> "GaussianMixture": """Estimate a GMM from points and a set of component probabilities.""" mean, cov, wts = get_summary_stats_from_points_and_assignment_probs( points=points, @@ -321,9 +319,9 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102 def __repr__(self): class_name = type(self).__name__ children, aux = self.tree_flatten() - return '{}({})'.format( - class_name, ', '.join([repr(c) for c in children] + - [f'{k}: {repr(v)}' for k, v in aux.items()]) + return "{}({})".format( + class_name, ", ".join([repr(c) for c in children] + + [f"{k}: {repr(v)}" for k, v in aux.items()]) ) def __hash__(self): diff --git a/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py b/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py index 1cc8f8c63..214fd991a 100644 --- a/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py +++ b/src/ott/tools/gaussian_mixture/gaussian_mixture_pair.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Pytree containing parameters for a pair of coupled Gaussian mixture models. -""" # noqa: D200 from typing import Any import jax @@ -178,12 +176,12 @@ def tree_flatten(self): """ # noqa: D401 children = [self.gmm0] aux_data = { - 'epsilon': self.epsilon, - 'tau': self.tau, - 'lock_gmm1': self.lock_gmm1 + "epsilon": self.epsilon, + "tau": self.tau, + "lock_gmm1": self.lock_gmm1 } if self.lock_gmm1: - aux_data['gmm1'] = self.gmm1 + aux_data["gmm1"] = self.gmm1 else: children.append(self.gmm1) return tuple(children), aux_data @@ -203,17 +201,17 @@ def tree_unflatten(cls, aux_data, children): A GaussianMixturePair. 
""" # noqa: D401 children = list(children) - if 'gmm1' in aux_data: - gmm1 = aux_data.pop('gmm1') + if "gmm1" in aux_data: + gmm1 = aux_data.pop("gmm1") children.insert(1, gmm1) return cls(*children, **aux_data) def __repr__(self): class_name = type(self).__name__ children, aux = self.tree_flatten() - return '{}({})'.format( - class_name, ', '.join([repr(c) for c in children] + - [f'{k}: {repr(v)}' for k, v in aux.items()]) + return "{}({})".format( + class_name, ", ".join([repr(c) for c in children] + + [f"{k}: {repr(v)}" for k, v in aux.items()]) ) def __hash__(self): diff --git a/src/ott/tools/gaussian_mixture/linalg.py b/src/ott/tools/gaussian_mixture/linalg.py index 869a0975e..34b1c6dff 100644 --- a/src/ott/tools/gaussian_mixture/linalg.py +++ b/src/ott/tools/gaussian_mixture/linalg.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Linear algebra utility methods for optimal transport of Gaussian mixtures.""" - from typing import Callable, Iterable, List, Optional, Tuple import jax diff --git a/src/ott/tools/gaussian_mixture/probabilities.py b/src/ott/tools/gaussian_mixture/probabilities.py index ce7fa1d6d..9a7ad5e24 100644 --- a/src/ott/tools/gaussian_mixture/probabilities.py +++ b/src/ott/tools/gaussian_mixture/probabilities.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Pytree for a vector of probabilities.""" - from typing import Optional import jax @@ -41,7 +39,7 @@ def from_random( n_dimensions: int, stdev: Optional[float] = 0.1, dtype: Optional[jnp.dtype] = None - ) -> 'Probabilities': + ) -> "Probabilities": """Construct a random Probabilities.""" return cls( params=jax.random @@ -49,7 +47,7 @@ def from_random( ) @classmethod - def from_probs(cls, probs: jnp.ndarray) -> 'Probabilities': + def from_probs(cls, probs: jnp.ndarray) -> "Probabilities": """Construct Probabilities from a vector of probabilities.""" log_probs = jnp.log(probs) log_probs_normalized, norm = log_probs[:-1], log_probs[-1] @@ -96,9 +94,9 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102 def __repr__(self): class_name = type(self).__name__ children, aux = self.tree_flatten() - return '{}({})'.format( - class_name, ', '.join([repr(c) for c in children] + - [f'{k}: {repr(v)}' for k, v in aux.items()]) + return "{}({})".format( + class_name, ", ".join([repr(c) for c in children] + + [f"{k}: {repr(v)}" for k, v in aux.items()]) ) def __hash__(self): diff --git a/src/ott/tools/gaussian_mixture/scale_tril.py b/src/ott/tools/gaussian_mixture/scale_tril.py index 8aac0a455..03cbf0a7a 100644 --- a/src/ott/tools/gaussian_mixture/scale_tril.py +++ b/src/ott/tools/gaussian_mixture/scale_tril.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Pytree for a lower triangular Cholesky factored covariance matrix.""" - from typing import Optional, Tuple import jax @@ -38,7 +36,7 @@ def from_points_and_weights( cls, points: jnp.ndarray, weights: jnp.ndarray, - ) -> Tuple[jnp.ndarray, 'ScaleTriL']: + ) -> Tuple[jnp.ndarray, "ScaleTriL"]: """Get a mean and a ScaleTriL from a set of points and weights.""" mean, cov = linalg.get_mean_and_cov(points=points, weights=weights) return mean, cls.from_covariance(cov) @@ -50,7 +48,7 @@ def from_random( n_dimensions: int, stdev: Optional[float] = 0.1, dtype: jnp.dtype = jnp.float32, - ) -> 'ScaleTriL': + ) -> "ScaleTriL": """Construct a random ScaleTriL. Args: @@ -82,7 +80,7 @@ def from_random( return cls(params=flat, size=n_dimensions) @classmethod - def from_cholesky(cls, cholesky: jnp.ndarray) -> 'ScaleTriL': + def from_cholesky(cls, cholesky: jnp.ndarray) -> "ScaleTriL": """Construct ScaleTriL from a Cholesky factor of a covariance matrix.""" m = linalg.apply_to_diag(cholesky, jnp.log) flat = linalg.tril_to_flat(m) @@ -92,7 +90,7 @@ def from_cholesky(cls, cholesky: jnp.ndarray) -> 'ScaleTriL': def from_covariance( cls, covariance: jnp.ndarray, - ) -> 'ScaleTriL': + ) -> "ScaleTriL": """Construct ScaleTriL from a covariance matrix.""" cholesky = jnp.linalg.cholesky(covariance) return cls.from_cholesky(cholesky) @@ -139,7 +137,7 @@ def z_to_centered(self, z: jnp.ndarray) -> jnp.ndarray: """Scale standardized points to points with the specified covariance.""" return (self.cholesky() @ z.T).T - def w2_dist(self, other: 'ScaleTriL') -> jnp.ndarray: + def w2_dist(self, other: "ScaleTriL") -> jnp.ndarray: r"""Wasserstein distance W_2^2 to another Gaussian with same mean. Args: @@ -160,7 +158,7 @@ def _flatten_cov(cov: jnp.ndarray) -> jnp.ndarray: return (cost_fn.norm(x0) + cost_fn.norm(x1) + cost_fn.pairwise(x0, x1))[...,] - def gaussian_map(self, dest_scale: 'ScaleTriL') -> jnp.ndarray: + def gaussian_map(self, dest_scale: "ScaleTriL") -> jnp.ndarray: """Scaling matrix used in transport between 0-mean Gaussians. Sigma_mu^{-1/2} @ @@ -178,11 +176,10 @@ def gaussian_map(self, dest_scale: 'ScaleTriL') -> jnp.ndarray: m = matrix_square_root.sqrtm_only( jnp.matmul(sqrt0, jnp.matmul(sigma1, sqrt0)) ) - m = jnp.matmul(sqrt0_inv, jnp.matmul(m, sqrt0_inv)) - return m + return jnp.matmul(sqrt0_inv, jnp.matmul(m, sqrt0_inv)) def transport( - self, dest_scale: 'ScaleTriL', points: jnp.ndarray + self, dest_scale: "ScaleTriL", points: jnp.ndarray ) -> jnp.ndarray: """Apply Monge map, computed between two 0-mean Gaussians, to points. 
@@ -198,7 +195,7 @@ def transport( def tree_flatten(self): # noqa: D102 children = (self.params,) - aux_data = {'size': self.size} + aux_data = {"size": self.size} return children, aux_data @classmethod @@ -208,9 +205,9 @@ def tree_unflatten(cls, aux_data, children): # noqa: D102 def __repr__(self): class_name = type(self).__name__ children, aux = self.tree_flatten() - return '{}({})'.format( - class_name, ', '.join([repr(c) for c in children] + - [f'{k}: {repr(v)}' for k, v in aux.items()]) + return "{}({})".format( + class_name, ", ".join([repr(c) for c in children] + + [f"{k}: {repr(v)}" for k, v in aux.items()]) ) def __hash__(self): diff --git a/src/ott/tools/k_means.py b/src/ott/tools/k_means.py index 8f0b46334..46aed8be7 100644 --- a/src/ott/tools/k_means.py +++ b/src/ott/tools/k_means.py @@ -27,13 +27,13 @@ Callable[[pointcloud.PointCloud, int, jnp.ndarray], jnp.ndarray]] -class KPPState(NamedTuple): +class KPPState(NamedTuple): # noqa: D101 rng: jax.random.PRNGKeyArray centroids: jnp.ndarray centroid_dists: jnp.ndarray -class KMeansState(NamedTuple): +class KMeansState(NamedTuple): # noqa: D101 centroids: jnp.ndarray prev_assignment: jnp.ndarray assignment: jnp.ndarray @@ -41,7 +41,7 @@ class KMeansState(NamedTuple): center_shift: float -class KMeansConst(NamedTuple): +class KMeansConst(NamedTuple): # noqa: D101 geom: pointcloud.PointCloud x_weights: jnp.ndarray diff --git a/src/ott/tools/plot.py b/src/ott/tools/plot.py index d9eb34f23..9d99a22d4 100644 --- a/src/ott/tools/plot.py +++ b/src/ott/tools/plot.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Plotting utils.""" - from typing import List, Optional, Sequence, Tuple, Union import jax.numpy as jnp @@ -48,8 +46,7 @@ def bidimensional(x: jnp.ndarray, class Plot: - """Plot an optimal transport map between two \ - :class:`PointClouds `. + """Plot an optimal transport map between two point clouds. It enables to either plot or update a plot in a single object, offering the possibilities to create animations as a @@ -68,7 +65,7 @@ def __init__( cost_threshold: float = -1.0, # should be negative for animations. scale: int = 200, show_lines: bool = True, - cmap: str = 'cool' + cmap: str = "cool" ): if plt is None: raise RuntimeError("Please install `matplotlib` first.") @@ -92,7 +89,7 @@ def __init__( def _scatter(self, ot: Transport): """Compute the position and scales of the points on a 2D plot.""" if not isinstance(ot.geom, pointcloud.PointCloud): - raise ValueError('So far we only plot PointCloud geometry.') + raise ValueError("So far we only plot PointCloud geometry.") x, y = ot.geom.x, ot.geom.y a, b = ot.a, ot.b @@ -116,10 +113,10 @@ def __call__(self, ot: Transport) -> List["plt.Artist"]: """Plot 2-D couplings. 
Projects via PCA if data is higher dimensional.""" x, y, sx, sy = self._scatter(ot) self._points_x = self.ax.scatter( - *x.T, s=sx, edgecolors='k', marker='o', label='x' + *x.T, s=sx, edgecolors="k", marker="o", label="x" ) self._points_y = self.ax.scatter( - *y.T, s=sy, edgecolors='k', marker='X', label='y' + *y.T, s=sy, edgecolors="k", marker="X", label="y" ) self.ax.legend(fontsize=15) if not self._show_lines: @@ -201,9 +198,9 @@ def _barycenters( ) -> None: """Plot 2-D sinkhorn barycenters.""" sa, sb = jnp.min(a) / scale, jnp.min(b) / scale - ax.scatter(*y.T, s=b / sb, edgecolors='k', marker='X', label='y') + ax.scatter(*y.T, s=b / sb, edgecolors="k", marker="X", label="y") tx = 1 / a[:, None] * jnp.matmul(matrix, y) - ax.scatter(*tx.T, s=a / sa, edgecolors='k', marker='X', label='T(x)') + ax.scatter(*tx.T, s=a / sa, edgecolors="k", marker="X", label="T(x)") ax.legend(fontsize=15) @@ -221,7 +218,7 @@ def barycentric_projections( if utils.is_jax_array(arg): if matrix is None: - raise ValueError('The `matrix` argument cannot be None.') + raise ValueError("The `matrix` argument cannot be None.") a = jnp.ones(matrix.shape[0]) / matrix.shape[0] if a is None else a b = jnp.ones(matrix.shape[1]) / matrix.shape[1] if b is None else b diff --git a/src/ott/tools/segment_sinkhorn.py b/src/ott/tools/segment_sinkhorn.py index d998850f9..852afd24d 100644 --- a/src/ott/tools/segment_sinkhorn.py +++ b/src/ott/tools/segment_sinkhorn.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Segmented sinkhorn utility.""" from types import MappingProxyType from typing import Any, Mapping, Optional, Tuple diff --git a/src/ott/tools/sinkhorn_divergence.py b/src/ott/tools/sinkhorn_divergence.py index efcb2fe01..54bfcb3ef 100644 --- a/src/ott/tools/sinkhorn_divergence.py +++ b/src/ott/tools/sinkhorn_divergence.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Implements the sinkhorn divergence.""" from types import MappingProxyType from typing import Any, List, Mapping, NamedTuple, Optional, Tuple, Type @@ -263,6 +262,7 @@ def segment_sinkhorn_divergence( :class:`~ott.geometry.pointcloud.PointCloud` geometry objects from the subsets of points and masses selected in `x` and `y`, this could be for instance entropy regularization float, scheduler or normalization. + Returns: An array of sinkhorn divergence values for each segment. """ diff --git a/src/ott/tools/soft_sort.py b/src/ott/tools/soft_sort.py index 6325634df..226d44bf8 100644 --- a/src/ott/tools/soft_sort.py +++ b/src/ott/tools/soft_sort.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Soft sort operators.""" - import functools from typing import Any, Callable, Optional @@ -56,7 +54,7 @@ def transport_for_sort( shape = inputs.shape if len(shape) > 2 or (len(shape) == 2 and shape[1] != 1): raise ValueError( - f'Shape ({shape}) not supported. The input should be one-dimensional.' + f"Shape ({shape}) not supported. The input should be one-dimensional." 
) x = jnp.expand_dims(jnp.squeeze(inputs), axis=1) @@ -110,8 +108,7 @@ def apply_on_axis(op, inputs, axis, *args, **kwargs: Any) -> jnp.ndarray: rank = len(result.shape) - 1 axis = min(axis) permutation = permutation[:axis] + (rank,) + permutation[axis:-1] - result = jnp.transpose(result, permutation) - return result + return jnp.transpose(result, permutation) def _sort( @@ -244,6 +241,7 @@ def quantile( ``num_targets`` target values (squared Euclidean distance by default, see ``pointcloud.py`` for more details); ``epsilon`` values as well as other parameters to shape the ``sinkhorn`` algorithm. + Returns: A jnp.ndarray, which has the same shape as the input, except on the given axis on which the dimension is 1. @@ -300,6 +298,7 @@ def quantile_normalization( ``num_targets`` target values (squared Euclidean distance by default, see ``pointcloud.py`` for more details); ``epsilon`` values as well as other parameters to shape the ``sinkhorn`` algorithm. + Returns: A jnp.ndarray, which has the same shape as the input, except on the given axis on which the dimension is 1. @@ -310,8 +309,8 @@ def quantile_normalization( """ if weights is not None and weights.shape != targets.shape: raise ValueError( - 'The target weights and targets values should have the ' - f'same shape: {targets.shape} != {weights.shape}' + "The target weights and targets values should have the " + f"same shape: {targets.shape} != {weights.shape}" ) if weights is None: num_targets = targets.shape[0] diff --git a/src/ott/types.py b/src/ott/types.py index 10210690d..7a4c88716 100644 --- a/src/ott/types.py +++ b/src/ott/types.py @@ -15,6 +15,8 @@ import jax.numpy as jnp +__all__ = ["Transport"] + # TODO(michalk8): introduce additional types here diff --git a/src/ott/utils.py b/src/ott/utils.py index 27e4474c4..4d2a10766 100644 --- a/src/ott/utils.py +++ b/src/ott/utils.py @@ -127,7 +127,7 @@ def progress_fn(status, *args): with tqdm() as pbar: out_sink = jax.jit(solver)(prob) - """ + """ # noqa: D205 # Convert arguments. iteration, inner_iterations, total_iter, state = status iteration = int(iteration) @@ -141,4 +141,4 @@ def progress_fn(status, *args): error_idx = max((iteration + 1) // inner_iterations - 1, 0) error = errors[error_idx] - print(f"{iteration} / {total_iter} -- {error}") + print(f"{iteration} / {total_iter} -- {error}") # noqa: T201 diff --git a/tests/conftest.py b/tests/conftest.py index 23ed538af..c2974d6ea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,24 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
import collections.abc import itertools from typing import Any, Mapping, Optional, Sequence -import pytest -from _pytest.config.argparsing import Parser -from _pytest.python import Metafunc - import jax import jax.numpy as jnp +import pytest +from _pytest.python import Metafunc def pytest_generate_tests(metafunc: Metafunc) -> None: @@ -56,30 +67,16 @@ def pytest_generate_tests(metafunc: Metafunc) -> None: metafunc.parametrize(argnames, combinations, ids=ids) -def pytest_addoption(parser: Parser) -> None: - parser.addoption( - "--kernel-name", - default="python3", - help="Jupyter kernel name when executing notebook tests." - ) - parser.addoption( - "--notebook-cell-timeout", - type=int, - default=60, - help="Execution timeout in seconds for notebook cells." - ) - - @pytest.fixture(scope="session") def rng() -> jnp.ndarray: return jax.random.PRNGKey(0) @pytest.fixture() -def enable_x64(): +def enable_x64() -> bool: previous_value = jax.config.jax_enable_x64 jax.config.update("jax_enable_x64", True) try: - yield + yield True finally: jax.config.update("jax_enable_x64", previous_value) diff --git a/tests/geometry/costs_test.py b/tests/geometry/costs_test.py index 875b6f0c7..80525ed63 100644 --- a/tests/geometry/costs_test.py +++ b/tests/geometry/costs_test.py @@ -11,15 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for the cost/norm functions.""" from typing import Type -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.geometry import costs, pointcloud from ott.solvers.linear import sinkhorn @@ -29,7 +26,7 @@ ts_metrics = None -@pytest.mark.fast +@pytest.mark.fast() class TestCostFn: def test_cosine(self, rng: jax.random.PRNGKeyArray): @@ -67,7 +64,7 @@ def test_cosine(self, rng: jax.random.PRNGKeyArray): atol=1e-5 ) - all_pairs = cosine_fn.all_pairs(x, y) + all_pairs = cosine_fn.all_pairs_pairwise(x, y) for i in range(n): for j in range(m): np.testing.assert_allclose( @@ -78,7 +75,7 @@ def test_cosine(self, rng: jax.random.PRNGKeyArray): ) -@pytest.mark.fast +@pytest.mark.fast() class TestBuresBarycenter: def test_bures(self, rng: jax.random.PRNGKeyArray): @@ -104,7 +101,7 @@ def test_bures(self, rng: jax.random.PRNGKeyArray): ) -@pytest.mark.fast +@pytest.mark.fast() class TestRegTICost: @pytest.mark.parametrize( @@ -192,7 +189,7 @@ def test_stronger_regularization_increases_sparsity( @pytest.mark.skipif(ts_metrics is None, reason="Not supported for Python 3.11") -@pytest.mark.fast +@pytest.mark.fast() class TestSoftDTW: @pytest.mark.parametrize("n", [11, 16]) @@ -210,7 +207,7 @@ def test_soft_dtw( np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6) - @pytest.mark.parametrize("debiased,jit", [(False, True), (True, False)]) + @pytest.mark.parametrize(("debiased", "jit"), [(False, True), (True, False)]) def test_soft_dtw_debiased( self, rng: jax.random.PRNGKeyArray, @@ -237,7 +234,7 @@ def test_soft_dtw_debiased( np.testing.assert_allclose(cost_fn(t1, t1), 0.0, rtol=1e-6, atol=1e-6) np.testing.assert_allclose(cost_fn(t2, t2), 0.0, rtol=1e-6, atol=1e-6) - @pytest.mark.parametrize("debiased,jit", [(False, False), (True, True)]) + @pytest.mark.parametrize(("debiased", "jit"), [(False, False), (True, True)]) @pytest.mark.parametrize("gamma", [1e-2, 1]) def test_soft_dtw_grad( self, rng: jax.random.PRNGKeyArray, debiased: bool, jit: bool, diff --git a/tests/geometry/graph_test.py 
b/tests/geometry/graph_test.py index 2ea0f653c..057120431 100644 --- a/tests/geometry/graph_test.py +++ b/tests/geometry/graph_test.py @@ -1,16 +1,27 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import time from typing import Any, Callable, Literal, Optional, Tuple, Union -import networkx as nx -import pytest -from networkx.algorithms import shortest_paths -from networkx.generators import balanced_tree, random_graphs - import jax import jax.experimental.sparse as jesp import jax.numpy as jnp +import networkx as nx import numpy as np - +import pytest +from networkx.algorithms import shortest_paths +from networkx.generators import balanced_tree, random_graphs from ott.geometry import geometry, graph from ott.math import decomposition from ott.problems.linear import linear_problem @@ -78,12 +89,13 @@ class TestGraph: @pytest.mark.parametrize("empty", [False, True]) def test_invalid_initialization(self, empty): - with pytest.raises(AssertionError, match="Please provide"): - if empty: + if empty: + with pytest.raises(AssertionError, match="Please provide"): _ = graph.Graph(graph=None, laplacian=None) - else: - G = random_graph(100) - L = random_graph(100, return_laplacian=True) + else: + G = random_graph(100) + L = random_graph(100, return_laplacian=True) + with pytest.raises(AssertionError, match="Please provide"): _ = graph.Graph(graph=G, laplacian=L) @pytest.mark.parametrize("fmt", [None, "coo"]) @@ -110,7 +122,7 @@ def test_init_laplacian(self, fmt: Optional[str]): assert geom.laplacian is L assert geom.graph is None - @pytest.mark.fast + @pytest.mark.fast() @pytest.mark.parametrize("as_laplacian", [False, True]) @pytest.mark.parametrize("fmt", [None, "coo"]) def test_pytree(self, fmt: Optional[str], as_laplacian: bool): @@ -274,7 +286,7 @@ def test_crank_nicolson_sparse_matches_dense(self, eps: float): atol=eps * 1e2, ) - @pytest.mark.parametrize("jit,normalize", [(False, True), (True, False)]) + @pytest.mark.parametrize(("jit", "normalize"), [(False, True), (True, False)]) def test_directed_graph(self, jit: bool, normalize: bool): def callback(geom: graph.Graph, @@ -318,7 +330,7 @@ def laplacian(geom: graph.Graph) -> jnp.ndarray: np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6) - @pytest.mark.fast + @pytest.mark.fast() def test_factor_cache_works(self, rng: jax.random.PRNGKeyArray): def timeit(fn: Callable[[Any], Any]) -> Callable[[Any], float]: @@ -349,6 +361,7 @@ def callback(g: graph.Graph, x: jnp.ndarray) -> jnp.ndarray: assert time_cached < time_non_cached @pytest.mark.parametrize("jit", [False, True]) + @pytest.mark.skip(reason="Buggy") def test_factor_cache_unique(self, jit: bool): def callback(g: graph.Graph) -> decomposition.CholeskySolver: @@ -372,7 +385,7 @@ def callback(g: graph.Graph) -> decomposition.CholeskySolver: assert key2 in decomposition.SparseCholeskySolver._FACTOR_CACHE # Total memory allocated: 99.1MiB - @pytest.mark.fast + @pytest.mark.fast() @pytest.mark.limit_memory("200 MB") def test_sparse_graph_memory(self, rng: 
jax.random.PRNGKeyArray): # use a graph with some structure for Cholesky to be faster diff --git a/tests/geometry/low_rank_test.py b/tests/geometry/low_rank_test.py index c0c1e7f78..5ba0ea889 100644 --- a/tests/geometry/low_rank_test.py +++ b/tests/geometry/low_rank_test.py @@ -11,19 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Test Low-Rank Geometry.""" -from typing import Callable, Optional, Union, Tuple - -import pytest +from typing import Callable, Optional, Tuple, Union import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.geometry import costs, geometry, grid, low_rank, pointcloud -@pytest.mark.fast +@pytest.mark.fast() class TestLRGeometry: def test_apply(self, rng: jax.random.PRNGKeyArray): @@ -45,7 +42,7 @@ def test_apply(self, rng: jax.random.PRNGKeyArray): rtol=1e-4 ) - @pytest.mark.parametrize("scale_cost", ['mean', 'max_cost', 'max_bound', 42.]) + @pytest.mark.parametrize("scale_cost", ["mean", "max_cost", "max_bound", 42.]) def test_conversion_pointcloud( self, rng: jax.random.PRNGKeyArray, scale_cost: Union[str, float] ): @@ -132,9 +129,8 @@ def test_add_lr_geoms( rtol=1e-4 ) - @pytest.mark.parametrize( - "scale,scale_cost,epsilon", [(0.1, "mean", None), (0.9, "max_cost", 1e-2)] - ) + @pytest.mark.parametrize(("scale", "scale_cost", "epsilon"), + [(0.1, "mean", None), (0.9, "max_cost", 1e-2)]) def test_add_lr_geoms_scale_factor( self, rng: jax.random.PRNGKeyArray, scale: float, scale_cost: str, epsilon: Optional[float] diff --git a/tests/geometry/pointcloud_test.py b/tests/geometry/pointcloud_test.py index 5cfa2cf7c..a1048c448 100644 --- a/tests/geometry/pointcloud_test.py +++ b/tests/geometry/pointcloud_test.py @@ -11,19 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for apply_cost and apply_kernel.""" from typing import Union -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.geometry import costs, geometry, pointcloud -@pytest.mark.fast +@pytest.mark.fast() class TestPointCloudApply: def test_apply_cost_and_kernel(self, rng: jax.random.PRNGKeyArray): diff --git a/tests/geometry/scaling_cost_test.py b/tests/geometry/scaling_cost_test.py index d6c0d4d7e..2652dd905 100644 --- a/tests/geometry/scaling_cost_test.py +++ b/tests/geometry/scaling_cost_test.py @@ -11,15 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
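(Editor's note, not part of the diff: the low-rank and scaling-cost tests around here parametrize over ``scale_cost`` options. A small sketch of how such scaling is requested on a point-cloud geometry, mirroring what the tests below assert for ``"mean"`` scaling; illustrative only.)

import jax
import jax.numpy as jnp
from ott.geometry import pointcloud

rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(rng_x, (16, 3))
y = jax.random.normal(rng_y, (14, 3))

# "mean" rescales costs so their mean is 1; "max_cost" uses the maximum entry.
geom = pointcloud.PointCloud(x, y, scale_cost="mean")
print(geom.cost_matrix.mean())  # approximately 1.0 after scaling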
-"""Tests for the option to scale the cost matrix.""" from typing import Optional, Union -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.geometry import geometry, low_rank, pointcloud from ott.problems.linear import linear_problem from ott.solvers.linear import sinkhorn, sinkhorn_lr @@ -110,9 +107,9 @@ def test_online_matches_offline_pointcloud(self, scale: Union[str, float]): np.testing.assert_allclose( geom2.inv_scale_cost, geom1.inv_scale_cost, rtol=1e-4 ) - if scale == 'mean': + if scale == "mean": np.testing.assert_allclose(1.0, geom1.cost_matrix.mean(), rtol=1e-4) - elif scale == 'max_cost': + elif scale == "max_cost": np.testing.assert_allclose(1.0, geom1.cost_matrix.max(), rtol=1e-4) @pytest.mark.fast.with_args( @@ -184,9 +181,9 @@ def apply_sinkhorn(cost1, cost2, scale_cost): rtol=1e-4 ) - if scale == 'mean': + if scale == "mean": np.testing.assert_allclose(1.0, geom.cost_matrix.mean(), rtol=1e-4) - if scale == 'max_cost': + if scale == "max_cost": np.testing.assert_allclose(1.0, geom.cost_matrix.max(), rtol=1e-4) @pytest.mark.parametrize("batch_size", [5, 12]) @@ -194,7 +191,7 @@ def test_max_scale_cost_low_rank_with_batch(self, batch_size: int): """Test max_cost options for low rank with batch_size fixed.""" geom0 = low_rank.LRCGeometry( - self.cost1, self.cost2, scale_cost='max_cost', batch_size=batch_size + self.cost1, self.cost2, scale_cost="max_cost", batch_size=batch_size ) np.testing.assert_allclose( @@ -209,7 +206,7 @@ def test_max_scale_cost_low_rank_large_array(self): cost2 = jax.random.uniform(rngs[1], (11000, 2)) max_cost_lr = jnp.max(jnp.dot(cost1, cost2.T)) - geom0 = low_rank.LRCGeometry(cost1, cost2, scale_cost='max_cost') + geom0 = low_rank.LRCGeometry(cost1, cost2, scale_cost="max_cost") np.testing.assert_allclose( geom0.inv_scale_cost, 1.0 / max_cost_lr, rtol=1e-4 diff --git a/tests/geometry/subsetting_test.py b/tests/geometry/subsetting_test.py index 0b44f0f18..9edb04246 100644 --- a/tests/geometry/subsetting_test.py +++ b/tests/geometry/subsetting_test.py @@ -1,11 +1,22 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Optional, Sequence, Tuple, Type, Union -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.geometry import geometry, low_rank, pointcloud Geom_t = Union[pointcloud.PointCloud, geometry.Geometry, low_rank.LRCGeometry] @@ -45,7 +56,7 @@ def geom_masked(request, pc_masked) -> Tuple[Geom_t, pointcloud.PointCloud]: return geom, masked -@pytest.mark.fast +@pytest.mark.fast() class TestMaskPointCloud: @pytest.mark.parametrize("tgt_ixs", [7, jnp.arange(5)]) diff --git a/tests/initializers/linear/sinkhorn_init_test.py b/tests/initializers/linear/sinkhorn_init_test.py index 7d0abe78a..9c996e213 100644 --- a/tests/initializers/linear/sinkhorn_init_test.py +++ b/tests/initializers/linear/sinkhorn_init_test.py @@ -11,15 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. -"""Tests for Sinkhorn initializers.""" -from typing import Any, Literal, Optional - -import pytest +from typing import Literal, Optional import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.geometry import geometry, pointcloud from ott.initializers.linear import initializers as linear_init from ott.initializers.nn import initializers as nn_init @@ -86,13 +83,11 @@ def run_sinkhorn( x: jnp.ndarray, y: jnp.ndarray, *, - initializer: linear_init.SinkhornInitializer = linear_init - .DefaultInitializer(), + initializer: linear_init.SinkhornInitializer, a: Optional[jnp.ndarray] = None, b: Optional[jnp.ndarray] = None, epsilon: float = 1e-2, lse_mode: bool = True, - **kwargs: Any ) -> sinkhorn.SinkhornOutput: """Runs Sinkhorn algorithm with given initializer.""" @@ -102,7 +97,7 @@ def run_sinkhorn( return solver(prob) -@pytest.mark.fast +@pytest.mark.fast() class TestSinkhornInitializers: @pytest.mark.parametrize( @@ -112,10 +107,9 @@ class TestSinkhornInitializers: ] ) def test_create_initializer(self, init: str): + kwargs_init = {} if init == "subsample": - kwargs_init = dict(subsample_n_x=10) - else: - kwargs_init = dict() + kwargs_init["subsample_n_x"] = 10 solver = sinkhorn.Sinkhorn(initializer=init, kwargs_init=kwargs_init) expected_types = { @@ -135,9 +129,9 @@ def test_create_initializer(self, init: str): expected_type = expected_types[init] assert isinstance(actual, expected_type) - @pytest.mark.parametrize( - "vector_min, lse_mode", [(True, True), (True, False), (False, True)] - ) + @pytest.mark.parametrize(("vector_min", "lse_mode"), [(True, True), + (True, False), + (False, True)]) def test_sorting_init(self, vector_min: bool, lse_mode: bool): """Tests sorting dual initializer.""" rng = jax.random.PRNGKey(42) @@ -149,6 +143,7 @@ def test_sorting_init(self, vector_min: bool, lse_mode: bool): sink_out_base = run_sinkhorn( x=ot_problem.geom.x, y=ot_problem.geom.y, + initializer=linear_init.DefaultInitializer(), a=ot_problem.a, b=ot_problem.b, epsilon=epsilon @@ -227,7 +222,7 @@ def test_gauss_pointcloud_geom(self, rng: jax.random.PRNGKeyArray): with pytest.raises(AssertionError, match=r"pointcloud"): gaus_init.init_dual_a(ot_problem, lse_mode=True) - @pytest.mark.parametrize('lse_mode', [True, False]) + @pytest.mark.parametrize("lse_mode", [True, False]) @pytest.mark.parametrize("jit", [False, True]) @pytest.mark.parametrize("initializer", ["sorting", "gaussian", "subsample"]) def test_initializer_n_iter( @@ -255,15 +250,15 @@ def test_initializer_n_iter( rng, n, m, d, epsilon=epsilon, batch_size=3 ) + run_fn = run_sinkhorn if jit: - run_fn = jax.jit(run_sinkhorn, static_argnames=["lse_mode"]) - else: - run_fn = run_sinkhorn + run_fn = jax.jit(run_fn, static_argnames=["lse_mode"]) # run sinkhorn default_out = run_fn( x=ot_problem.geom.x, y=ot_problem.geom.y, + initializer=linear_init.DefaultInitializer(), a=ot_problem.a, b=ot_problem.b, epsilon=epsilon, @@ -287,7 +282,7 @@ def test_initializer_n_iter( else: assert default_out.n_iters >= init_out.n_iters - @pytest.mark.parametrize('lse_mode', [True, False]) + @pytest.mark.parametrize("lse_mode", [True, False]) def test_meta_initializer(self, rng: jax.random.PRNGKeyArray, lse_mode: bool): """Tests Meta initializer""" n, m, d = 200, 200, 2 @@ -302,7 +297,7 @@ def test_meta_initializer(self, rng: jax.random.PRNGKeyArray, lse_mode: bool): sink_out = run_sinkhorn( x=ot_problem.geom.x, y=ot_problem.geom.y, - 
initializer="default", + initializer=linear_init.DefaultInitializer(), a=ot_problem.a, b=ot_problem.b, epsilon=epsilon, diff --git a/tests/initializers/linear/sinkhorn_lr_init_test.py b/tests/initializers/linear/sinkhorn_lr_init_test.py index c85ff8e8f..63ce10fc3 100644 --- a/tests/initializers/linear/sinkhorn_lr_init_test.py +++ b/tests/initializers/linear/sinkhorn_lr_init_test.py @@ -11,14 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for Sinkhorn initializers.""" - -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.geometry import geometry, low_rank, pointcloud from ott.initializers.linear import initializers_lr from ott.problems.linear import linear_problem diff --git a/tests/initializers/quadratic/gw_init_test.py b/tests/initializers/quadratic/gw_init_test.py index 1b111f1ab..b9aef30d8 100644 --- a/tests/initializers/quadratic/gw_init_test.py +++ b/tests/initializers/quadratic/gw_init_test.py @@ -11,13 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for Gromov-Wasserstein initializers.""" - -import pytest - import jax import numpy as np - +import pytest from ott.geometry import geometry, pointcloud from ott.initializers.linear import initializers as lin_init from ott.initializers.linear import initializers_lr diff --git a/tests/math/lse_test.py b/tests/math/lse_test.py index 5f8665807..9a01d54cd 100644 --- a/tests/math/lse_test.py +++ b/tests/math/lse_test.py @@ -11,18 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for the jvp of a custom implementation of lse.""" - -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.math import utils as mu -@pytest.mark.fast +@pytest.mark.fast() class TestGeometryLse: def test_lse(self, rng: jax.random.PRNGKeyArray): diff --git a/tests/math/matrix_square_root_test.py b/tests/math/matrix_square_root_test.py index 71e46c1c8..fd4584c18 100644 --- a/tests/math/matrix_square_root_test.py +++ b/tests/math/matrix_square_root_test.py @@ -11,15 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Tests for matrix square roots.""" from typing import Any, Callable -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.math import matrix_square_root @@ -116,7 +113,7 @@ def test_sqrtm(self): atol=1e-2 ) - @pytest.mark.fast + @pytest.mark.fast() def test_sqrtm_batch(self): """Check sqrtm on larger of matrices.""" batch_dim0 = 2 @@ -152,7 +149,7 @@ def test_sqrtm_batch(self): ) # requires Schur decomposition, which jax does not implement on GPU - @pytest.mark.cpu + @pytest.mark.cpu() def test_solve_bartels_stewart(self): x = matrix_square_root.solve_sylvester_bartels_stewart( a=self.a[0], b=self.b[0], c=self.c[0] @@ -160,7 +157,7 @@ def test_solve_bartels_stewart(self): np.testing.assert_allclose(self.x[0], x, atol=1.e-5) # requires Schur decomposition, which jax does not implement on GPU - @pytest.mark.cpu + @pytest.mark.cpu() def test_solve_bartels_stewart_batch(self): x = matrix_square_root.solve_sylvester_bartels_stewart( a=self.a, b=self.b, c=self.c @@ -176,7 +173,7 @@ def test_solve_bartels_stewart_batch(self): np.testing.assert_allclose(self.x, x[0, 0], atol=1.e-5) # requires Schur decomposition, which jax does not implement on GPU - @pytest.mark.cpu + @pytest.mark.cpu() @pytest.mark.fast.with_args( "fn,n_tests,dim,epsilon,atol,rtol", [(lambda x: matrix_square_root.sqrtm(x)[0], 3, 3, 1e-6, 1e-6, 1e-6), @@ -190,9 +187,10 @@ def test_solve_bartels_stewart_batch(self): ], only_fast=-1, ) + @pytest.mark.usefixtures("enable_x64") def test_grad( - self, enable_x64, fn: Callable, n_tests: int, dim: int, epsilon: float, - atol: float, rtol: float + self, fn: Callable, n_tests: int, dim: int, epsilon: float, atol: float, + rtol: float ): rng = self.rng for _ in range(n_tests): diff --git a/tests/problems/linear/potentials_test.py b/tests/problems/linear/potentials_test.py index b09fdc502..a67876287 100644 --- a/tests/problems/linear/potentials_test.py +++ b/tests/problems/linear/potentials_test.py @@ -1,9 +1,20 @@ -import pytest - +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.geometry import costs, pointcloud from ott.problems.linear import linear_problem, potentials from ott.solvers.linear import sinkhorn @@ -134,7 +145,8 @@ def test_entropic_potentials_sqpnorm( div_0 = sdiv(x, y).divergence mult = .1 if p > 1.0 else .25 - assert div < mult * div_0 # check we have moved points much closer to target. + # check we have moved points much closer to target + assert div < mult * div_0 @pytest.mark.fast.with_args( p=[1.45, 2.2, 1.0], forward=[False, True], only_fast=0 @@ -176,7 +188,8 @@ def test_entropic_potentials_pnorm( div = sdiv(x, z).divergence div_0 = sdiv(x, y).divergence - assert div < .1 * div_0 # check we have moved points much closer to target. 
+ # check we have moved points much closer to target + assert div < .1 * div_0 @pytest.mark.parametrize("jit", [False, True]) def test_distance_differentiability( diff --git a/tests/solvers/linear/continuous_barycenter_test.py b/tests/solvers/linear/continuous_barycenter_test.py index e1a336f9a..aaf243cc8 100644 --- a/tests/solvers/linear/continuous_barycenter_test.py +++ b/tests/solvers/linear/continuous_barycenter_test.py @@ -11,16 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for continuous barycenter.""" import functools from typing import Tuple -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.geometry import costs, segment from ott.problems.linear import barycenter_problem from ott.solvers.linear import continuous_barycenter as cb diff --git a/tests/solvers/linear/discrete_barycenter_test.py b/tests/solvers/linear/discrete_barycenter_test.py index e4d5c6e36..0bb94bc6d 100644 --- a/tests/solvers/linear/discrete_barycenter_test.py +++ b/tests/solvers/linear/discrete_barycenter_test.py @@ -11,22 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import pytest - import jax.numpy as jnp - +import pytest from ott.geometry import grid, pointcloud -from ott.solvers.linear import discrete_barycenter as db from ott.problems.linear import barycenter_problem as bp +from ott.solvers.linear import discrete_barycenter as db class TestDiscreteBarycenter: - @pytest.mark.parametrize( - "lse_mode,debiased,epsilon", [(True, True, 1e-2), (False, False, 2e-2)], - ids=["lse-deb", 'scal-no-deb'] - ) + @pytest.mark.parametrize(("lse_mode", "debiased", "epsilon"), + [(True, True, 1e-2), (False, False, 2e-2)], + ids=["lse-deb", "scal-no-deb"]) def test_discrete_barycenter_grid( self, lse_mode: bool, debiased: bool, epsilon: float ): @@ -60,13 +56,13 @@ def test_discrete_barycenter_grid( bar, errors = out.histogram, out.errors assert bar[(jnp.prod(size) - 1) // 2] > 0.7 - assert 1 > bar[(jnp.prod(size) - 1) // 2] + assert bar[(jnp.prod(size) - 1) // 2] < 1 err = errors[jnp.isfinite(errors)][-1] assert threshold > err - @pytest.mark.parametrize( - "lse_mode,epsilon", [(True, 1e-3), (False, 1e-2)], ids=["lse", "scale"] - ) + @pytest.mark.parametrize(("lse_mode", "epsilon"), [(True, 1e-3), + (False, 1e-2)], + ids=["lse", "scale"]) def test_discrete_barycenter_pointcloud(self, lse_mode: bool, epsilon: float): """Tests the discrete barycenters on pointclouds. diff --git a/tests/solvers/linear/sinkhorn_diff_test.py b/tests/solvers/linear/sinkhorn_diff_test.py index 3abe2cd83..8512dcdae 100644 --- a/tests/solvers/linear/sinkhorn_diff_test.py +++ b/tests/solvers/linear/sinkhorn_diff_test.py @@ -11,16 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Tests for the differentiability of reg_ot_cost w.r.t weights/locations.""" import functools from typing import Callable, List, Optional, Tuple -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.geometry import costs, geometry, grid, pointcloud from ott.problems.linear import linear_problem from ott.solvers.linear import implicit_differentiation as implicit_lib @@ -44,9 +41,8 @@ def initialize(self, rng: jax.random.PRNGKeyArray): self.a = a / jnp.sum(a) self.b = b / jnp.sum(b) - @pytest.mark.parametrize( - "lse_mode,threshold,pcg", [(False, 1e-6, False), (True, 1e-4, True)] - ) + @pytest.mark.parametrize(("lse_mode", "threshold", "pcg"), + [(False, 1e-6, False), (True, 1e-4, True)]) def test_implicit_differentiation_versus_autodiff( self, lse_mode: bool, threshold: float, pcg: bool ): @@ -182,9 +178,8 @@ def reg_ot(a: jnp.ndarray, b: jnp.ndarray) -> float: atol=1e-02 ) - @pytest.mark.parametrize( - "lse_mode,shape_data", [(True, (7, 9)), (False, (11, 5))] - ) + @pytest.mark.parametrize(("lse_mode", "shape_data"), [(True, (7, 9)), + (False, (11, 5))]) def test_gradient_sinkhorn_geometry( self, rng: jax.random.PRNGKeyArray, lse_mode: bool, shape_data: Tuple[int, int] @@ -337,7 +332,7 @@ def reg_ot_cost(c: jnp.ndarray) -> float: gradient = jax.grad(reg_ot_cost)(cost) np.testing.assert_array_equal(jnp.isnan(gradient), False) - @pytest.mark.fast + @pytest.mark.fast() def test_differentiability_with_jit(self, rng: jax.random.PRNGKeyArray): def reg_ot_cost(c: jnp.ndarray) -> float: @@ -541,7 +536,7 @@ def loss_from_potential(a: jnp.ndarray, x: jnp.ndarray, implicit: bool): np.testing.assert_allclose(g_imp, g_back, atol=5e-2, rtol=1e-2) -@pytest.mark.fast +@pytest.mark.fast() class TestSinkhornGradGrid: @pytest.mark.parametrize("lse_mode", [False, True]) @@ -803,7 +798,7 @@ def loss(a: jnp.ndarray, x: jnp.ndarray, implicit: bool = True): if test_back: dif_norm = jnp.sum(jnp.abs(hess_imp - hess_back)) rel_dif_norm = dif_norm / jnp.sum(jnp.abs(hess_imp)) - assert 0.1 > rel_dif_norm + assert rel_dif_norm < 0.1 eps = 1e-3 for impl in [True, False] if test_back else [True]: diff --git a/tests/solvers/linear/sinkhorn_grid_test.py b/tests/solvers/linear/sinkhorn_grid_test.py index 764a6c3fd..abeeb9f88 100644 --- a/tests/solvers/linear/sinkhorn_grid_test.py +++ b/tests/solvers/linear/sinkhorn_grid_test.py @@ -11,14 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for Sinkhorn when applied on a grid.""" - -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.geometry import grid, pointcloud from ott.problems.linear import linear_problem from ott.solvers.linear import sinkhorn @@ -123,7 +119,7 @@ def test_apply_transport_grid( ) np.testing.assert_array_equal(jnp.isnan(mat_transport_t_vec_a), False) - @pytest.mark.fast + @pytest.mark.fast() def test_apply_cost(self, rng: jax.random.PRNGKeyArray): grid_size = (5, 6, 7) diff --git a/tests/solvers/linear/sinkhorn_lr_test.py b/tests/solvers/linear/sinkhorn_lr_test.py index f2ba405eb..7cda2bb62 100644 --- a/tests/solvers/linear/sinkhorn_lr_test.py +++ b/tests/solvers/linear/sinkhorn_lr_test.py @@ -11,13 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Tests Sinkhorn Low-Rank solver with various initializations.""" -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.geometry import low_rank, pointcloud from ott.problems.linear import linear_problem from ott.solvers.linear import sinkhorn_lr diff --git a/tests/solvers/linear/sinkhorn_misc_test.py b/tests/solvers/linear/sinkhorn_misc_test.py index cb48c1c87..2787ca73f 100644 --- a/tests/solvers/linear/sinkhorn_misc_test.py +++ b/tests/solvers/linear/sinkhorn_misc_test.py @@ -11,21 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests Anderson acceleration for Sinkhorn.""" from typing import Optional, Tuple import chex -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.geometry import costs, geometry, pointcloud from ott.problems.linear import linear_problem -from ott.solvers.linear import acceleration +from ott.solvers.linear import acceleration, sinkhorn from ott.solvers.linear import implicit_differentiation as implicit_lib -from ott.solvers.linear import sinkhorn class TestSinkhornAnderson: @@ -100,7 +96,7 @@ def test_anderson( assert iterations_anderson[0] > iterations_anderson[i] -@pytest.mark.fast +@pytest.mark.fast() class TestSinkhornBures: @pytest.fixture(autouse=True) @@ -132,7 +128,8 @@ def initialize(self): self.b = b / jnp.sum(b) @pytest.mark.parametrize("lse_mode", [False, True]) - @pytest.mark.parametrize("unbalanced,thresh", [(False, 1e-3), (True, 1e-4)]) + @pytest.mark.parametrize(("unbalanced", "thresh"), [(False, 1e-3), + (True, 1e-4)]) def test_bures_point_cloud( self, rng: jax.random.PRNGKeyArray, lse_mode: bool, unbalanced: bool, thresh: float @@ -235,7 +232,7 @@ def callback(epsilon: float, batch_size: int) -> sinkhorn.SinkhornOutput: assert threshold > err -@pytest.mark.fast +@pytest.mark.fast() class TestSinkhornUnbalanced: @pytest.fixture(autouse=True) @@ -343,7 +340,7 @@ def initialize(self, rng: jax.random.PRNGKeyArray): epsilon=self.epsilon ) - @pytest.mark.fast + @pytest.mark.fast() def test_jit_vs_non_jit_fwd(self): def assert_output_close( diff --git a/tests/solvers/linear/sinkhorn_test.py b/tests/solvers/linear/sinkhorn_test.py index 0a432121d..bf0ff23d7 100644 --- a/tests/solvers/linear/sinkhorn_test.py +++ b/tests/solvers/linear/sinkhorn_test.py @@ -11,16 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for Sinkhorn.""" from typing import Any, Optional, Tuple -import pytest - import jax import jax.numpy as jnp import numpy as np - -from ott.geometry import costs, geometry, grid, pointcloud, epsilon_scheduler +import pytest +from ott.geometry import costs, epsilon_scheduler, geometry, grid, pointcloud from ott.problems.linear import linear_problem from ott.solvers.linear import acceleration, sinkhorn @@ -171,7 +168,7 @@ def test_autoepsilon_with_decay( f_1, f_2 = out_1.f, out_2.f np.testing.assert_allclose(f_1, f_2, rtol=1e-4, atol=1e-4) - @pytest.mark.fast + @pytest.mark.fast() def test_euclidean_point_cloud_min_iter(self): """Testing the min_iterations parameter.""" threshold = 1e-3 @@ -455,7 +452,7 @@ def test_restart(self, lse_mode: bool): # check only one iteration suffices when restarting with same data. 
assert num_iter_restarted == 1 - @pytest.mark.cpu + @pytest.mark.cpu() @pytest.mark.limit_memory("110 MB") @pytest.mark.fast.with_args("batch_size", [500, 1000], only_fast=0) def test_sinkhorn_online_memory_jit(self, batch_size: int): diff --git a/tests/solvers/nn/icnn_test.py b/tests/solvers/nn/icnn_test.py index afa517a86..131d17b74 100644 --- a/tests/solvers/nn/icnn_test.py +++ b/tests/solvers/nn/icnn_test.py @@ -11,18 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for ICNN network architecture.""" - -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.solvers.nn import models -@pytest.mark.fast +@pytest.mark.fast() class TestICNN: def test_icnn_convexity(self, rng: jax.random.PRNGKeyArray): @@ -35,18 +31,18 @@ def test_icnn_convexity(self, rng: jax.random.PRNGKeyArray): # initialize model rng1, rng2, rng3 = jax.random.split(rng, 3) - params = model.init(rng1, jnp.ones(n_features))['params'] + params = model.init(rng1, jnp.ones(n_features))["params"] # check convexity x = jax.random.normal(rng1, (n_samples, n_features)) * 0.1 y = jax.random.normal(rng2, (n_samples, n_features)) - out_x = model.apply({'params': params}, x) - out_y = model.apply({'params': params}, y) + out_x = model.apply({"params": params}, x) + out_y = model.apply({"params": params}, y) - out = list() + out = [] for t in jnp.linspace(0, 1): - out_xy = model.apply({'params': params}, t * x + (1 - t) * y) + out_xy = model.apply({"params": params}, t * x + (1 - t) * y) out.append((t * out_x + (1 - t) * out_y) - out_xy) np.testing.assert_array_equal(jnp.asarray(out) >= 0, True) @@ -61,13 +57,13 @@ def test_icnn_hessian(self, rng: jax.random.PRNGKeyArray): # initialize model rng1, rng2 = jax.random.split(rng) - params = model.init(rng1, jnp.ones(n_features))['params'] + params = model.init(rng1, jnp.ones(n_features))["params"] # check if Hessian is positive-semidefinite via eigenvalues data = jax.random.normal(rng2, (n_features,)) # compute Hessian - hessian = jax.hessian(model.apply, argnums=1)({'params': params}, data) + hessian = jax.hessian(model.apply, argnums=1)({"params": params}, data) # compute eigenvalues w = jnp.linalg.eigvalsh((hessian + hessian.T) / 2.0) diff --git a/tests/solvers/nn/neuraldual_test.py b/tests/solvers/nn/neuraldual_test.py index 8bdc2409c..bec9db919 100644 --- a/tests/solvers/nn/neuraldual_test.py +++ b/tests/solvers/nn/neuraldual_test.py @@ -1,4 +1,3 @@ -# # Copyright OTT-JAX # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,14 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Tests for implementation of ICNN-based Kantorovich dual by Makkuva+(2020).""" from typing import Sequence, Tuple -import pytest - import jax import numpy as np - +import pytest from ott.problems.nn import dataset from ott.solvers.nn import models, neuraldual @@ -37,20 +33,19 @@ def datasets(request: Tuple[str, str]) -> DatasetPair_t: @pytest.fixture(params=["icnns", "mlps", "mlps-grad"]) def neural_models(request: str) -> ModelPair_t: - if request.param == 'icnns': + if request.param == "icnns": return ( models.ICNN(dim_data=2, dim_hidden=[128]), models.ICNN(dim_data=2, dim_hidden=[128]) ) - elif request.param == 'mlps': - return (models.MLP(dim_hidden=[128]), models.MLP(dim_hidden=[128])) - elif request.param == 'mlps-grad': + if request.param == "mlps": + return models.MLP(dim_hidden=[128]), models.MLP(dim_hidden=[128]), + if request.param == "mlps-grad": return ( models.MLP(dim_hidden=[128]), models.MLP(is_potential=False, dim_hidden=[128]) ) - else: - raise ValueError(f'Invalid request: {request.param}') + raise ValueError(f"Invalid request: {request.param}") class TestNeuralDual: @@ -85,8 +80,8 @@ def decreasing(losses: Sequence[float]) -> bool: neural_dual, logs = neural_dual_solver(*train_dataset, *valid_dataset) # check if training loss of f is increasing and g is decreasing - assert increasing(logs['train_logs']['loss_f']) - assert decreasing(logs['train_logs']['loss_g']) + assert increasing(logs["train_logs"]["loss_f"]) + assert decreasing(logs["train_logs"]["loss_g"]) def test_neural_dual_jit(self, datasets: DatasetPair_t): num_train_iters = 10 diff --git a/tests/solvers/quadratic/fgw_test.py b/tests/solvers/quadratic/fgw_test.py index 55ce18891..380f68fbf 100644 --- a/tests/solvers/quadratic/fgw_test.py +++ b/tests/solvers/quadratic/fgw_test.py @@ -11,15 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for the Fused Gromov Wasserstein.""" from typing import Literal, Tuple, Union -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.geometry import geometry, low_rank, pointcloud from ott.problems.quadratic import quadratic_problem from ott.solvers.linear import implicit_differentiation as implicit_lib @@ -97,10 +94,9 @@ def reg_gw(a: jnp.ndarray, b: jnp.ndarray, implicit: bool): np.testing.assert_allclose(g_a, gi_a, rtol=1e-02, atol=1e-02) np.testing.assert_allclose(g_b, gi_b, rtol=1e-02, atol=1e-02) - @pytest.mark.parametrize( - "lse_mode,is_cost", [(True, False), (False, True)], - ids=["lse-pc", "kernel-cost-mat"] - ) + @pytest.mark.parametrize(("lse_mode", "is_cost"), [(True, False), + (False, True)], + ids=["lse-pc", "kernel-cost-mat"]) def test_gradient_fgw_solver_geometry(self, lse_mode: bool, is_cost: bool): """Test gradient w.r.t. the geometries.""" diff --git a/tests/solvers/quadratic/gw_barycenter_test.py b/tests/solvers/quadratic/gw_barycenter_test.py index be5d71e3b..d16892c3c 100644 --- a/tests/solvers/quadratic/gw_barycenter_test.py +++ b/tests/solvers/quadratic/gw_barycenter_test.py @@ -11,15 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Tests for Gromov-Wasserstein barycenter.""" from typing import Any, Optional, Sequence, Tuple -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.geometry import pointcloud from ott.problems.quadratic import gw_barycenter as gwb from ott.solvers.quadratic import gw_barycenter as gwb_solver @@ -63,7 +60,7 @@ def pad_cost_matrices( # TODO(cuturi) add back KL test when KL cost GW is fixed. @pytest.mark.parametrize( - "gw_loss,bar_size,epsilon", + ("gw_loss", "bar_size", "epsilon"), [("sqeucl", 17, None)] # , ("kl", 22, 1e-2)] ) def test_gw_barycenter( diff --git a/tests/solvers/quadratic/gw_test.py b/tests/solvers/quadratic/gw_test.py index a9db95902..b7d97a32c 100644 --- a/tests/solvers/quadratic/gw_test.py +++ b/tests/solvers/quadratic/gw_test.py @@ -11,15 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for the Gromov Wasserstein.""" from typing import Tuple, Union -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.geometry import geometry, low_rank, pointcloud from ott.problems.quadratic import quadratic_problem from ott.solvers.linear import implicit_differentiation as implicit_lib @@ -27,7 +24,7 @@ from ott.solvers.quadratic import gromov_wasserstein -@pytest.mark.fast +@pytest.mark.fast() class TestQuadraticProblem: @pytest.mark.parametrize("as_pc", [False, True]) @@ -198,10 +195,9 @@ def reg_gw(a: jnp.ndarray, b: jnp.ndarray, grad_matrices[0][1], grad_matrices[1][1], rtol=1e-02, atol=1e-02 ) - @pytest.mark.fast - @pytest.mark.parametrize( - "balanced,rank", [(True, -1), (False, -1), (True, 3)] - ) + @pytest.mark.fast() + @pytest.mark.parametrize(("balanced", "rank"), [(True, -1), (False, -1), + (True, 3)]) def test_gw_pointcloud(self, balanced: bool, rank: int): """Test basic computations pointclouds.""" geom_x = pointcloud.PointCloud(self.x) @@ -228,15 +224,12 @@ def test_gw_pointcloud(self, balanced: bool, rank: int): assert not jnp.isnan(out.reg_gw_cost) - @pytest.mark.parametrize( - "unbalanced,unbalanced_correction", [(False, False), (True, False), - (True, True)], - ids=["bal", "unbal-nocorr", "unbal-corr"] - ) - @pytest.mark.parametrize( - "lse_mode,is_cost", [(True, False), (False, True)], - ids=["lse-pc", "kernel-cost-mat"] - ) + @pytest.mark.parametrize(("unbalanced", "unbalanced_correction"), + [(False, False), (True, False), (True, True)], + ids=["bal", "unbal-nocorr", "unbal-corr"]) + @pytest.mark.parametrize(("lse_mode", "is_cost"), [(True, False), + (False, True)], + ids=["lse-pc", "kernel-cost-mat"]) def test_gradient_gw_geometry( self, lse_mode: bool, is_cost: bool, unbalanced: bool, unbalanced_correction: bool @@ -308,7 +301,7 @@ def loss_thre(threshold: float) -> float: assert loss_thre(1e-1) >= loss_thre(1e-4) assert loss_thre(1e-3) >= loss_thre(1e-5) - @pytest.mark.fast + @pytest.mark.fast() def test_gw_lr(self, rng: jax.random.PRNGKeyArray): """Checking LR and Entropic have similar outputs on same problem.""" rngs = jax.random.split(rng, 4) @@ -355,8 +348,8 @@ def test_gw_lr_matches_fused(self, rng: jax.random.PRNGKeyArray): ot_gw = solver(prob) # Test solutions look alike - assert 0.11 > jnp.linalg.norm(ot_gwlr.matrix - ot_gw.matrix) - assert 0.15 > jnp.linalg.norm(ot_gwlr.matrix - ot_gwlreps.matrix) + assert jnp.linalg.norm(ot_gwlr.matrix - ot_gw.matrix) < 0.11 + assert jnp.linalg.norm(ot_gwlr.matrix - ot_gwlreps.matrix) 
< 0.15 # Test at least some difference when adding bigger entropic regularization assert jnp.linalg.norm(ot_gwlr.matrix - ot_gwlreps.matrix) > 1e-3 diff --git a/tests/tools/gaussian_mixture/fit_gmm_pair_test.py b/tests/tools/gaussian_mixture/fit_gmm_pair_test.py index 63c3d01a9..22b1d430e 100644 --- a/tests/tools/gaussian_mixture/fit_gmm_pair_test.py +++ b/tests/tools/gaussian_mixture/fit_gmm_pair_test.py @@ -11,13 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for fit_gmm_pair.""" - -import pytest - import jax import jax.numpy as jnp - +import pytest from ott.tools.gaussian_mixture import ( fit_gmm, fit_gmm_pair, @@ -66,16 +62,13 @@ def initialize(self, rng: jax.random.PRNGKeyArray): self.samples_gmm1 = gmm_generator1.sample(rng=subrng1, size=2000) # requires Schur decomposition, which jax does not implement on GPU - @pytest.mark.cpu + @pytest.mark.cpu() @pytest.mark.fast.with_args( balanced=[False, True], weighted=[False, True], only_fast=0 ) def test_fit_gmm(self, balanced, weighted): # dumb integration test that makes sure nothing crashes - if balanced: - tau = 1. - else: - tau = self.tau + tau = 1.0 if balanced else self.tau if weighted: weights0 = jnp.ones(self.samples_gmm0.shape[0]) diff --git a/tests/tools/gaussian_mixture/fit_gmm_test.py b/tests/tools/gaussian_mixture/fit_gmm_test.py index 647e4f7ff..4ef662d5d 100644 --- a/tests/tools/gaussian_mixture/fit_gmm_test.py +++ b/tests/tools/gaussian_mixture/fit_gmm_test.py @@ -11,18 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for fit_gmm_pair.""" - -import pytest - import jax import jax.numpy as jnp import jax.test_util - +import pytest from ott.tools.gaussian_mixture import fit_gmm, gaussian_mixture -@pytest.mark.fast +@pytest.mark.fast() class TestFitGmm: @pytest.fixture(autouse=True) diff --git a/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py b/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py index ba81cdcba..3865cbd86 100644 --- a/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py +++ b/tests/tools/gaussian_mixture/gaussian_mixture_pair_test.py @@ -11,14 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for gaussian_mixture_pair.""" - -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.tools.gaussian_mixture import gaussian_mixture, gaussian_mixture_pair @@ -71,7 +67,7 @@ def test_get_sinkhorn_to_same_gmm_is_almost_zero(self): np.testing.assert_almost_equal(cost, 0.00, decimal=2) - @pytest.mark.fast + @pytest.mark.fast() def test_get_sinkhorn_to_shifted_is_almost_shift(self): loc_shift = jnp.stack([ 2. 
* jnp.ones(self.n_components), @@ -92,7 +88,7 @@ def test_get_sinkhorn_to_shifted_is_almost_shift(self): np.testing.assert_approx_equal(cost, 4.0, significant=2) - @pytest.mark.fast + @pytest.mark.fast() def test_get_coupling_between_same_gmm(self): gmm = self.gmm0 pair = gaussian_mixture_pair.GaussianMixturePair( diff --git a/tests/tools/gaussian_mixture/gaussian_mixture_test.py b/tests/tools/gaussian_mixture/gaussian_mixture_test.py index 1fdfbd8db..4a23b4b1c 100644 --- a/tests/tools/gaussian_mixture/gaussian_mixture_test.py +++ b/tests/tools/gaussian_mixture/gaussian_mixture_test.py @@ -11,18 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for gaussian_mixture.""" - -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.tools.gaussian_mixture import gaussian_mixture, linalg -@pytest.mark.fast +@pytest.mark.fast() class TestGaussianMixture: def test_get_summary_stats_from_points_and_assignment_probs( diff --git a/tests/tools/gaussian_mixture/gaussian_test.py b/tests/tools/gaussian_mixture/gaussian_test.py index b09c9c8fe..2c412cc32 100644 --- a/tests/tools/gaussian_mixture/gaussian_test.py +++ b/tests/tools/gaussian_mixture/gaussian_test.py @@ -11,18 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for gaussian.""" - -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.tools.gaussian_mixture import gaussian, scale_tril -@pytest.mark.fast +@pytest.mark.fast() class TestGaussian: def test_from_random(self, rng: jax.random.PRNGKeyArray): @@ -100,8 +96,8 @@ def test_w2_dist(self, rng: jax.random.PRNGKeyArray): np.testing.assert_almost_equal(w2, 0., decimal=5) # When covariances commute (e.g. if covariance is diagonal), have - # distance between covariances = frobenius norm^2 of (delta cholesky) - # see https://djalil.chafai.net/blog/2010/04/30/wasserstein-distance-between-two-gaussians/ # pylint: disable=line-too-long + # distance between covariances = frobenius norm^2 of (delta cholesky), see + # https://djalil.chafai.net/blog/2010/04/30/wasserstein-distance-between-two-gaussians/ # noqa: E501 size = 4 rng, subrng0, subrng1 = jax.random.split(rng, num=3) loc0 = jax.random.normal(key=subrng0, shape=(size,)) diff --git a/tests/tools/gaussian_mixture/linalg_test.py b/tests/tools/gaussian_mixture/linalg_test.py index 927d8f849..cf9af762e 100644 --- a/tests/tools/gaussian_mixture/linalg_test.py +++ b/tests/tools/gaussian_mixture/linalg_test.py @@ -11,18 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Tests for linalg.""" - -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.tools.gaussian_mixture import linalg -@pytest.mark.fast +@pytest.mark.fast() class TestLinalg: def test_get_mean_and_var(self, rng: jax.random.PRNGKeyArray): diff --git a/tests/tools/gaussian_mixture/probabilities_test.py b/tests/tools/gaussian_mixture/probabilities_test.py index 65dd727c1..d6db75008 100644 --- a/tests/tools/gaussian_mixture/probabilities_test.py +++ b/tests/tools/gaussian_mixture/probabilities_test.py @@ -11,18 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for probabilities.""" - -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.tools.gaussian_mixture import probabilities -@pytest.mark.fast +@pytest.mark.fast() class TestProbabilities: def test_probs(self): diff --git a/tests/tools/gaussian_mixture/scale_tril_test.py b/tests/tools/gaussian_mixture/scale_tril_test.py index 6aa28b4ae..eb42cdbcb 100644 --- a/tests/tools/gaussian_mixture/scale_tril_test.py +++ b/tests/tools/gaussian_mixture/scale_tril_test.py @@ -11,14 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for ScaleTriL.""" - -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.math import matrix_square_root from ott.tools.gaussian_mixture import scale_tril @@ -29,7 +25,7 @@ def chol() -> scale_tril.ScaleTriL: return scale_tril.ScaleTriL(params=params, size=2) -@pytest.mark.fast +@pytest.mark.fast() class TestScaleTriL: def test_cholesky(self, chol: scale_tril.ScaleTriL): @@ -76,8 +72,8 @@ def test_w2_dist(self, rng: jax.random.PRNGKeyArray): np.testing.assert_allclose(expected, w2, atol=1e-4, rtol=1e-4) # When covariances commute (e.g. if covariance is diagonal), have - # distance between covariances = Frobenius norm^2 of (delta sqrt(cov)) - # see https://djalil.chafai.net/blog/2010/04/30/wasserstein-distance-between-two-gaussians/ # pylint: disable=line-too-long + # distance between covariances = Frobenius norm^2 of (delta sqrt(cov)), see + # see https://djalil.chafai.net/blog/2010/04/30/wasserstein-distance-between-two-gaussians/ # noqa: E501 size = 4 rng, subrng0, subrng1 = jax.random.split(rng, num=3) diag0 = jnp.exp(jax.random.normal(key=subrng0, shape=(size,))) diff --git a/tests/tools/k_means_test.py b/tests/tools/k_means_test.py index 82fd31a30..91c56839f 100644 --- a/tests/tools/k_means_test.py +++ b/tests/tools/k_means_test.py @@ -1,32 +1,43 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import os import sys from typing import Any, Literal, Optional, Tuple, Union -import pytest - import jax import jax.numpy as jnp import numpy as np +import pytest +from ott.geometry import costs, pointcloud +from ott.tools import k_means from sklearn import datasets from sklearn.cluster import KMeans, kmeans_plusplus from sklearn.cluster._k_means_common import _is_same_clustering -from ott.geometry import costs, pointcloud -from ott.tools import k_means - def make_blobs( *args: Any, - cost_fn: Optional[Literal['sqeucl', 'cosine']] = None, + cost_fn: Optional[Literal["sqeucl", "cosine"]] = None, **kwargs: Any ) -> Tuple[Union[jnp.ndarray, pointcloud.PointCloud], jnp.ndarray, jnp.ndarray]: X, y, c = datasets.make_blobs(*args, return_centers=True, **kwargs) X, y, c = jnp.asarray(X), jnp.asarray(y), jnp.asarray(c) if cost_fn is None: pass - elif cost_fn == 'sqeucl': + elif cost_fn == "sqeucl": X = pointcloud.PointCloud(X, cost_fn=costs.SqEuclidean()) - elif cost_fn == 'cosine': + elif cost_fn == "cosine": X = pointcloud.PointCloud(X, cost_fn=costs.Cosine()) else: raise NotImplementedError(cost_fn) @@ -55,7 +66,7 @@ def test_n_local_trials(self, rng: jax.random.PRNGKeyArray, n_local_trials): n, k = 150, 4 rng1, rng2 = jax.random.split(rng) geom, _, c = make_blobs( - n_samples=n, centers=k, cost_fn='sqeucl', random_state=0 + n_samples=n, centers=k, cost_fn="sqeucl", random_state=0 ) centers1 = k_means._k_means_plus_plus(geom, k, rng1, n_local_trials) centers2 = k_means._k_means_plus_plus(geom, k, rng2, 20) @@ -72,7 +83,7 @@ def test_matches_sklearn(self, rng: jax.random.PRNGKeyArray, k: int): n_samples=200, centers=k, n_features=ndim, - cost_fn='sqeucl', + cost_fn="sqeucl", random_state=0 ) gt_centers, _ = kmeans_plusplus(np.asarray(geom.x), k, random_state=1) @@ -109,7 +120,7 @@ def callback(x: jnp.ndarray) -> float: class TestKmeans: - @pytest.mark.fast + @pytest.mark.fast() @pytest.mark.parametrize("k", [1, 6]) def test_k_means_output(self, rng: jax.random.PRNGKeyArray, k: int): max_iter, ndim = 10, 4 @@ -129,7 +140,7 @@ def test_k_means_output(self, rng: jax.random.PRNGKeyArray, k: int): assert res.inner_errors is None assert _is_same_clustering(pred_assignment, gt_assignment, k) - @pytest.mark.fast + @pytest.mark.fast() def test_k_means_simple_example(self): expected_labels = np.asarray([1, 1, 0, 0], dtype=np.int32) expected_centers = np.asarray([[0.75, 1], [0.25, 0]]) @@ -281,19 +292,20 @@ def test_weight_scaling_effects_only_inertia( res.error, res_scaled.error * jnp.sum(weights), rtol=1e-3, atol=1e-3 ) - @pytest.mark.fast + @pytest.mark.fast() def test_empty_weights(self, rng: jax.random.PRNGKeyArray): n, ndim, k, d = 20, 2, 3, 5. - x = np.random.normal(size=(n, ndim)) + gen = np.random.RandomState(0) + x = gen.normal(size=(n, ndim)) x[:, 0] += d x[:, 1] += d - y = np.random.normal(size=(n, ndim)) + y = gen.normal(size=(n, ndim)) y[:, 0] -= d y[:, 1] -= d - z = np.random.normal(size=(n, ndim)) + z = gen.normal(size=(n, ndim)) z[:, 0] += d z[:, 1] -= d - w = np.random.normal(size=(n, ndim)) + w = gen.normal(size=(n, ndim)) w[:, 0] -= d w[:, 1] += d x = jnp.concatenate((x, y, z, w)) @@ -354,13 +366,12 @@ def callback(x: jnp.ndarray) -> k_means.KMeansOutput: assert res.converged == res_jit.converged @pytest.mark.skipif( - sys.platform == 'darwin' and os.environ.get("CI", "false") == "true", - reason='Fails on macOS CI.' 
-  )
-  @pytest.mark.parametrize(
-      "jit,force_scan", [(True, False), (False, True)],
-      ids=["jit-while-loop", "nojit-for-loop"]
+      sys.platform == "darwin" and os.environ.get("CI", "false") == "true",
+      reason="Fails on macOS CI."
   )
+  @pytest.mark.parametrize(("jit", "force_scan"), [(True, False),
+                                                   (False, True)],
+                           ids=["jit-while-loop", "nojit-for-loop"])
   def test_k_means_differentiability(
       self, rng: jax.random.PRNGKeyArray, jit: bool, force_scan: bool
   ):
@@ -399,7 +410,7 @@ def inertia(x: jnp.ndarray, w: jnp.ndarray) -> float:
     np.testing.assert_allclose(actual, expected, rtol=tol, atol=tol)

   @pytest.mark.parametrize("tol", [1e-3, 0.])
-  @pytest.mark.parametrize("n,k", [(37, 4), (128, 6)])
+  @pytest.mark.parametrize(("n", "k"), [(37, 4), (128, 6)])
   def test_clustering_matches_sklearn(
       self, rng: jax.random.PRNGKeyArray, n: int, k: int, tol: float
   ):
diff --git a/tests/tools/segment_sinkhorn_test.py b/tests/tools/segment_sinkhorn_test.py
index 8a5b9e923..111af8aae 100644
--- a/tests/tools/segment_sinkhorn_test.py
+++ b/tests/tools/segment_sinkhorn_test.py
@@ -11,14 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for Segmented Sinkhorn."""
-
-import pytest
-
 import jax
 import jax.numpy as jnp
 import numpy as np
-
+import pytest
 from ott.geometry import costs, pointcloud
 from ott.problems.linear import linear_problem
 from ott.solvers.linear import sinkhorn
@@ -102,7 +98,7 @@ def test_segment_sinkhorn_different_segment_sizes(self):

     sink = jax.jit(
         segment_sinkhorn.segment_sinkhorn,
-        static_argnames=['num_segments', 'max_measure_size'],
+        static_argnames=["num_segments", "max_measure_size"],
     )
     segmented_reg_ot_cost = sink(
         jnp.concatenate((x1, x2)),
@@ -167,7 +163,7 @@ def g(rng, n):
         cost_fn=b_cost,
         num_per_segment_x=num_per_segment_x,
         num_per_segment_y=num_per_segment_y,
-        sinkhorn_kwargs={'lse_mode': True},
+        sinkhorn_kwargs={"lse_mode": True},
         epsilon=0.1,
     )
     np.testing.assert_allclose(segmented_reg_ot_cost, true_reg_ot_cost)
diff --git a/tests/tools/sinkhorn_divergence_test.py b/tests/tools/sinkhorn_divergence_test.py
index 71e5c1846..0ad7004c8 100644
--- a/tests/tools/sinkhorn_divergence_test.py
+++ b/tests/tools/sinkhorn_divergence_test.py
@@ -11,15 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Tests for the Sinkhorn divergence.""" from typing import Any, Dict, Optional -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.geometry import costs, geometry, pointcloud from ott.solvers.linear import acceleration, sinkhorn from ott.tools import sinkhorn_divergence @@ -82,14 +79,14 @@ def test_euclidean_point_cloud(self, cost_fn: costs.CostFn, epsilon: float): x, cost_fn=cost_fn, epsilon=1e-1, - sinkhorn_kwargs={'inner_iterations': 1}, + sinkhorn_kwargs={"inner_iterations": 1}, ) np.testing.assert_allclose(div.divergence, 0.0, rtol=1e-5, atol=1e-5) iters_xx = jnp.sum(div.errors[0] > 0) iters_xx_sym = jnp.sum(div.errors[1] > 0) assert iters_xx >= iters_xx_sym - @pytest.mark.fast + @pytest.mark.fast() def test_euclidean_autoepsilon(self): rngs = jax.random.split(self.rng, 2) cloud_a = jax.random.uniform(rngs[0], (self._num_points[0], self._dim)) @@ -140,7 +137,7 @@ def test_euclidean_point_cloud_wrapper(self, use_weights: bool): assert len(div.potentials) == 3 assert len(div.geoms) == 3 - @pytest.mark.fast + @pytest.mark.fast() def test_euclidean_point_cloud_unbalanced_wrapper(self): rngs = jax.random.split(self.rng, 2) cloud_a = jax.random.uniform(rngs[0], (self._num_points[0], self._dim)) @@ -218,8 +215,8 @@ def test_segment_sinkhorn_result(self, shuffle: bool): rngs = jax.random.split(self.rng, 4) x = jax.random.uniform(rngs[0], (self._num_points[0], self._dim)) y = jax.random.uniform(rngs[1], (self._num_points[1], self._dim)) - geom_kwargs = dict(epsilon=0.01) - sinkhorn_kwargs = dict(threshold=1e-2) + geom_kwargs = {"epsilon": 0.01} + sinkhorn_kwargs = {"threshold": 1e-2} true_divergence = sinkhorn_divergence.sinkhorn_divergence( pointcloud.PointCloud, x, @@ -281,7 +278,7 @@ def test_segment_sinkhorn_different_segment_sizes(self): sink_div = jax.jit( sinkhorn_divergence.segment_sinkhorn_divergence, - static_argnames=['num_per_segment_x', 'num_per_segment_y'], + static_argnames=["num_per_segment_x", "num_per_segment_y"], ) segmented_divergences = sink_div( @@ -336,7 +333,7 @@ def g(rng, n): x, y, sinkhorn_kwargs={ - 'lse_mode': True + "lse_mode": True }, epsilon=0.1, cost_fn=b_cost @@ -353,7 +350,7 @@ def g(rng, n): max_measure_size=5, num_per_segment_x=num_per_segment_x, num_per_segment_y=num_per_segment_y, - sinkhorn_kwargs={'lse_mode': True}, + sinkhorn_kwargs={"lse_mode": True}, epsilon=0.1, cost_fn=b_cost ) @@ -427,7 +424,7 @@ def loss_fn(cloud_a: jnp.ndarray, cloud_b: jnp.ndarray) -> float: epsilon=1.0, a=self._a, b=self._b, - sinkhorn_kwargs=dict(threshold=0.05) + sinkhorn_kwargs={"threshold": 0.05}, ) return div.divergence diff --git a/tests/tools/soft_sort_test.py b/tests/tools/soft_sort_test.py index 034198ae0..706014c40 100644 --- a/tests/tools/soft_sort_test.py +++ b/tests/tools/soft_sort_test.py @@ -11,16 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-"""Tests for the soft sort tools.""" import functools from typing import Tuple -import pytest - import jax import jax.numpy as jnp import numpy as np - +import pytest from ott.solvers.linear import acceleration from ott.solvers.linear import implicit_differentiation as implicit_lib from ott.tools import soft_sort @@ -63,7 +60,7 @@ def test_sort_array_squashing_momentum(self, rng: jax.random.PRNGKeyArray): np.testing.assert_array_equal(jnp.diff(xs_lin, axis=0) >= -1e-8, True) np.testing.assert_array_equal(jnp.diff(xs_sig, axis=0) >= -1e-8, True) - @pytest.mark.fast + @pytest.mark.fast() @pytest.mark.parametrize("k", [-1, 4, 100]) def test_topk_one_array(self, rng: jax.random.PRNGKeyArray, k: int): n = 20 @@ -98,7 +95,7 @@ def test_rank_one_array(self, rng: jax.random.PRNGKeyArray): np.testing.assert_array_equal(x.shape, ranks.shape) np.testing.assert_allclose(ranks, expected_ranks, atol=0.9, rtol=0.1) - @pytest.mark.fast + @pytest.mark.fast() @pytest.mark.parametrize("level", [0.2, 0.9]) def test_quantile(self, level: float): x = jnp.linspace(0.0, 1.0, 100) @@ -149,7 +146,7 @@ def test_sort_with(self, rng: jax.random.PRNGKeyArray): np.testing.assert_array_equal(output.shape, [k, d]) np.testing.assert_allclose(output, inputs[-k:], atol=0.05) - @pytest.mark.fast + @pytest.mark.fast() def test_quantize(self): n = 100 inputs = jnp.linspace(0.0, 1.0, n)[..., None] diff --git a/tests/utils_test.py b/tests/utils_test.py index 7e79a93ea..768a498b5 100644 --- a/tests/utils_test.py +++ b/tests/utils_test.py @@ -1,13 +1,24 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Optional import pytest - from ott import utils -@pytest.mark.parametrize( - "version,msg", [(None, "foo, bar, baz"), ("quux", None)] -) +@pytest.mark.parametrize(("version", "msg"), [(None, "foo, bar, baz"), + ("quux", None)]) def test_deprecation_warning(version: Optional[str], msg: Optional[str]): @utils.deprecate(version=version, alt=msg)