Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
ENH: integrate grouped histogram #2186
  • Loading branch information
changhiskhan committed Nov 16, 2012
1 parent 8be92dc commit 3fd8f36
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 32 deletions.
2 changes: 2 additions & 0 deletions RELEASE.rst
Expand Up @@ -33,6 +33,8 @@ pandas 0.9.2

**Improvements to existing features**

- Grouped histogram via `by` keyword in Series/DataFrame.hist (#2186)

**Bug fixes**

- Fixes bug when negative period passed to Series/DataFrame.diff (#2266)
Expand Down
17 changes: 17 additions & 0 deletions doc/source/visualization.rst
Expand Up @@ -271,6 +271,7 @@ Histograms
@savefig hist_plot_ex.png width=4.5in
df['A'].diff().hist()
For a DataFrame, ``hist`` plots the histograms of the columns on multiple
subplots:

Expand All @@ -281,6 +282,22 @@ subplots:
@savefig frame_hist_ex.png width=4.5in
df.diff().hist(color='k', alpha=0.5, bins=50)
New in version 0.9.2: the ``by`` keyword can be specified to plot grouped histograms, with one subplot drawn per group:

.. ipython:: python
:suppress:
plt.figure();
.. ipython:: python
data = Series(np.random.randn(1000))
@savefig grouped_hist.png width=4.5in
data.hist(by=np.random.randint(0, 4, 1000))
.. _visualization.box:

Box-Plotting
Expand Down
15 changes: 13 additions & 2 deletions pandas/tests/test_graphics.py
Expand Up @@ -134,6 +134,8 @@ def test_hist(self):
_check_plot_works(self.ts.hist)
_check_plot_works(self.ts.hist, grid=False)

_check_plot_works(self.ts.hist, by=self.ts.index.month)

@slow
def test_kde(self):
_check_plot_works(self.ts.plot, kind='kde')
Expand Down Expand Up @@ -609,11 +611,20 @@ def test_time_series_plot_color_kwargs(self):

@slow
def test_grouped_hist(self):
df = DataFrame(np.random.randn(50, 2), columns=['A', 'B'])
df['C'] = np.random.randint(0, 3, 50)
import matplotlib.pyplot as plt
df = DataFrame(np.random.randn(500, 2), columns=['A', 'B'])
df['C'] = np.random.randint(0, 4, 500)
axes = plotting.grouped_hist(df.A, by=df.C)
self.assert_(len(axes.ravel()) == 4)

plt.close('all')
axes = df.hist(by=df.C)
self.assert_(axes.ndim == 2)
self.assert_(len(axes.ravel()) == 4)

for ax in axes.ravel():
self.assert_(len(ax.patches) > 0)

PNG_PATH = 'tmp.png'

def _check_plot_works(f, *args, **kwargs):
Expand Down
100 changes: 70 additions & 30 deletions pandas/tools/plotting.py
Expand Up @@ -536,24 +536,36 @@ def r(h):

def grouped_hist(data, column=None, by=None, ax=None, bins=50, log=False,
figsize=None, layout=None, sharex=False, sharey=False,
rot=90):
rot=90, **kwargs):
"""
Grouped histogram
Parameters
----------
data: Series/DataFrame
column: object, optional
by: object, optional
ax: axes, optional
bins: int, default 50
log: boolean, default False
figsize: tuple, optional
layout: optional
sharex: boolean, default False
sharey: boolean, default False
rot: int, default 90
Returns
-------
fig : matplotlib.Figure
axes: collection of Matplotlib Axes
"""
# if isinstance(data, DataFrame):
# data = data[column]

def plot_group(group, ax):
ax.hist(group.dropna(), bins=bins)
ax.hist(group.dropna().values, bins=bins)

fig, axes = _grouped_plot(plot_group, data, column=column,
by=by, sharex=sharex, sharey=sharey,
figsize=figsize, layout=layout, rot=rot)
fig.subplots_adjust(bottom=0.15, top=0.9, left=0.1, right=0.9,
hspace=0.3, wspace=0.2)
hspace=0.5, wspace=0.3)
return axes

class MPLPlot(object):
Expand Down Expand Up @@ -1573,7 +1585,7 @@ def plot_group(group, ax):
return fig


