Skip to content

Commit

Permalink
count expression that does not require column name (#2558)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Feb 6, 2022
1 parent 11ff23e commit 194e402
Show file tree
Hide file tree
Showing 12 changed files with 126 additions and 4 deletions.
7 changes: 7 additions & 0 deletions polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,8 @@ pub enum Expr {
value: String,
expr: Box<Expr>,
},
/// Special case without that does not need columns
Count,
}

impl Default for Expr {
Expand Down Expand Up @@ -2190,6 +2192,11 @@ where
}
}

/// Count expression
pub fn count() -> Expr {
Expr::Count
}

#[cfg(test)]
mod test {
use super::*;
Expand Down
2 changes: 2 additions & 0 deletions polars/polars-lazy/src/logical_plan/aexpr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ pub enum AExpr {
offset: i64,
length: usize,
},
Count,
}

impl Default for AExpr {
Expand Down Expand Up @@ -126,6 +127,7 @@ impl AExpr {
) -> Result<Field> {
use AExpr::*;
match self {
Count => Ok(Field::new("count", DataType::UInt32)),
Window { function, .. } => {
let e = arena.get(*function);

Expand Down
2 changes: 2 additions & 0 deletions polars/polars-lazy/src/logical_plan/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ pub(crate) fn to_aexpr(expr: Expr, arena: &mut Arena<AExpr>) -> Node {
length,
},
Expr::Wildcard => AExpr::Wildcard,
Expr::Count => AExpr::Count,
Expr::KeepName(_) => panic!("no keep_name expected at this point"),
Expr::Exclude(_, _) => panic!("no exclude expected at this point"),
Expr::SufPreFix { .. } => panic!("no `suffix/prefix` expected at this point"),
Expand Down Expand Up @@ -389,6 +390,7 @@ pub(crate) fn node_to_expr(node: Node, expr_arena: &Arena<AExpr>) -> Expr {
let expr = expr_arena.get(node).clone();

match expr {
AExpr::Count => Expr::Count,
AExpr::Duplicated(node) => Expr::Duplicated(Box::new(node_to_expr(node, expr_arena))),
AExpr::IsUnique(node) => Expr::IsUnique(Box::new(node_to_expr(node, expr_arena))),
AExpr::Reverse(node) => Expr::Reverse(Box::new(node_to_expr(node, expr_arena))),
Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/src/logical_plan/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ impl fmt::Debug for Expr {
partition_by,
..
} => write!(f, "{:?}.over({:?})", function, partition_by),
Count => write!(f, "count()"),
IsUnique(expr) => write!(f, "{:?}.unique()", expr),
Explode(expr) => write!(f, "{:?}.explode()", expr),
Duplicated(expr) => write!(f, "{:?}.is_duplicate()", expr),
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-lazy/src/logical_plan/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ macro_rules! push_expr {
($current_expr:expr, $push:ident, $iter:ident) => {{
use Expr::*;
match $current_expr {
Column(_) | Literal(_) | Wildcard | Columns(_) | DtypeColumn(_) => {}
Column(_) | Literal(_) | Wildcard | Columns(_) | DtypeColumn(_) | Count => {}
Alias(e, _) => $push(e),
Not(e) => $push(e),
BinaryExpr { left, op: _, right } => {
Expand Down Expand Up @@ -155,7 +155,7 @@ impl AExpr {
use AExpr::*;

match self {
Column(_) | Literal(_) | Wildcard => {}
Column(_) | Literal(_) | Wildcard | Count => {}
Alias(e, _) => push(e),
Not(e) => push(e),
BinaryExpr { left, op: _, right } => {
Expand Down
68 changes: 68 additions & 0 deletions polars/polars-lazy/src/physical_plan/expressions/count.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
use crate::physical_plan::state::ExecutionState;
use crate::prelude::*;
use polars_arrow::utils::CustomIterTools;
use polars_core::prelude::*;
use polars_core::utils::NoNull;
use std::borrow::Cow;

pub struct CountExpr {
expr: Expr,
}

impl CountExpr {
pub(crate) fn new() -> Self {
Self { expr: Expr::Count }
}
}

impl PhysicalExpr for CountExpr {
fn as_expression(&self) -> &Expr {
&self.expr
}

fn evaluate(&self, df: &DataFrame, _state: &ExecutionState) -> Result<Series> {
Ok(Series::new("count", [df.height() as u32]))
}

fn evaluate_on_groups<'a>(
&self,
_df: &DataFrame,
groups: &'a GroupsProxy,
_state: &ExecutionState,
) -> Result<AggregationContext<'a>> {
let mut ca = match groups {
GroupsProxy::Idx(groups) => {
let ca: NoNull<UInt32Chunked> = groups
.all()
.iter()
.map(|g| g.len() as u32)
.collect_trusted();
ca.into_inner()
}
GroupsProxy::Slice(groups) => {
let ca: NoNull<UInt32Chunked> = groups.iter().map(|g| g[1]).collect_trusted();
ca.into_inner()
}
};
ca.rename("count");
let s = ca.into_series();

Ok(AggregationContext::new(s, Cow::Borrowed(groups), true))
}
fn to_field(&self, _input_schema: &Schema) -> Result<Field> {
Ok(Field::new("count", DataType::UInt32))
}
}

impl PhysicalAggregation for CountExpr {
fn aggregate(
&self,
df: &DataFrame,
groups: &GroupsProxy,
state: &ExecutionState,
) -> Result<Option<Series>> {
let mut ac = self.evaluate_on_groups(df, groups, state)?;
let s = ac.aggregated();
Ok(Some(s))
}
}
1 change: 1 addition & 0 deletions polars/polars-lazy/src/physical_plan/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub(crate) mod apply;
pub(crate) mod binary;
pub(crate) mod cast;
pub(crate) mod column;
pub(crate) mod count;
pub(crate) mod filter;
pub(crate) mod is_not_null;
pub(crate) mod is_null;
Expand Down
2 changes: 2 additions & 0 deletions polars/polars-lazy/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::physical_plan::executors::groupby_rolling::GroupByRollingExec;
#[cfg(feature = "ipc")]
use crate::physical_plan::executors::scan::IpcExec;
use crate::physical_plan::executors::union::UnionExec;
use crate::prelude::count::CountExpr;
use crate::prelude::shift::ShiftExpr;
use crate::prelude::*;
use crate::utils::{expr_to_root_column_name, has_window_aexpr};
Expand Down Expand Up @@ -486,6 +487,7 @@ impl DefaultPlanner {
use AExpr::*;

match expr_arena.get(expression).clone() {
Count => Ok(Arc::new(CountExpr::new())),
Window {
mut function,
partition_by,
Expand Down
21 changes: 19 additions & 2 deletions py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from polars.polars import cols as pycols
from polars.polars import concat_lst as _concat_lst
from polars.polars import concat_str as _concat_str
from polars.polars import count as _count
from polars.polars import cov as pycov
from polars.polars import dtype_cols as _dtype_cols
from polars.polars import fold as pyfold
Expand Down Expand Up @@ -160,10 +161,26 @@ def count(column: "pli.Series") -> int:
...


def count(column: Union[str, "pli.Series"] = "") -> Union["pli.Expr", int]:
@overload
def count(column: None = None) -> "pli.Expr":
...


def count(column: Optional[Union[str, "pli.Series"]] = None) -> Union["pli.Expr", int]:
"""
Count the number of values in this column.
Count the number of values in this column/context.
Parameters
----------
column
If dtype is:
pl.Series -> count the values in the series
str -> count the values in this column
None -> count the number of values in this context
"""
if column is None:
return pli.wrap_expr(_count())

if isinstance(column, pli.Series):
return column.len()
return col(column).count()
Expand Down
4 changes: 4 additions & 0 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1285,6 +1285,10 @@ pub fn col(name: &str) -> PyExpr {
dsl::col(name).into()
}

pub fn count() -> PyExpr {
dsl::count().into()
}

pub fn cols(names: Vec<String>) -> PyExpr {
dsl::cols(names).into()
}
Expand Down
6 changes: 6 additions & 0 deletions py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ fn col(name: &str) -> dsl::PyExpr {
dsl::col(name)
}

#[pyfunction]
fn count() -> dsl::PyExpr {
dsl::count()
}

#[pyfunction]
fn cols(names: Vec<String>) -> dsl::PyExpr {
dsl::cols(names)
Expand Down Expand Up @@ -344,6 +349,7 @@ fn polars(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_class::<PyLazyGroupBy>().unwrap();
m.add_class::<dsl::PyExpr>().unwrap();
m.add_wrapped(wrap_pyfunction!(col)).unwrap();
m.add_wrapped(wrap_pyfunction!(count)).unwrap();
m.add_wrapped(wrap_pyfunction!(cols)).unwrap();
m.add_wrapped(wrap_pyfunction!(dtype_cols)).unwrap();
m.add_wrapped(wrap_pyfunction!(lit)).unwrap();
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,15 @@ def test_list_join_strings() -> None:
s = pl.Series("a", [["ab", "c", "d"], ["e", "f"], ["g"], []])
expected = pl.Series("a", ["ab-c-d", "e-f", "g", ""])
verify_series_and_expr_api(s, expected, "arr.join", "-")


def test_count_expr() -> None:
df = pl.DataFrame({"a": [1, 2, 3, 3, 3], "b": ["a", "a", "b", "a", "a"]})

out = df.select(pl.count())
assert out.shape == (1, 1)
assert out[0, 0] == 5

out = df.groupby("b", maintain_order=True).agg(pl.count())
assert out["b"].to_list() == ["a", "b"]
assert out["count"].to_list() == [4, 1]

0 comments on commit 194e402

Please sign in to comment.