Skip to content

Commit

Permalink
spearman's rank correlation
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 15, 2021
1 parent b30de44 commit f8ebd59
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 0 deletions.
7 changes: 7 additions & 0 deletions polars/polars-lazy/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ pub fn pearson_corr(a: Expr, b: Expr) -> Expr {
map_binary(a, b, function, Some(Field::new(name, DataType::Float32))).alias(name)
}

/// Compute the Spearman rank correlation between two columns.
///
/// Both inputs are replaced by their ranks, and the Pearson correlation of
/// those ranks is returned in a column named `"spearman_rank_corr"`.
///
/// Tied values receive the average of the ranks they span (fractional
/// ranking), which is the convention Spearman's coefficient is defined with.
#[cfg(feature = "rank")]
#[cfg_attr(docsrs, doc(cfg(feature = "rank")))]
pub fn spearman_rank_corr(a: Expr, b: Expr) -> Expr {
    // `RankMethod::Min` would hand every member of a tie group the lowest rank
    // of that group, which shifts the rank means and yields a coefficient that
    // differs from the standard definition whenever the data contains ties.
    let ranked_a = a.rank(RankMethod::Average);
    let ranked_b = b.rank(RankMethod::Average);
    // `pearson_corr` names its output after itself, so re-alias the result.
    pearson_corr(ranked_a, ranked_b).alias("spearman_rank_corr")
}

/// Find the indexes that would sort these series in order of appearance.
/// That means that the first `Series` will be used to determine the ordering
/// until duplicates are found. Once duplicates are found, the next `Series` will
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expression.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ These functions can be used as expression and sometimes also in eager contexts.
tail
lit
pearson_corr
spearman_rank_corr
cov
map_binary
fold
Expand Down
23 changes: 23 additions & 0 deletions py-polars/polars/lazy/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from polars.polars import fold as pyfold
from polars.polars import lit as pylit
from polars.polars import pearson_corr as pypearson_corr
from polars.polars import spearman_rank_corr as pyspearman_rank_corr

_DOCUMENTING = False
except ImportError:
Expand All @@ -45,6 +46,7 @@
"tail",
"lit",
"pearson_corr",
"spearman_rank_corr",
"cov",
"map_binary",
"fold",
Expand Down Expand Up @@ -397,6 +399,27 @@ def lit(
return pl.lazy.expr.wrap_expr(pylit(value))


def spearman_rank_corr(
    a: Union[str, "pl.Expr"],
    b: Union[str, "pl.Expr"],
) -> "pl.Expr":
    """
    Compute the Spearman rank correlation between two columns.

    Parameters
    ----------
    a
        Column name or Expression.
    b
        Column name or Expression.
    """
    # Plain strings are treated as column names and promoted to expressions.
    lhs = col(a) if isinstance(a, str) else a
    rhs = col(b) if isinstance(b, str) else b
    return pl.lazy.expr.wrap_expr(pyspearman_rank_corr(lhs._pyexpr, rhs._pyexpr))


def pearson_corr(
a: Union[str, "pl.Expr"],
b: Union[str, "pl.Expr"],
Expand Down
6 changes: 6 additions & 0 deletions py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ fn pearson_corr(a: dsl::PyExpr, b: dsl::PyExpr) -> dsl::PyExpr {
polars::lazy::functions::pearson_corr(a.inner, b.inner).into()
}

// Python-facing binding for `polars::lazy::functions::spearman_rank_corr`.
// Plain `//` comments are used on purpose so the Python-visible docstring of
// this function is left untouched.
#[pyfunction]
fn spearman_rank_corr(a: dsl::PyExpr, b: dsl::PyExpr) -> dsl::PyExpr {
    // Unwrap the native expressions, build the correlation expression, and
    // wrap the result back into the Python-exposed expression type.
    let correlation = polars::lazy::functions::spearman_rank_corr(a.inner, b.inner);
    correlation.into()
}

#[pyfunction]
fn cov(a: dsl::PyExpr, b: dsl::PyExpr) -> dsl::PyExpr {
polars::lazy::functions::cov(a.inner, b.inner).into()
Expand Down Expand Up @@ -223,5 +228,6 @@ fn polars(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(concat_series)).unwrap();
m.add_wrapped(wrap_pyfunction!(ipc_schema)).unwrap();
m.add_wrapped(wrap_pyfunction!(collect_all)).unwrap();
m.add_wrapped(wrap_pyfunction!(spearman_rank_corr)).unwrap();
Ok(())
}
18 changes: 18 additions & 0 deletions py-polars/tests/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,3 +629,21 @@ def test_collect_all(df):
out = pl.collect_all([lf1, lf2])
out[0][0, 0] == 6
out[1][0, 0] == 12.0


def test_spearman_corr():
    """Check the per-group Spearman rank correlation on a small frame with two eras."""
    frame = pl.DataFrame(
        {
            "era": [1, 1, 1, 2, 2, 2],
            "prediction": [2, 4, 5, 190, 1, 4],
            "target": [1, 3, 2, 1, 43, 3],
        }
    )

    # One correlation value per era; `maintain_order` keeps era 1 ahead of era 2
    # so the positional checks below are stable.
    corr_expr = pl.spearman_rank_corr(pl.col("prediction"), pl.col("target")).alias("c")
    per_era = frame.groupby("era", maintain_order=True).agg(corr_expr)["c"]

    assert np.isclose(per_era[0], 0.5)
    assert np.isclose(per_era[1], -1.0)

0 comments on commit f8ebd59

Please sign in to comment.