diff --git a/src/ott/tools/plot.py b/src/ott/tools/plot.py index ccd3b1ec3..1a1e340d1 100644 --- a/src/ott/tools/plot.py +++ b/src/ott/tools/plot.py @@ -61,18 +61,31 @@ class Plot: This step requires a conversion to a numpy array, in order to compute leading singular values. This tool is therefore not designed having performance in mind. + + Args: + fig: Specify figure object. Created by default + ax: Specify axes objects. Created by default + threshold: value below which links in transportation matrix won't be + plotted. This value should be negative when using animations. + scale: scale used for marker plots. + show_lines: whether to show OT lines, as described in ``ot.matrix`` argument + cmap: cmap used to plot line colors. + scale_alpha_by_coupling: use or not the coupling's value as proxy for alpha + alpha: default alpha value for lines. + title: title of the plot. """ def __init__( self, fig: Optional["plt.Figure"] = None, ax: Optional["plt.Axes"] = None, - cost_threshold: float = -1.0, # should be negative for animations. + threshold: float = -1.0, scale: int = 200, show_lines: bool = True, cmap: str = "cool", scale_alpha_by_coupling: bool = False, alpha: float = 0.7, + title: Optional[str] = None ): if plt is None: raise RuntimeError("Please install `matplotlib` first.") @@ -89,11 +102,12 @@ def __init__( self._lines = [] self._points_x = None self._points_y = None - self._threshold = cost_threshold + self._threshold = threshold self._scale = scale self._cmap = cmap self._scale_alpha_by_coupling = scale_alpha_by_coupling self._alpha = alpha + self._title = title def _scatter(self, ot: Transport): """Compute the position and scales of the points on a 2D plot.""" @@ -165,9 +179,13 @@ def __call__(self, ot: Transport) -> List["plt.Artist"]: alpha=alpha ) self._lines.append(line) + if self._title is not None: + self.ax.set_title(self._title) return [self._points_x, self._points_y] + self._lines - def update(self, ot: Transport) -> List["plt.Artist"]: + def update(self, + ot: Transport, + title: Optional[str] = None) -> List["plt.Artist"]: """Update a plot with a transport instance.""" x, y, _, _ = self._scatter(ot) self._points_x.set_offsets(x) @@ -202,20 +220,28 @@ def update(self, ot: Transport) -> List["plt.Artist"]: self._lines.append(line) self._lines = self._lines[:num_to_plot] # Maybe remove some + if title is not None: + self.ax.set_title(title) return [self._points_x, self._points_y] + self._lines def animate( self, transports: Sequence[Transport], + titles: Optional[Sequence[str]] = None, frame_rate: float = 10.0 ) -> "animation.FuncAnimation": """Make an animation from several transports.""" _ = self(transports[0]) + if titles is None: + titles = [None for _ in np.range(0, len(transports))] + assert len(titles) == len(transports), ( + "titles and transports have different lengths" + ) return animation.FuncAnimation( self.fig, - lambda i: self.update(transports[i]), + lambda i: self.update(transports[i], titles[i]), np.arange(0, len(transports)), - init_func=lambda: self.update(transports[0]), + init_func=lambda: self.update(transports[0], titles[0]), interval=1000 / frame_rate, blit=True ) diff --git a/tests/tools/plot_test.py b/tests/tools/plot_test.py index 78a6edd2d..80e374bb6 100644 --- a/tests/tools/plot_test.py +++ b/tests/tools/plot_test.py @@ -42,5 +42,5 @@ def test_plot(self, monkeypatch): plott = plot.Plot() _ = plott(ots[0]) fig = plt.figure(figsize=(8, 5)) - plott = ott.tools.plot.Plot(fig=fig) - plott.animate(ots, frame_rate=2) + plott = ott.tools.plot.Plot(fig=fig, title="test") + plott.animate(ots, frame_rate=2, titles=["test1", "test2"])