Skip to content

Commit

Permalink
fix quadratic behavior pivot (#2455)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 24, 2022
1 parent f5700ef commit c649a20
Show file tree
Hide file tree
Showing 9 changed files with 400 additions and 65 deletions.
2 changes: 1 addition & 1 deletion polars/polars-core/src/chunked_array/ops/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::convert::TryFrom;

#[inline]
#[allow(unused_variables)]
unsafe fn arr_to_any_value<'a>(
pub(crate) unsafe fn arr_to_any_value<'a>(
arr: &'a dyn Array,
idx: usize,
categorical_map: &'a Option<Arc<RevMapping>>,
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use std::ops::Deref;
#[cfg(feature = "abs")]
mod abs;
pub(crate) mod aggregate;
mod any_value;
pub(crate) mod any_value;
mod append;
mod apply;
mod bit_repr;
Expand Down
2 changes: 2 additions & 0 deletions polars/polars-core/src/chunked_array/ops/take/traits.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
//! Traits that indicate the allowed arguments in a ChunkedArray::take operation.
use crate::frame::groupby::GroupsProxyIter;
use crate::prelude::*;
use arrow::array::UInt32Array;
use polars_arrow::array::PolarsArray;
Expand All @@ -13,6 +14,7 @@ pub trait TakeIteratorNulls: Iterator<Item = Option<usize>> + TrustedLen {

unsafe impl TrustedLen for &mut dyn TakeIterator {}
unsafe impl TrustedLen for &mut dyn TakeIteratorNulls {}
unsafe impl TrustedLen for GroupsProxyIter<'_> {}

// Implement for the ref as well
impl TakeIterator for &mut dyn TakeIterator {
Expand Down
25 changes: 25 additions & 0 deletions polars/polars-core/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub use crate::chunked_array::logical::*;
use crate::chunked_array::object::PolarsObjectSafe;
use crate::chunked_array::ops::sort::PlIsNan;
use crate::prelude::*;
use crate::utils::Wrap;
use ahash::RandomState;
use arrow::compute::arithmetics::basic::NativeArithmetics;
use arrow::compute::comparison::Simd8;
Expand All @@ -26,6 +27,7 @@ use num::{Bounded, FromPrimitive, Num, NumCast, Zero};
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::fmt::{Display, Formatter};
use std::hash::{Hash, Hasher};
use std::ops::{Add, AddAssign, Div, Mul, Rem, Sub};

pub struct Utf8Type {}
Expand Down Expand Up @@ -261,6 +263,29 @@ pub enum AnyValue<'a> {
Object(&'a dyn PolarsObjectSafe),
}

impl<'a> Hash for AnyValue<'a> {
fn hash<H: Hasher>(&self, state: &mut H) {
use AnyValue::*;
match self {
Null => state.write_u64(u64::MAX / 2 + 135123),
Int8(v) => state.write_i8(*v),
Int16(v) => state.write_i16(*v),
Int32(v) => state.write_i32(*v),
Int64(v) => state.write_i64(*v),
UInt8(v) => state.write_u8(*v),
UInt16(v) => state.write_u16(*v),
UInt32(v) => state.write_u32(*v),
UInt64(v) => state.write_u64(*v),
Utf8(s) => state.write(s.as_bytes()),
Boolean(v) => state.write_u8(*v as u8),
List(v) => Hash::hash(&Wrap(v.clone()), state),
_ => unimplemented!(),
}
}
}

impl<'a> Eq for AnyValue<'a> {}

impl From<f64> for AnyValue<'_> {
fn from(a: f64) -> Self {
AnyValue::Float64(a)
Expand Down
186 changes: 129 additions & 57 deletions polars/polars-core/src/frame/groupby/pivot.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
use super::GroupBy;
use crate::prelude::*;
use rayon::prelude::*;
use std::borrow::Cow;
use std::cmp::Ordering;

use crate::frame::groupby::{GroupsIndicator, GroupsProxy};
use crate::utils::accumulate_dataframes_vertical;
use crate::POOL;
#[cfg(feature = "dtype-date")]
use arrow::temporal_conversions::date32_to_date;
Expand Down Expand Up @@ -123,14 +121,17 @@ impl DataFrame {
.column(columns[i].as_str())?
.to_physical_repr()
.into_owned();
let values = self.column(values[i].as_str())?;
let values = self
.column(values[i].as_str())?
.to_physical_repr()
.into_owned();

Ok(DataFrame::new_no_checks(vec![values.clone(), column]))
Ok(DataFrame::new_no_checks(vec![values, column]))
})
.collect::<Result<Vec<_>>>()?;

// make sure that we make smaller dataframes then the take operations are cheaper
let index_df = self.select(index)?;
let mut index_df = self.select(index)?;

let mut im_result = POOL.install(|| {
groups
Expand Down Expand Up @@ -215,65 +216,90 @@ impl DataFrame {
.collect::<Vec<_>>()
});
// Now we have a lot of small DataFrames with aggregation results
// we first join them together.
// This will lead to a long dataframe that finally is transposed
// we must map the results to the right column. This requires a hashmap

let columns_unique = columns
.iter()
.map(|name| self.column(name)?.to_physical_repr().unique())
.collect::<Result<Vec<_>>>()?;

// for every column where the values are aggregated
let mut all_values = (0..columns.len())
.map(|i| {
let to_join = im_result
let df_cols = (0..columns.len())
.zip(columns_unique)
.flat_map(|(i, unique_vals)| {
let im_results = im_result
.iter_mut()
.map(|v| std::mem::take(&mut v[i + 1]))
.collect::<Vec<_>>();
let mut name_count = 0;

let mut joined = to_join
.iter()
.map(Cow::Borrowed)
.reduce(|df_l, df_r| {
let mut out = df_l
.outer_join(&df_r, [columns[i].as_str()], [columns[i].as_str()])
.unwrap();
let last_idx = out.width() - 1;
out.columns[last_idx].rename(&format!("{}_{}", values[i], name_count));
name_count += 1;
Cow::Owned(out)
})
.unwrap()
.into_owned();
let header = joined
.drop_in_place(&columns[i])
.unwrap()
.cast(&DataType::Utf8)
.unwrap();
let header = header.utf8().unwrap();
let mut values = joined.transpose().unwrap();

for (opt_name, s) in header.into_iter().zip(values.columns.iter_mut()) {
match opt_name {
None => s.rename("null"),
Some(v) => s.rename(v),
};

let mut result_map = PlHashMap::with_capacity(unique_vals.len());
for av in unique_vals.iter() {
let value = Vec::with_capacity(groups.len());
result_map.insert(av, value);
}
values
})
.collect::<Vec<_>>();

let indices = im_result.iter_mut().map(|v| std::mem::take(&mut v[0]));
let mut out = accumulate_dataframes_vertical(indices).unwrap();
for df in im_results.iter() {
let keys = &df.columns[0];
let vals = &df.columns[1];

// some groups are not available in the intermediate result,
// so we must make sure we keep track of those and insert Null
// in the proper locations.
// because the rows are not guaranteed to be ordered, we must use
// a hash set to keep track the rows we've processed
// in the end, we will take the none processed values and push a null for those.
let not_all = keys.len() != unique_vals.len();

let mut remaining = if not_all {
PlHashSet::from_iter(unique_vals.iter())
} else {
PlHashSet::new()
};

// values is the dataframe to stack
// columns is the original series that is pivoted
for (values, columns) in all_values.iter_mut().zip(columns) {
let mut cols = std::mem::take(&mut values.columns);
sort_cols(&mut cols, 0);
keys.iter().zip(vals.iter()).for_each(|(k, v)| {
if not_all {
remaining.remove(&k);
};

let df = DataFrame::new_no_checks(cols);
let df = finish_logical_types(df, self.column(columns).unwrap()).unwrap();
result_map.entry(k).and_modify(|buf| buf.push(v));
});

out = out.hstack(&df.columns)?
}
Ok(out)
for k in remaining {
result_map
.entry(k)
.and_modify(|buf| buf.push(AnyValue::Null));
}
}
let mut df_cols = result_map
.into_iter()
.map(|(k, v)| {
let name = k.to_string();
let name_slice = if let AnyValue::Utf8(_) = k {
// slice off the quotation marks;
&name[1..name.len() - 1]
} else {
name.as_str()
};
Series::from_any_values(name_slice, &v)
})
.collect::<Vec<_>>();

sort_cols(&mut df_cols, 0);
let df = DataFrame::new_no_checks(df_cols);
let column_name = &columns[i];
let values_name = &values[i];
let columns_s = self.column(column_name).unwrap();
let values_s = self.column(values_name).unwrap();
finish_logical_types(df, columns_s, values_s)
.unwrap()
.columns
})
.collect::<Vec<_>>();

index_df.columns.iter_mut().for_each(|s| {
*s = s.agg_first(groups);
});
index_df.hstack(&df_cols)
}
}

Expand Down Expand Up @@ -404,7 +430,14 @@ fn sort_cols(cols: &mut [Series], offset: usize) {
});
}

fn finish_logical_types(mut out: DataFrame, columns: &Series) -> Result<DataFrame> {
// Takes a `DataFrame` that only consists of the column aggregates that are pivoted by
// the values in `columns`
fn finish_logical_types(
mut out: DataFrame,
columns: &Series,
values: &Series,
) -> Result<DataFrame> {
// We cast the column headers to another string repr
match columns.dtype() {
#[cfg(feature = "dtype-categorical")]
DataType::Categorical => {
Expand Down Expand Up @@ -439,6 +472,29 @@ fn finish_logical_types(mut out: DataFrame, columns: &Series) -> Result<DataFram
}
_ => {}
}

let dtype = values.dtype();
match dtype {
#[cfg(feature = "dtype-categorical")]
DataType::Categorical => {
let piv = columns.categorical().unwrap();
let rev_map = piv.categorical_map.as_ref().cloned();

for s in out.columns.iter_mut() {
let s_ = s.cast(&DataType::Categorical).unwrap();
let mut ca = s_.categorical().unwrap().clone();
ca.categorical_map = rev_map.clone();
*s = ca.into_series();
}
}
DataType::Datetime(_, _) | DataType::Date | DataType::Time => {
for s in out.columns.iter_mut() {
*s = s.cast(dtype).unwrap();
}
}
_ => {}
}

Ok(out)
}

Expand Down Expand Up @@ -584,8 +640,24 @@ mod test {
]?;
df.try_apply("C", |s| s.cast(&DataType::Date))?;

let out = df.groupby(["B"])?.pivot(["C"], ["A"]).count()?;
assert_eq!(out.get_column_names(), &["B", "1972-09-27"]);
let out = df.groupby_stable(["B"])?.pivot(["C"], ["A"]).count()?;
let expected = df![
"B" => [8i32, 2, 3, 6],
"1972-09-27" => [1u32, 3, 2, 2]
]?;
assert!(out.frame_equal_missing(&expected));

let mut out = df.groupby_stable(["B"])?.pivot(["A"], ["C"]).first()?;
out.try_apply("1", |s| {
let ca = s.date()?;
Ok(ca.strftime("%Y-%d-%m"))
})?;

let expected = df![
"B" => [8i32, 2, 3, 6],
"1" => ["1972-27-09", "1972-27-09", "1972-27-09", "1972-27-09"]
]?;
assert!(out.frame_equal_missing(&expected));

Ok(())
}
Expand Down
6 changes: 6 additions & 0 deletions polars/polars-core/src/frame/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -297,10 +297,16 @@ impl Buffer {
(UInt64(builder), AnyValue::Null) => builder.append_null(),
#[cfg(feature = "dtype-date")]
(Date(builder), AnyValue::Null) => builder.append_null(),
#[cfg(feature = "dtype-date")]
(Date(builder), AnyValue::Date(v)) => builder.append_value(v),
#[cfg(feature = "dtype-datetime")]
(Datetime(builder, _, _), AnyValue::Null) => builder.append_null(),
#[cfg(feature = "dtype-datetime")]
(Datetime(builder, _, _), AnyValue::Datetime(v, _, _)) => builder.append_value(v),
#[cfg(feature = "dtype-time")]
(Time(builder), AnyValue::Time(v)) => builder.append_value(v),
#[cfg(feature = "dtype-time")]
(Time(builder), AnyValue::Null) => builder.append_null(),
(Float32(builder), AnyValue::Float32(v)) => builder.append_value(v),
(Float32(builder), AnyValue::Null) => builder.append_null(),
(Float64(builder), AnyValue::Float64(v)) => builder.append_value(v),
Expand Down

0 comments on commit c649a20

Please sign in to comment.