Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions pandas/tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,13 +380,21 @@ def scat(**kwds):
_check_plot_works(scat, diagonal='kde')
_check_plot_works(scat, diagonal='density')
_check_plot_works(scat, diagonal='hist')

def scat2(x, y, by=None, ax=None, figsize=None):
return plt.scatter_plot(df, x, y, by, ax, figsize=None)

def scat2(x, y, by=None, ax=None, figsize=None, **kwds):
return plt.scatter_plot(df, x, y, by, ax, figsize=None, **kwds)
_check_plot_works(scat2, 0, 1)
grouper = Series(np.repeat([1, 2, 3, 4, 5], 20), df.index)
_check_plot_works(scat2, 0, 1, by=grouper)
_check_plot_works(scat2, 0, 1, color='red', xlim=(1,5), ylim=(1,5))
_check_plot_works(scat2, 0, 1, by=grouper, sharex=True, sharey=True)

xf, yf = 20, 30
fig = scat2(0, 1, xlabelsize=xf, ylabelsize=yf)
for ax in fig.axes:
ytick = ax.get_yticklabels()[0]
xtick = ax.get_xticklabels()[0]
self.assertAlmostEqual(ytick.get_fontsize(), yf)
self.assertAlmostEqual(xtick.get_fontsize(), xf)

@slow
def test_andrews_curves(self):
Expand Down
96 changes: 81 additions & 15 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1431,34 +1431,55 @@ def format_date_labels(ax, rot):
pass


def scatter_plot(data, x, y, by=None, ax=None, figsize=None, grid=False):
def scatter_plot(data, x, y, by=None, ax=None, figsize=None, grid=False,
sharex=True, sharey=True, xlabelsize=None, ylabelsize=None,
xlim=None, ylim=None, **kwds):
"""
Draw scatter plot of the DataFrame's series using matplotlib / pylab.

Returns
-------
fig : matplotlib.Figure
Parameters
----------
data : Dataframe
x : column name of Dataframe for x axis
y : column name of Dataframe for y axis
by : column in the DataFrame to group by
ax : matplotlib axes object, default None
figsize :
grid : boolean, default True
Whether to show axis grid lines
xlabelsize : int, default None
If specified changes the x-axis label size
ylabelsize : int, default None
If specified changes the y-axis label size
sharex : bool, if True, the X axis will be shared amongst all subplots.
sharey : bool, if True, the Y axis will be shared amongst all subplots.
xlim : 2-tuple/list
ylim : 2-tuple/list
kwds : other plotting keyword arguments
To be passed to scatter function
"""
import matplotlib.pyplot as plt

def plot_group(group, ax):
def plot_group(group, ax, **kwds):
xvals = group[x].values
yvals = group[y].values
ax.scatter(xvals, yvals)
ax.grid(grid)
ax.scatter(xvals, yvals, **kwds)
_decorate_axes(ax, grid=grid, xlabelsize=xlabelsize,
ylabelsize=ylabelsize, xlim=xlim, ylim=ylim)

if by is not None:
fig = _grouped_plot(plot_group, data, by=by, figsize=figsize, ax=ax)
fig = _grouped_plot(plot_group, data, by=by, figsize=figsize,
sharex=sharex, sharey=sharey, ax=ax, **kwds)
else:
if ax is None:
fig = plt.figure()
ax = fig.add_subplot(111)
else:
fig = ax.get_figure()
plot_group(data, ax)
ax.set_ylabel(com._stringify(y))
ax.set_xlabel(com._stringify(x))

ax.grid(grid)
plot_group(data, ax, **kwds)
_decorate_axes(ax, grid=grid, xlabelsize=xlabelsize,
ylabelsize=ylabelsize, xlim=xlim, ylim=ylim,
xlabel=com._stringify(x), ylabel=com._stringify(y))

return fig

Expand Down Expand Up @@ -1638,7 +1659,7 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,

def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
figsize=None, sharex=True, sharey=True, layout=None,
rot=0, ax=None):
rot=0, grid=False, ax=None, **kwds):
from pandas.core.frame import DataFrame
import matplotlib.pyplot as plt

Expand Down Expand Up @@ -1675,8 +1696,9 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
ax = ravel_axes[i]
if numeric_only and isinstance(group, DataFrame):
group = group._get_numeric_data()
plotf(group, ax)
plotf(group, ax, **kwds)
ax.set_title(com._stringify(key))
ax.grid(grid)

return fig, axes

Expand Down Expand Up @@ -1735,6 +1757,50 @@ def _get_layout(nplots):
else:
return k, k

def _decorate_axes(axes, title=None, legend=None,
xlim=None, ylim=None, grid=None,
xticks=None, yticks=None, xticklabels=None, yticklabels=None,
xlabelsize=None, ylabelsize=None, xrot=None, yrot=None,
xlabel=None, ylabel=None):

import matplotlib.pyplot as plt
import matplotlib.axes
assert isinstance(axes, matplotlib.axes.SubplotBase)

if title is not None:
axes.set_title(title)

if legend == True:
axes.legend()
elif isinstance(legend, dict):
axes.legend(**legend)

if xticks is not None:
axes.xaxis.set_ticks(xticks)

if xticklabels is not None:
axes.xaxis.set_ticklabels(xtickslabels)
if xlabelsize is not None:
plt.setp(axes.get_xticklabels(), fontsize=xlabelsize)
if xrot is not None:
plt.setp(axes.get_xticklabels(), rotation=xrot)

if xlabel is not None:
axes.set_xlabel(xlabel)

if yticks is not None:
axes.yaxis.set_ticks(yticks)

if yticklabels is not None:
axes.yaxis.set_ticklabels(yticklabels)
if ylabelsize is not None:
plt.setp(axes.get_yticklabels(), fontsize=ylabelsize)
if yrot is not None:
plt.setp(axes.get_yticklabels(), rotation=yrot)

if ylabel is not None:
axes.set_ylabel(ylabel)

# copied from matplotlib/pyplot.py for compatibility with matplotlib < 1.0

def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
Expand Down