Skip to content

Commit

Permalink
improve pivot performance (#2458)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 24, 2022
1 parent 12dea7f commit 527553d
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 249 deletions.
130 changes: 74 additions & 56 deletions polars/polars-core/src/frame/groupby/pivot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use rayon::prelude::*;
use std::cmp::Ordering;

use crate::frame::groupby::{GroupsIndicator, GroupsProxy};
use crate::frame::row::AnyValueBuffer;
use crate::POOL;
#[cfg(feature = "dtype-date")]
use arrow::temporal_conversions::date32_to_date;
Expand Down Expand Up @@ -133,7 +134,7 @@ impl DataFrame {
// make sure that we make smaller dataframes then the take operations are cheaper
let mut index_df = self.select(index)?;

let mut im_result = POOL.install(|| {
let im_result = POOL.install(|| {
groups
.par_iter()
.map(|indicator| {
Expand Down Expand Up @@ -226,71 +227,88 @@ impl DataFrame {
// for every column where the values are aggregated
let df_cols = (0..columns.len())
.zip(columns_unique)
.flat_map(|(i, unique_vals)| {
let im_results = im_result
.iter_mut()
.map(|v| std::mem::take(&mut v[i + 1]))
.collect::<Vec<_>>();
.flat_map(|(column_index, unique_vals)| {
// the values that will be the new headers

let mut result_map = PlHashMap::with_capacity(unique_vals.len());
for av in unique_vals.iter() {
let value = Vec::with_capacity(groups.len());
result_map.insert(av, value);
// Join every row with the unique column. This join is needed because some rows don't have all values and we want to have
// nulls there.
let result_columns = POOL.install(|| {
im_result
.par_iter()
.map(|im_r| {
// we offset 1 because the first is the group index (can be removed?)
let current_result = &im_r[column_index + 1];
let key = &current_result.get_columns()[0];
let tuples = unique_vals.hash_join_left(key);
let mut iter = tuples.iter().map(|t| t.1.map(|i| i as usize));

let values = &current_result.get_columns()[1];
// Safety
// join tuples are in bounds
unsafe { values.take_opt_iter_unchecked(&mut iter) }
})
.collect::<Vec<_>>()
});

let mut dtype = self
.column(&values[column_index])
.unwrap()
.dtype()
.to_physical();
match (dtype.clone(), &agg_fn) {
(DataType::Float32, PivotAgg::Mean | PivotAgg::Median) => {}
(_, PivotAgg::Mean | PivotAgg::Median) => dtype = DataType::Float64,
(_, PivotAgg::Count) => dtype = DataType::UInt32,
_ => {}
}
let len = result_columns.len();
let mut buffers = (0..unique_vals.len())
.map(|_| {
let buf: AnyValueBuffer = (&dtype, len).into();
buf
})
.collect::<Vec<_>>();

for df in im_results.iter() {
let keys = &df.columns[0];
let vals = &df.columns[1];

// some groups are not available in the intermediate result,
// so we must make sure we keep track of those and insert Null
// in the proper locations.
// because the rows are not guaranteed to be ordered, we must use
// a hash set to keep track the rows we've processed
// in the end, we will take the none processed values and push a null for those.
let not_all = keys.len() != unique_vals.len();

let mut remaining = if not_all {
PlHashSet::from_iter(unique_vals.iter())
} else {
PlHashSet::new()
};

keys.iter().zip(vals.iter()).for_each(|(k, v)| {
if not_all {
remaining.remove(&k);
};

result_map.entry(k).and_modify(|buf| buf.push(v));
// this is very expensive. A lot of cache misses here.
// This is the part that is performance critical.
result_columns.iter().for_each(|s| {
s.iter().zip(buffers.iter_mut()).for_each(|(av, buf)| {
let _out = buf.add(av);
debug_assert!(_out.is_some());
});

for k in remaining {
result_map
.entry(k)
.and_modify(|buf| buf.push(AnyValue::Null));
}
}
let mut df_cols = result_map
});
let cols = buffers
.into_iter()
.map(|(k, v)| {
let name = k.to_string();
let name_slice = if let AnyValue::Utf8(_) = k {
// slice off the quotation marks;
&name[1..name.len() - 1]
} else {
name.as_str()
};
Series::from_any_values(name_slice, &v)
.enumerate()
.map(|(i, buf)| {
let mut s = buf.into_series();
s.rename(&format!("{i}"));
s
})
.collect::<Vec<_>>();
let mut out = DataFrame::new_no_checks(cols);

// add the headers based on the unique vals
let headers = unique_vals.cast(&DataType::Utf8).unwrap();
let headers = headers.utf8().unwrap();
out.get_columns_mut()
.iter_mut()
.zip(headers.into_iter())
.for_each(|(s, name)| {
match name {
None => s.rename("null"),
Some(name) => s.rename(name),
};
});

// make output predictable
sort_cols(out.get_columns_mut(), 0);

sort_cols(&mut df_cols, 0);
let df = DataFrame::new_no_checks(df_cols);
let column_name = &columns[i];
let values_name = &values[i];
let column_name = &columns[column_index];
let values_name = &values[column_index];
let columns_s = self.column(column_name).unwrap();
let values_s = self.column(values_name).unwrap();
finish_logical_types(df, columns_s, values_s)
finish_logical_types(out, columns_s, values_s)
.unwrap()
.columns
})
Expand Down
55 changes: 30 additions & 25 deletions polars/polars-core/src/frame/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ impl DataFrame {
.fields()
.iter()
.map(|fld| {
let buf: Buffer = (fld.data_type(), capacity).into();
let buf: AnyValueBuffer = (fld.data_type(), capacity).into();
buf
})
.collect();

rows.try_for_each::<_, Result<()>>(|row| {
for (value, buf) in row.0.iter().zip(&mut buffers) {
buf.add(value.clone())?
buf.add_falible(value.clone())?
}
Ok(())
})?;
Expand Down Expand Up @@ -236,7 +236,7 @@ impl From<&Row<'_>> for Schema {
}
}

pub(crate) enum Buffer {
pub(crate) enum AnyValueBuffer {
Boolean(BooleanChunkedBuilder),
Int32(PrimitiveChunkedBuilder<Int32Type>),
Int64(PrimitiveChunkedBuilder<Int64Type>),
Expand All @@ -258,9 +258,9 @@ pub(crate) enum Buffer {
List(Box<dyn ListBuilderTrait>),
}

impl Debug for Buffer {
impl Debug for AnyValueBuffer {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use Buffer::*;
use AnyValueBuffer::*;
match self {
Boolean(_) => f.write_str("boolean"),
Int32(_) => f.write_str("i32"),
Expand All @@ -281,9 +281,9 @@ impl Debug for Buffer {
}
}

impl Buffer {
fn add(&mut self, val: AnyValue) -> Result<()> {
use Buffer::*;
impl AnyValueBuffer {
pub(crate) fn add(&mut self, val: AnyValue) -> Option<()> {
use AnyValueBuffer::*;
match (self, val) {
(Boolean(builder), AnyValue::Boolean(v)) => builder.append_value(v),
(Boolean(builder), AnyValue::Null) => builder.append_null(),
Expand Down Expand Up @@ -315,14 +315,19 @@ impl Buffer {
(Utf8(builder), AnyValue::Null) => builder.append_null(),
(List(builder), AnyValue::List(v)) => builder.append_series(&v),
(List(builder), AnyValue::Null) => builder.append_null(),
(buf, val) => return Err(PolarsError::ValueError(format!("Could not append {:?} to builder {:?}; make sure that all rows have the same schema.", val, std::mem::discriminant(buf)).into()))
_ => return None,
};
Some(())
}

Ok(())
/// Append `val` to the buffer, returning a descriptive error (instead of
/// `None`) when the value's variant does not match the buffer's dtype.
/// NOTE(review): the clone is needed because `add` consumes the value and we
/// still want to print it in the error message on failure.
pub(crate) fn add_falible(&mut self, val: AnyValue) -> Result<()> {
    match self.add(val.clone()) {
        Some(()) => Ok(()),
        None => Err(PolarsError::ValueError(
            format!("Could not append {:?} to builder; make sure that all rows have the same schema.", val).into(),
        )),
    }
}

fn into_series(self) -> Series {
use Buffer::*;
pub(crate) fn into_series(self) -> Series {
use AnyValueBuffer::*;
match self {
Boolean(b) => b.finish().into_series(),
Int32(b) => b.finish().into_series(),
Expand All @@ -344,28 +349,28 @@ impl Buffer {
}

// datatype and length
impl From<(&DataType, usize)> for Buffer {
impl From<(&DataType, usize)> for AnyValueBuffer {
fn from(a: (&DataType, usize)) -> Self {
let (dt, len) = a;
use DataType::*;
match dt {
Boolean => Buffer::Boolean(BooleanChunkedBuilder::new("", len)),
Int32 => Buffer::Int32(PrimitiveChunkedBuilder::new("", len)),
Int64 => Buffer::Int64(PrimitiveChunkedBuilder::new("", len)),
UInt32 => Buffer::UInt32(PrimitiveChunkedBuilder::new("", len)),
UInt64 => Buffer::UInt64(PrimitiveChunkedBuilder::new("", len)),
Boolean => AnyValueBuffer::Boolean(BooleanChunkedBuilder::new("", len)),
Int32 => AnyValueBuffer::Int32(PrimitiveChunkedBuilder::new("", len)),
Int64 => AnyValueBuffer::Int64(PrimitiveChunkedBuilder::new("", len)),
UInt32 => AnyValueBuffer::UInt32(PrimitiveChunkedBuilder::new("", len)),
UInt64 => AnyValueBuffer::UInt64(PrimitiveChunkedBuilder::new("", len)),
#[cfg(feature = "dtype-date")]
Date => Buffer::Date(PrimitiveChunkedBuilder::new("", len)),
Date => AnyValueBuffer::Date(PrimitiveChunkedBuilder::new("", len)),
#[cfg(feature = "dtype-datetime")]
Datetime(tu, tz) => {
Buffer::Datetime(PrimitiveChunkedBuilder::new("", len), *tu, tz.clone())
AnyValueBuffer::Datetime(PrimitiveChunkedBuilder::new("", len), *tu, tz.clone())
}
#[cfg(feature = "dtype-time")]
Time => Buffer::Time(PrimitiveChunkedBuilder::new("", len)),
Float32 => Buffer::Float32(PrimitiveChunkedBuilder::new("", len)),
Float64 => Buffer::Float64(PrimitiveChunkedBuilder::new("", len)),
Utf8 => Buffer::Utf8(Utf8ChunkedBuilder::new("", len, len * 5)),
List(inner) => Buffer::List(get_list_builder(inner, len * 10, len, "")),
Time => AnyValueBuffer::Time(PrimitiveChunkedBuilder::new("", len)),
Float32 => AnyValueBuffer::Float32(PrimitiveChunkedBuilder::new("", len)),
Float64 => AnyValueBuffer::Float64(PrimitiveChunkedBuilder::new("", len)),
Utf8 => AnyValueBuffer::Utf8(Utf8ChunkedBuilder::new("", len, len * 5)),
List(inner) => AnyValueBuffer::List(get_list_builder(inner, len * 10, len, "")),
_ => unimplemented!(),
}
}
Expand Down

0 comments on commit 527553d

Please sign in to comment.