Skip to content

Commit

Permalink
ENH: correlation function accepts method being a callable (#22684)
Browse files Browse the repository at this point in the history
  • Loading branch information
shadiakiki1986 authored and jreback committed Sep 26, 2018
1 parent 4a459b8 commit a393675
Show file tree
Hide file tree
Showing 6 changed files with 85 additions and 4 deletions.
15 changes: 15 additions & 0 deletions doc/source/computation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,21 @@ Like ``cov``, ``corr`` also supports the optional ``min_periods`` keyword:
frame.corr(min_periods=12)
.. versionadded:: 0.24.0

The ``method`` argument can also be a callable for a generic correlation
calculation. In this case, it should be a single function
that produces a single value from two ndarray inputs. Suppose we wanted to
compute the correlation based on histogram intersection:

.. ipython:: python
# histogram intersection
histogram_intersection = lambda a, b: np.minimum(
np.true_divide(a, a.sum()), np.true_divide(b, b.sum())
).sum()
frame.corr(method=histogram_intersection)
A related method :meth:`~DataFrame.corrwith` is implemented on DataFrame to
compute the correlation between like-labeled Series contained in different
DataFrame objects.
Expand Down
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v0.24.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ New features
- :func:`DataFrame.to_parquet` now accepts ``index`` as an argument, allowing
the user to override the engine's default behavior to include or omit the
dataframe's indexes from the resulting Parquet file. (:issue:`20768`)
- :meth:`DataFrame.corr` and :meth:`Series.corr` now accept a callable for generic calculation methods of correlation, e.g. histogram intersection (:issue:`22684`)


.. _whatsnew_0240.enhancements.extension_array_operators:

Expand Down
20 changes: 18 additions & 2 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -6672,10 +6672,14 @@ def corr(self, method='pearson', min_periods=1):
Parameters
----------
method : {'pearson', 'kendall', 'spearman'}
method : {'pearson', 'kendall', 'spearman'} or callable
* pearson : standard correlation coefficient
* kendall : Kendall Tau correlation coefficient
* spearman : Spearman rank correlation
* callable: callable with input two 1d ndarrays
and returning a float
.. versionadded:: 0.24.0
min_periods : int, optional
Minimum number of observations required per pair of columns
to have a valid result. Currently only available for pearson
Expand All @@ -6684,6 +6688,18 @@ def corr(self, method='pearson', min_periods=1):
Returns
-------
y : DataFrame
Examples
--------
>>> import numpy as np
>>> histogram_intersection = lambda a, b: np.minimum(a, b
... ).sum().round(decimals=1)
>>> df = pd.DataFrame([(.2, .3), (.0, .6), (.6, .0), (.2, .1)],
... columns=['dogs', 'cats'])
>>> df.corr(method=histogram_intersection)
dogs cats
dogs 1.0 0.3
cats 0.3 1.0
"""
numeric_df = self._get_numeric_data()
cols = numeric_df.columns
Expand All @@ -6695,7 +6711,7 @@ def corr(self, method='pearson', min_periods=1):
elif method == 'spearman':
correl = libalgos.nancorr_spearman(ensure_float64(mat),
minp=min_periods)
elif method == 'kendall':
elif method == 'kendall' or callable(method):
if min_periods is None:
min_periods = 1
mat = ensure_float64(mat).T
Expand Down
2 changes: 2 additions & 0 deletions pandas/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -766,6 +766,8 @@ def nancorr(a, b, method='pearson', min_periods=None):
def get_corr_func(method):
if method in ['kendall', 'spearman']:
from scipy.stats import kendalltau, spearmanr
elif callable(method):
return method

def _pearson(a, b):
return np.corrcoef(a, b)[0, 1]
Expand Down
18 changes: 16 additions & 2 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1910,23 +1910,37 @@ def corr(self, other, method='pearson', min_periods=None):
Parameters
----------
other : Series
method : {'pearson', 'kendall', 'spearman'}
method : {'pearson', 'kendall', 'spearman'} or callable
* pearson : standard correlation coefficient
* kendall : Kendall Tau correlation coefficient
* spearman : Spearman rank correlation
* callable: callable with input two 1d ndarray
and returning a float
.. versionadded:: 0.24.0
min_periods : int, optional
Minimum number of observations needed to have a valid result
Returns
-------
correlation : float
Examples
--------
>>> import numpy as np
>>> histogram_intersection = lambda a, b: np.minimum(a, b
... ).sum().round(decimals=1)
>>> s1 = pd.Series([.2, .0, .6, .2])
>>> s2 = pd.Series([.3, .6, .0, .1])
>>> s1.corr(s2, method=histogram_intersection)
0.3
"""
this, other = self.align(other, join='inner', copy=False)
if len(this) == 0:
return np.nan

if method in ['pearson', 'spearman', 'kendall']:
if method in ['pearson', 'spearman', 'kendall'] or callable(method):
return nanops.nancorr(this.values, other.values, method=method,
min_periods=min_periods)

Expand Down
32 changes: 32 additions & 0 deletions pandas/tests/series/test_analytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,38 @@ def test_corr_invalid_method(self):
with tm.assert_raises_regex(ValueError, msg):
s1.corr(s2, method="____")

def test_corr_callable_method(self):
# simple correlation example
# returns 1 if exact equality, 0 otherwise
my_corr = lambda a, b: 1. if (a == b).all() else 0.

# simple example
s1 = Series([1, 2, 3, 4, 5])
s2 = Series([5, 4, 3, 2, 1])
expected = 0
tm.assert_almost_equal(
s1.corr(s2, method=my_corr),
expected)

# full overlap
tm.assert_almost_equal(
self.ts.corr(self.ts, method=my_corr), 1.)

# partial overlap
tm.assert_almost_equal(
self.ts[:15].corr(self.ts[5:], method=my_corr), 1.)

# No overlap
assert np.isnan(
self.ts[::2].corr(self.ts[1::2], method=my_corr))

# dataframe example
df = pd.DataFrame([s1, s2])
expected = pd.DataFrame([
{0: 1., 1: 0}, {0: 0, 1: 1.}])
tm.assert_almost_equal(
df.transpose().corr(method=my_corr), expected)

def test_cov(self):
# full overlap
tm.assert_almost_equal(self.ts.cov(self.ts), self.ts.std() ** 2)
Expand Down

0 comments on commit a393675

Please sign in to comment.