From 68f9d08d29d04953acedf7eef29cc3465cde11d1 Mon Sep 17 00:00:00 2001 From: Marco Cuturi Date: Wed, 12 Jul 2023 14:48:39 -0700 Subject: [PATCH] test `Plot` class in `tools` module to increase coverage (#394) * test plot tool to increase coverage * drop deprecated jnp.DeviceArray * remove unused barycenters functions * minor fixes --- src/ott/tools/plot.py | 59 +++++++--------------------------------- src/ott/utils.py | 5 ++-- tests/tools/plot_test.py | 46 +++++++++++++++++++++++++++++++ 3 files changed, 58 insertions(+), 52 deletions(-) create mode 100644 tests/tools/plot_test.py diff --git a/src/ott/tools/plot.py b/src/ott/tools/plot.py index 529406de5..ccd3b1ec3 100644 --- a/src/ott/tools/plot.py +++ b/src/ott/tools/plot.py @@ -17,7 +17,6 @@ import numpy as np import scipy -from ott import utils from ott.geometry import pointcloud from ott.solvers.linear import sinkhorn, sinkhorn_lr from ott.solvers.quadratic import gromov_wasserstein @@ -39,7 +38,9 @@ def bidimensional(x: jnp.ndarray, if x.shape[1] < 3: return x, y - u, s, _ = scipy.sparse.linalg.svds(jnp.concatenate([x, y], axis=0), k=2) + u, s, _ = scipy.sparse.linalg.svds( + np.array(jnp.concatenate([x, y], axis=0)), k=2 + ) proj = u * s k = x.shape[0] return proj[:k], proj[k:] @@ -48,14 +49,18 @@ def bidimensional(x: jnp.ndarray, class Plot: """Plot an optimal transport map between two point clouds. - It enables to either plot or update a plot in a single object, offering the - possibilities to create animations as a + This object can either plot or update a plot, to create animations as a :class:`~matplotlib.animation.FuncAnimation`, which can in turned be saved to disk at will. There are two design principles here: #. we do not rely on saving to/loading from disk to create animations #. we try as much as possible to disentangle the transport problem from its visualization. + + We use 2D scatter plots by default, relying on PCA visualization for d>3 data. + This step requires a conversion to a numpy array, in order to compute leading + singular values. This tool is therefore not designed having performance in + mind. """ def __init__( @@ -135,7 +140,7 @@ def _mapping(self, x: jnp.ndarray, y: jnp.ndarray, matrix: jnp.ndarray): return result def __call__(self, ot: Transport) -> List["plt.Artist"]: - """Plot 2-D couplings. Projects via PCA if data is higher dimensional.""" + """Plot couplings in 2-D, using PCA if data is higher dimensional.""" x, y, sx, sy = self._scatter(ot) self._points_x = self.ax.scatter( *x.T, s=sx, edgecolors="k", marker="o", label="x" @@ -214,47 +219,3 @@ def animate( interval=1000 / frame_rate, blit=True ) - - -def _barycenters( - ax: "plt.Axes", - y: jnp.ndarray, - a: jnp.ndarray, - b: jnp.ndarray, - matrix: jnp.ndarray, - scale: int = 200 -) -> None: - """Plot 2-D sinkhorn barycenters.""" - sa, sb = jnp.min(a) / scale, jnp.min(b) / scale - ax.scatter(*y.T, s=b / sb, edgecolors="k", marker="X", label="y") - tx = 1 / a[:, None] * jnp.matmul(matrix, y) - ax.scatter(*tx.T, s=a / sa, edgecolors="k", marker="X", label="T(x)") - ax.legend(fontsize=15) - - -def barycentric_projections( - arg: Union[Transport, jnp.ndarray], - a: jnp.ndarray = None, - b: jnp.ndarray = None, - matrix: jnp.ndarray = None, - ax: Optional["plt.Axes"] = None, - **kwargs -): - """Plot the barycenters, from the Transport object or from arguments.""" - if ax is None: - _, ax = plt.subplots(1, 1, figsize=(8, 5)) - - if utils.is_jax_array(arg): - if matrix is None: - raise ValueError("The `matrix` argument cannot be None.") - - a = jnp.ones(matrix.shape[0]) / matrix.shape[0] if a is None else a - b = jnp.ones(matrix.shape[1]) / matrix.shape[1] if b is None else b - return _barycenters(ax, arg, a, b, matrix, **kwargs) - - if isinstance(arg, gromov_wasserstein.GWOutput): - geom = arg.linear_state.geom - else: - geom = arg.geom - - return _barycenters(ax, geom.y, arg.a, arg.b, arg.matrix, **kwargs) diff --git a/src/ott/utils.py b/src/ott/utils.py index b74f55af2..20e673695 100644 --- a/src/ott/utils.py +++ b/src/ott/utils.py @@ -60,7 +60,7 @@ def is_jax_array(obj: Any) -> bool: """Check if an object is a Jax array.""" if hasattr(jax, "Array"): # https://jax.readthedocs.io/en/latest/jax_array_migration.html - return isinstance(obj, (jax.Array, jnp.DeviceArray)) + return isinstance(obj, jax.Array) return isinstance(obj, jnp.DeviceArray) @@ -70,8 +70,7 @@ def default_prng_key( """Get the default PRNG key. Args: - rng: - PRNG key. + rng: PRNG key. Returns: If ``rng = None``, returns the default PRNG key. diff --git a/tests/tools/plot_test.py b/tests/tools/plot_test.py new file mode 100644 index 000000000..78a6edd2d --- /dev/null +++ b/tests/tools/plot_test.py @@ -0,0 +1,46 @@ +# Copyright OTT-JAX +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +import matplotlib.pyplot as plt +import ott +from ott.geometry import pointcloud +from ott.problems.linear import linear_problem +from ott.solvers.linear import sinkhorn +from ott.tools import plot + + +class TestSoftSort: + + def test_plot(self, monkeypatch): + monkeypatch.setattr(plt, "show", lambda: None) + n, m, d = 12, 7, 3 + rngs = jax.random.split(jax.random.PRNGKey(0), 3) + xs = [ + jax.random.normal(rngs[0], (n, d)) + 1, + jax.random.normal(rngs[1], (n, d)) + 1 + ] + y = jax.random.uniform(rngs[2], (m, d)) + + solver = sinkhorn.Sinkhorn() + ots = [ + solver(linear_problem.LinearProblem(pointcloud.PointCloud(x, y))) + for x in xs + ] + + plott = plot.Plot() + _ = plott(ots[0]) + fig = plt.figure(figsize=(8, 5)) + plott = ott.tools.plot.Plot(fig=fig) + plott.animate(ots, frame_rate=2)