Skip to content

Commit

Permalink
perf(rust, python): optimize arr.mean (#7048)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Feb 20, 2023
1 parent b619f42 commit 4672a28
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 89 deletions.
2 changes: 1 addition & 1 deletion polars/polars-ops/src/chunked_array/list/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ mod count;
#[cfg(feature = "hash")]
pub(crate) mod hash;
mod namespace;
mod sum;
mod sum_mean;
#[cfg(feature = "list_to_struct")]
mod to_struct;

Expand Down
38 changes: 33 additions & 5 deletions polars/polars-ops/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use polars_core::series::ops::NullBehavior;
use polars_core::utils::{try_get_supertype, CustomIterTools};

use super::*;
use crate::prelude::list::sum::sum_list_numerical;
use crate::prelude::list::sum_mean::{mean_list_numerical, sum_list_numerical};
use crate::series::ArgAgg;

fn has_inner_nulls(ca: &ListChunked) -> bool {
Expand Down Expand Up @@ -154,11 +154,39 @@ pub trait ListNameSpaceImpl: AsList {
}
}

fn lst_mean(&self) -> Float64Chunked {
fn lst_mean(&self) -> Series {
fn inner(ca: &ListChunked) -> Series {
let mut out: Float64Chunked = ca
.amortized_iter()
.map(|s| s.and_then(|s| s.as_ref().mean()))
.collect();

out.rename(ca.name());
out.into_series()
}
use DataType::*;

let ca = self.as_list();
ca.amortized_iter()
.map(|s| s.and_then(|s| s.as_ref().mean()))
.collect()

if has_inner_nulls(ca) {
return match ca.inner_dtype() {
Float32 => {
let mut out: Float32Chunked = ca
.amortized_iter()
.map(|s| s.and_then(|s| s.as_ref().mean().map(|v| v as f32)))
.collect();

out.rename(ca.name());
out.into_series()
}
_ => inner(ca),
};
};

match ca.inner_dtype() {
dt if dt.is_numeric() => mean_list_numerical(ca, &dt),
_ => inner(ca),
}
}

#[must_use]
Expand Down
83 changes: 0 additions & 83 deletions polars/polars-ops/src/chunked_array/list/sum.rs

This file was deleted.

146 changes: 146 additions & 0 deletions polars/polars-ops/src/chunked_array/list/sum_mean.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
use std::ops::Div;

use arrow::array::{Array, PrimitiveArray};
use arrow::bitmap::Bitmap;
use arrow::types::NativeType;
use polars_arrow::utils::CustomIterTools;
use polars_core::datatypes::ListChunked;
use polars_core::export::num::{NumCast, ToPrimitive};
use polars_utils::unwrap::UnwrapUncheckedRelease;

use super::*;

fn sum_slice<T, S>(values: &[T]) -> S
where
T: NativeType + ToPrimitive,
S: NumCast + std::iter::Sum,
{
values
.iter()
.copied()
.map(|t| unsafe {
let s: S = NumCast::from(t).unwrap_unchecked_release();
s
})
.sum()
}

fn sum_between_offsets<T, S>(values: &[T], offset: &[i64]) -> Vec<S>
where
T: NativeType + ToPrimitive,
S: NumCast + std::iter::Sum,
{
let mut running_offset = offset[0];

(offset[1..])
.iter()
.map(|end| {
let current_offset = running_offset;
running_offset = *end;

let slice = unsafe { values.get_unchecked(current_offset as usize..*end as usize) };
sum_slice(slice)
})
.collect_trusted()
}

fn dispatch_sum<T, S>(arr: &dyn Array, offsets: &[i64], validity: Option<&Bitmap>) -> ArrayRef
where
T: NativeType + ToPrimitive,
S: NativeType + NumCast + std::iter::Sum,
{
let values = arr.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
let values = values.values().as_slice();
Box::new(PrimitiveArray::from_data_default(
sum_between_offsets::<_, S>(values, offsets).into(),
validity.cloned(),
)) as ArrayRef
}

pub(super) fn sum_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Series {
use DataType::*;
let chunks = ca
.downcast_iter()
.map(|arr| {
let offsets = arr.offsets().as_slice();
let values = arr.values().as_ref();

match inner_type {
Int8 => dispatch_sum::<i8, i64>(values, offsets, arr.validity()),
Int16 => dispatch_sum::<i16, i64>(values, offsets, arr.validity()),
Int32 => dispatch_sum::<i32, i32>(values, offsets, arr.validity()),
Int64 => dispatch_sum::<i64, i64>(values, offsets, arr.validity()),
UInt8 => dispatch_sum::<u8, i64>(values, offsets, arr.validity()),
UInt16 => dispatch_sum::<u16, i64>(values, offsets, arr.validity()),
UInt32 => dispatch_sum::<u32, u32>(values, offsets, arr.validity()),
UInt64 => dispatch_sum::<u64, u64>(values, offsets, arr.validity()),
Float32 => dispatch_sum::<f32, f32>(values, offsets, arr.validity()),
Float64 => dispatch_sum::<f64, f64>(values, offsets, arr.validity()),
_ => unimplemented!(),
}
})
.collect::<Vec<_>>();

Series::try_from((ca.name(), chunks)).unwrap()
}

fn mean_between_offsets<T, S>(values: &[T], offset: &[i64]) -> Vec<S>
where
T: NativeType + ToPrimitive,
S: NumCast + std::iter::Sum + Div<Output = S>,
{
let mut running_offset = offset[0];

(offset[1..])
.iter()
.map(|end| {
let current_offset = running_offset;
running_offset = *end;

let slice = unsafe { values.get_unchecked(current_offset as usize..*end as usize) };
unsafe {
sum_slice::<_, S>(slice) / NumCast::from(slice.len()).unwrap_unchecked_release()
}
})
.collect_trusted()
}

fn dispatch_mean<T, S>(arr: &dyn Array, offsets: &[i64], validity: Option<&Bitmap>) -> ArrayRef
where
T: NativeType + ToPrimitive,
S: NativeType + NumCast + std::iter::Sum + Div<Output = S>,
{
let values = arr.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
let values = values.values().as_slice();
Box::new(PrimitiveArray::from_data_default(
mean_between_offsets::<_, S>(values, offsets).into(),
validity.cloned(),
)) as ArrayRef
}

pub(super) fn mean_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Series {
use DataType::*;
let chunks = ca
.downcast_iter()
.map(|arr| {
let offsets = arr.offsets().as_slice();
let values = arr.values().as_ref();

match inner_type {
Int8 => dispatch_mean::<i8, f64>(values, offsets, arr.validity()),
Int16 => dispatch_mean::<i16, f64>(values, offsets, arr.validity()),
Int32 => dispatch_mean::<i32, f64>(values, offsets, arr.validity()),
Int64 => dispatch_mean::<i64, f64>(values, offsets, arr.validity()),
UInt8 => dispatch_mean::<u8, f64>(values, offsets, arr.validity()),
UInt16 => dispatch_mean::<u16, f64>(values, offsets, arr.validity()),
UInt32 => dispatch_mean::<u32, f64>(values, offsets, arr.validity()),
UInt64 => dispatch_mean::<u64, f64>(values, offsets, arr.validity()),
Float32 => dispatch_mean::<f32, f32>(values, offsets, arr.validity()),
Float64 => dispatch_mean::<f64, f64>(values, offsets, arr.validity()),
_ => unimplemented!(),
}
})
.collect::<Vec<_>>();

Series::try_from((ca.name(), chunks)).unwrap()
}
10 changes: 10 additions & 0 deletions py-polars/tests/unit/datatypes/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,3 +405,13 @@ def test_list_sum_and_dtypes() -> None:
assert pl.DataFrame(
{"a": [[1], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5], None]}
).select(pl.col("a").arr.sum()).to_dict(False) == {"a": [1, 6, 10, 15, None]}


def test_list_mean() -> None:
assert pl.DataFrame({"a": [[1], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]]}).select(
pl.col("a").arr.mean()
).to_dict(False) == {"a": [1.0, 2.0, 2.5, 3.0]}

assert pl.DataFrame({"a": [[1], [1, 2, 3], [1, 2, 3, 4], None]}).select(
pl.col("a").arr.mean()
).to_dict(False) == {"a": [1.0, 2.0, 2.5, None]}

0 comments on commit 4672a28

Please sign in to comment.