Commit
feat: Support list group-by of non numeric lists (#15540)
ritchie46 committed Apr 8, 2024
1 parent 2a42096 commit a8c9738
Showing 13 changed files with 28 additions and 197 deletions.
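
In short: this commit deletes the numeric-only `numeric_list_bytes_iter` kernel and the `group_by_list` feature gate, and re-implements `ListChunked::group_tuples` on top of the row-encoding helpers (`encode_rows_vertical_par_unordered` / `_get_rows_encoded_ca_unordered`), so group-by on list columns no longer requires a numeric inner type. A new unit test in `py-polars/tests/unit/operations/test_join.py` exercises grouping by a list-of-strings column.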
56 changes: 0 additions & 56 deletions crates/polars-arrow/src/legacy/kernels/list_bytes_iter.rs

This file was deleted.

1 change: 0 additions & 1 deletion crates/polars-arrow/src/legacy/kernels/mod.rs
@@ -10,7 +10,6 @@ pub mod fixed_size_list;
pub mod float;
#[cfg(feature = "compute_take")]
pub mod list;
pub mod list_bytes_iter;
pub mod pow;
pub mod rolling;
pub mod set;
66 changes: 8 additions & 58 deletions crates/polars-core/src/frame/group_by/into_groups.rs
@@ -1,10 +1,9 @@
#[cfg(feature = "group_by_list")]
use arrow::legacy::kernels::list_bytes_iter::numeric_list_bytes_iter;
use arrow::legacy::kernels::sort_partition::{create_clean_partitions, partition_to_groups};
use polars_utils::total_ord::{ToTotalOrd, TotalHash};

use super::*;
use crate::config::verbose;
use crate::prelude::sort::arg_sort_multiple::_get_rows_encoded_ca_unordered;
use crate::utils::_split_offsets;
use crate::utils::flatten::flatten_par;

@@ -368,63 +367,14 @@ impl IntoGroupsProxy for ListChunked {
#[allow(clippy::needless_lifetimes)]
#[allow(unused_variables)]
fn group_tuples<'a>(&'a self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsProxy> {
#[cfg(feature = "group_by_list")]
{
polars_ensure!(
self.inner_dtype().to_physical().is_numeric(),
ComputeError: "grouping on list type is only allowed if the inner type is numeric"
);

let hb = RandomState::default();
let null_h = get_null_hash_value(&hb);

let arr_to_hashes = |ca: &ListChunked| {
let mut out = Vec::with_capacity(ca.len());

for arr in ca.downcast_iter() {
out.extend(numeric_list_bytes_iter(arr)?.map(|opt_bytes| {
let hash = match opt_bytes {
Some(s) => hb.hash_one(s),
None => null_h,
};

// SAFETY:
// the underlying data is tied to self
unsafe {
std::mem::transmute::<BytesHash<'_>, BytesHash<'a>>(BytesHash::new(
opt_bytes, hash,
))
}
}))
}
Ok(out)
};
let by = &[self.clone().into_series()];
let ca = if multithreaded {
encode_rows_vertical_par_unordered(by).unwrap()
} else {
_get_rows_encoded_ca_unordered("", by).unwrap()
};

if multithreaded {
let n_partitions = _set_partition_size();
let split = _split_offsets(self.len(), n_partitions);

let groups: PolarsResult<_> = POOL.install(|| {
let bytes_hashes = split
.into_par_iter()
.map(|(offset, len)| {
let ca = self.slice(offset as i64, len);
arr_to_hashes(&ca)
})
.collect::<PolarsResult<Vec<_>>>()?;
let bytes_hashes = bytes_hashes.iter().collect::<Vec<_>>();
Ok(group_by_threaded_slice(bytes_hashes, n_partitions, sorted))
});
groups
} else {
let hashes = arr_to_hashes(self)?;
Ok(group_by(hashes.iter(), sorted))
}
}
#[cfg(not(feature = "group_by_list"))]
{
panic!("activate 'group_by_list' feature")
}
ca.group_tuples(multithreaded, sorted)
}
}

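For readability, the new `group_tuples` body shown in the hunk above boils down to the following sketch (paraphrased, not the verbatim commit; `encode_rows_vertical_par_unordered` and `_get_rows_encoded_ca_unordered` are the row-encoding helpers imported at the top of `into_groups.rs`):

```rust
// A sketch of the new ListChunked::group_tuples path (paraphrased from the diff above,
// not compilable standalone). Each list row is row-encoded into an opaque binary key,
// so equal lists -- whatever their inner dtype -- yield equal byte strings, and the
// existing binary group-by does the grouping. This replaces the numeric-only
// `numeric_list_bytes_iter` hashing path and the `group_by_list` feature gate.
fn group_tuples<'a>(&'a self, multithreaded: bool, sorted: bool) -> PolarsResult<GroupsProxy> {
    let by = &[self.clone().into_series()];
    let ca = if multithreaded {
        // parallel row encoding across the thread pool
        encode_rows_vertical_par_unordered(by).unwrap()
    } else {
        _get_rows_encoded_ca_unordered("", by).unwrap()
    };
    // delegate to group_tuples on the encoded (binary) keys
    ca.group_tuples(multithreaded, sorted)
}
```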
54 changes: 0 additions & 54 deletions crates/polars-core/src/hashing/vector_hasher.rs
@@ -1,6 +1,4 @@
use arrow::bitmap::utils::get_bit_unchecked;
#[cfg(feature = "group_by_list")]
use arrow::legacy::kernels::list_bytes_iter::numeric_list_bytes_iter;
use polars_utils::total_ord::{ToTotalOrd, TotalHash};
use rayon::prelude::*;
use xxhash_rust::xxh3::xxh3_64_with_seed;
@@ -379,58 +377,6 @@ impl VecHash for BooleanChunked {
}
}

#[cfg(feature = "group_by_list")]
impl VecHash for ListChunked {
fn vec_hash(&self, _random_state: RandomState, _buf: &mut Vec<u64>) -> PolarsResult<()> {
polars_ensure!(
self.inner_dtype().to_physical().is_numeric(),
ComputeError: "grouping on list type is only allowed if the inner type is numeric"
);
_buf.clear();
_buf.reserve(self.len());
let null_h = get_null_hash_value(&_random_state);

for arr in self.downcast_iter() {
_buf.extend(
numeric_list_bytes_iter(arr)?.map(|opt_bytes| match opt_bytes {
Some(s) => xxh3_64_with_seed(s, null_h),
None => null_h,
}),
)
}
Ok(())
}

fn vec_hash_combine(
&self,
_random_state: RandomState,
_hashes: &mut [u64],
) -> PolarsResult<()> {
polars_ensure!(
self.inner_dtype().to_physical().is_numeric(),
ComputeError: "grouping on list type is only allowed if the inner type is numeric"
);

let null_h = get_null_hash_value(&_random_state);

let mut offset = 0;
self.downcast_iter().try_for_each(|arr| {
numeric_list_bytes_iter(arr)?
.zip(&mut _hashes[offset..])
.for_each(|(opt_bytes, h)| {
let l = match opt_bytes {
Some(s) => xxh3_64_with_seed(s, null_h),
None => null_h,
};
*h = _boost_hash_combine(l, *h)
});
offset += arr.len();
PolarsResult::Ok(())
})?;
Ok(())
}
}

#[cfg(feature = "object")]
impl<T> VecHash for ObjectChunked<T>
where
19 changes: 0 additions & 19 deletions crates/polars-core/src/series/implementations/list.rs
@@ -44,22 +44,6 @@ impl private::PrivateSeries for SeriesWrap<ListChunked> {
IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted)
}

#[cfg(feature = "group_by_list")]
fn vec_hash(&self, _build_hasher: RandomState, _buf: &mut Vec<u64>) -> PolarsResult<()> {
self.0.vec_hash(_build_hasher, _buf)?;
Ok(())
}

#[cfg(feature = "group_by_list")]
fn vec_hash_combine(
&self,
_build_hasher: RandomState,
_hashes: &mut [u64],
) -> PolarsResult<()> {
self.0.vec_hash_combine(_build_hasher, _hashes)?;
Ok(())
}

fn into_total_eq_inner<'a>(&'a self) -> Box<dyn TotalEqInner + 'a> {
(&self.0).into_total_eq_inner()
}
@@ -154,7 +138,6 @@ impl SeriesTrait for SeriesWrap<ListChunked> {
self.0.has_validity()
}

#[cfg(feature = "group_by_list")]
#[cfg(feature = "algorithm_group_by")]
fn unique(&self) -> PolarsResult<Series> {
if !self.inner_dtype().is_numeric() {
@@ -171,7 +154,6 @@ impl SeriesTrait for SeriesWrap<ListChunked> {
Ok(unsafe { self.0.clone().into_series().agg_first(&groups?) })
}

#[cfg(feature = "group_by_list")]
#[cfg(feature = "algorithm_group_by")]
fn n_unique(&self) -> PolarsResult<usize> {
if !self.inner_dtype().is_numeric() {
@@ -189,7 +171,6 @@ impl SeriesTrait for SeriesWrap<ListChunked> {
}
}

#[cfg(feature = "group_by_list")]
#[cfg(feature = "algorithm_group_by")]
fn arg_unique(&self) -> PolarsResult<IdxCa> {
if !self.inner_dtype().is_numeric() {
1 change: 0 additions & 1 deletion crates/polars-ops/Cargo.toml
@@ -104,7 +104,6 @@ extract_jsonpath = ["serde_json", "jsonpath_lib", "polars-json"]
log = []
hash = []
reinterpret = ["polars-core/reinterpret"]
group_by_list = ["polars-core/group_by_list"]
rolling_window = ["polars-core/rolling_window"]
moment = []
mode = []
2 changes: 0 additions & 2 deletions crates/polars-ops/src/series/ops/is_first_distinct.rs
@@ -88,7 +88,6 @@ fn is_first_distinct_struct(s: &Series) -> PolarsResult<BooleanChunked> {
Ok(BooleanChunked::with_chunk(s.name(), arr))
}

#[cfg(feature = "group_by_list")]
fn is_first_distinct_list(ca: &ListChunked) -> PolarsResult<BooleanChunked> {
let groups = ca.group_tuples(true, false)?;
let first = groups.take_group_firsts();
@@ -136,7 +135,6 @@ pub fn is_first_distinct(s: &Series) -> PolarsResult<BooleanChunked> {
},
#[cfg(feature = "dtype-struct")]
Struct(_) => return is_first_distinct_struct(&s),
#[cfg(feature = "group_by_list")]
List(inner) if inner.is_numeric() => {
let ca = s.list().unwrap();
return is_first_distinct_list(ca);
2 changes: 0 additions & 2 deletions crates/polars-ops/src/series/ops/is_last_distinct.rs
@@ -40,7 +40,6 @@ pub fn is_last_distinct(s: &Series) -> PolarsResult<BooleanChunked> {
},
#[cfg(feature = "dtype-struct")]
Struct(_) => return is_last_distinct_struct(&s),
#[cfg(feature = "group_by_list")]
List(inner) if inner.is_numeric() => {
let ca = s.list().unwrap();
return is_last_distinct_list(ca);
@@ -157,7 +156,6 @@ fn is_last_distinct_struct(s: &Series) -> PolarsResult<BooleanChunked> {
Ok(BooleanChunked::with_chunk(s.name(), arr))
}

#[cfg(feature = "group_by_list")]
fn is_last_distinct_list(ca: &ListChunked) -> PolarsResult<BooleanChunked> {
let groups = ca.group_tuples(true, false)?;
// SAFETY: all groups have at least a single member
1 change: 0 additions & 1 deletion crates/polars/Cargo.toml
@@ -160,7 +160,6 @@ extract_jsonpath = [
]
find_many = ["polars-plan/find_many"]
fused = ["polars-ops/fused", "polars-lazy?/fused"]
group_by_list = ["polars-core/group_by_list", "polars-ops/group_by_list"]
interpolate = ["polars-ops/interpolate", "polars-lazy?/interpolate"]
is_between = ["polars-lazy?/is_between", "polars-ops/is_between"]
is_first_distinct = ["polars-lazy?/is_first_distinct", "polars-ops/is_first_distinct"]
1 change: 0 additions & 1 deletion crates/polars/src/lib.rs
@@ -225,7 +225,6 @@
//! - `asof_join` - Join ASOF, to join on nearest keys instead of exact equality match.
//! - `cross_join` - Create the Cartesian product of two [`DataFrame`]s.
//! - `semi_anti_join` - SEMI and ANTI joins.
//! - `group_by_list` - Allow group_by operation on keys of type List.
//! - `row_hash` - Utility to hash [`DataFrame`] rows to [`UInt64Chunked`]
//! - `diagonal_concat` - Concat diagonally thereby combining different schemas.
//! - `dataframe_arithmetic` - Arithmetic on ([`Dataframe`] and [`DataFrame`]s) and ([`DataFrame`] on [`Series`])
1 change: 0 additions & 1 deletion docs/user-guide/installation.md
@@ -125,7 +125,6 @@ The opt-in features are:
- `join_asof` - Join ASOF, to join on nearest keys instead of exact equality match.
- `cross_join` - Create the Cartesian product of two DataFrames.
- `semi_anti_join` - SEMI and ANTI joins.
- `group_by_list` - Allow group by operation on keys of type List.
- `row_hash` - Utility to hash DataFrame rows to UInt64Chunked
- `diagonal_concat` - Concat diagonally thereby combining different schemas.
- `dataframe_arithmetic` - Arithmetic on (Dataframe and DataFrames) and (DataFrame on Series)
1 change: 0 additions & 1 deletion py-polars/Cargo.toml
@@ -163,7 +163,6 @@ dtypes = [
"dtype-i16",
"dtype-u8",
"dtype-u16",
"polars/group_by_list",
"object",
]

20 changes: 20 additions & 0 deletions py-polars/tests/unit/operations/test_join.py
@@ -815,3 +815,23 @@ def test_join_projection_invalid_name_contains_suffix_15243() -> None:
.select(pl.col("b").filter(pl.col("b") == pl.col("foo_right")))
.collect()
)


def test_join_list_non_numeric() -> None:
assert (
pl.DataFrame(
{
"lists": [
["a", "b", "c"],
["a", "c", "b"],
["a", "c", "b"],
["a", "c", "d"],
]
}
)
).group_by("lists", maintain_order=True).agg(pl.len().alias("count")).to_dict(
as_series=False
) == {
"lists": [["a", "b", "c"], ["a", "c", "b"], ["a", "c", "d"]],
"count": [1, 2, 1],
}
