Skip to content

Commit

Permalink
add pivot list first (#2141)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 23, 2021
1 parent aaa5805 commit 3b01500
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 3 deletions.
30 changes: 28 additions & 2 deletions polars/polars-core/src/chunked_array/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -707,14 +707,40 @@ macro_rules! impl_cmp_list {
(Some(_), None) => None,
(Some(left), Some(right)) => Some(left.$cmp_method(&right)),
})
.collect(),
.collect_trusted(),
}
}};
}

impl ChunkCompare<&ListChunked> for ListChunked {
/// Element-wise equality of two list columns where a missing value is treated
/// as an ordinary value: two missing entries are equal, and a missing entry is
/// never equal to a present one.
///
/// The result holds one boolean per zipped pair of entries and never contains
/// nulls.
fn eq_missing(&self, rhs: &ListChunked) -> BooleanChunked {
    // Dispatch on `has_validity()` so that the sides reporting no validity can
    // be walked with the cheaper non-optional iterator.
    match (self.has_validity(), rhs.has_validity()) {
        // No missing entries on either side: compare the sub-lists directly.
        (false, false) => self
            .into_no_null_iter()
            .zip(rhs.into_no_null_iter())
            .map(|(left, right)| left.eq(&right))
            .collect_trusted(),
        // Only `rhs` may hold missing entries. A present left value never equals
        // a missing right value, so the answer is `false` rather than null,
        // matching the fully optional arm below.
        (false, _) => self
            .into_no_null_iter()
            .zip(rhs.into_iter())
            .map(|(left, opt_right)| opt_right.map_or(false, |right| left.eq(&right)))
            .collect_trusted(),
        // Mirror image of the arm above: only `self` may hold missing entries.
        (_, false) => self
            .into_iter()
            .zip(rhs.into_no_null_iter())
            .map(|(opt_left, right)| opt_left.map_or(false, |left| left.eq(&right)))
            .collect_trusted(),
        // Both sides may hold missing entries.
        (_, _) => self
            .into_iter()
            .zip(rhs.into_iter())
            .map(|(opt_left, opt_right)| match (opt_left, opt_right) {
                (None, None) => true,
                (Some(left), Some(right)) => left.eq(&right),
                // Exactly one side is missing.
                (None, Some(_)) | (Some(_), None) => false,
            })
            .collect_trusted(),
    }
}

fn equal(&self, rhs: &ListChunked) -> BooleanChunked {
Expand Down
86 changes: 85 additions & 1 deletion polars/polars-core/src/frame/groupby/pivot.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::GroupBy;
use crate::chunked_array::builder::get_list_builder;
use crate::prelude::*;
use hashbrown::HashMap;
use itertools::Itertools;
Expand Down Expand Up @@ -406,7 +407,90 @@ impl ChunkPivot for CategoricalChunked {
}
}

impl ChunkPivot for ListChunked {}
impl ChunkPivot for ListChunked {
    /// Pivots this list column: every distinct non-null value of `pivot_series`
    /// becomes an output column and every entry of `groups` becomes an output
    /// row.
    ///
    /// * `pivot_series` - column whose values select the output column of a row.
    /// * `keys` - columns that are placed first in the resulting frame.
    /// * `groups` - `(first, indexes)` pairs; only the row indexes are used.
    /// * `agg_type` - aggregation applied to the values collected per cell.
    ///
    /// A cell that collected no values is null. Output columns are named with
    /// the `Debug` formatting of the pivot value and are appended in hash-map
    /// iteration order.
    ///
    /// # Panics
    ///
    /// Only `PivotAgg::First` is handled; any other aggregation reaches
    /// `unimplemented!()` as soon as a cell has at least one value.
    fn pivot<'a>(
        &self,
        pivot_series: &'a Series,
        keys: Vec<Series>,
        groups: &[(u32, Vec<u32>)],
        agg_type: PivotAgg,
    ) -> Result<DataFrame> {
        // TODO: save an allocation by creating a random access struct for the Groupable utility type.

        // The unique pivot values are collected separately so that the per-group
        // map is built from them instead of from every row (avoids quadratic
        // behavior).
        let mut pivot_series = pivot_series.clone();
        let mut pivot_unique = pivot_series.unique()?;
        let pivot_vec_unique: Vec<_> = pivot_unique.as_groupable_iter()?.collect();
        let pivot_vec: Vec<_> = pivot_series.as_groupable_iter()?.collect();
        let values_taker = self.take_rand();

        // One list builder per non-null pivot value; these accumulate the final
        // output columns, one appended entry per group.
        let n_groups = groups.len();
        let mut builders_by_pivot = PlHashMap::new();
        for pivot_value in pivot_vec.iter().flatten() {
            builders_by_pivot.entry(pivot_value).or_insert_with(|| {
                let column_name = format!("{:?}", pivot_value);
                get_list_builder(&self.inner_dtype(), n_groups, n_groups, &column_name)
            });
        }

        // Each group contributes exactly one entry to every output column.
        for (_first, rows) in groups {
            // Scratch map for this group, keyed by pivot value.
            // NOTE(review): presumably starts with one empty vector per unique
            // pivot value - confirm against `create_column_values_map`.
            let mut values_by_pivot =
                create_column_values_map::<Series>(&pivot_vec_unique, rows.len());

            for row in rows.iter().map(|&row| row as usize) {
                // SAFETY: assumes every group index is a valid row of the pivot
                // column; nothing in this function checks it - TODO confirm the
                // guarantee with the caller.
                let opt_pivot_val = unsafe { pivot_vec.get_unchecked(row) };

                // Rows whose pivot value is null are not assigned to any column.
                if let Some(pivot_val) = opt_pivot_val {
                    let value = values_taker.get(row);
                    if let Some(bucket) = values_by_pivot.get_mut(&pivot_val) {
                        bucket.push(value)
                    }
                }
            }

            // Reduce every bucket to a single cell and append it to the builder
            // of the matching output column.
            for (pivot_val, bucket) in values_by_pivot.iter_mut() {
                let builder = builders_by_pivot.get_mut(pivot_val).unwrap();

                match bucket.first() {
                    None => builder.append_null(),
                    // NOTE: this is the place where all aggregations happen;
                    // currently only `First` exists.
                    Some(first) => match agg_type {
                        PivotAgg::First => {
                            builder.append_opt_series(first.as_ref());
                        }
                        _ => unimplemented!(),
                    },
                }
            }
        }

        // Assemble the frame: key columns first, then one column per pivot value.
        let mut cols = keys;
        cols.reserve_exact(builders_by_pivot.len());
        cols.extend(
            builders_by_pivot
                .into_iter()
                .map(|(_, mut builder)| builder.finish().into_series()),
        );

        DataFrame::new(cols)
    }
}

