Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kde plotting #1059

Closed
wants to merge 9 commits into from
97 changes: 89 additions & 8 deletions pandas/tools/plotting.py
Expand Up @@ -3,6 +3,7 @@
from itertools import izip

import numpy as np
from scipy import stats

from pandas.util.decorators import cache_readonly
import pandas.core.common as com
Expand All @@ -11,8 +12,8 @@
from pandas.tseries.period import PeriodIndex
from pandas.tseries.offsets import DateOffset

def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,
**kwds):

def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False, **kwds):
"""
Draw a matrix of scatter plots.

Expand All @@ -36,6 +37,51 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,

for i, a in zip(range(n), df.columns):
for j, b in zip(range(n), df.columns):
if i == j:
# Deal with the diagonal by drawing a histogram there.
if diagonal == 'hist':
axes[i, j].hist(df[a])
elif diagonal == 'kde':
y = df[a]
gkde = stats.gaussian_kde(y)
ind = np.linspace(min(y), max(y), 1000)
axes[i, j].plot(ind, gkde.evaluate(ind), **kwds)
axes[i, j].yaxis.set_visible(False)
axes[i, j].xaxis.set_visible(False)
if i == 0 and j == 0:
axes[i, j].yaxis.set_ticks_position('left')
axes[i, j].yaxis.set_label_position('left')
axes[i, j].yaxis.set_visible(True)
if i == n - 1 and j == n - 1:
axes[i, j].yaxis.set_ticks_position('right')
axes[i, j].yaxis.set_label_position('right')
axes[i, j].yaxis.set_visible(True)
else:
axes[i, j].scatter(df[b], df[a], alpha=alpha, **kwds)
axes[i, j].yaxis.set_visible(False)
axes[i, j].xaxis.set_visible(False)

# setup labels
if i == 0 and j % 2 == 1:
axes[i, j].set_xlabel(b, visible=True)
axes[i, j].xaxis.set_visible(True)
axes[i, j].xaxis.set_ticks_position('top')
axes[i, j].xaxis.set_label_position('top')
if i == n - 1 and j % 2 == 0:
axes[i, j].set_xlabel(b, visible=True)
axes[i, j].xaxis.set_visible(True)
axes[i, j].xaxis.set_ticks_position('bottom')
axes[i, j].xaxis.set_label_position('bottom')
if j == 0 and i % 2 == 0:
axes[i, j].set_ylabel(a, visible=True)
axes[i, j].yaxis.set_visible(True)
axes[i, j].yaxis.set_ticks_position('left')
axes[i, j].yaxis.set_label_position('left')
if j == n - 1 and i % 2 == 1:
axes[i, j].set_ylabel(a, visible=True)
axes[i, j].yaxis.set_visible(True)
axes[i, j].yaxis.set_ticks_position('right')
axes[i, j].yaxis.set_label_position('right')
axes[i, j].scatter(df[b], df[a], alpha=alpha, **kwds)
axes[i, j].set_xlabel('')
axes[i, j].set_ylabel('')
Expand Down Expand Up @@ -84,15 +130,14 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False,
axes[i, j].set_yticklabels(ticks)
axes[i, j].yaxis.set_ticks_position('right')
axes[i, j].yaxis.set_label_position('right')

axes[i, j].grid(b=grid)

# ensure {x,y}lim off diagonal are the same as diagonal
for i in range(n):
for j in range(n):
if i != j:
axes[i, j].set_xlim(axes[j, j].get_xlim())
axes[i, j].set_ylim(axes[i, i].get_ylim())
#for i in range(n):
# for j in range(n):
# if i != j:
# axes[i, j].set_xlim(axes[j, j].get_xlim())
# axes[i, j].set_ylim(axes[i, i].get_ylim())

return axes

Expand Down Expand Up @@ -326,6 +371,38 @@ def _get_xticks(self):

return x

class KdePlot(MPLPlot):
def __init__(self, data, **kwargs):
MPLPlot.__init__(self, data, **kwargs)

def _get_plot_function(self):
return self.plt.Axes.plot

def _make_plot(self):
plotf = self._get_plot_function()
for i, (label, y) in enumerate(self._iter_data()):
if self.subplots:
ax = self.axes[i]
style = 'k'
else:
style = '' # empty string ignored
ax = self.ax
if self.style:
style = self.style
gkde = stats.gaussian_kde(y)
sample_range = max(y) - min(y)
ind = np.linspace(min(y) - 0.5 * sample_range,
max(y) + 0.5 * sample_range, 1000)
ax.set_ylabel("Density")
plotf(ax, ind, gkde.evaluate(ind), style, label=label, **self.kwds)
ax.grid(self.grid)

def _post_plot_logic(self):
df = self.data

if self.subplots and self.legend:
self.axes[0].legend(loc='best')

class LinePlot(MPLPlot):

def __init__(self, data, **kwargs):
Expand Down Expand Up @@ -608,6 +685,8 @@ def plot_frame(frame=None, subplots=False, sharex=True, sharey=False,
klass = LinePlot
elif kind in ('bar', 'barh'):
klass = BarPlot
elif kind == 'kde':
klass = KdePlot
else:
raise ValueError('Invalid chart type given %s' % kind)

Expand Down Expand Up @@ -670,6 +749,8 @@ def plot_series(series, label=None, kind='line', use_index=True, rot=None,
klass = LinePlot
elif kind in ('bar', 'barh'):
klass = BarPlot
elif kind == 'kde':
klass = KdePlot

if ax is None:
ax = _gca()
Expand Down