Skip to content

Commit

Permalink
add ability to add titles on animation plots (#400)
Browse files Browse the repository at this point in the history
* add ability to add titles on plots

* include title / titles in test

* cost_threshold -> threshold
  • Loading branch information
marcocuturi committed Jul 20, 2023
1 parent 5ab76b0 commit 34f2465
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 7 deletions.
36 changes: 31 additions & 5 deletions src/ott/tools/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,31 @@ class Plot:
This step requires a conversion to a numpy array, in order to compute leading
singular values. This tool is therefore not designed having performance in
mind.
Args:
fig: Specify figure object. Created by default
ax: Specify axes objects. Created by default
threshold: value below which links in transportation matrix won't be
plotted. This value should be negative when using animations.
scale: scale used for marker plots.
show_lines: whether to show OT lines, as described in ``ot.matrix`` argument
cmap: cmap used to plot line colors.
scale_alpha_by_coupling: use or not the coupling's value as proxy for alpha
alpha: default alpha value for lines.
title: title of the plot.
"""

def __init__(
self,
fig: Optional["plt.Figure"] = None,
ax: Optional["plt.Axes"] = None,
cost_threshold: float = -1.0, # should be negative for animations.
threshold: float = -1.0,
scale: int = 200,
show_lines: bool = True,
cmap: str = "cool",
scale_alpha_by_coupling: bool = False,
alpha: float = 0.7,
title: Optional[str] = None
):
if plt is None:
raise RuntimeError("Please install `matplotlib` first.")
Expand All @@ -89,11 +102,12 @@ def __init__(
self._lines = []
self._points_x = None
self._points_y = None
self._threshold = cost_threshold
self._threshold = threshold
self._scale = scale
self._cmap = cmap
self._scale_alpha_by_coupling = scale_alpha_by_coupling
self._alpha = alpha
self._title = title

def _scatter(self, ot: Transport):
"""Compute the position and scales of the points on a 2D plot."""
Expand Down Expand Up @@ -165,9 +179,13 @@ def __call__(self, ot: Transport) -> List["plt.Artist"]:
alpha=alpha
)
self._lines.append(line)
if self._title is not None:
self.ax.set_title(self._title)
return [self._points_x, self._points_y] + self._lines

def update(self, ot: Transport) -> List["plt.Artist"]:
def update(self,
ot: Transport,
title: Optional[str] = None) -> List["plt.Artist"]:
"""Update a plot with a transport instance."""
x, y, _, _ = self._scatter(ot)
self._points_x.set_offsets(x)
Expand Down Expand Up @@ -202,20 +220,28 @@ def update(self, ot: Transport) -> List["plt.Artist"]:
self._lines.append(line)

self._lines = self._lines[:num_to_plot] # Maybe remove some
if title is not None:
self.ax.set_title(title)
return [self._points_x, self._points_y] + self._lines

def animate(
self,
transports: Sequence[Transport],
titles: Optional[Sequence[str]] = None,
frame_rate: float = 10.0
) -> "animation.FuncAnimation":
"""Make an animation from several transports."""
_ = self(transports[0])
if titles is None:
titles = [None for _ in np.range(0, len(transports))]
assert len(titles) == len(transports), (
"titles and transports have different lengths"
)
return animation.FuncAnimation(
self.fig,
lambda i: self.update(transports[i]),
lambda i: self.update(transports[i], titles[i]),
np.arange(0, len(transports)),
init_func=lambda: self.update(transports[0]),
init_func=lambda: self.update(transports[0], titles[0]),
interval=1000 / frame_rate,
blit=True
)
4 changes: 2 additions & 2 deletions tests/tools/plot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@ def test_plot(self, monkeypatch):
plott = plot.Plot()
_ = plott(ots[0])
fig = plt.figure(figsize=(8, 5))
plott = ott.tools.plot.Plot(fig=fig)
plott.animate(ots, frame_rate=2)
plott = ott.tools.plot.Plot(fig=fig, title="test")
plott.animate(ots, frame_rate=2, titles=["test1", "test2"])

0 comments on commit 34f2465

Please sign in to comment.