Skip to content

Commit

Permalink
perf: optimize arr.sum for inner non-null bool (pola-rs#13800)
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa authored and r-brink committed Jan 22, 2024
1 parent 6647e9c commit d6504a5
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 5 deletions.
3 changes: 2 additions & 1 deletion crates/polars-ops/src/chunked_array/array/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use polars_core::prelude::arity::unary_mut_with_options;

use super::*;

#[cfg(feature = "array_count")]
pub fn array_count_matches(ca: &ArrayChunked, value: AnyValue) -> PolarsResult<Series> {
let value = Series::new("", [value]);

Expand All @@ -16,7 +17,7 @@ pub fn array_count_matches(ca: &ArrayChunked, value: AnyValue) -> PolarsResult<S
Ok(out.into_series())
}

fn count_boolean_bits(ca: &ArrayChunked) -> IdxCa {
pub(super) fn count_boolean_bits(ca: &ArrayChunked) -> IdxCa {
unary_mut_with_options(ca, |arr| {
let inner_arr = arr.values();
let mask = inner_arr.as_any().downcast_ref::<BooleanArray>().unwrap();
Expand Down
1 change: 0 additions & 1 deletion crates/polars-ops/src/chunked_array/array/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#[cfg(feature = "array_any_all")]
mod any_all;
#[cfg(feature = "array_count")]
mod count;
mod get;
mod join;
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-ops/src/chunked_array/array/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use super::min_max::AggType;
use super::*;
#[cfg(feature = "array_count")]
use crate::chunked_array::array::count::array_count_matches;
use crate::chunked_array::array::count::count_boolean_bits;
use crate::chunked_array::array::sum_mean::sum_with_nulls;
#[cfg(feature = "array_any_all")]
use crate::prelude::array::any_all::{array_all, array_any};
Expand Down Expand Up @@ -44,6 +45,7 @@ pub trait ArrayNameSpace: AsArray {
};

match ca.inner_dtype() {
DataType::Boolean => Ok(count_boolean_bits(ca).into_series()),
dt if dt.is_numeric() => Ok(sum_array_numerical(ca, &dt)),
dt => sum_with_nulls(ca, &dt),
}
Expand Down
19 changes: 16 additions & 3 deletions py-polars/tests/unit/namespaces/array/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,22 @@ def test_array_min_max_dtype_12123() -> None:
assert_frame_equal(out, pl.DataFrame({"max": [3.0, 10.0], "min": [1.0, 4.0]}))


def test_arr_sum() -> None:
s = pl.Series("a", [[1, 2], [4, 3]], dtype=pl.Array(pl.Int64, 2))
assert s.arr.sum().to_list() == [3, 7]
@pytest.mark.parametrize(
("data", "expected_sum", "dtype"),
[
([[1, 2], [4, 3]], [3, 7], pl.Int64),
([[1, None], [None, 3], [None, None]], [1, 3, 0], pl.Int64),
([[1.0, 2.0], [4.0, 3.0]], [3.0, 7.0], pl.Float32),
([[1.0, None], [None, 3.0], [None, None]], [1.0, 3.0, 0], pl.Float32),
([[True, False], [True, True], [False, False]], [1, 2, 0], pl.Boolean),
([[True, None], [None, False], [None, None]], [1, 0, 0], pl.Boolean),
],
)
def test_arr_sum(
data: list[list[Any]], expected_sum: list[Any], dtype: pl.DataType
) -> None:
s = pl.Series("a", data, dtype=pl.Array(dtype, 2))
assert s.arr.sum().to_list() == expected_sum


def test_arr_unique() -> None:
Expand Down

0 comments on commit d6504a5

Please sign in to comment.