Skip to content

Commit

Permalink
TST: Clean tests/plotting/test_frame.py misc (#53914)
Browse files Browse the repository at this point in the history
  • Loading branch information
mroeschke committed Jun 28, 2023
1 parent c96544e commit 0aa3994
Show file tree
Hide file tree
Showing 2 changed files with 200 additions and 149 deletions.
48 changes: 31 additions & 17 deletions pandas/tests/plotting/frame/test_frame_legend.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class TestFrameLegend:
)
def test_mixed_yerr(self):
# https://github.com/pandas-dev/pandas/issues/39522
import matplotlib as mpl
from matplotlib.collections import LineCollection
from matplotlib.lines import Line2D

Expand All @@ -46,8 +45,6 @@ def test_mixed_yerr(self):

def test_legend_false(self):
# https://github.com/pandas-dev/pandas/issues/40044
import matplotlib as mpl

df = DataFrame({"a": [1, 1], "b": [2, 3]})
df2 = DataFrame({"d": [2.5, 2.5]})

Expand All @@ -63,27 +60,31 @@ def test_legend_false(self):
assert result == expected

@td.skip_if_no_scipy
def test_df_legend_labels(self):
kinds = ["line", "bar", "barh", "kde", "area", "hist"]
@pytest.mark.parametrize("kind", ["line", "bar", "barh", "kde", "area", "hist"])
def test_df_legend_labels(self, kind):
df = DataFrame(np.random.rand(3, 3), columns=["a", "b", "c"])
df2 = DataFrame(np.random.rand(3, 3), columns=["d", "e", "f"])
df3 = DataFrame(np.random.rand(3, 3), columns=["g", "h", "i"])
df4 = DataFrame(np.random.rand(3, 3), columns=["j", "k", "l"])

for kind in kinds:
ax = df.plot(kind=kind, legend=True)
_check_legend_labels(ax, labels=df.columns)
ax = df.plot(kind=kind, legend=True)
_check_legend_labels(ax, labels=df.columns)

ax = df2.plot(kind=kind, legend=False, ax=ax)
_check_legend_labels(ax, labels=df.columns)
ax = df2.plot(kind=kind, legend=False, ax=ax)
_check_legend_labels(ax, labels=df.columns)

ax = df3.plot(kind=kind, legend=True, ax=ax)
_check_legend_labels(ax, labels=df.columns.union(df3.columns))
ax = df3.plot(kind=kind, legend=True, ax=ax)
_check_legend_labels(ax, labels=df.columns.union(df3.columns))

ax = df4.plot(kind=kind, legend="reverse", ax=ax)
expected = list(df.columns.union(df3.columns)) + list(reversed(df4.columns))
_check_legend_labels(ax, labels=expected)
ax = df4.plot(kind=kind, legend="reverse", ax=ax)
expected = list(df.columns.union(df3.columns)) + list(reversed(df4.columns))
_check_legend_labels(ax, labels=expected)

@td.skip_if_no_scipy
def test_df_legend_labels_secondary_y(self):
df = DataFrame(np.random.rand(3, 3), columns=["a", "b", "c"])
df2 = DataFrame(np.random.rand(3, 3), columns=["d", "e", "f"])
df3 = DataFrame(np.random.rand(3, 3), columns=["g", "h", "i"])
# Secondary Y
ax = df.plot(legend=True, secondary_y="b")
_check_legend_labels(ax, labels=["a", "b (right)", "c"])
Expand All @@ -92,6 +93,8 @@ def test_df_legend_labels(self):
ax = df3.plot(kind="bar", legend=True, secondary_y="h", ax=ax)
_check_legend_labels(ax, labels=["a", "b (right)", "c", "g", "h (right)", "i"])

@td.skip_if_no_scipy
def test_df_legend_labels_time_series(self):
# Time Series
ind = date_range("1/1/2014", periods=3)
df = DataFrame(np.random.randn(3, 3), columns=["a", "b", "c"], index=ind)
Expand All @@ -104,6 +107,13 @@ def test_df_legend_labels(self):
ax = df3.plot(legend=True, ax=ax)
_check_legend_labels(ax, labels=["a", "b (right)", "c", "g", "h", "i"])

@td.skip_if_no_scipy
def test_df_legend_labels_time_series_scatter(self):
# Time Series
ind = date_range("1/1/2014", periods=3)
df = DataFrame(np.random.randn(3, 3), columns=["a", "b", "c"], index=ind)
df2 = DataFrame(np.random.randn(3, 3), columns=["d", "e", "f"], index=ind)
df3 = DataFrame(np.random.randn(3, 3), columns=["g", "h", "i"], index=ind)
# scatter
ax = df.plot.scatter(x="a", y="b", label="data1")
_check_legend_labels(ax, labels=["data1"])
Expand All @@ -112,6 +122,10 @@ def test_df_legend_labels(self):
ax = df3.plot.scatter(x="g", y="h", label="data3", ax=ax)
_check_legend_labels(ax, labels=["data1", "data3"])

@td.skip_if_no_scipy
def test_df_legend_labels_time_series_no_mutate(self):
ind = date_range("1/1/2014", periods=3)
df = DataFrame(np.random.randn(3, 3), columns=["a", "b", "c"], index=ind)
# ensure label args pass through and
# index name does not mutate
# column names don't mutate
Expand All @@ -128,7 +142,7 @@ def test_df_legend_labels(self):
def test_missing_marker_multi_plots_on_same_ax(self):
# GH 18222
df = DataFrame(data=[[1, 1, 1, 1], [2, 2, 4, 8]], columns=["x", "r", "g", "b"])
fig, ax = mpl.pyplot.subplots(nrows=1, ncols=3)
_, ax = mpl.pyplot.subplots(nrows=1, ncols=3)
# Left plot
df.plot(x="x", y="r", linewidth=0, marker="o", color="r", ax=ax[0])
df.plot(x="x", y="g", linewidth=1, marker="x", color="g", ax=ax[0])
Expand Down Expand Up @@ -210,7 +224,7 @@ def test_missing_markers_legend_using_style(self):
}
)

fig, ax = mpl.pyplot.subplots()
_, ax = mpl.pyplot.subplots()
for kind in "ABC":
df.plot("X", kind, label=kind, ax=ax, style=".")

Expand Down

0 comments on commit 0aa3994

Please sign in to comment.