Skip to content

Commit ead0c76

Browse files
authored
Merge branch 'main' into plt-import
2 parents d9ba65c + aa4c088 commit ead0c76

File tree

3 files changed

+137
-81
lines changed

3 files changed

+137
-81
lines changed

plotly/matplotlylib/renderer.py

Lines changed: 46 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ def __init__(self):
6060
self.mpl_x_bounds = (0, 1)
6161
self.mpl_y_bounds = (0, 1)
6262
self.msg = "Initialized PlotlyRenderer\n"
63+
self._processing_legend = False
64+
self._legend_visible = False
6365

6466
def open_figure(self, fig, props):
6567
"""Creates a new figure by beginning to fill out layout dict.
@@ -108,7 +110,6 @@ def close_figure(self, fig):
108110
fig -- a matplotlib.figure.Figure object.
109111
110112
"""
111-
self.plotly_fig["layout"]["showlegend"] = False
112113
self.msg += "Closing figure\n"
113114

114115
def open_axes(self, ax, props):
@@ -198,6 +199,37 @@ def close_axes(self, ax):
198199
self.msg += " Closing axes\n"
199200
self.x_is_mpl_date = False
200201

202+
def open_legend(self, legend, props):
203+
"""Enable Plotly's native legend when matplotlib legend is detected.
204+
205+
This method is called when a matplotlib legend is found. It enables
206+
Plotly's showlegend only if the matplotlib legend is visible.
207+
208+
Positional arguments:
209+
legend -- matplotlib.legend.Legend object
210+
props -- legend properties dictionary
211+
"""
212+
self.msg += " Opening legend\n"
213+
self._processing_legend = True
214+
self._legend_visible = props.get("visible", True)
215+
if self._legend_visible:
216+
self.msg += (
217+
" Enabling native plotly legend (matplotlib legend is visible)\n"
218+
)
219+
self.plotly_fig["layout"]["showlegend"] = True
220+
else:
221+
self.msg += " Not enabling legend (matplotlib legend is not visible)\n"
222+
223+
def close_legend(self, legend):
224+
"""Finalize legend processing.
225+
226+
Positional arguments:
227+
legend -- matplotlib.legend.Legend object
228+
"""
229+
self.msg += " Closing legend\n"
230+
self._processing_legend = False
231+
self._legend_visible = False
232+
201233
def draw_bars(self, bars):
202234
# sort bars according to bar containers
203235
mpl_traces = []
@@ -299,7 +331,7 @@ def draw_bar(self, coll):
299331
) # TODO ditto
300332
if len(bar["x"]) > 1:
301333
self.msg += " Heck yeah, I drew that bar chart\n"
302-
(self.plotly_fig.add_trace(bar),)
334+
self.plotly_fig.add_trace(bar)
303335
if bar_gap is not None:
304336
self.plotly_fig["layout"]["bargap"] = bar_gap
305337
else:
@@ -309,83 +341,6 @@ def draw_bar(self, coll):
309341
"assuming data redundancy, not plotting."
310342
)
311343

