Skip to content

Commit

Permalink
minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
marcocuturi committed Jul 12, 2023
1 parent 8fc3dc9 commit 83ebc8f
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/ott/tools/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,18 @@ def bidimensional(x: jnp.ndarray,
class Plot:
"""Plot an optimal transport map between two point clouds.
It enables to either plot or update a plot in a single object, offering the
possibilities to create animations as a
This object can either plot or update a plot, to create animations as a
:class:`~matplotlib.animation.FuncAnimation`, which can in turned be saved to
disk at will. There are two design principles here:
#. we do not rely on saving to/loading from disk to create animations
#. we try as much as possible to disentangle the transport problem from
its visualization.
#. we rely on PCA visualization for d>3 data. this requires a conversion to
a numpy array, which can be slow for large samples.
We use 2D scatter plots by default, relying on PCA visualization for d>3 data.
This step requires a conversion to a numpy array, in order to compute leading
singular values. This tool is therefore not designed having performance in
mind.
"""

def __init__(
Expand Down Expand Up @@ -138,7 +140,7 @@ def _mapping(self, x: jnp.ndarray, y: jnp.ndarray, matrix: jnp.ndarray):
return result

def __call__(self, ot: Transport) -> List["plt.Artist"]:
"""Plot 2-D couplings. Projects via PCA if data is higher dimensional."""
"""Plot couplings in 2-D, using PCA if data is higher dimensional."""
x, y, sx, sy = self._scatter(ot)
self._points_x = self.ax.scatter(
*x.T, s=sx, edgecolors="k", marker="o", label="x"
Expand Down

0 comments on commit 83ebc8f

Please sign in to comment.