Skip to content

Commit

Permalink
perf(rust, python): optimize arr.sum (#7047)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Feb 20, 2023
1 parent 3c57e67 commit b619f42
Show file tree
Hide file tree
Showing 10 changed files with 177 additions and 40 deletions.
6 changes: 6 additions & 0 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub enum ListFunction {
Take(bool),
#[cfg(feature = "list_count")]
CountMatch,
Sum,
}

impl Display for ListFunction {
Expand All @@ -31,6 +32,7 @@ impl Display for ListFunction {
Take(_) => "take",
#[cfg(feature = "list_count")]
CountMatch => "count",
Sum => "sum",
};
write!(f, "{name}")
}
Expand Down Expand Up @@ -229,3 +231,7 @@ pub(super) fn count_match(args: &[Series]) -> PolarsResult<Series> {
let ca = s.list()?;
list_count_match(ca, element.get(0).unwrap())
}

pub(super) fn sum(s: &Series) -> PolarsResult<Series> {
Ok(s.list()?.lst_sum())
}
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,7 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
Take(null_ob_oob) => map_as_slice!(list::take, null_ob_oob),
#[cfg(feature = "list_count")]
CountMatch => map_as_slice!(list::count_match),
Sum => map!(list::sum),
}
}
#[cfg(feature = "dtype-struct")]
Expand Down
12 changes: 12 additions & 0 deletions polars/polars-lazy/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,18 @@ impl FunctionExpr {
Take(_) => same_type(),
#[cfg(feature = "list_count")]
CountMatch => with_dtype(IDX_DTYPE),
Sum => {
let mut first = fields[0].clone();
use DataType::*;
let dt = first.data_type().inner_dtype().cloned().unwrap_or(Unknown);

if matches!(dt, UInt8 | Int8 | Int16 | UInt16) {
first.coerce(Int64);
} else {
first.coerce(dt);
}
Ok(first)
}
}
}
#[cfg(feature = "dtype-struct")]
Expand Down
13 changes: 1 addition & 12 deletions polars/polars-lazy/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,18 +62,7 @@ impl ListNameSpace {
/// Compute the sum the items in every sublist.
pub fn sum(self) -> Expr {
self.0
.map(
|s| Ok(Some(s.list()?.lst_sum())),
GetOutput::map_field(|f| {
if let DataType::List(adt) = f.data_type() {
Field::new(f.name(), *adt.clone())
} else {
// inner type
f.clone()
}
}),
)
.with_fmt("arr.sum")
.map_private(FunctionExpr::ListExpr(ListFunction::Sum))
}

/// Compute the mean of every sublist and return a `Series` of dtype `Float64`
Expand Down
31 changes: 17 additions & 14 deletions polars/polars-ops/src/chunked_array/list/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@ use arrow::array::{Array, BooleanArray};
use arrow::bitmap::utils::count_zeros;
use arrow::bitmap::Bitmap;
use polars_arrow::utils::CustomIterTools;
use polars_core::utils::NoNull;

use super::*;

fn count_bits_set_by_offsets(values: &Bitmap, offset: &[i64]) -> IdxCa {
fn count_bits_set_by_offsets(values: &Bitmap, offset: &[i64]) -> Vec<IdxSize> {
let (bits, bitmap_offset, _) = values.as_slice();

let mut running_offset = offset[0];

let ca: NoNull<IdxCa> = (offset[1..])
(offset[1..])
.iter()
.map(|end| {
let current_offset = running_offset;
Expand All @@ -22,9 +21,7 @@ fn count_bits_set_by_offsets(values: &Bitmap, offset: &[i64]) -> IdxCa {
let set_ones = len - count_zeros(bits, bitmap_offset + current_offset as usize, len);
set_ones as IdxSize
})
.collect_trusted();

ca.into_inner()
.collect_trusted()
}

#[cfg(feature = "list_count")]
Expand All @@ -39,12 +36,18 @@ pub fn list_count_match(ca: &ListChunked, value: AnyValue) -> PolarsResult<Serie
}

pub(super) fn count_boolean_bits(ca: &ListChunked) -> IdxCa {
assert_eq!(ca.chunks().len(), 1);
let arr = ca.downcast_iter().next().unwrap();
let inner_arr = arr.values();
let mask = inner_arr.as_any().downcast_ref::<BooleanArray>().unwrap();
assert_eq!(mask.null_count(), 0);
let mut out = count_bits_set_by_offsets(mask.values(), arr.offsets().as_slice());
out.rename(ca.name());
out
let chunks = ca
.downcast_iter()
.map(|arr| {
let inner_arr = arr.values();
let mask = inner_arr.as_any().downcast_ref::<BooleanArray>().unwrap();
assert_eq!(mask.null_count(), 0);
let out = count_bits_set_by_offsets(mask.values(), arr.offsets().as_slice());
Box::new(IdxArr::from_data_default(
out.into(),
arr.validity().cloned(),
)) as ArrayRef
})
.collect();
unsafe { IdxCa::from_chunks(ca.name(), chunks) }
}
1 change: 1 addition & 0 deletions polars/polars-ops/src/chunked_array/list/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod count;
#[cfg(feature = "hash")]
pub(crate) mod hash;
mod namespace;
mod sum;
#[cfg(feature = "list_to_struct")]
mod to_struct;

Expand Down
33 changes: 21 additions & 12 deletions polars/polars-ops/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,18 @@ use polars_core::series::ops::NullBehavior;
use polars_core::utils::{try_get_supertype, CustomIterTools};

use super::*;
use crate::prelude::list::sum::sum_list_numerical;
use crate::series::ArgAgg;

fn has_inner_nulls(ca: &ListChunked) -> bool {
for arr in ca.downcast_iter() {
if arr.values().null_count() > 0 {
return true;
}
}
false
}

fn cast_rhs(
other: &mut [Series],
inner_type: &DataType,
Expand Down Expand Up @@ -123,25 +133,24 @@ pub trait ListNameSpaceImpl: AsList {
}

fn lst_sum(&self) -> Series {
let ca = self.as_list();

fn inner(ca: &ListChunked) -> Series {
ca.apply_amortized(|s| s.as_ref().sum_as_series())
.explode()
.unwrap()
.into_series()
}

// fast implementation for booleans
if matches!(ca.inner_dtype(), DataType::Boolean) {
let ca = ca.rechunk();
if ca.chunks()[0].null_count() == 0 {
count_boolean_bits(&ca).into()
} else {
inner(&ca)
}
} else {
inner(ca)
let ca = self.as_list();

if has_inner_nulls(ca) {
return inner(ca);
};

use DataType::*;
match ca.inner_dtype() {
Boolean => count_boolean_bits(ca).into_series(),
dt if dt.is_numeric() => sum_list_numerical(ca, &dt),
_ => inner(ca),
}
}

Expand Down
83 changes: 83 additions & 0 deletions polars/polars-ops/src/chunked_array/list/sum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use arrow::array::{Array, PrimitiveArray};
use arrow::bitmap::Bitmap;
use arrow::types::NativeType;
use polars_arrow::utils::CustomIterTools;
use polars_core::datatypes::ListChunked;
use polars_core::export::num::{NumCast, ToPrimitive};
use polars_utils::unwrap::UnwrapUncheckedRelease;

use super::*;

fn sum_slice<T, S>(values: &[T]) -> S
where
T: NativeType + ToPrimitive,
S: NumCast + std::iter::Sum,
{
values
.iter()
.copied()
.map(|t| unsafe {
let s: S = NumCast::from(t).unwrap_unchecked_release();
s
})
.sum()
}

fn sum_between_offsets<T, S>(values: &[T], offset: &[i64]) -> Vec<S>
where
T: NativeType + ToPrimitive,
S: NumCast + std::iter::Sum,
{
let mut running_offset = offset[0];

(offset[1..])
.iter()
.map(|end| {
let current_offset = running_offset;
running_offset = *end;

let slice = unsafe { values.get_unchecked(current_offset as usize..*end as usize) };
sum_slice(slice)
})
.collect_trusted()
}

fn dispatch<T, S>(arr: &dyn Array, offsets: &[i64], validity: Option<&Bitmap>) -> ArrayRef
where
T: NativeType + ToPrimitive,
S: NativeType + NumCast + std::iter::Sum,
{
let values = arr.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
let values = values.values().as_slice();
Box::new(PrimitiveArray::from_data_default(
sum_between_offsets::<_, S>(values, offsets).into(),
validity.cloned(),
)) as ArrayRef
}

pub(super) fn sum_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Series {
use DataType::*;
let chunks = ca
.downcast_iter()
.map(|arr| {
let offsets = arr.offsets().as_slice();
let values = arr.values().as_ref();

match inner_type {
Int8 => dispatch::<i8, i64>(values, offsets, arr.validity()),
Int16 => dispatch::<i16, i64>(values, offsets, arr.validity()),
Int32 => dispatch::<i32, i32>(values, offsets, arr.validity()),
Int64 => dispatch::<i64, i64>(values, offsets, arr.validity()),
UInt8 => dispatch::<u8, i64>(values, offsets, arr.validity()),
UInt16 => dispatch::<u16, i64>(values, offsets, arr.validity()),
UInt32 => dispatch::<u32, u32>(values, offsets, arr.validity()),
UInt64 => dispatch::<u64, u64>(values, offsets, arr.validity()),
Float32 => dispatch::<f32, f32>(values, offsets, arr.validity()),
Float64 => dispatch::<f64, f64>(values, offsets, arr.validity()),
_ => unimplemented!(),
}
})
.collect::<Vec<_>>();

Series::try_from((ca.name(), chunks)).unwrap()
}
4 changes: 2 additions & 2 deletions py-polars/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 33 additions & 0 deletions py-polars/tests/unit/datatypes/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,3 +372,36 @@ def test_list_count_match() -> None:
assert pl.DataFrame({"listcol": [[], [1], [1, 2, 3, 2], [1, 2, 1], [4, 4]]}).select(
pl.col("listcol").arr.count_match(2).alias("number_of_twos")
).to_dict(False) == {"number_of_twos": [0, 0, 2, 1, 0]}
assert pl.DataFrame({"listcol": [[], [1], [1, 2, 3, 2], [1, 2, 1], [4, 4]]}).select(
pl.col("listcol").arr.count_match(2).alias("number_of_twos")
).to_dict(False) == {"number_of_twos": [0, 0, 2, 1, 0]}


def test_list_sum_and_dtypes() -> None:
# ensure the dtypes of sum align with normal sum
for dt_in, dt_out in [
(pl.Int8, pl.Int64),
(pl.Int16, pl.Int64),
(pl.Int32, pl.Int32),
(pl.Int64, pl.Int64),
(pl.UInt8, pl.Int64),
(pl.UInt16, pl.Int64),
(pl.UInt32, pl.UInt32),
(pl.UInt64, pl.UInt64),
]:
df = pl.DataFrame(
{"a": [[1], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]},
schema={"a": pl.List(dt_in)},
)

summed = df.explode("a").sum()
assert summed.dtypes == [dt_out]
assert summed.item() == 32
assert df.select(pl.col("a").arr.sum()).dtypes == [dt_out]

assert df.select(pl.col("a").arr.sum()).to_dict(False) == {"a": [1, 6, 10, 15]}

# include nulls
assert pl.DataFrame(
{"a": [[1], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5], None]}
).select(pl.col("a").arr.sum()).to_dict(False) == {"a": [1, 6, 10, 15, None]}

0 comments on commit b619f42

Please sign in to comment.