Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test Plot class in tools module to increase coverage #394

Merged
merged 4 commits into from
Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 10 additions & 49 deletions src/ott/tools/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import numpy as np
import scipy

from ott import utils
from ott.geometry import pointcloud
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein
Expand All @@ -39,7 +38,9 @@ def bidimensional(x: jnp.ndarray,
if x.shape[1] < 3:
return x, y

u, s, _ = scipy.sparse.linalg.svds(jnp.concatenate([x, y], axis=0), k=2)
u, s, _ = scipy.sparse.linalg.svds(
np.array(jnp.concatenate([x, y], axis=0)), k=2
)
proj = u * s
k = x.shape[0]
return proj[:k], proj[k:]
Expand All @@ -48,14 +49,18 @@ def bidimensional(x: jnp.ndarray,
class Plot:
"""Plot an optimal transport map between two point clouds.
It enables to either plot or update a plot in a single object, offering the
possibilities to create animations as a
This object can either plot or update a plot, to create animations as a
:class:`~matplotlib.animation.FuncAnimation`, which can in turned be saved to
disk at will. There are two design principles here:
#. we do not rely on saving to/loading from disk to create animations
#. we try as much as possible to disentangle the transport problem from
its visualization.
We use 2D scatter plots by default, relying on PCA visualization for d>3 data.
This step requires a conversion to a numpy array, in order to compute leading
singular values. This tool is therefore not designed having performance in
mind.
"""

def __init__(
Expand Down Expand Up @@ -135,7 +140,7 @@ def _mapping(self, x: jnp.ndarray, y: jnp.ndarray, matrix: jnp.ndarray):
return result

def __call__(self, ot: Transport) -> List["plt.Artist"]:
"""Plot 2-D couplings. Projects via PCA if data is higher dimensional."""
"""Plot couplings in 2-D, using PCA if data is higher dimensional."""
x, y, sx, sy = self._scatter(ot)
self._points_x = self.ax.scatter(
*x.T, s=sx, edgecolors="k", marker="o", label="x"
Expand Down Expand Up @@ -214,47 +219,3 @@ def animate(
interval=1000 / frame_rate,
blit=True
)


def _barycenters(
ax: "plt.Axes",
y: jnp.ndarray,
a: jnp.ndarray,
b: jnp.ndarray,
matrix: jnp.ndarray,
scale: int = 200
) -> None:
"""Plot 2-D sinkhorn barycenters."""
sa, sb = jnp.min(a) / scale, jnp.min(b) / scale
ax.scatter(*y.T, s=b / sb, edgecolors="k", marker="X", label="y")
tx = 1 / a[:, None] * jnp.matmul(matrix, y)
ax.scatter(*tx.T, s=a / sa, edgecolors="k", marker="X", label="T(x)")
ax.legend(fontsize=15)


def barycentric_projections(
arg: Union[Transport, jnp.ndarray],
a: jnp.ndarray = None,
b: jnp.ndarray = None,
matrix: jnp.ndarray = None,
ax: Optional["plt.Axes"] = None,
**kwargs
):
"""Plot the barycenters, from the Transport object or from arguments."""
if ax is None:
_, ax = plt.subplots(1, 1, figsize=(8, 5))

if utils.is_jax_array(arg):
if matrix is None:
raise ValueError("The `matrix` argument cannot be None.")

a = jnp.ones(matrix.shape[0]) / matrix.shape[0] if a is None else a
b = jnp.ones(matrix.shape[1]) / matrix.shape[1] if b is None else b
return _barycenters(ax, arg, a, b, matrix, **kwargs)

if isinstance(arg, gromov_wasserstein.GWOutput):
geom = arg.linear_state.geom
else:
geom = arg.geom

return _barycenters(ax, geom.y, arg.a, arg.b, arg.matrix, **kwargs)
5 changes: 2 additions & 3 deletions src/ott/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def is_jax_array(obj: Any) -> bool:
"""Check if an object is a Jax array."""
if hasattr(jax, "Array"):
# https://jax.readthedocs.io/en/latest/jax_array_migration.html
return isinstance(obj, (jax.Array, jnp.DeviceArray))
return isinstance(obj, jax.Array)
return isinstance(obj, jnp.DeviceArray)


Expand All @@ -70,8 +70,7 @@ def default_prng_key(
"""Get the default PRNG key.
Args:
rng:
PRNG key.
rng: PRNG key.
Returns:
If ``rng = None``, returns the default PRNG key.
Expand Down
46 changes: 46 additions & 0 deletions tests/tools/plot_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
import matplotlib.pyplot as plt
import ott
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
from ott.tools import plot


class TestSoftSort:

def test_plot(self, monkeypatch):
monkeypatch.setattr(plt, "show", lambda: None)
n, m, d = 12, 7, 3
rngs = jax.random.split(jax.random.PRNGKey(0), 3)
xs = [
jax.random.normal(rngs[0], (n, d)) + 1,
jax.random.normal(rngs[1], (n, d)) + 1
]
y = jax.random.uniform(rngs[2], (m, d))

solver = sinkhorn.Sinkhorn()
ots = [
solver(linear_problem.LinearProblem(pointcloud.PointCloud(x, y)))
for x in xs
]

plott = plot.Plot()
_ = plott(ots[0])
fig = plt.figure(figsize=(8, 5))
plott = ott.tools.plot.Plot(fig=fig)
plott.animate(ots, frame_rate=2)