Skip to content

Commit

Permalink
feat: Impl count_matches for array namespace (#13675)
Browse files Browse the repository at this point in the history
Co-authored-by: Stijn de Gooijer <stijndegooijer@gmail.com>
  • Loading branch information
2 people authored and r-brink committed Jan 24, 2024
1 parent 86f0451 commit ece6b00
Show file tree
Hide file tree
Showing 19 changed files with 181 additions and 11 deletions.
2 changes: 1 addition & 1 deletion crates/polars-lazy/Cargo.toml
Expand Up @@ -81,7 +81,7 @@ sign = ["polars-plan/sign"]
timezones = ["polars-plan/timezones"]
list_gather = ["polars-ops/list_gather", "polars-plan/list_gather"]
list_count = ["polars-ops/list_count", "polars-plan/list_count"]

array_count = ["polars-ops/array_count", "polars-plan/array_count", "dtype-array"]
true_div = ["polars-plan/true_div"]
extract_jsonpath = ["polars-plan/extract_jsonpath", "polars-ops/extract_jsonpath"]

Expand Down
1 change: 1 addition & 0 deletions crates/polars-ops/Cargo.toml
Expand Up @@ -110,6 +110,7 @@ chunked_ids = ["polars-core/chunked_ids"]
asof_join = ["polars-core/asof_join"]
semi_anti_join = []
array_any_all = ["dtype-array"]
array_count = ["dtype-array"]
list_gather = []
list_sets = []
list_any_all = []
Expand Down
45 changes: 45 additions & 0 deletions crates/polars-ops/src/chunked_array/array/count.rs
@@ -0,0 +1,45 @@
use arrow::array::{Array, BooleanArray};
use arrow::bitmap::utils::count_zeros;
use arrow::bitmap::Bitmap;
use arrow::legacy::utils::CustomIterTools;
use polars_core::prelude::arity::unary_mut_with_options;

use super::*;

/// Count, for every sub-array in `ca`, how many of its elements equal `value`.
///
/// The inner values are compared against `value` with `equal_missing`, which
/// yields a boolean mask; the set bits of that mask are then counted per
/// sub-array by `count_boolean_bits`.
///
/// # Errors
///
/// Propagates any error raised by the comparison or by `apply_to_inner`.
pub fn array_count_matches(ca: &ArrayChunked, value: AnyValue) -> PolarsResult<Series> {
    // A length-1 Series lets the scalar be compared against the inner values.
    let needle = Series::new("", [value]);

    let match_mask = ca.apply_to_inner(&|inner| {
        let hits = ChunkCompare::<&Series>::equal_missing(&inner, &needle)?;
        Ok(hits.into_series())
    })?;

    Ok(count_boolean_bits(&match_mask).into_series())
}

/// Count the set bits of each sub-array of a boolean `ArrayChunked`.
///
/// The validity of the outer array is cloned onto the result.
///
/// # Panics
///
/// Panics if the inner values are not a `BooleanArray`, or if that inner
/// boolean array contains nulls.
fn count_boolean_bits(ca: &ArrayChunked) -> IdxCa {
    unary_mut_with_options(ca, |chunk| {
        let mask = chunk
            .values()
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        assert_eq!(mask.null_count(), 0);

        // One count per sub-array; each sub-array spans `chunk.size()` bits.
        let counts = count_bits_set(mask.values(), chunk.len(), chunk.size());
        IdxArr::from_data_default(counts.into(), chunk.validity().cloned())
    })
}

/// Count the set bits in `len` consecutive windows of `width` bits each.
///
/// Window `i` covers bits `[i * width, (i + 1) * width)` of `values`.
/// Returns one count per window.
fn count_bits_set(values: &Bitmap, len: usize, width: usize) -> Vec<IdxSize> {
    // Fast paths: a uniform bitmap gives the same count for every window.
    let unset = values.unset_bits();
    if unset == values.len() {
        return vec![0 as IdxSize; len];
    }
    if unset == 0 {
        return vec![width as IdxSize; len];
    }

    // The bitmap may start at a non-zero bit offset inside its byte buffer.
    let (bytes, start, _) = values.as_slice();

    (0..len)
        .map(|row| {
            let window_start = start + row * width;
            (width - count_zeros(bytes, window_start, width)) as IdxSize
        })
        .collect_trusted()
}
2 changes: 2 additions & 0 deletions crates/polars-ops/src/chunked_array/array/mod.rs
@@ -1,5 +1,7 @@
#[cfg(feature = "array_any_all")]
mod any_all;
#[cfg(feature = "array_count")]
mod count;
mod get;
mod join;
mod min_max;
Expand Down
8 changes: 8 additions & 0 deletions crates/polars-ops/src/chunked_array/array/namespace.rs
@@ -1,5 +1,7 @@
use super::min_max::AggType;
use super::*;
#[cfg(feature = "array_count")]
use crate::chunked_array::array::count::array_count_matches;
use crate::chunked_array::array::sum_mean::sum_with_nulls;
#[cfg(feature = "array_any_all")]
use crate::prelude::array::any_all::{array_all, array_any};
Expand Down Expand Up @@ -104,6 +106,12 @@ pub trait ArrayNameSpace: AsArray {
let ca = self.as_array();
array_join(ca, separator).map(|ok| ok.into_series())
}

/// Count how often `element` occurs in each sub-array.
#[cfg(feature = "array_count")]
fn array_count_matches(&self, element: AnyValue) -> PolarsResult<Series> {
    // Delegates to the free function of the same name imported from `count`.
    array_count_matches(self.as_array(), element)
}
}

