Skip to content

Commit

Permalink
feat[rust, python]: coalesce function (#4931)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 22, 2022
1 parent 770a089 commit 23a2309
Show file tree
Hide file tree
Showing 12 changed files with 127 additions and 34 deletions.
2 changes: 2 additions & 0 deletions polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ top_k = ["polars-lazy/top_k"]
algo = ["polars-algo"]
cse = ["polars-lazy/cse"]
propagate_nans = ["polars-lazy/propagate_nans"]
coalesce = ["polars-lazy/coalesce"]

test = [
"lazy",
Expand Down Expand Up @@ -259,6 +260,7 @@ docs-selection = [
"timezones",
"arg_where",
"propagate_nans",
"coalesce",
]

bench = [
Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ top_k = ["polars-ops/top_k"]
semi_anti_join = ["polars-core/semi_anti_join"]
cse = []
propagate_nans = ["polars-ops/propagate_nans"]
coalesce = []

# no guarantees whatsoever
private = ["polars-time/private"]
Expand Down
19 changes: 19 additions & 0 deletions polars/polars-lazy/src/dsl/function_expr/fill_null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,22 @@ pub(super) fn fill_null(s: &[Series], super_type: &DataType) -> PolarsResult<Ser
array.zip_with_same_type(&mask, &fill_value)
}
}

pub(super) fn coalesce(s: &mut [Series]) -> PolarsResult<Series> {
if s.is_empty() {
Err(PolarsError::ComputeError(
"cannot coalesce empty list".into(),
))
} else {
let mut out = s[0].clone();
for s in s {
if !out.null_count() == 0 {
return Ok(out);
} else {
let mask = out.is_not_null();
out = out.zip_with_same_type(&mask, s)?;
}
}
Ok(out)
}
}
74 changes: 43 additions & 31 deletions polars/polars-lazy/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,61 +109,64 @@ pub enum FunctionExpr {
Not,
IsUnique,
IsDuplicated,
Coalesce,
}

impl Display for FunctionExpr {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use FunctionExpr::*;

match self {
NullCount => write!(f, "null_count"),
Pow => write!(f, "pow"),
let s = match self {
NullCount => "null_count",
Pow => "pow",
#[cfg(feature = "row_hash")]
Hash(_, _, _, _) => write!(f, "hash"),
Hash(_, _, _, _) => "hash",
#[cfg(feature = "is_in")]
IsIn => write!(f, "is_in"),
IsIn => "is_in",
#[cfg(feature = "arg_where")]
ArgWhere => write!(f, "arg_where"),
ArgWhere => "arg_where",
#[cfg(feature = "search_sorted")]
SearchSorted => write!(f, "search_sorted"),
SearchSorted => "search_sorted",
#[cfg(feature = "strings")]
StringExpr(s) => write!(f, "{}", s),
StringExpr(s) => return write!(f, "{}", s),
#[cfg(feature = "temporal")]
TemporalExpr(fun) => write!(f, "{}", fun),
TemporalExpr(fun) => return write!(f, "{}", fun),
#[cfg(feature = "date_offset")]
DateOffset(_) => write!(f, "dt.offset_by"),
DateOffset(_) => "dt.offset_by",
#[cfg(feature = "trigonometry")]
Trigonometry(func) => write!(f, "{}", func),
Trigonometry(func) => return write!(f, "{}", func),
#[cfg(feature = "sign")]
Sign => write!(f, "sign"),
FillNull { .. } => write!(f, "fill_null"),
Sign => "sign",
FillNull { .. } => "fill_null",
#[cfg(feature = "is_in")]
ListContains => write!(f, "arr.contains"),
ListContains => "arr.contains",
#[cfg(all(feature = "rolling_window", feature = "moment"))]
RollingSkew { .. } => write!(f, "rolling_skew"),
ShiftAndFill { .. } => write!(f, "shift_and_fill"),
Nan(_) => write!(f, "nan"),
RollingSkew { .. } => "rolling_skew",
ShiftAndFill { .. } => "shift_and_fill",
Nan(_) => "nan",
#[cfg(feature = "round_series")]
Clip { min, max } => match (min, max) {
(Some(_), Some(_)) => write!(f, "clip"),
(None, Some(_)) => write!(f, "clip_max"),
(Some(_), None) => write!(f, "clip_min"),
(Some(_), Some(_)) => "clip",
(None, Some(_)) => "clip_max",
(Some(_), None) => "clip_min",
_ => unreachable!(),
},
#[cfg(feature = "list")]
ListExpr(func) => write!(f, "{}", func),
ListExpr(func) => return write!(f, "{}", func),
#[cfg(feature = "dtype-struct")]
StructExpr(func) => write!(f, "{}", func),
StructExpr(func) => return write!(f, "{}", func),
#[cfg(feature = "top_k")]
TopK { .. } => write!(f, "top_k"),
Shift(_) => write!(f, "shift"),
Reverse => write!(f, "reverse"),
Not => write!(f, "is_not"),
IsNull => write!(f, "is_null"),
IsNotNull => write!(f, "is_not_null"),
IsUnique => write!(f, "is_unique"),
IsDuplicated => write!(f, "is_duplicated"),
}
TopK { .. } => "top_k",
Shift(_) => "shift",
Reverse => "reverse",
Not => "is_not",
IsNull => "is_null",
IsNotNull => "is_not_null",
IsUnique => "is_unique",
IsDuplicated => "is_duplicated",
Coalesce => "coalesce",
};
write!(f, "{}", s)
}
}

