From 83ebc8ff69758168b3035fff6491f91b3ca898e6 Mon Sep 17 00:00:00 2001 From: marcocuturi Date: Wed, 12 Jul 2023 10:31:09 -0700 Subject: [PATCH] minor fixes --- src/ott/tools/plot.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/ott/tools/plot.py b/src/ott/tools/plot.py index 6f045dbf0..ccd3b1ec3 100644 --- a/src/ott/tools/plot.py +++ b/src/ott/tools/plot.py @@ -49,16 +49,18 @@ def bidimensional(x: jnp.ndarray, class Plot: """Plot an optimal transport map between two point clouds. - It enables to either plot or update a plot in a single object, offering the - possibilities to create animations as a + This object can either plot or update a plot, to create animations as a :class:`~matplotlib.animation.FuncAnimation`, which can in turned be saved to disk at will. There are two design principles here: #. we do not rely on saving to/loading from disk to create animations #. we try as much as possible to disentangle the transport problem from its visualization. - #. we rely on PCA visualization for d>3 data. this requires a conversion to - a numpy array, which can be slow for large samples. + + We use 2D scatter plots by default, relying on PCA visualization for d>3 data. + This step requires a conversion to a numpy array, in order to compute leading + singular values. This tool is therefore not designed having performance in + mind. """ def __init__( @@ -138,7 +140,7 @@ def _mapping(self, x: jnp.ndarray, y: jnp.ndarray, matrix: jnp.ndarray): return result def __call__(self, ot: Transport) -> List["plt.Artist"]: - """Plot 2-D couplings. Projects via PCA if data is higher dimensional.""" + """Plot couplings in 2-D, using PCA if data is higher dimensional.""" x, y, sx, sy = self._scatter(ot) self._points_x = self.ax.scatter( *x.T, s=sx, edgecolors="k", marker="o", label="x"