impl ArrayNameSpace for ArrayChunked {}
1 change: 1 addition & 0 deletions crates/polars-plan/Cargo.toml
Expand Up @@ -80,6 +80,7 @@ object = ["polars-core/object"]
date_offset = ["polars-time", "chrono"]
list_gather = ["polars-ops/list_gather"]
list_count = ["polars-ops/list_count"]
array_count = ["polars-ops/array_count", "dtype-array"]
trigonometry = []
sign = []
timezones = ["chrono-tz", "polars-time/timezones", "polars-core/timezones", "regex"]
Expand Down
18 changes: 18 additions & 0 deletions crates/polars-plan/src/dsl/array.rs
Expand Up @@ -114,4 +114,22 @@ impl ArrayNameSpace {
false,
)
}

/// Count how often the value produced by `element` occurs.
#[cfg(feature = "array_count")]
pub fn count_matches<E: Into<Expr>>(self, element: E) -> Expr {
    let function = FunctionExpr::ArrayExpr(ArrayFunction::CountMatches);
    let needle: Expr = element.into();

    let mapped = self.0.map_many_private(function, &[needle], false, false);
    mapped.with_function_options(|mut opts| {
        opts.input_wildcard_expansion = true;
        opts
    })
}
}
21 changes: 21 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/array.rs
Expand Up @@ -23,6 +23,8 @@ pub enum ArrayFunction {
Join,
#[cfg(feature = "is_in")]
Contains,
#[cfg(feature = "array_count")]
CountMatches,
}

impl ArrayFunction {
Expand All @@ -42,6 +44,8 @@ impl ArrayFunction {
Join => mapper.with_dtype(DataType::String),
#[cfg(feature = "is_in")]
Contains => mapper.with_dtype(DataType::Boolean),
#[cfg(feature = "array_count")]
CountMatches => mapper.with_dtype(IDX_DTYPE),
}
}
}
Expand Down Expand Up @@ -75,6 +79,8 @@ impl Display for ArrayFunction {
Join => "join",
#[cfg(feature = "is_in")]
Contains => "contains",
#[cfg(feature = "array_count")]
CountMatches => "count_matches",
};
write!(f, "arr.{name}")
}
Expand All @@ -101,6 +107,8 @@ impl From<ArrayFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
Join => map_as_slice!(join),
#[cfg(feature = "is_in")]
Contains => map_as_slice!(contains),
#[cfg(feature = "array_count")]
CountMatches => map_as_slice!(count_matches),
}
}
}
Expand Down Expand Up @@ -177,3 +185,16 @@ pub(super) fn contains(s: &[Series]) -> PolarsResult<Series> {
let item = &s[1];
Ok(is_in(item, array)?.with_name(array.name()).into_series())
}

