Skip to content

Commit

Permalink
arg_where expression (#3757)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 22, 2022
1 parent 4f55ebc commit c20d943
Show file tree
Hide file tree
Showing 21 changed files with 153 additions and 77 deletions.
2 changes: 2 additions & 0 deletions polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ list_to_struct = ["polars-ops/list_to_struct", "polars-lazy/list_to_struct"]
describe = ["polars-core/describe"]
timezones = ["polars-core/timezones"]
string_justify = ["polars-lazy/string_justify", "polars-ops/string_justify"]
arg_where = ["polars-lazy/arg_where"]

test = [
"lazy",
Expand Down Expand Up @@ -235,6 +236,7 @@ docs-selection = [
"list_eval",
"cumulative_eval",
"timezones",
"arg_where",
]

bench = [
Expand Down
9 changes: 0 additions & 9 deletions polars/polars-core/src/chunked_array/boolean.rs

This file was deleted.

1 change: 0 additions & 1 deletion polars/polars-core/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use std::sync::Arc;
pub mod ops;
#[macro_use]
pub mod arithmetic;
pub mod boolean;
pub mod builder;
pub mod cast;
pub mod comparison;
Expand Down
15 changes: 0 additions & 15 deletions polars/polars-core/src/chunked_array/ops/take/take_single.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,18 +162,3 @@ impl TakeRandom for ListChunked {
})
}
}

#[cfg(test)]
mod test {
    use super::*;

    /// Fetching index 0 from the `arg_true` result of a mask with no `true`
    /// entries is expected to panic.
    #[test]
    #[should_panic]
    fn test_oob() {
        let series: Series = [1.0, 2.0, 3.0].iter().collect();
        let floats = series.f64().unwrap();
        // None of the values equals 5.0, so no element of the mask is `true`.
        let hits = floats.equal(5.0).arg_true();
        hits.get(0);
    }
}
4 changes: 0 additions & 4 deletions polars/polars-core/src/series/implementations/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,6 @@ impl SeriesTrait for SeriesWrap<BooleanChunked> {
ArgAgg::arg_max(&self.0)
}

/// Delegates to `arg_true` on the wrapped boolean array and wraps the result in `Ok`.
fn arg_true(&self) -> Result<IdxCa> {
    let indices = self.0.arg_true();
    Ok(indices)
}

/// Delegates to `is_null` on the wrapped boolean array.
fn is_null(&self) -> BooleanChunked {
    let null_mask = self.0.is_null();
    null_mask
}
Expand Down
7 changes: 0 additions & 7 deletions polars/polars-core/src/series/series_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,13 +457,6 @@ pub trait SeriesTrait:
None
}

/// Get the indexes that evaluate to `true`.
///
/// This default implementation always returns an `InvalidOperation` error;
/// implementors backed by boolean data are expected to override it.
fn arg_true(&self) -> Result<IdxCa> {
    let msg = "arg_true can only be called for boolean dtype";
    Err(PolarsError::InvalidOperation(msg.into()))
}

