Skip to content

Commit

Permalink
more map_many instead of map_binary and more expression parallelism
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 31, 2021
1 parent 4848b99 commit 943c1b5
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 35 deletions.
54 changes: 33 additions & 21 deletions polars/polars-lazy/src/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,19 @@ impl GetOutput {
fld
}))
}

pub fn map_dtypes<F>(f: F) -> Self
where
F: 'static + Fn(&[&DataType]) -> DataType + Send + Sync,
{
NoEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
let mut fld = flds[0].clone();
let dtypes = flds.iter().map(|fld| fld.data_type()).collect::<Vec<_>>();
let new_type = f(&dtypes);
fld.coerce(new_type);
fld
}))
}
}

impl<F> FunctionOutputField for F
Expand Down Expand Up @@ -1226,12 +1239,13 @@ impl Expr {

/// Replace the null values by a value.
pub fn fill_null(self, fill_value: Expr) -> Self {
map_binary_lazy_field(
self,
fill_value,
|a, b| {
if !a.has_validity() {
Ok(a)
self.map_many(
|s| {
let a = &s[0];
let b = &s[1];

if !a.null_count() == 0 {
Ok(a.clone())
} else {
let st = get_supertype(a.dtype(), b.dtype())?;
let a = a.cast(&st)?;
Expand All @@ -1240,10 +1254,8 @@ impl Expr {
a.zip_with_same_type(&mask, &b)
}
},
|_schema, _ctx, a, b| {
let st = get_supertype(a.data_type(), b.data_type()).unwrap();
Some(Field::new(a.name(), st))
},
&[fill_value],
GetOutput::map_dtypes(|dtypes| get_supertype(dtypes[0], dtypes[1]).unwrap()),
)
}

Expand Down Expand Up @@ -1426,16 +1438,18 @@ impl Expr {
#[cfg(feature = "repeat_by")]
#[cfg_attr(docsrs, doc(cfg(feature = "repeat_by")))]
pub fn repeat_by(self, by: Expr) -> Expr {
let function = |s: Series, by: Series| {
let function = |s: &mut [Series]| {
let by = &s[1];
let s = &s[0];
let by = by.cast(&DataType::UInt32)?;
Ok(s.repeat_by(by.u32()?).into_series())
};
map_binary_lazy_field(self, by, function, |_schema, _ctxt, l, _r| {
Some(Field::new(
l.name(),
DataType::List(l.data_type().clone().into()),
))
})

self.map_many(
function,
&[by],
GetOutput::map_dtype(|dt| DataType::List(dt.clone().into())),
)
}

#[cfg(feature = "is_first")]
Expand All @@ -1452,11 +1466,9 @@ impl Expr {
#[cfg(feature = "dot_product")]
#[cfg_attr(docsrs, doc(cfg(feature = "dot_product")))]
pub fn dot(self, other: Expr) -> Expr {
let function = |s: Series, other: Series| Ok((&s * &other).sum_as_series());
let function = |s: &mut [Series]| Ok((&s[0] * &s[1]).sum_as_series());

map_binary_lazy_field(self, other, function, |_schema, _ctxt, l, _r| {
Some(Field::new(l.name(), l.data_type().clone()))
})
self.map_many(function, &[other], GetOutput::same_type())
}

#[cfg(feature = "mode")]
Expand Down
16 changes: 12 additions & 4 deletions polars/polars-lazy/src/physical_plan/expressions/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,12 @@ impl PhysicalExpr for FilterExpr {
}

fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> Result<Series> {
let series = self.input.evaluate(df, state)?;
let predicate = self.by.evaluate(df, state)?;
let s_f = || self.input.evaluate(df, state);
let predicate_f = || self.by.evaluate(df, state);

let (series, predicate) = POOL.install(|| rayon::join(s_f, predicate_f));
let (series, predicate) = (series?, predicate?);

series.filter(predicate.bool()?)
}

Expand All @@ -34,8 +38,12 @@ impl PhysicalExpr for FilterExpr {
groups: &'a GroupTuples,
state: &ExecutionState,
) -> Result<AggregationContext<'a>> {
let mut ac_s = self.input.evaluate_on_groups(df, groups, state)?;
let ac_predicate = self.by.evaluate_on_groups(df, groups, state)?;
let ac_s_f = || self.input.evaluate_on_groups(df, groups, state);
let ac_predicate_f = || self.by.evaluate_on_groups(df, groups, state);

let (ac_s, ac_predicate) = POOL.install(|| rayon::join(ac_s_f, ac_predicate_f));
let (mut ac_s, ac_predicate) = (ac_s?, ac_predicate?);

let groups = ac_s.groups();
let predicate_s = ac_predicate.flat_naive();
let predicate = predicate_s.bool()?;
Expand Down
28 changes: 18 additions & 10 deletions polars/polars-lazy/src/physical_plan/expressions/sortby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::physical_plan::state::ExecutionState;
use crate::prelude::*;
use polars_core::frame::groupby::GroupTuples;
use polars_core::prelude::*;
use polars_core::POOL;
use rayon::prelude::*;
use std::sync::Arc;

Expand Down Expand Up @@ -45,21 +46,28 @@ impl PhysicalExpr for SortByExpr {
}

fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> Result<Series> {
let series = self.input.evaluate(df, state)?;
let series_f = || self.input.evaluate(df, state);
let reverse = prepare_reverse(&self.reverse, self.by.len());

let sorted_idx = if self.by.len() == 1 {
let s_sort_by = self.by[0].evaluate(df, state)?;
s_sort_by.argsort(reverse[0])
let (series, sorted_idx) = if self.by.len() == 1 {
let sorted_idx_f = || {
let s_sort_by = self.by[0].evaluate(df, state)?;
Ok(s_sort_by.argsort(reverse[0]))
};
POOL.install(|| rayon::join(series_f, sorted_idx_f))
} else {
let s_sort_by = self
.by
.iter()
.map(|e| e.evaluate(df, state))
.collect::<Result<Vec<_>>>()?;
let sorted_idx_f = || {
let s_sort_by = self
.by
.iter()
.map(|e| e.evaluate(df, state))
.collect::<Result<Vec<_>>>()?;

s_sort_by[0].argsort_multiple(&s_sort_by[1..], &reverse)?
s_sort_by[0].argsort_multiple(&s_sort_by[1..], &reverse)
};
POOL.install(|| rayon::join(series_f, sorted_idx_f))
};
let (sorted_idx, series) = (sorted_idx?, series?);

// Safety:
// sorted index are within bounds
Expand Down

0 comments on commit 943c1b5

Please sign in to comment.