/// Evaluate `arr.count_matches`: `args[0]` is the array column and `args[1]`
/// is the element to look for, which must contain exactly one value.
///
/// # Errors
///
/// Returns a `ComputeError` when the element series does not have length 1,
/// and propagates errors from the array cast or from the count itself.
#[cfg(feature = "array_count")]
pub(super) fn count_matches(args: &[Series]) -> PolarsResult<Series> {
    let (array, needle) = (&args[0], &args[1]);
    polars_ensure!(
        needle.len() == 1,
        ComputeError: "argument expression in `arr.count_matches` must produce exactly one element, got {}",
        needle.len()
    );
    // The `unwrap` cannot fail: the length check above guarantees index 0 exists.
    array.array()?.array_count_matches(needle.get(0).unwrap())
}
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/function_expr/list.rs
Expand Up @@ -128,7 +128,7 @@ impl Display for ListFunction {
#[cfg(feature = "list_gather")]
Gather(_) => "gather",
#[cfg(feature = "list_count")]
CountMatches => "count",
CountMatches => "count_matches",
Sum => "sum",
Min => "min",
Max => "max",
Expand Down Expand Up @@ -459,7 +459,7 @@ pub(super) fn count_matches(args: &[Series]) -> PolarsResult<Series> {
let element = &args[1];
polars_ensure!(
element.len() == 1,
ComputeError: "argument expression in `arr.count` must produce exactly one element, got {}",
ComputeError: "argument expression in `list.count_matches` must produce exactly one element, got {}",
element.len()
);
let ca = s.list()?;
Expand Down
1 change: 1 addition & 0 deletions crates/polars/Cargo.toml
Expand Up @@ -155,6 +155,7 @@ is_unique = ["polars-lazy?/is_unique", "polars-ops/is_unique"]
regex = ["polars-lazy?/regex"]
list_any_all = ["polars-lazy?/list_any_all"]
list_count = ["polars-ops/list_count", "polars-lazy?/list_count"]
array_count = ["polars-ops/array_count", "polars-lazy?/array_count", "dtype-array"]
list_drop_nulls = ["polars-lazy?/list_drop_nulls"]
list_eval = ["polars-lazy?/list_eval"]
list_gather = ["polars-ops/list_gather", "polars-lazy?/list_gather"]
Expand Down
2 changes: 2 additions & 0 deletions py-polars/Cargo.toml
Expand Up @@ -137,6 +137,7 @@ cse = ["polars/cse"]
merge_sorted = ["polars/merge_sorted"]
list_gather = ["polars/list_gather"]
list_count = ["polars/list_count"]
array_count = ["polars/array_count", "polars/dtype-array"]
binary_encoding = ["polars/binary_encoding"]
list_sets = ["polars-lazy/list_sets"]
list_any_all = ["polars/list_any_all"]
Expand All @@ -163,6 +164,7 @@ dtypes = [

operations = [
"array_any_all",
"array_count",
"is_in",
"repeat_by",
"trigonometry",
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expressions/array.rst
Expand Up @@ -25,3 +25,4 @@ The following methods are available under the `expr.arr` attribute.
Expr.arr.last
Expr.arr.join
Expr.arr.contains
Expr.arr.count_matches
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/series/array.rst
Expand Up @@ -25,3 +25,4 @@ The following methods are available under the `Series.arr` attribute.
Series.arr.last
Series.arr.join
Series.arr.contains
Series.arr.count_matches
31 changes: 30 additions & 1 deletion py-polars/polars/expr/array.py
Expand Up @@ -9,7 +9,7 @@
from datetime import date, datetime, time

from polars import Expr
from polars.type_aliases import IntoExprColumn
from polars.type_aliases import IntoExpr, IntoExprColumn


class ExprArrayNameSpace:
Expand Down Expand Up @@ -509,3 +509,32 @@ def contains(
"""
item = parse_as_expression(item, str_as_lit=True)
return wrap_expr(self._pyexpr.arr_contains(item))

def count_matches(self, element: IntoExpr) -> Expr:
    """
    Count how often the value produced by `element` occurs.

    Parameters
    ----------
    element
        An expression that produces a single value

    Examples
    --------
    >>> df = pl.DataFrame(
    ...     {"a": [[1, 2], [1, 1], [2, 2]]}, schema={"a": pl.Array(pl.Int64, 2)}
    ... )
    >>> df.with_columns(number_of_twos=pl.col("a").arr.count_matches(2))
    shape: (3, 2)
    ┌───────────────┬────────────────┐
    │ a             ┆ number_of_twos │
    │ ---           ┆ ---            │
    │ array[i64, 2] ┆ u32            │
    ╞═══════════════╪════════════════╡
    │ [1, 2]        ┆ 1              │
    │ [1, 1]        ┆ 0              │
    │ [2, 2]        ┆ 2              │
    └───────────────┴────────────────┘
    """
    # Plain strings are treated as literals, not column names.
    element_pyexpr = parse_as_expression(element, str_as_lit=True)
    counted = self._pyexpr.arr_count_matches(element_pyexpr)
    return wrap_expr(counted)
24 changes: 23 additions & 1 deletion py-polars/polars/series/array.py
Expand Up @@ -9,7 +9,7 @@

from polars import Series
from polars.polars import PySeries
from polars.type_aliases import IntoExprColumn
from polars.type_aliases import IntoExpr, IntoExprColumn


@expr_dispatch
Expand Down Expand Up @@ -406,3 +406,25 @@ def contains(
]
"""

def count_matches(self, element: IntoExpr) -> Series:
    """
    Count how often the value produced by `element` occurs.

    Parameters
    ----------
    element
        An expression that produces a single value

    Examples
    --------
    >>> s = pl.Series("a", [[1, 2, 3], [2, 2, 2]], dtype=pl.Array(pl.Int64, 3))
    >>> s.arr.count_matches(2)
    shape: (2,)
    Series: 'a' [u32]
    [
        1
        3
    ]
    """
5 changes: 2 additions & 3 deletions py-polars/polars/series/list.py
Expand Up @@ -16,6 +16,7 @@
from polars import Expr, Series
from polars.polars import PySeries
from polars.type_aliases import (
IntoExpr,
IntoExprColumn,
NullBehavior,
ToStructStrategy,
Expand Down Expand Up @@ -694,9 +695,7 @@ def explode(self) -> Series:
]
"""

def count_matches(
self, element: float | str | bool | int | date | datetime | time | Expr
) -> Expr:
def count_matches(self, element: IntoExpr) -> Series:
"""
Count how often the value produced by `element` occurs.
Expand Down
5 changes: 5 additions & 0 deletions py-polars/src/expr/array.rs
Expand Up @@ -73,4 +73,9 @@ impl PyExpr {
fn arr_contains(&self, other: PyExpr) -> Self {
self.inner.clone().arr().contains(other.inner).into()
}

// Builds the `arr.count_matches` expression from a clone of the wrapped expression.
// A plain `//` comment is used on purpose: a `///` doc comment here would be
// exposed as the Python-side docstring of this method.
#[cfg(feature = "array_count")]
fn arr_count_matches(&self, expr: PyExpr) -> Self {
    let array_ns = self.inner.clone().arr();
    array_ns.count_matches(expr.inner).into()
}
}
3 changes: 0 additions & 3 deletions py-polars/tests/unit/datatypes/test_list.py
Expand Up @@ -304,9 +304,6 @@ def test_list_count_matches() -> None:
assert pl.DataFrame({"listcol": [[], [1], [1, 2, 3, 2], [1, 2, 1], [4, 4]]}).select(
pl.col("listcol").list.count_matches(2).alias("number_of_twos")
).to_dict(as_series=False) == {"number_of_twos": [0, 0, 2, 1, 0]}
assert pl.DataFrame({"listcol": [[], [1], [1, 2, 3, 2], [1, 2, 1], [4, 4]]}).select(
pl.col("listcol").list.count_matches(2).alias("number_of_twos")
).to_dict(as_series=False) == {"number_of_twos": [0, 0, 2, 1, 0]}


def test_list_sum_and_dtypes() -> None:
Expand Down
17 changes: 17 additions & 0 deletions py-polars/tests/unit/namespaces/array/test_array.py
Expand Up @@ -280,3 +280,20 @@ def test_array_contains_literal(
out = df.select(contains=pl.col("array").arr.contains(data)).to_series()
expected_series = pl.Series("contains", expected)
assert_series_equal(out, expected_series)


@pytest.mark.parametrize(
    ("arr", "data", "expected", "dtype"),
    [
        ([[1, 2], [3, None], None], 1, [1, 0, None], pl.Int64),
        ([[True, False], [True, None], None], True, [1, 1, None], pl.Boolean),
        ([["a", "b"], ["c", None], None], "a", [1, 0, None], pl.String),
        ([[b"a", b"b"], [b"c", None], None], b"a", [1, 0, None], pl.Binary),
    ],
)
def test_array_count_matches(
    arr: list[list[Any] | None], data: Any, expected: list[Any], dtype: pl.DataType
) -> None:
    """Check `arr.count_matches` per dtype, including null elements and null rows."""
    # Every case uses fixed-size arrays of width 2; a null row must yield a null count.
    frame = pl.DataFrame({"arr": arr}, schema={"arr": pl.Array(dtype, 2)})
    result = frame.select(count_matches=pl.col("arr").arr.count_matches(data))
    assert result.to_dict(as_series=False) == {"count_matches": expected}

0 comments on commit ece6b00

Please sign in to comment.