#[cfg(feature = "object")]
impl<T> ChunkPivot for ObjectChunked<T> {}

Expand Down
16 changes: 16 additions & 0 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1610,3 +1610,19 @@ def test_get_item() -> None:
df[[False, True]].frame_equal(pl.DataFrame({"a": [1.0], "b": [3]}))
with pytest.raises(IndexError):
_ = df[pl.Series("", ["hello Im a string"])]


def test_pivot_list() -> None:
    """Pivot a list column with the ``first`` aggregation.

    Every key of ``a`` occurs once, so each pivoted column holds the list of
    its own key on one row and nulls on the others.
    """
    frame = pl.DataFrame({"a": [1, 2, 3], "b": [[1, 1], [2, 2], [3, 3]]})

    # Pivot, then fix the column order and the row order so the comparison
    # below does not depend on the order in which the pivot emits them.
    pivoted = frame.groupby("a").pivot("a", "b").first()
    result = pivoted["a", "1", "2", "3"].sort("a")

    want = pl.DataFrame(
        {
            "a": [1, 2, 3],
            "1": [[1, 1], None, None],
            "2": [None, [2, 2], None],
            "3": [None, None, [3, 3]],
        }
    )

    # null_equal=True: nulls in matching positions count as equal.
    assert result.frame_equal(want, null_equal=True)

0 comments on commit 3b01500

Please sign in to comment.