Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

Already on GitHub? Sign in to your account

BUG: inconsistent subplot ax handling #7391

Merged
merged 1 commit into from Jul 6, 2014
Jump to file or symbol
Failed to load files and symbols.
+57 −39
Split
View
@@ -176,6 +176,8 @@ Bug Fixes
- Bug in ``to_timedelta`` that accepted invalid units and misinterpreted 'm/h' (:issue:`7611`, :issue: `6423`)
- Bug in grouped ``hist`` and ``scatter`` plots use old ``figsize`` default (:issue:`7394`)
+- Bug in plotting subplots with ``DataFrame.plot``, ``hist`` clears passed ``ax`` even if the number of subplots is one (:issue:`7391`).
+- Bug in plotting subplots with ``DataFrame.boxplot`` with ``by`` kw raises ``ValueError`` if the number of subplots exceeds 1 (:issue:`7391`).
- Bug in ``Panel.apply`` with a multi-index as an axis (:issue:`7469`)
@@ -859,6 +859,13 @@ def test_plot(self):
axes = _check_plot_works(df.plot, kind='bar', subplots=True)
self._check_axes_shape(axes, axes_num=1, layout=(1, 1))
+ # When ax is supplied and required number of axes is 1,
+ # passed ax should be used:
+ fig, ax = self.plt.subplots()
+ axes = df.plot(kind='bar', subplots=True, ax=ax)
+ self.assertEqual(len(axes), 1)
+ self.assertIs(ax.get_axes(), axes[0])
+
def test_nonnumeric_exclude(self):
df = DataFrame({'A': ["x", "y", "z"], 'B': [1, 2, 3]})
ax = df.plot()
@@ -1419,17 +1426,23 @@ def test_boxplot(self):
df = DataFrame(np.random.rand(10, 2), columns=['Col1', 'Col2'])
df['X'] = Series(['A', 'A', 'A', 'A', 'A', 'B', 'B', 'B', 'B', 'B'])
+ df['Y'] = Series(['A'] * 10)
_check_plot_works(df.boxplot, by='X')
- # When ax is supplied, existing axes should be used:
+ # When ax is supplied and required number of axes is 1,
+ # passed ax should be used:
fig, ax = self.plt.subplots()
axes = df.boxplot('Col1', by='X', ax=ax)
self.assertIs(ax.get_axes(), axes)
- # Multiple columns with an ax argument is not supported
fig, ax = self.plt.subplots()
- with tm.assertRaisesRegexp(ValueError, 'existing axis'):
- df.boxplot(column=['Col1', 'Col2'], by='X', ax=ax)
+ axes = df.groupby('Y').boxplot(ax=ax, return_type='axes')
+ self.assertIs(ax.get_axes(), axes['A'])
+
+ # Multiple columns with an ax argument should use same figure
+ fig, ax = self.plt.subplots()
+ axes = df.boxplot(column=['Col1', 'Col2'], by='X', ax=ax, return_type='axes')
+ self.assertIs(axes['Col1'].get_figure(), fig)
# When by is None, check that all relevant lines are present in the dict
fig, ax = self.plt.subplots()
@@ -2180,32 +2193,32 @@ class TestDataFrameGroupByPlots(TestPlotBase):
@slow
def test_boxplot(self):
grouped = self.hist_df.groupby(by='gender')
- box = _check_plot_works(grouped.boxplot, return_type='dict')
- self._check_axes_shape(self.plt.gcf().axes, axes_num=2, layout=(1, 2))
+ axes = _check_plot_works(grouped.boxplot, return_type='axes')
+ self._check_axes_shape(axes.values(), axes_num=2, layout=(1, 2))
- box = _check_plot_works(grouped.boxplot, subplots=False,
- return_type='dict')
- self._check_axes_shape(self.plt.gcf().axes, axes_num=2, layout=(1, 2))
+ axes = _check_plot_works(grouped.boxplot, subplots=False,
+ return_type='axes')
+ self._check_axes_shape(axes, axes_num=1, layout=(1, 1))
tuples = lzip(string.ascii_letters[:10], range(10))
df = DataFrame(np.random.rand(10, 3),
index=MultiIndex.from_tuples(tuples))
grouped = df.groupby(level=1)
- box = _check_plot_works(grouped.boxplot, return_type='dict')
- self._check_axes_shape(self.plt.gcf().axes, axes_num=10, layout=(4, 3))
+ axes = _check_plot_works(grouped.boxplot, return_type='axes')
+ self._check_axes_shape(axes.values(), axes_num=10, layout=(4, 3))
- box = _check_plot_works(grouped.boxplot, subplots=False,
- return_type='dict')
- self._check_axes_shape(self.plt.gcf().axes, axes_num=10, layout=(4, 3))
+ axes = _check_plot_works(grouped.boxplot, subplots=False,
+ return_type='axes')
+ self._check_axes_shape(axes, axes_num=1, layout=(1, 1))
grouped = df.unstack(level=1).groupby(level=0, axis=1)
- box = _check_plot_works(grouped.boxplot, return_type='dict')
- self._check_axes_shape(self.plt.gcf().axes, axes_num=3, layout=(2, 2))
+ axes = _check_plot_works(grouped.boxplot, return_type='axes')
+ self._check_axes_shape(axes.values(), axes_num=3, layout=(2, 2))
- box = _check_plot_works(grouped.boxplot, subplots=False,
- return_type='dict')
- self._check_axes_shape(self.plt.gcf().axes, axes_num=3, layout=(2, 2))
+ axes = _check_plot_works(grouped.boxplot, subplots=False,
+ return_type='axes')
+ self._check_axes_shape(axes, axes_num=1, layout=(1, 1))
def test_series_plot_color_kwargs(self):
# GH1890
View
@@ -2665,7 +2665,8 @@ def plot_group(group, ax):
def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,
- rot=0, grid=True, figsize=None, layout=None, **kwds):
+ rot=0, grid=True, ax=None, figsize=None,
+ layout=None, **kwds):
"""
Make box plots from DataFrameGroupBy data.
@@ -2712,7 +2713,7 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,
naxes = len(grouped)
nrows, ncols = _get_layout(naxes, layout=layout)
fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes, squeeze=False,
- sharex=False, sharey=True)
+ ax=ax, sharex=False, sharey=True, figsize=figsize)
axes = _flatten(axes)
ret = compat.OrderedDict()
@@ -2733,7 +2734,7 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,
else:
df = frames[0]
ret = df.boxplot(column=column, fontsize=fontsize, rot=rot,
- grid=grid, figsize=figsize, layout=layout, **kwds)
+ grid=grid, ax=ax, figsize=figsize, layout=layout, **kwds)
return ret
@@ -2779,17 +2780,10 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None,
by = [by]
columns = data._get_numeric_data().columns - by
naxes = len(columns)
-
- if ax is None:
- nrows, ncols = _get_layout(naxes, layout=layout)
- fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes,
- sharex=True, sharey=True,
- figsize=figsize, ax=ax)
- else:
- if naxes > 1:
- raise ValueError("Using an existing axis is not supported when plotting multiple columns.")
- fig = ax.get_figure()
- axes = ax.get_axes()
+ nrows, ncols = _get_layout(naxes, layout=layout)
+ fig, axes = _subplots(nrows=nrows, ncols=ncols, naxes=naxes,
+ sharex=True, sharey=True,
+ figsize=figsize, ax=ax)
ravel_axes = _flatten(axes)
@@ -2974,12 +2968,6 @@ def _subplots(nrows=1, ncols=1, naxes=None, sharex=False, sharey=False, squeeze=
if subplot_kw is None:
subplot_kw = {}
- if ax is None:
- fig = plt.figure(**fig_kw)
- else:
- fig = ax.get_figure()
- fig.clear()
-
# Create empty object array to hold all axes. It's easiest to make it 1-d
# so we can just append subplots upon creation, and then
nplots = nrows * ncols
@@ -2989,6 +2977,21 @@ def _subplots(nrows=1, ncols=1, naxes=None, sharex=False, sharey=False, squeeze=
elif nplots < naxes:
raise ValueError("naxes {0} is larger than layour size defined by nrows * ncols".format(naxes))
+ if ax is None:
+ fig = plt.figure(**fig_kw)
+ else:
+ fig = ax.get_figure()
+ # if ax is passed and a number of subplots is 1, return ax as it is
+ if naxes == 1:
+ if squeeze:
+ return fig, ax
+ else:
+ return fig, _flatten(ax)
+ else:
+ warnings.warn("To output multiple subplots, the figure containing the passed axes "
+ "is being cleared", UserWarning)
+ fig.clear()
+
axarr = np.empty(nplots, dtype=object)
def on_right(i):