/// Get a mask of the null values.
fn is_null(&self) -> BooleanChunked {
invalid_operation_panic!(self)
Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ list_to_struct = ["polars-ops/list_to_struct"]
python = ["pyo3"]
row_hash = ["polars-core/row_hash"]
string_justify = ["polars-ops/string_justify"]
arg_where = []

# no guarantees whatsoever
private = ["polars-time/private"]
Expand Down
33 changes: 33 additions & 0 deletions polars/polars-lazy/src/dsl/function_expr/arg_where.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use super::*;
use polars_arrow::trusted_len::PushUnchecked;

/// Compute the positions at which the boolean predicate in `s[0]` is `true`.
///
/// Null entries are treated as `false`: each chunk's validity mask is AND-ed
/// with its values before scanning. The position counter keeps running across
/// chunks, so the returned indices refer to the column as a whole.
///
/// An empty predicate yields an empty series of `IDX_DTYPE`.
///
/// # Errors
///
/// Propagates the error from `Series::bool` (presumably raised when `s[0]` is
/// not of boolean dtype).
///
/// # Panics
///
/// Panics if `s` is empty, because `s[0]` is indexed unconditionally.
pub(super) fn arg_where(s: &mut [Series]) -> Result<Series> {
    let predicate = s[0].bool()?;

    if predicate.is_empty() {
        return Ok(Series::full_null(predicate.name(), 0, &IDX_DTYPE));
    }

    // The sum of the boolean column is used as the number of output slots.
    // Fall back to 0 instead of panicking should `sum` ever report `None`
    // for a non-empty column.
    let capacity = predicate.sum().unwrap_or(0);
    let mut out = Vec::with_capacity(capacity as usize);
    let mut cnt = 0 as IdxSize;

    predicate.downcast_iter().for_each(|arr| {
        // Fold the validity mask into the values so that nulls read as `false`.
        let values = match arr.validity() {
            Some(validity) => validity & arr.values(),
            None => arr.values().clone(),
        };

        // todo! could use chunkiter from arrow here
        for bit in values.iter() {
            if bit {
                debug_assert!(out.len() < out.capacity());
                // SAFETY: `out` was allocated with `capacity` slots, taken from
                // `predicate.sum()`, and exactly one index is pushed per set bit
                // of `validity & values`. This relies on `sum` counting those
                // same bits (NOTE(review): confirm against `BooleanChunked::sum`).
                unsafe { out.push_unchecked(cnt) }
            }
            cnt += 1;
        }
    });

    let arr = Box::new(IdxArr::from_vec(out)) as ArrayRef;
    Ok(IdxCa::from_chunks(predicate.name(), vec![arr]).into_series())
}
10 changes: 10 additions & 0 deletions polars/polars-lazy/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#[cfg(feature = "arg_where")]
mod arg_where;
#[cfg(feature = "is_in")]
mod is_in;
mod pow;
Expand All @@ -16,6 +18,8 @@ pub enum FunctionExpr {
Hash(usize),
#[cfg(feature = "is_in")]
IsIn,
#[cfg(feature = "arg_where")]
ArgWhere,
}

impl FunctionExpr {
Expand Down Expand Up @@ -46,6 +50,8 @@ impl FunctionExpr {
Hash(_) => with_dtype(DataType::UInt64),
#[cfg(feature = "is_in")]
IsIn => with_dtype(DataType::Boolean),
#[cfg(feature = "arg_where")]
ArgWhere => with_dtype(IDX_DTYPE),
}
}
}
Expand Down Expand Up @@ -82,6 +88,10 @@ impl From<FunctionExpr> for NoEq<Arc<dyn SeriesUdf>> {
IsIn => {
wrap!(is_in::is_in)
}
#[cfg(feature = "arg_where")]
ArgWhere => {
wrap!(arg_where::arg_where)
}
}
}
}
19 changes: 19 additions & 0 deletions polars/polars-lazy/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
//!
//! Functions on expressions that might be useful.
//!
#[cfg(feature = "arg_where")]
use crate::dsl::function_expr::FunctionExpr;
use crate::prelude::*;
use crate::utils::has_wildcard;
use polars_core::export::arrow::temporal_conversions::NANOSECONDS;
Expand Down Expand Up @@ -861,6 +863,7 @@ pub fn as_struct(exprs: &[Expr]) -> Expr {
})
}