312-
def draw_legend_shapes(self, mode, shape, **props):
313-
"""Create a shape that matches lines or markers in legends.
314-
315-
Main issue is that path for circles do not render, so we have to use 'circle'
316-
instead of 'path'.
317-
"""
318-
for single_mode in mode.split("+"):
319-
x = props["data"][0][0]
320-
y = props["data"][0][1]
321-
if single_mode == "markers" and props.get("markerstyle"):
322-
size = shape.pop("size", 6)
323-
symbol = shape.pop("symbol")
324-
# aligning to "center"
325-
x0 = 0
326-
y0 = 0
327-
x1 = size
328-
y1 = size
329-
markerpath = props["markerstyle"].get("markerpath")
330-
if markerpath is None and symbol != "circle":
331-
self.msg += (
332-
"not sure how to handle this marker without a valid path\n"
333-
)
334-
return
335-
# marker path to SVG path conversion
336-
path = " ".join(
337-
[f"{a} {t[0]},{t[1]}" for a, t in zip(markerpath[1], markerpath[0])]
338-
)
339-
340-
if symbol == "circle":
341-
# symbols like . and o in matplotlib, use circle
342-
# plotly also maps many other markers to circle, such as 1,8 and p
343-
path = None
344-
shape_type = "circle"
345-
x0 = -size / 2
346-
y0 = size / 2
347-
x1 = size / 2
348-
y1 = size + size / 2
349-
else:
350-
# triangles, star etc
351-
shape_type = "path"
352-
legend_shape = go.layout.Shape(
353-
type=shape_type,
354-
xref="paper",
355-
yref="paper",
356-
x0=x0,
357-
y0=y0,
358-
x1=x1,
359-
y1=y1,
360-
xsizemode="pixel",
361-
ysizemode="pixel",
362-
xanchor=x,
363-
yanchor=y,
364-
path=path,
365-
**shape,
366-
)
367-
368-
elif single_mode == "lines":
369-
mode = "line"
370-
x1 = props["data"][1][0]
371-
y1 = props["data"][1][1]
372-
373-
legend_shape = go.layout.Shape(
374-
type=mode,
375-
xref="paper",
376-
yref="paper",
377-
x0=x,
378-
y0=y + 0.02,
379-
x1=x1,
380-
y1=y1 + 0.02,
381-
**shape,
382-
)
383-
else:
384-
self.msg += "not sure how to handle this element\n"
385-
return
386-
self.plotly_fig.add_shape(legend_shape)
387-
self.msg += " Heck yeah, I drew that shape\n"
388-
389344
def draw_marked_line(self, **props):
390345
"""Create a data dict for a line obj.
391346
@@ -497,11 +452,11 @@ def draw_marked_line(self, **props):
497452
marked_line["x"] = mpltools.mpl_dates_to_datestrings(
498453
marked_line["x"], formatter
499454
)
500-
(self.plotly_fig.add_trace(marked_line),)
455+
self.plotly_fig.add_trace(marked_line)
501456
self.msg += " Heck yeah, I drew that line\n"
502457
elif props["coordinates"] == "axes":
503458
# dealing with legend graphical elements
504-
self.draw_legend_shapes(mode=mode, shape=shape, **props)
459+
self.msg += " Using native legend\n"
505460
else:
506461
self.msg += " Line didn't have 'data' coordinates, not drawing\n"
507462
warnings.warn(
@@ -667,6 +622,16 @@ def draw_text(self, **props):
667622
self.draw_title(**props)
668623
else: # just a regular text annotation...
669624
self.msg += " Text object is a normal annotation\n"
625+
# Skip creating annotations for legend text when using native legend
626+
if (
627+
self._processing_legend
628+
and self._legend_visible
629+
and props["coordinates"] == "axes"
630+
):
631+
self.msg += (
632+
" Skipping legend text annotation (using native legend)\n"
633+
)
634+
return
670635
if props["coordinates"] != "data":
671636
self.msg += " Text object isn't linked to 'data' coordinates\n"
672637
x_px, y_px = (
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import matplotlib
2+
3+
matplotlib.use("Agg")
4+
import matplotlib.pyplot as plt
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import plotly.tools as tls
2+
3+
from . import plt
4+
5+
6+
def test_native_legend_enabled_when_matplotlib_legend_present():
7+
"""Test that when matplotlib legend is present, Plotly uses native legend."""
8+
fig, ax = plt.subplots()
9+
ax.plot([0, 1], [0, 1], label="Line 1")
10+
ax.plot([0, 1], [1, 0], label="Line 2")
11+
ax.legend()
12+
13+
plotly_fig = tls.mpl_to_plotly(fig)
14+
15+
# Should enable native legend
16+
assert plotly_fig.layout.showlegend == True
17+
# Should have 2 traces with names
18+
assert len(plotly_fig.data) == 2
19+
assert plotly_fig.data[0].name == "Line 1"
20+
assert plotly_fig.data[1].name == "Line 2"
21+
22+
23+
def test_no_fake_legend_shapes_with_native_legend():
24+
"""Test that fake legend shapes are not created when using native legend."""
25+
fig, ax = plt.subplots()
26+
ax.plot([0, 1], [0, 1], "o-", label="Data with markers")
27+
ax.legend()
28+
29+
plotly_fig = tls.mpl_to_plotly(fig)
30+
31+
# Should use native legend
32+
assert plotly_fig.layout.showlegend == True
33+
# Should not create fake legend elements
34+
assert len(plotly_fig.layout.shapes) == 0
35+
assert len(plotly_fig.layout.annotations) == 0
36+
37+
38+
def test_legend_disabled_when_no_matplotlib_legend():
39+
"""Test that legend is not enabled when no matplotlib legend is present."""
40+
fig, ax = plt.subplots()
41+
ax.plot([0, 1], [0, 1], label="Line 1") # Has label but no legend() call
42+
43+
plotly_fig = tls.mpl_to_plotly(fig)
44+
45+
# Should not have showlegend explicitly set to True
46+
# (Plotly's default behavior when no legend elements exist)
47+
assert (
48+
not hasattr(plotly_fig.layout, "showlegend")
49+
or plotly_fig.layout.showlegend != True
50+
)
51+
52+
53+
def test_legend_disabled_when_matplotlib_legend_not_visible():
54+
"""Test that legend is not enabled when no matplotlib legend is not visible."""
55+
fig, ax = plt.subplots()
56+
ax.plot([0, 1], [0, 1], label="Line 1")
57+
legend = ax.legend()
58+
legend.set_visible(False) # Hide the legend
59+
60+
plotly_fig = tls.mpl_to_plotly(fig)
61+
62+
# Should not enable legend when matplotlib legend is hidden
63+
assert (
64+
not hasattr(plotly_fig.layout, "showlegend")
65+
or plotly_fig.layout.showlegend != True
66+
)
67+
68+
69+
def test_multiple_traces_native_legend():
70+
"""Test native legend works with multiple traces of different types."""
71+
fig, ax = plt.subplots()
72+
ax.plot([0, 1, 2], [0, 1, 0], "-", label="Line")
73+
ax.plot([0, 1, 2], [1, 0, 1], "o", label="Markers")
74+
ax.plot([0, 1, 2], [0.5, 0.5, 0.5], "s-", label="Line+Markers")
75+
ax.legend()
76+
77+
plotly_fig = tls.mpl_to_plotly(fig)
78+
79+
assert plotly_fig.layout.showlegend == True
80+
assert len(plotly_fig.data) == 3
81+
assert plotly_fig.data[0].name == "Line"
82+
assert plotly_fig.data[1].name == "Markers"
83+
assert plotly_fig.data[2].name == "Line+Markers"
84+
# Verify modes are correct
85+
assert plotly_fig.data[0].mode == "lines"
86+
assert plotly_fig.data[1].mode == "markers"
87+
assert plotly_fig.data[2].mode == "lines+markers"

0 commit comments

Comments
 (0)