Skip to content

Commit

Permalink
feat[rust]: groupby numeric list columns (#4919)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 21, 2022
1 parent d4f38bc commit faeb014
Show file tree
Hide file tree
Showing 25 changed files with 208 additions and 64 deletions.
59 changes: 59 additions & 0 deletions polars/polars-arrow/src/kernels/list_bytes_iter.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
use arrow::array::{ListArray, PrimitiveArray};
use arrow::bitmap::Bitmap;
use arrow::datatypes::PhysicalType::Primitive;
use arrow::types::NativeType;

use crate::error::PolarsError;
use crate::utils::with_match_primitive_type;

unsafe fn bytes_iter<'a, T: NativeType>(
values: &'a [T],
offsets: &'a [i64],
validity: Option<&'a Bitmap>,
) -> impl ExactSizeIterator<Item = Option<&'a [u8]>> {
let mut start = offsets[0] as usize;
offsets[1..].iter().enumerate().map(move |(i, end)| {
let end = *end as usize;
let out = values.get_unchecked(start..end);
start = end;

let data = out.as_ptr() as *const u8;
let out = std::slice::from_raw_parts(data, std::mem::size_of::<T>() * out.len());
match validity {
None => Some(out),
Some(validity) => {
if validity.get_bit_unchecked(i) {
Some(out)
} else {
None
}
}
}
})
}

pub fn numeric_list_bytes_iter(
arr: &ListArray<i64>,
) -> Result<Box<dyn ExactSizeIterator<Item = Option<&[u8]>> + '_>, PolarsError> {
let values = arr.values();
if values.null_count() > 0 {
return Err(PolarsError::ComputeError(
"only allowed for child arrays without nulls".into(),
));
}
let offsets = arr.offsets().as_slice();
let validity = arr.validity();

if let Primitive(primitive) = values.data_type().to_physical_type() {
with_match_primitive_type!(primitive, |$T| {
let arr: &PrimitiveArray<$T> = values.as_any().downcast_ref().unwrap();
let values = arr.values();
let iter = unsafe { bytes_iter(values.as_slice(), offsets, validity) };
Ok(Box::new(iter))
})
} else {
Err(PolarsError::ComputeError(
"only allowed for numeric child arrays".into(),
))
}
}
1 change: 1 addition & 0 deletions polars/polars-arrow/src/kernels/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub mod concatenate;
pub mod ewm;
pub mod float;
pub mod list;
pub mod list_bytes_iter;
pub mod rolling;
pub mod set;
pub mod sort_partition;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ mod test {
));