/// Repeat a literal `value` `n` times.
pub fn repeat<L: Literal>(value: L, n_times: Expr) -> Expr {
let function = |s: Series, n: Series| {
let n = n.get(0).extract::<usize>().ok_or_else(|| {
Expand All @@ -870,3 +873,19 @@ pub fn repeat<L: Literal>(value: L, n_times: Expr) -> Expr {
};
apply_binary(lit(value), n_times, function, GetOutput::same_type())
}

/// Build an expression that yields the indices where `condition` evaluates `true`.
///
/// `condition` may be anything convertible into an [`Expr`]; it becomes the
/// single input of a `FunctionExpr::ArgWhere` function expression.
#[cfg(feature = "arg_where")]
pub fn arg_where<E: Into<Expr>>(condition: E) -> Expr {
    let input = vec![condition.into()];
    let options = FunctionOptions {
        collect_groups: ApplyOptions::ApplyGroups,
        input_wildcard_expansion: false,
        auto_explode: false,
        fmt_str: "arg_where",
    };
    Expr::Function {
        input,
        function: FunctionExpr::ArgWhere,
        options,
    }
}
3 changes: 2 additions & 1 deletion polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@
//! - `horizontal_concat` - Concat horizontally and extend with null values if lengths don't match
//! - `dataframe_arithmetic` - Arithmetic on (Dataframe and DataFrames) and (DataFrame on Series)
//! - `partition_by` - Split into multiple DataFrames partitioned by groups.
//! * `Series` operations:
//! * `Series`/`Expression` operations:
//! - `is_in` - [Check for membership in `Series`](crate::chunked_array::ops::IsIn)
//! - `zip_with` - [Zip two Series/ ChunkedArrays](crate::chunked_array::ops::ChunkZip)
//! - `round_series` - round underlying float types of `Series`.
Expand Down Expand Up @@ -231,6 +231,7 @@
//! - `list_to_struct` - Convert `List` to `Struct` dtypes.
//! - `list_eval` - Apply expressions over list elements.
//! - `cumulative_eval` - Apply expressions over cumulatively increasing windows.
//! - `arg_where` - Get indices where condition holds.
//! * `DataFrame` pretty printing
//! - `fmt` - Activate DataFrame formatting
//!
Expand Down
1 change: 1 addition & 0 deletions py-polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ features = [
"list_to_struct",
"to_dummies",
"string_justify",
"arg_where",
]

# [patch.crates-io]
Expand Down
4 changes: 2 additions & 2 deletions py-polars/docs/source/reference/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ Conversion
from_arrow
from_pandas

Eager functions
~~~~~~~~~~~~~~~
Eager/Lazy functions
~~~~~~~~~~~~~~~~~~~~
.. autosummary::
:toctree: api/

Expand Down
3 changes: 2 additions & 1 deletion py-polars/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def version() -> str:
DataFrame,
wrap_df,
)
from polars.internals.functions import arg_where, concat, date_range, get_dummies
from polars.internals.functions import concat, date_range, get_dummies
from polars.internals.io import read_ipc_schema, read_parquet_schema
from polars.internals.lazy_frame import LazyFrame
from polars.internals.lazy_functions import _date as date
Expand All @@ -68,6 +68,7 @@ def version() -> str:
any,
apply,
arange,
arg_where,
argsort_by,
avg,
col,
Expand Down
11 changes: 10 additions & 1 deletion py-polars/polars/internals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@
from .frame import DataFrame, LazyFrame, wrap_df, wrap_ldf
from .functions import concat, date_range # DataFrame.describe() & DataFrame.upsample()
from .io import _is_local_file, _prepare_file_arg, read_ipc_schema, read_parquet_schema
from .lazy_functions import all, argsort_by, col, concat_list, element, lit, select
from .lazy_functions import (
all,
arg_where,
argsort_by,
col,
concat_list,
element,
lit,
select,
)
from .series import Series, wrap_s
from .whenthen import when # used in expr.clip()
29 changes: 0 additions & 29 deletions py-polars/polars/internals/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,35 +146,6 @@ def concat(
return out


def arg_where(mask: "pli.Series") -> "pli.Series":
    """
    Return the positions at which a Boolean mask is True.

    Parameters
    ----------
    mask
        Boolean Series.

    Returns
    -------
    UInt32 Series

    Examples
    --------
    >>> df = pl.DataFrame({"a": [1, 2, 3, 4, 5]})
    >>> pl.arg_where(df.select(pl.col("a") % 2 == 0).to_series())
    shape: (2,)
    Series: '' [u32]
    [
        1
        3
    ]

    """
    true_positions = mask.arg_true()
    return true_positions


def date_range(
low: datetime,
high: datetime,
Expand Down
63 changes: 63 additions & 0 deletions py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

try:
from polars.polars import arange as pyarange
from polars.polars import arg_where as py_arg_where
from polars.polars import argsort_by as pyargsort_by
from polars.polars import as_struct as _as_struct
from polars.polars import binary_function as pybinary_function
Expand Down Expand Up @@ -1653,3 +1654,65 @@ def repeat(
if isinstance(n, int):
n = lit(n)
return pli.wrap_expr(_repeat(value, n._pyexpr))


@overload
def arg_where(
    condition: Union["pli.Expr", "pli.Series"],
    eager: Literal[False] = ...,
) -> "pli.Expr":
    ...


@overload
def arg_where(
    condition: Union["pli.Expr", "pli.Series"], eager: Literal[True]
) -> "pli.Series":
    ...


@overload
def arg_where(
    condition: Union["pli.Expr", "pli.Series"], eager: bool
) -> Union["pli.Expr", "pli.Series"]:
    ...


def arg_where(
    condition: Union["pli.Expr", "pli.Series"], eager: bool = False
) -> Union["pli.Expr", "pli.Series"]:
    """
    Return indices where `condition` evaluates `True`.

    Parameters
    ----------
    condition
        Boolean expression to evaluate
    eager
        If `True`, `condition` must be a `Series`; it is evaluated immediately
        and a `Series` is returned. If `False` (default), an expression is
        returned.

    Raises
    ------
    ValueError
        If `eager=True` and `condition` is not a `Series`.

    Examples
    --------
    >>> df = pl.DataFrame({"a": [1, 2, 3, 4, 5]})
    >>> df.select(
    ...     [
    ...         pl.arg_where(pl.col("a") % 2 == 0),
    ...     ]
    ... ).to_series()
    shape: (2,)
    Series: 'a' [u32]
    [
        1
        3
    ]

    """
    if eager:
        if not isinstance(condition, pli.Series):
            raise ValueError(
                f"expected 'Series' in 'arg_where' if 'eager=True', got {type(condition)}"
            )
        # Evaluate through the lazy expression so eager and lazy share one code path.
        return (
            condition.to_frame().select(arg_where(pli.col(condition.name))).to_series()
        )
    else:
        condition = pli.expr_to_lit_or_expr(condition, str_to_lit=True)
        return pli.wrap_expr(py_arg_where(condition._pyexpr))
2 changes: 1 addition & 1 deletion py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1768,7 +1768,7 @@ def arg_true(self) -> "Series":
-------
UInt32 Series
"""
return wrap_s(self._s.arg_true())
return pli.arg_where(self, eager=True)

def is_unique(self) -> "Series":
"""
Expand Down
6 changes: 6 additions & 0 deletions py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,11 @@ fn pool_size() -> usize {
POOL.current_num_threads()
}

/// Python binding: forwards the wrapped expression to `polars::lazy::dsl::arg_where`
/// and converts the result back into a `PyExpr`.
#[pyfunction]
pub fn arg_where(condition: PyExpr) -> PyExpr {
    let expr = polars::lazy::dsl::arg_where(condition.inner);
    expr.into()
}

#[pymodule]
fn polars(py: Python, m: &PyModule) -> PyResult<()> {
m.add("NotFoundError", py.get_type::<NotFoundError>())
Expand Down Expand Up @@ -530,5 +535,6 @@ fn polars(py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(as_struct)).unwrap();
m.add_wrapped(wrap_pyfunction!(repeat)).unwrap();
m.add_wrapped(wrap_pyfunction!(pool_size)).unwrap();
m.add_wrapped(wrap_pyfunction!(arg_where)).unwrap();
Ok(())
}

0 comments on commit c20d943

Please sign in to comment.