From 78a0ca900ba31023e8cebb9fa208ea1cba1031e1 Mon Sep 17 00:00:00 2001 From: sinhrks Date: Sat, 22 Sep 2012 22:53:16 +0900 Subject: [PATCH 1/2] ENH: add arguments for scatter_plot --- pandas/tools/plotting.py | 47 +++++++++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 60ed0c70d516b..56caebf8e92b8 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -1431,30 +1431,56 @@ def format_date_labels(ax, rot): pass -def scatter_plot(data, x, y, by=None, ax=None, figsize=None, grid=False): +def scatter_plot(data, x, y, by=None, ax=None, figsize=None, grid=False, + sharex=True, sharey=True, xlim=None, ylim=None, + **kwds): """ + Draw scatter plot of the DataFrame's series using matplotlib / pylab. - Returns - ------- - fig : matplotlib.Figure + Parameters + ---------- + data : Dataframe + x : column name of Dataframe for x axis + y : column name of Dataframe for y axis + by : column in the DataFrame to group by + ax : matplotlib axes object, default None + figsize : + grid : boolean, default True + Whether to show axis grid lines + xlabelsize : int, default None + If specified changes the x-axis label size + ylabelsize : int, default None + If specified changes the y-axis label size + sharex : bool, if True, the X axis will be shared amongst all subplots. + sharey : bool, if True, the Y axis will be shared amongst all subplots. + xlim : 2-tuple/list + ylim : 2-tuple/list + kwds : other plotting keyword arguments + To be passed to scatter function """ import matplotlib.pyplot as plt - def plot_group(group, ax): + def plot_group(group, ax, xlim=None, ylim=None,**kwds): xvals = group[x].values yvals = group[y].values - ax.scatter(xvals, yvals) + ax.scatter(xvals, yvals, **kwds) + if ylim is not None: + ax.set_ylim(ylim) + if xlim is not None: + ax.set_xlim(xlim) ax.grid(grid) if by is not None: - fig = _grouped_plot(plot_group, data, by=by, figsize=figsize, ax=ax) + fig = _grouped_plot(plot_group, data, by=by, figsize=figsize, + sharex=sharex, sharey=sharey, xlim=xlim, ylim=ylim, + ax=ax, **kwds) else: if ax is None: fig = plt.figure() ax = fig.add_subplot(111) else: fig = ax.get_figure() - plot_group(data, ax) + plot_group(data, ax, xlim=xlim, ylim=ylim, **kwds) ax.set_ylabel(com._stringify(y)) ax.set_xlabel(com._stringify(x)) @@ -1638,7 +1664,7 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None, def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True, figsize=None, sharex=True, sharey=True, layout=None, - rot=0, ax=None): + rot=0, grid=False, ax=None, **kwds): from pandas.core.frame import DataFrame import matplotlib.pyplot as plt @@ -1675,8 +1701,9 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True, ax = ravel_axes[i] if numeric_only and isinstance(group, DataFrame): group = group._get_numeric_data() - plotf(group, ax) + plotf(group, ax, **kwds) ax.set_title(com._stringify(key)) + ax.grid(grid) return fig, axes From 9210b6669d258b8ba95c56caee53d66c38fc1037 Mon Sep 17 00:00:00 2001 From: sinhrks Date: Mon, 8 Oct 2012 23:44:37 +0900 Subject: [PATCH 2/2] ENH: Add arguments for scatter_plot and test cases --- pandas/tests/test_graphics.py | 16 ++++++-- pandas/tools/plotting.py | 69 +++++++++++++++++++++++++++-------- 2 files changed, 66 insertions(+), 19 deletions(-) diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py index f2fd98169b585..eb2596c5cf9b3 100644 --- a/pandas/tests/test_graphics.py +++ b/pandas/tests/test_graphics.py @@ -380,13 +380,21 @@ def scat(**kwds): _check_plot_works(scat, diagonal='kde') _check_plot_works(scat, diagonal='density') _check_plot_works(scat, diagonal='hist') - - def scat2(x, y, by=None, ax=None, figsize=None): - return plt.scatter_plot(df, x, y, by, ax, figsize=None) - + def scat2(x, y, by=None, ax=None, figsize=None, **kwds): + return plt.scatter_plot(df, x, y, by, ax, figsize=None, **kwds) _check_plot_works(scat2, 0, 1) grouper = Series(np.repeat([1, 2, 3, 4, 5], 20), df.index) _check_plot_works(scat2, 0, 1, by=grouper) + _check_plot_works(scat2, 0, 1, color='red', xlim=(1,5), ylim=(1,5)) + _check_plot_works(scat2, 0, 1, by=grouper, sharex=True, sharey=True) + + xf, yf = 20, 30 + fig = scat2(0, 1, xlabelsize=xf, ylabelsize=yf) + for ax in fig.axes: + ytick = ax.get_yticklabels()[0] + xtick = ax.get_xticklabels()[0] + self.assertAlmostEqual(ytick.get_fontsize(), yf) + self.assertAlmostEqual(xtick.get_fontsize(), xf) @slow def test_andrews_curves(self): diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 56caebf8e92b8..03524ed40b119 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -1432,8 +1432,8 @@ def format_date_labels(ax, rot): def scatter_plot(data, x, y, by=None, ax=None, figsize=None, grid=False, - sharex=True, sharey=True, xlim=None, ylim=None, - **kwds): + sharex=True, sharey=True, xlabelsize=None, ylabelsize=None, + xlim=None, ylim=None, **kwds): """ Draw scatter plot of the DataFrame's series using matplotlib / pylab. @@ -1460,31 +1460,26 @@ def scatter_plot(data, x, y, by=None, ax=None, figsize=None, grid=False, """ import matplotlib.pyplot as plt - def plot_group(group, ax, xlim=None, ylim=None,**kwds): + def plot_group(group, ax, **kwds): xvals = group[x].values yvals = group[y].values ax.scatter(xvals, yvals, **kwds) - if ylim is not None: - ax.set_ylim(ylim) - if xlim is not None: - ax.set_xlim(xlim) - ax.grid(grid) + _decorate_axes(ax, grid=grid, xlabelsize=xlabelsize, + ylabelsize=ylabelsize, xlim=xlim, ylim=ylim) if by is not None: fig = _grouped_plot(plot_group, data, by=by, figsize=figsize, - sharex=sharex, sharey=sharey, xlim=xlim, ylim=ylim, - ax=ax, **kwds) + sharex=sharex, sharey=sharey, ax=ax, **kwds) else: if ax is None: fig = plt.figure() ax = fig.add_subplot(111) else: fig = ax.get_figure() - plot_group(data, ax, xlim=xlim, ylim=ylim, **kwds) - ax.set_ylabel(com._stringify(y)) - ax.set_xlabel(com._stringify(x)) - - ax.grid(grid) + plot_group(data, ax, **kwds) + _decorate_axes(ax, grid=grid, xlabelsize=xlabelsize, + ylabelsize=ylabelsize, xlim=xlim, ylim=ylim, + xlabel=com._stringify(x), ylabel=com._stringify(y)) return fig @@ -1762,6 +1757,50 @@ def _get_layout(nplots): else: return k, k +def _decorate_axes(axes, title=None, legend=None, + xlim=None, ylim=None, grid=None, + xticks=None, yticks=None, xticklabels=None, yticklabels=None, + xlabelsize=None, ylabelsize=None, xrot=None, yrot=None, + xlabel=None, ylabel=None): + + import matplotlib.pyplot as plt + import matplotlib.axes + assert isinstance(axes, matplotlib.axes.SubplotBase) + + if title is not None: + axes.set_title(title) + + if legend == True: + axes.legend() + elif isinstance(legend, dict): + axes.legend(**legend) + + if xticks is not None: + axes.xaxis.set_ticks(xticks) + + if xticklabels is not None: + axes.xaxis.set_ticklabels(xtickslabels) + if xlabelsize is not None: + plt.setp(axes.get_xticklabels(), fontsize=xlabelsize) + if xrot is not None: + plt.setp(axes.get_xticklabels(), rotation=xrot) + + if xlabel is not None: + axes.set_xlabel(xlabel) + + if yticks is not None: + axes.yaxis.set_ticks(yticks) + + if yticklabels is not None: + axes.yaxis.set_ticklabels(yticklabels) + if ylabelsize is not None: + plt.setp(axes.get_yticklabels(), fontsize=ylabelsize) + if yrot is not None: + plt.setp(axes.get_yticklabels(), rotation=yrot) + + if ylabel is not None: + axes.set_ylabel(ylabel) + # copied from matplotlib/pyplot.py for compatibility with matplotlib < 1.0 def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,