Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ability to add titles on animation plots #400

Merged
merged 3 commits into from
Jul 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 31 additions & 5 deletions src/ott/tools/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,31 @@ class Plot:
This step requires a conversion to a numpy array, in order to compute leading
singular values. This tool is therefore not designed having performance in
mind.

Args:
fig: Specify figure object. Created by default
ax: Specify axes objects. Created by default
threshold: value below which links in transportation matrix won't be
plotted. This value should be negative when using animations.
scale: scale used for marker plots.
show_lines: whether to show OT lines, as described in ``ot.matrix`` argument
cmap: cmap used to plot line colors.
scale_alpha_by_coupling: use or not the coupling's value as proxy for alpha
alpha: default alpha value for lines.
title: title of the plot.
"""

def __init__(
self,
fig: Optional["plt.Figure"] = None,
ax: Optional["plt.Axes"] = None,
cost_threshold: float = -1.0, # should be negative for animations.
threshold: float = -1.0,
scale: int = 200,
show_lines: bool = True,
cmap: str = "cool",
scale_alpha_by_coupling: bool = False,
alpha: float = 0.7,
title: Optional[str] = None
):
if plt is None:
raise RuntimeError("Please install `matplotlib` first.")
Expand All @@ -89,11 +102,12 @@ def __init__(
self._lines = []
self._points_x = None
self._points_y = None
self._threshold = cost_threshold
self._threshold = threshold
self._scale = scale
self._cmap = cmap
self._scale_alpha_by_coupling = scale_alpha_by_coupling
self._alpha = alpha
self._title = title

def _scatter(self, ot: Transport):
"""Compute the position and scales of the points on a 2D plot."""
Expand Down Expand Up @@ -165,9 +179,13 @@ def __call__(self, ot: Transport) -> List["plt.Artist"]:
alpha=alpha
)
self._lines.append(line)
if self._title is not None:
self.ax.set_title(self._title)
return [self._points_x, self._points_y] + self._lines

def update(self, ot: Transport) -> List["plt.Artist"]:
def update(self,
ot: Transport,
title: Optional[str] = None) -> List["plt.Artist"]:
"""Update a plot with a transport instance."""
x, y, _, _ = self._scatter(ot)
self._points_x.set_offsets(x)
Expand Down Expand Up @@ -202,20 +220,28 @@ def update(self, ot: Transport) -> List["plt.Artist"]:
self._lines.append(line)

self._lines = self._lines[:num_to_plot] # Maybe remove some
if title is not None:
self.ax.set_title(title)
return [self._points_x, self._points_y] + self._lines

def animate(
self,
transports: Sequence[Transport],
titles: Optional[Sequence[str]] = None,
frame_rate: float = 10.0
) -> "animation.FuncAnimation":
"""Make an animation from several transports."""
_ = self(transports[0])
if titles is None:
titles = [None for _ in np.range(0, len(transports))]
assert len(titles) == len(transports), (
"titles and transports have different lengths"
)
return animation.FuncAnimation(
self.fig,
lambda i: self.update(transports[i]),
lambda i: self.update(transports[i], titles[i]),
np.arange(0, len(transports)),
init_func=lambda: self.update(transports[0]),
init_func=lambda: self.update(transports[0], titles[0]),
interval=1000 / frame_rate,
blit=True
)
4 changes: 2 additions & 2 deletions tests/tools/plot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@ def test_plot(self, monkeypatch):
plott = plot.Plot()
_ = plott(ots[0])
fig = plt.figure(figsize=(8, 5))
plott = ott.tools.plot.Plot(fig=fig)
plott.animate(ots, frame_rate=2)
plott = ott.tools.plot.Plot(fig=fig, title="test")
plott.animate(ots, frame_rate=2, titles=["test1", "test2"])
Loading