diff --git a/src/ott/tools/plot.py b/src/ott/tools/plot.py index b0d2d1d08..6f045dbf0 100644 --- a/src/ott/tools/plot.py +++ b/src/ott/tools/plot.py @@ -17,7 +17,6 @@ import numpy as np import scipy -from ott import utils from ott.geometry import pointcloud from ott.solvers.linear import sinkhorn, sinkhorn_lr from ott.solvers.quadratic import gromov_wasserstein @@ -218,47 +217,3 @@ def animate( interval=1000 / frame_rate, blit=True ) - - -def _barycenters( - ax: "plt.Axes", - y: jnp.ndarray, - a: jnp.ndarray, - b: jnp.ndarray, - matrix: jnp.ndarray, - scale: int = 200 -) -> None: - """Plot 2-D sinkhorn barycenters.""" - sa, sb = jnp.min(a) / scale, jnp.min(b) / scale - ax.scatter(*y.T, s=b / sb, edgecolors="k", marker="X", label="y") - tx = 1 / a[:, None] * jnp.matmul(matrix, y) - ax.scatter(*tx.T, s=a / sa, edgecolors="k", marker="X", label="T(x)") - ax.legend(fontsize=15) - - -def barycentric_projections( - arg: Union[Transport, jnp.ndarray], - a: jnp.ndarray = None, - b: jnp.ndarray = None, - matrix: jnp.ndarray = None, - ax: Optional["plt.Axes"] = None, - **kwargs -): - """Plot the barycenters, from the Transport object or from arguments.""" - if ax is None: - _, ax = plt.subplots(1, 1, figsize=(8, 5)) - - if utils.is_jax_array(arg): - if matrix is None: - raise ValueError("The `matrix` argument cannot be None.") - - a = jnp.ones(matrix.shape[0]) / matrix.shape[0] if a is None else a - b = jnp.ones(matrix.shape[1]) / matrix.shape[1] if b is None else b - return _barycenters(ax, arg, a, b, matrix, **kwargs) - - if isinstance(arg, gromov_wasserstein.GWOutput): - geom = arg.linear_state.geom - else: - geom = arg.geom - - return _barycenters(ax, geom.y, arg.a, arg.b, arg.matrix, **kwargs)