def hist_frame(data, grid=True, xlabelsize=None, xrot=None,
def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None, xrot=None,
ylabelsize=None, yrot=None, ax=None,
sharex=False, sharey=False, **kwds):
"""
Expand All @@ -1597,6 +1609,27 @@ def hist_frame(data, grid=True, xlabelsize=None, xrot=None,
kwds : other plotting keyword arguments
To be passed to hist function
"""
if column is not None:
if not isinstance(column, (list, np.ndarray)):
column = [column]
data = data.ix[:, column]

if by is not None:

axes = grouped_hist(data, by=by, ax=ax, grid=grid, **kwds)

for ax in axes.ravel():
if xlabelsize is not None:
plt.setp(ax.get_xticklabels(), fontsize=xlabelsize)
if xrot is not None:
plt.setp(ax.get_xticklabels(), rotation=xrot)
if ylabelsize is not None:
plt.setp(ax.get_yticklabels(), fontsize=ylabelsize)
if yrot is not None:
plt.setp(ax.get_yticklabels(), rotation=yrot)

return axes

import matplotlib.pyplot as plt
n = len(data.columns)
rows, cols = 1, 1
Expand Down Expand Up @@ -1633,14 +1666,15 @@ def hist_frame(data, grid=True, xlabelsize=None, xrot=None,

return axes


def hist_series(self, ax=None, grid=True, xlabelsize=None, xrot=None,
ylabelsize=None, yrot=None, **kwds):
def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None,
xrot=None, ylabelsize=None, yrot=None, **kwds):
"""
Draw histogram of the input series using matplotlib
Parameters
----------
by : object, optional
If passed, then used to form histograms for separate groups
ax : matplotlib axis object
If not passed, uses gca()
grid : boolean, default True
Expand All @@ -1663,24 +1697,30 @@ def hist_series(self, ax=None, grid=True, xlabelsize=None, xrot=None,
"""
import matplotlib.pyplot as plt

if ax is None:
ax = plt.gca()

values = self.dropna().values
if by is None:
if ax is None:
ax = plt.gca()
values = self.dropna().values

ax.hist(values, **kwds)
ax.grid(grid)
ax.hist(values, **kwds)
ax.grid(grid)
axes = np.array([ax])
else:
axes = grouped_hist(self, by=by, ax=ax, grid=grid, **kwds)

if xlabelsize is not None:
plt.setp(ax.get_xticklabels(), fontsize=xlabelsize)
if xrot is not None:
plt.setp(ax.get_xticklabels(), rotation=xrot)
if ylabelsize is not None:
plt.setp(ax.get_yticklabels(), fontsize=ylabelsize)
if yrot is not None:
plt.setp(ax.get_yticklabels(), rotation=yrot)
for ax in axes.ravel():
if xlabelsize is not None:
plt.setp(ax.get_xticklabels(), fontsize=xlabelsize)
if xrot is not None:
plt.setp(ax.get_xticklabels(), rotation=xrot)
if ylabelsize is not None:
plt.setp(ax.get_yticklabels(), fontsize=ylabelsize)
if yrot is not None:
plt.setp(ax.get_yticklabels(), rotation=yrot)

return ax
if axes.ndim == 1 and len(axes) == 1:
return axes[0]
return axes


def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,
Expand Down Expand Up @@ -1751,7 +1791,7 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None,

def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
figsize=None, sharex=True, sharey=True, layout=None,
rot=0, ax=None):
rot=0, ax=None, **kwargs):
from pandas.core.frame import DataFrame
import matplotlib.pyplot as plt

Expand Down Expand Up @@ -1788,15 +1828,15 @@ def _grouped_plot(plotf, data, column=None, by=None, numeric_only=True,
ax = ravel_axes[i]
if numeric_only and isinstance(group, DataFrame):
group = group._get_numeric_data()
plotf(group, ax)
plotf(group, ax, **kwargs)
ax.set_title(com.pprint_thing(key))

return fig, axes


def _grouped_plot_by_column(plotf, data, columns=None, by=None,
numeric_only=True, grid=False,
figsize=None, ax=None):
figsize=None, ax=None, **kwargs):
import matplotlib.pyplot as plt

grouped = data.groupby(by)
Expand All @@ -1822,7 +1862,7 @@ def _grouped_plot_by_column(plotf, data, columns=None, by=None,
for i, col in enumerate(columns):
ax = ravel_axes[i]
gp_col = grouped[col]
plotf(gp_col, ax)
plotf(gp_col, ax, **kwargs)
ax.set_title(col)
ax.set_xlabel(com.pprint_thing(by))
ax.grid(grid)
Expand Down

0 comments on commit 3fd8f36

Please sign in to comment.