let groups = s.group_tuples(false, true);
let aggregated = unsafe { s.agg_list(&groups) };
let aggregated = unsafe { s.agg_list(&groups?) };
match aggregated.get(0) {
AnyValue::List(s) => {
assert!(matches!(s.dtype(), DataType::Categorical(_)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl CategoricalChunked {
}

pub fn value_counts(&self) -> PolarsResult<DataFrame> {
let groups = self.logical().group_tuples(true, false);
let groups = self.logical().group_tuples(true, false).unwrap();
let logical_values = unsafe {
self.logical()
.clone()
Expand Down
1 change: 1 addition & 0 deletions polars/polars-core/src/chunked_array/ops/unique/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ where
}
let mut groups = ca
.group_tuples(true, false)
.unwrap()
.into_idx()
.into_iter()
.collect_trusted::<Vec<_>>();
Expand Down
100 changes: 83 additions & 17 deletions polars/polars-core/src/frame/groupby/into_groups.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#[cfg(feature = "groupby_list")]
use polars_arrow::kernels::list_bytes_iter::numeric_list_bytes_iter;
use polars_arrow::kernels::sort_partition::{create_clean_partitions, partition_to_groups};
use polars_arrow::prelude::*;
use polars_utils::flatten;
Expand All @@ -10,7 +12,7 @@ pub trait IntoGroupsProxy {
/// Create the tuples need for a groupby operation.
/// * The first value in the tuple is the first index of the group.
/// * The second value in the tuple is are the indexes of the groups including the first value.
fn group_tuples(&self, _multithreaded: bool, _sorted: bool) -> GroupsProxy {
fn group_tuples(&self, _multithreaded: bool, _sorted: bool) -> PolarsResult<GroupsProxy> {
unimplemented!()
}
}
Expand Down Expand Up @@ -141,17 +143,17 @@ where
T: PolarsNumericType,
T::Native: NumCast,
{
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> GroupsProxy {
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsProxy> {
// sorted path
if self.is_sorted() || self.is_sorted_reverse() && self.chunks().len() == 1 {
// don't have to pass `sorted` arg, GroupSlice is always sorted.
return GroupsProxy::Slice {
return Ok(GroupsProxy::Slice {
groups: self.create_groups_from_sorted(multithreaded),
rolling: false,
};
});
}

match self.dtype() {
let out = match self.dtype() {
DataType::UInt64 => {
// convince the compiler that we are this type.
let ca: &UInt64Chunked = unsafe {
Expand Down Expand Up @@ -210,11 +212,12 @@ where
let ca = ca.u32().unwrap();
num_groups_proxy(ca, multithreaded, sorted)
}
}
};
Ok(out)
}
}
impl IntoGroupsProxy for BooleanChunked {
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> GroupsProxy {
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsProxy> {
#[cfg(feature = "performant")]
{
let ca = self.cast(&DataType::UInt8).unwrap();
Expand All @@ -232,11 +235,11 @@ impl IntoGroupsProxy for BooleanChunked {

impl IntoGroupsProxy for Utf8Chunked {
#[allow(clippy::needless_lifetimes)]
fn group_tuples<'a>(&'a self, multithreaded: bool, sorted: bool) -> GroupsProxy {
fn group_tuples<'a>(&'a self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsProxy> {
let hb = RandomState::default();
let null_h = get_null_hash_value(hb.clone());

if multithreaded {
let out = if multithreaded {
let n_partitions = set_partition_size();

let split = _split_offsets(self.len(), n_partitions);
Expand All @@ -260,7 +263,7 @@ impl IntoGroupsProxy for Utf8Chunked {
)
}
})
.collect::<Vec<_>>()
.collect_trusted::<Vec<_>>()
})
.collect::<Vec<_>>()
});
Expand All @@ -275,16 +278,79 @@ impl IntoGroupsProxy for Utf8Chunked {
};
BytesHash::new_from_str(opt_s, hash)
})
.collect::<Vec<_>>();
.collect_trusted::<Vec<_>>();
groupby(str_hashes.iter(), sorted)
}
};
Ok(out)
}
}

impl IntoGroupsProxy for ListChunked {
#[cfg(feature = "groupby_list")]
fn group_tuples(&self, _multithreaded: bool, sorted: bool) -> GroupsProxy {
groupby(self.into_iter().map(|opt_s| opt_s.map(Wrap)), sorted)
#[allow(clippy::needless_lifetimes)]
#[allow(unused_variables)]
fn group_tuples<'a>(&'a self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsProxy> {
#[cfg(feature = "groupby_list")]
{
if !self.inner_dtype().to_physical().is_numeric() {
return Err(PolarsError::ComputeError(
"Grouping on List type is only allowed if the inner type is numeric".into(),
));
}

let hb = RandomState::default();
let null_h = get_null_hash_value(hb.clone());

let arr_to_hashes = |ca: &ListChunked| {
let mut out = Vec::with_capacity(ca.len());

for arr in ca.downcast_iter() {
out.extend(numeric_list_bytes_iter(arr)?.map(|opt_bytes| {
let hash = match opt_bytes {
Some(s) => str::get_hash(s, &hb),
None => null_h,
};

// Safety:
// the underlying data is tied to self
unsafe {
std::mem::transmute::<BytesHash<'_>, BytesHash<'a>>(BytesHash::new(
opt_bytes, hash,
))
}
}))
}
Ok(out)
};

if multithreaded {
let n_partitions = set_partition_size();
let split = _split_offsets(self.len(), n_partitions);

let groups: PolarsResult<_> = POOL.install(|| {
let bytes_hashes = split
.into_par_iter()
.map(|(offset, len)| {
let ca = self.slice(offset as i64, len);
arr_to_hashes(&ca)
})
.collect::<PolarsResult<Vec<_>>>()?;
Ok(groupby_threaded_num(
bytes_hashes,
0,
n_partitions as u64,
sorted,
))
});
groups
} else {
let hashes = arr_to_hashes(self)?;
Ok(groupby(hashes.iter(), sorted))
}
}
#[cfg(not(feature = "groupby_list"))]
{
panic!("activate 'groupby_list' feature")
}
}
}

Expand All @@ -293,8 +359,8 @@ impl<T> IntoGroupsProxy for ObjectChunked<T>
where
T: PolarsObject,
{
fn group_tuples(&self, _multithreaded: bool, sorted: bool) -> GroupsProxy {
groupby(self.into_iter(), sorted)
fn group_tuples(&self, _multithreaded: bool, sorted: bool) -> PolarsResult<GroupsProxy> {
Ok(groupby(self.into_iter(), sorted))
}
}

Expand Down
12 changes: 5 additions & 7 deletions polars/polars-core/src/frame/groupby/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ use rayon::prelude::*;

use self::hashing::*;
use crate::prelude::*;
#[cfg(feature = "groupby_list")]
use crate::utils::Wrap;
use crate::utils::{_split_offsets, accumulate_dataframes_vertical, set_partition_size};
use crate::vector_hasher::{get_null_hash_value, AsU64, BytesHash};
use crate::POOL;
Expand Down Expand Up @@ -151,20 +149,20 @@ impl DataFrame {
// arbitrarily chosen bound, if avg no of bytes to encode is larger than this
// value we fall back to default groupby
if (lhs.get_values_size() + rhs.get_values_size()) / (lhs.len() + 1) < 128 {
pack_utf8_columns(lhs, rhs, n_partitions, sorted)
Ok(pack_utf8_columns(lhs, rhs, n_partitions, sorted))
} else {
groupby_threaded_multiple_keys_flat(keys_df, n_partitions, sorted)?
groupby_threaded_multiple_keys_flat(keys_df, n_partitions, sorted)
}
} else {
groupby_threaded_multiple_keys_flat(keys_df, n_partitions, sorted)?
groupby_threaded_multiple_keys_flat(keys_df, n_partitions, sorted)
}
}
_ => {
let keys_df = prepare_dataframe_unsorted(&by);
groupby_threaded_multiple_keys_flat(keys_df, n_partitions, sorted)?
groupby_threaded_multiple_keys_flat(keys_df, n_partitions, sorted)
}
};
Ok(GroupBy::new(self, by, groups, None))
Ok(GroupBy::new(self, by, groups?, None))
}

/// Group DataFrame using a Series column.
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/series/implementations/boolean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl private::PrivateSeries for SeriesWrap<BooleanChunked> {
) -> Series {
ZipOuterJoinColumn::zip_outer_join_column(&self.0, right_column, opt_join_tuples)
}
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> GroupsProxy {
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsProxy> {
IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl private::PrivateSeries for SeriesWrap<CategoricalChunked> {
CategoricalChunked::from_cats_and_rev_map_unchecked(cats, new_rev_map).into_series()
}
}
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> GroupsProxy {
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsProxy> {
self.0.logical().group_tuples(multithreaded, sorted)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ macro_rules! impl_dyn_series {
"cannot do remainder operation on logical".into(),
))
}
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> GroupsProxy {
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsProxy> {
self.0.group_tuples(multithreaded, sorted)
}
#[cfg(feature = "sort_multiple")]
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/series/implementations/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ impl private::PrivateSeries for SeriesWrap<DatetimeChunked> {
"cannot do remainder operation on logical".into(),
))
}
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> GroupsProxy {
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsProxy> {
self.0.group_tuples(multithreaded, sorted)
}
#[cfg(feature = "sort_multiple")]
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/series/implementations/duration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ impl private::PrivateSeries for SeriesWrap<DurationChunked> {
"cannot do remainder operation on logical".into(),
))
}
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> GroupsProxy {
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsProxy> {
self.0.group_tuples(multithreaded, sorted)
}
#[cfg(feature = "sort_multiple")]
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/series/implementations/floats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ macro_rules! impl_dyn_series {
fn remainder(&self, rhs: &Series) -> PolarsResult<Series> {
NumOpsDispatch::remainder(&self.0, rhs)
}
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> GroupsProxy {
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsProxy> {
IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted)
}

Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/series/implementations/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ impl private::PrivateSeries for SeriesWrap<ListChunked> {
self.0.agg_list(groups)
}

fn group_tuples(&self, multithreaded: bool, sorted: bool) -> GroupsProxy {
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsProxy> {
IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted)
}
}
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/series/implementations/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ macro_rules! impl_dyn_series {
fn remainder(&self, rhs: &Series) -> PolarsResult<Series> {
NumOpsDispatch::remainder(&self.0, rhs)
}
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> GroupsProxy {
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsProxy> {
IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted)
}

Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/series/implementations/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ where
Ok(())
}

fn group_tuples(&self, multithreaded: bool, sorted: bool) -> GroupsProxy {
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsProxy> {
IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted)
}
#[cfg(feature = "zip_with")]
Expand Down

0 comments on commit faeb014

Please sign in to comment.