Expand All @@ -177,6 +180,14 @@ macro_rules! wrap {
// all expression arguments are in the slice.
// the first element is the root expression.
macro_rules! map_as_slice {
($func:path) => {{
let f = move |s: &mut [Series]| {
$func(s)
};

SpecialEq::new(Arc::new(f))
}};

($func:path, $($args:expr),*) => {{
let f = move |s: &mut [Series]| {
$func(s, $($args),*)
Expand Down Expand Up @@ -324,6 +335,7 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
Not => map!(dispatch::is_not),
IsUnique => map!(dispatch::is_unique),
IsDuplicated => map!(dispatch::is_duplicated),
Coalesce => map_as_slice!(fill_null::coalesce),
}
}
}
Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ impl FunctionExpr {
match self {
NullCount => with_dtype(IDX_DTYPE),
Pow => super_type(),
Coalesce => super_type(),
#[cfg(feature = "row_hash")]
Hash(..) => with_dtype(DataType::UInt64),
#[cfg(feature = "is_in")]
Expand Down
15 changes: 15 additions & 0 deletions polars/polars-lazy/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,7 @@ pub fn repeat<L: Literal>(value: L, n_times: Expr) -> Expr {
}

#[cfg(feature = "arg_where")]
#[cfg_attr(docsrs, doc(cfg(feature = "arg_where")))]
/// Get the indices where `condition` evaluates `true`.
pub fn arg_where<E: Into<Expr>>(condition: E) -> Expr {
let condition = condition.into();
Expand All @@ -925,3 +926,17 @@ pub fn arg_where<E: Into<Expr>>(condition: E) -> Expr {
},
}
}

/// Folds the expressions from left to right keeping the first no null values.
pub fn coalesce(exprs: &[Expr]) -> Expr {
let input = exprs.to_vec();
Expr::Function {
input,
function: FunctionExpr::Coalesce,
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
cast_to_supertypes: true,
..Default::default()
},
}
}
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expression.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ These functions can be used as expression and sometimes also in eager contexts.
concat_list
concat_str
count
coalesce
cov
date
datetime
Expand Down
2 changes: 2 additions & 0 deletions py-polars/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def version() -> str:
arg_where,
argsort_by,
avg,
coalesce,
col,
collect_all,
concat_list,
Expand Down Expand Up @@ -258,6 +259,7 @@ def version() -> str:
"var",
"struct",
"duration",
"coalesce",
# polars.convert
"from_dict",
"from_dicts",
Expand Down
7 changes: 4 additions & 3 deletions py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import math
import random
from datetime import date, datetime, time
from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING, Any, Callable, Sequence
from warnings import warn

Expand Down Expand Up @@ -52,7 +52,7 @@ def selection_to_pyexpr_list(
exprs: str
| Expr
| pli.Series
| Sequence[str | Expr | pli.Series | date | datetime | int | float],
| Sequence[str | Expr | pli.Series | timedelta | date | datetime | int | float],
) -> list[PyExpr]:
if isinstance(exprs, (str, Expr, pli.Series)):
exprs = [exprs]
Expand All @@ -72,6 +72,7 @@ def expr_to_lit_or_expr(
| date
| datetime
| time
| timedelta
| Sequence[(int | float | str | None)]
),
str_to_lit: bool = True,
Expand All @@ -95,7 +96,7 @@ def expr_to_lit_or_expr(
if isinstance(expr, str) and not str_to_lit:
return pli.col(expr)
elif (
isinstance(expr, (int, float, str, pli.Series, datetime, date, time))
isinstance(expr, (int, float, str, pli.Series, datetime, date, time, timedelta))
or expr is None
):
return pli.lit(expr)
Expand Down
19 changes: 19 additions & 0 deletions py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from polars.polars import arg_where as py_arg_where
from polars.polars import argsort_by as pyargsort_by
from polars.polars import as_struct as _as_struct
from polars.polars import coalesce_exprs as _coalesce_exprs
from polars.polars import col as pycol
from polars.polars import collect_all as _collect_all
from polars.polars import cols as pycols
Expand Down Expand Up @@ -1847,3 +1848,21 @@ def arg_where(
else:
condition = pli.expr_to_lit_or_expr(condition, str_to_lit=True)
return pli.wrap_expr(py_arg_where(condition._pyexpr))


def coalesce(
exprs: Sequence[
pli.Expr | str | date | datetime | timedelta | int | float | bool | pli.Series
],
) -> pli.Expr:
"""
Folds the expressions from left to right keeping the first no null values.
Parameters
----------
exprs
Expression to coalesce.
"""
exprs = pli.selection_to_pyexpr_list(exprs)
return pli.wrap_expr(_coalesce_exprs(exprs))
7 changes: 7 additions & 0 deletions py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,12 @@ fn max_exprs(exprs: Vec<PyExpr>) -> PyExpr {
polars::lazy::dsl::max_exprs(exprs).into()
}

#[pyfunction]
fn coalesce_exprs(exprs: Vec<PyExpr>) -> PyExpr {
let exprs = exprs.to_exprs();
polars::lazy::dsl::coalesce(&exprs).into()
}

#[pyfunction]
fn sum_exprs(exprs: Vec<PyExpr>) -> PyExpr {
let exprs = exprs.to_exprs();
Expand Down Expand Up @@ -551,5 +557,6 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(pool_size)).unwrap();
m.add_wrapped(wrap_pyfunction!(arg_where)).unwrap();
m.add_wrapped(wrap_pyfunction!(get_idx_type)).unwrap();
m.add_wrapped(wrap_pyfunction!(coalesce_exprs)).unwrap();
Ok(())
}
13 changes: 13 additions & 0 deletions py-polars/tests/unit/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,3 +203,16 @@ def test_nan_aggregations() -> None:
str(df.groupby("b").agg(aggs).to_dict(False))
== "{'b': [1], 'max': [3.0], 'min': [2.0], 'nan_max': [nan], 'nan_min': [nan]}"
)


def test_coalesce() -> None:
df = pl.DataFrame(
{
"a": [None, None, None, None],
"b": [1, 2, None, None],
"c": [1, None, 3, None],
}
)
assert df.select(pl.coalesce(["a", "b", "c", 10])).to_dict(False) == {
"a": [1.0, 2.0, 3.0, 10.0]
}

0 comments on commit 23a2309

Please sign in to comment.