Skip to content

Commit

Permalink
perf(rust, python): improve streaming performance (~15%) (#5170)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 12, 2022
1 parent f86c3fa commit f8f3af9
Show file tree
Hide file tree
Showing 10 changed files with 113 additions and 52 deletions.
1 change: 1 addition & 0 deletions polars/polars-lazy/polars-pipe/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
enum_dispatch = "0.3"
hashbrown.workspace = true
num.wokspace = true
polars-core = { version = "0.24.2", path = "../../polars-core", features = ["lazy", "private", "zip_with", "random"], default-features = false }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::executors::sinks::groupby::aggregates::count::CountAgg;
use crate::executors::sinks::groupby::aggregates::first::FirstAgg;
use crate::executors::sinks::groupby::aggregates::last::LastAgg;
use crate::executors::sinks::groupby::aggregates::mean::MeanAgg;
use crate::executors::sinks::groupby::aggregates::{AggregateFn, SumAgg};
use crate::executors::sinks::groupby::aggregates::{AggregateFunction, SumAgg};
use crate::expressions::PhysicalPipedExpr;
use crate::operators::DataChunk;

Expand Down Expand Up @@ -83,64 +83,65 @@ pub fn convert_to_hash_agg<F>(
expr_arena: &Arena<AExpr>,
schema: &Schema,
to_physical: &F,
) -> (Arc<dyn PhysicalPipedExpr>, Box<dyn AggregateFn>)
) -> (Arc<dyn PhysicalPipedExpr>, AggregateFunction)
where
F: Fn(Node, &Arena<AExpr>) -> PolarsResult<Arc<dyn PhysicalPipedExpr>>,
{
match expr_arena.get(node) {
AExpr::Alias(input, _) => convert_to_hash_agg(*input, expr_arena, schema, to_physical),
AExpr::Count => (Arc::new(Count {}), Box::new(CountAgg::new())),
AExpr::Count => (
Arc::new(Count {}),
AggregateFunction::Count(CountAgg::new()),
),
AExpr::Agg(agg) => match agg {
AAggExpr::Sum(input) => {
let phys_expr = to_physical(*input, expr_arena).unwrap();
let agg_fn = match phys_expr.field(schema).unwrap().dtype.to_physical() {
// Boolean is aggregated as the IDX type.
DataType::Boolean => Box::new(SumAgg::<IdxSize>::new()) as Box<dyn AggregateFn>,
DataType::Boolean => {
if std::mem::size_of::<IdxSize>() == 4 {
AggregateFunction::SumU32(SumAgg::<u32>::new())
} else {
AggregateFunction::SumU64(SumAgg::<u64>::new())
}
}
// these are aggregated as i64 to prevent overflow
DataType::Int8 => Box::new(SumAgg::<i64>::new()) as Box<dyn AggregateFn>,
DataType::Int16 => Box::new(SumAgg::<i64>::new()) as Box<dyn AggregateFn>,
DataType::UInt8 => Box::new(SumAgg::<i64>::new()) as Box<dyn AggregateFn>,
DataType::UInt16 => Box::new(SumAgg::<i64>::new()) as Box<dyn AggregateFn>,
DataType::Int8 => AggregateFunction::SumI64(SumAgg::<i64>::new()),
DataType::Int16 => AggregateFunction::SumI64(SumAgg::<i64>::new()),
DataType::UInt8 => AggregateFunction::SumI64(SumAgg::<i64>::new()),
DataType::UInt16 => AggregateFunction::SumI64(SumAgg::<i64>::new()),
// these stay true to there types
DataType::UInt32 => Box::new(SumAgg::<u32>::new()) as Box<dyn AggregateFn>,
DataType::UInt64 => Box::new(SumAgg::<u64>::new()) as Box<dyn AggregateFn>,
DataType::Int32 => Box::new(SumAgg::<i32>::new()) as Box<dyn AggregateFn>,
DataType::Int64 => Box::new(SumAgg::<i64>::new()) as Box<dyn AggregateFn>,
DataType::Float32 => Box::new(SumAgg::<f32>::new()) as Box<dyn AggregateFn>,
DataType::Float64 => Box::new(SumAgg::<f64>::new()) as Box<dyn AggregateFn>,
DataType::UInt32 => AggregateFunction::SumI32(SumAgg::<i32>::new()),
DataType::UInt64 => AggregateFunction::SumI64(SumAgg::<i64>::new()),
DataType::Int32 => AggregateFunction::SumU32(SumAgg::<u32>::new()),
DataType::Int64 => AggregateFunction::SumU64(SumAgg::<u64>::new()),
DataType::Float32 => AggregateFunction::SumF32(SumAgg::<f32>::new()),
DataType::Float64 => AggregateFunction::SumF64(SumAgg::<f64>::new()),
_ => unreachable!(),
};
(phys_expr, agg_fn)
}
AAggExpr::Mean(input) => {
let phys_expr = to_physical(*input, expr_arena).unwrap();
let agg_fn = match phys_expr.field(schema).unwrap().dtype.to_physical() {
dt if dt.is_integer() => {
Box::new(MeanAgg::<f64>::new()) as Box<dyn AggregateFn>
}
dt if dt.is_integer() => AggregateFunction::MeanF64(MeanAgg::<f64>::new()),
// Boolean is aggregated as the IDX type.
DataType::Boolean => Box::new(MeanAgg::<f64>::new()) as Box<dyn AggregateFn>,
DataType::Float32 => Box::new(MeanAgg::<f32>::new()) as Box<dyn AggregateFn>,
DataType::Float64 => Box::new(MeanAgg::<f64>::new()) as Box<dyn AggregateFn>,
DataType::Boolean => AggregateFunction::MeanF64(MeanAgg::<f64>::new()),
DataType::Float32 => AggregateFunction::MeanF32(MeanAgg::<f32>::new()),
DataType::Float64 => AggregateFunction::MeanF64(MeanAgg::<f64>::new()),
_ => unreachable!(),
};
(phys_expr, agg_fn)
}
AAggExpr::First(input) => {
let phys_expr = to_physical(*input, expr_arena).unwrap();
let dtype = phys_expr.field(schema).unwrap().dtype;
(
phys_expr,
Box::new(FirstAgg::new(dtype)) as Box<dyn AggregateFn>,
)
(phys_expr, AggregateFunction::First(FirstAgg::new(dtype)))
}
AAggExpr::Last(input) => {
let phys_expr = to_physical(*input, expr_arena).unwrap();
let dtype = phys_expr.field(schema).unwrap().dtype;
(
phys_expr,
Box::new(LastAgg::new(dtype)) as Box<dyn AggregateFn>,
)
(phys_expr, AggregateFunction::Last(LastAgg::new(dtype)))
}
_ => todo!(),
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::operators::IdxSize;
pub struct FirstAgg {
chunk_idx: IdxSize,
first: Option<AnyValue<'static>>,
dtype: DataType,
pub(crate) dtype: DataType,
}

impl FirstAgg {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
use std::any::Any;

use enum_dispatch::enum_dispatch;
use polars_core::datatypes::DataType;
use polars_core::prelude::AnyValue;

use crate::executors::sinks::groupby::aggregates::count::CountAgg;
use crate::executors::sinks::groupby::aggregates::first::FirstAgg;
use crate::executors::sinks::groupby::aggregates::last::LastAgg;
use crate::executors::sinks::groupby::aggregates::mean::MeanAgg;
use crate::executors::sinks::groupby::aggregates::SumAgg;
use crate::operators::IdxSize;

#[enum_dispatch(AggregateFunction)]
pub trait AggregateFn: Send + Sync {
fn pre_agg(&mut self, _chunk_idx: IdxSize, item: &mut dyn ExactSizeIterator<Item = AnyValue>);

Expand All @@ -18,3 +25,43 @@ pub trait AggregateFn: Send + Sync {

fn as_any(&self) -> &dyn Any;
}

// We dispatch via an enum
// as that saves an indirection
#[enum_dispatch]
pub enum AggregateFunction {
First(FirstAgg),
Last(LastAgg),
Count(CountAgg),
SumF32(SumAgg<f32>),
SumF64(SumAgg<f64>),
SumU32(SumAgg<u32>),
SumU64(SumAgg<u64>),
SumI32(SumAgg<i32>),
SumI64(SumAgg<i64>),
MeanF32(MeanAgg<f32>),
MeanF64(MeanAgg<f64>),
// place holder for any aggregate function
// this is not preferred because of the extra
// indirection
// Other(Box<dyn AggregateFn>)
}

impl AggregateFunction {
pub(crate) fn split2(&self) -> Self {
use AggregateFunction::*;
match self {
First(agg) => First(FirstAgg::new(agg.dtype.clone())),
Last(agg) => Last(LastAgg::new(agg.dtype.clone())),
SumF32(_) => SumF32(SumAgg::new()),
SumF64(_) => SumF64(SumAgg::new()),
SumU32(_) => SumU32(SumAgg::new()),
SumU64(_) => SumU64(SumAgg::new()),
SumI32(_) => SumI32(SumAgg::new()),
SumI64(_) => SumI64(SumAgg::new()),
MeanF32(_) => MeanF32(MeanAgg::new()),
MeanF64(_) => MeanF64(MeanAgg::new()),
Count(_) => Count(CountAgg::new()),
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::operators::IdxSize;
pub struct LastAgg {
chunk_idx: IdxSize,
last: Option<AnyValue<'static>>,
dtype: DataType,
pub(crate) dtype: DataType,
}

impl LastAgg {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ mod mean;
mod sum;

pub use convert::*;
pub(crate) use interface::AggregateFn;
pub(crate) use interface::{AggregateFn, AggregateFunction};
pub(crate) use sum::SumAgg;
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@ use polars_utils::unwrap::UnwrapUncheckedRelease;
use rayon::prelude::*;

use super::aggregates::AggregateFn;
use super::HASHMAP_INIT_SIZE;
use crate::executors::sinks::groupby::aggregates::AggregateFunction;
use crate::expressions::PhysicalPipedExpr;
use crate::operators::{DataChunk, PExecutionContext, Sink, SinkResult};

// We must strike a balance between cache coherence and resizing costs.
// Overallocation seems a lot more expensive than resizing so we start reasonable small.
pub(crate) const HASHMAP_INIT_SIZE: usize = 128;

// This is the hash and the Index offset in the linear buffer
type Key = (u64, IdxSize);

Expand All @@ -40,14 +38,16 @@ pub struct GenericGroupbySink {
// * offset = (idx)
// * end = (offset + n_aggs)
keys: Vec<Vec<AnyValue<'static>>>,
aggregators: Vec<Vec<Box<dyn AggregateFn>>>,
aggregators: Vec<Vec<AggregateFunction>>,
// the keys that will be aggregated on
key_columns: Arc<Vec<Arc<dyn PhysicalPipedExpr>>>,
// the columns that will be aggregated
aggregation_columns: Arc<Vec<Arc<dyn PhysicalPipedExpr>>>,
hb: RandomState,
// Aggregation functions
agg_fns: Vec<Box<dyn AggregateFn>>,
// Initializing Aggregation functions. If we aggregate by 2 columns
// this vec will have two functions. We will use these functions
// to populate the buffer where the hashmap points to
agg_fns: Vec<AggregateFunction>,
output_schema: SchemaRef,
// amortize allocations
aggregation_series: Vec<Series>,
Expand All @@ -59,7 +59,7 @@ impl GenericGroupbySink {
pub fn new(
key_columns: Arc<Vec<Arc<dyn PhysicalPipedExpr>>>,
aggregation_columns: Arc<Vec<Arc<dyn PhysicalPipedExpr>>>,
agg_fns: Vec<Box<dyn AggregateFn>>,
agg_fns: Vec<AggregateFunction>,
output_schema: SchemaRef,
) -> Self {
let hb = RandomState::default();
Expand Down Expand Up @@ -198,7 +198,7 @@ impl Sink for GenericGroupbySink {
};
// initialize the aggregators
for agg_fn in &self.agg_fns {
current_aggregators.push(agg_fn.split())
current_aggregators.push(agg_fn.split2())
}
value_offset
}
Expand Down Expand Up @@ -287,7 +287,7 @@ impl Sink for GenericGroupbySink {
entry.insert_with_hasher(h, key, values_offset, |_| h);
// initialize the new aggregators
for agg_fn in &self.agg_fns {
aggregators_self.push(agg_fn.split())
aggregators_self.push(agg_fn.split2())
}
values_offset
}
Expand All @@ -313,7 +313,7 @@ impl Sink for GenericGroupbySink {
let mut new = Self::new(
self.key_columns.clone(),
self.aggregation_columns.clone(),
self.agg_fns.iter().map(|func| func.split()).collect(),
self.agg_fns.iter().map(|func| func.split2()).collect(),
self.output_schema.clone(),
);
new.hb = self.hb.clone();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,7 @@ mod primitive;

pub(crate) use generic::*;
pub(crate) use primitive::*;

// We must strike a balance between cache coherence and resizing costs.
// Overallocation seems a lot more expensive than resizing so we start reasonable small.
const HASHMAP_INIT_SIZE: usize = 64;
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@ use polars_utils::unwrap::UnwrapUncheckedRelease;
use rayon::prelude::*;

use super::aggregates::AggregateFn;
use super::HASHMAP_INIT_SIZE;
use crate::executors::sinks::groupby::aggregates::AggregateFunction;
use crate::expressions::PhysicalPipedExpr;
use crate::operators::{DataChunk, PExecutionContext, Sink, SinkResult};

// We must strike a balance between cache coherence and resizing costs.
// Overallocation seems a lot more expensive than resizing so we start reasonable small.
pub(crate) const HASHMAP_INIT_SIZE: usize = 128;
// hash + value
#[derive(Eq, Copy, Clone)]
struct Key<T: Copy> {
Expand Down Expand Up @@ -50,13 +49,15 @@ pub struct PrimitiveGroupbySink<K: PolarsNumericType> {
// first get the correct vec by the partition index
// * offset = (idx)
// * end = (offset + n_aggs)
aggregators: Vec<Vec<Box<dyn AggregateFn>>>,
aggregators: Vec<Vec<AggregateFunction>>,
key: Arc<dyn PhysicalPipedExpr>,
// the columns that will be aggregated
aggregation_columns: Arc<Vec<Arc<dyn PhysicalPipedExpr>>>,
hb: RandomState,
// Aggregation functions
agg_fns: Vec<Box<dyn AggregateFn>>,
// Initializing Aggregation functions. If we aggregate by 2 columns
// this vec will have two functions. We will use these functions
// to populate the buffer where the hashmap points to
agg_fns: Vec<AggregateFunction>,
output_schema: SchemaRef,
// amortize allocations
aggregation_series: Vec<Series>,
Expand All @@ -67,7 +68,7 @@ impl<K: PolarsNumericType> PrimitiveGroupbySink<K> {
pub fn new(
key: Arc<dyn PhysicalPipedExpr>,
aggregation_columns: Arc<Vec<Arc<dyn PhysicalPipedExpr>>>,
agg_fns: Vec<Box<dyn AggregateFn>>,
agg_fns: Vec<AggregateFunction>,
output_schema: SchemaRef,
) -> Self {
let hb = RandomState::default();
Expand Down Expand Up @@ -159,7 +160,7 @@ where
entry.insert(key, offset);
// initialize the aggregators
for agg_fn in &self.agg_fns {
current_aggregators.push(agg_fn.split())
current_aggregators.push(agg_fn.split2())
}
offset
}
Expand Down Expand Up @@ -198,7 +199,7 @@ where
entry.insert(*key, offset);
// initialize the aggregators
for agg_fn in &self.agg_fns {
aggregators_self.push(agg_fn.split())
aggregators_self.push(agg_fn.split2())
}
offset
}
Expand All @@ -221,7 +222,7 @@ where
let mut new = Self::new(
self.key.clone(),
self.aggregation_columns.clone(),
self.agg_fns.iter().map(|func| func.split()).collect(),
self.agg_fns.iter().map(|func| func.split2()).collect(),
self.output_schema.clone(),
);
new.hb = self.hb.clone();
Expand Down
7 changes: 7 additions & 0 deletions polars/polars-lazy/polars-pipe/src/pipeline/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ impl Pipeline {
};
sink.sink(ec, chunk)
})
// only collect failed and finished messages as there should be acted upon those
// the other ones (e.g. success and can have more input) can be ignored
// this saves a lot of allocations.
.filter(|result| match result {
Ok(sink_result) => matches!(sink_result, SinkResult::Finished),
Err(_) => true,
})
.collect()
})
}
Expand Down

0 comments on commit f8f3af9

Please sign in to comment.