Skip to content

Commit

Permalink
Fix ternary and execute on threadpool
Browse files Browse the repository at this point in the history
This fixes the masks used in ternary expressions whose branches
are in different aggregation states.

We also evaluate the three expressions (predicate, truthy and falsy)
on the thread pool, which is an easy parallelization win.
  • Loading branch information
ritchie46 committed Nov 25, 2021
1 parent 022464c commit 3cbb874
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 41 deletions.
87 changes: 61 additions & 26 deletions polars/polars-lazy/src/physical_plan/expressions/ternary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::prelude::*;
use polars_arrow::arrow::array::ArrayRef;
use polars_core::frame::groupby::GroupTuples;
use polars_core::prelude::*;
use polars_core::POOL;
use std::convert::TryFrom;
use std::sync::Arc;

Expand Down Expand Up @@ -36,7 +37,17 @@ impl PhysicalExpr for TernaryExpr {
state: &ExecutionState,
) -> Result<AggregationContext<'a>> {
let required_height = df.height();
let ac_mask = self.predicate.evaluate_on_groups(df, groups, state)?;

let op_mask = || self.predicate.evaluate_on_groups(df, groups, state);
let op_truthy = || self.truthy.evaluate_on_groups(df, groups, state);
let op_falsy = || self.falsy.evaluate_on_groups(df, groups, state);

let (ac_mask, (ac_truthy, ac_falsy)) =
POOL.install(|| rayon::join(op_mask, || rayon::join(op_truthy, op_falsy)));
let mut ac_mask = ac_mask?;
let mut ac_truthy = ac_truthy?;
let mut ac_falsy = ac_falsy?;

let mask_s = ac_mask.flat();

assert!(
Expand All @@ -47,17 +58,10 @@ The predicate produced {} values. Where the original DataFrame has {} values",
required_height
);

let mask = mask_s.bool()?;
let mut ac_truthy = self.truthy.evaluate_on_groups(df, groups, state)?;
let mut ac_falsy = self.falsy.evaluate_on_groups(df, groups, state)?;

if !ac_truthy.can_combine(&ac_falsy) {
return Err(PolarsError::InvalidOperation(
"\
cannot combine this ternary expression, the groups do not match"
.into(),
));
}
assert!(
ac_truthy.can_combine(&ac_falsy),
"cannot combine this ternary expression, the groups do not match"
);

match (ac_truthy.agg_state(), ac_falsy.agg_state()) {
(AggState::AggregatedFlat(_), AggState::NotAggregated(_)) => {
Expand All @@ -80,15 +84,28 @@ The predicate produced {} values. Where the original DataFrame has {} values",

// this is now a list
let falsy = ac_falsy.aggregated();
let falsy = falsy.as_ref();
let falsy = falsy.list().unwrap();

let mask = ac_mask.aggregated();
let mask = mask.as_ref();
let mask = mask.list()?;
if !matches!(mask.inner_dtype(), DataType::Boolean) {
return Err(PolarsError::ComputeError(
format!("expected mask of type bool, got {:?}", mask.inner_dtype()).into(),
));
}

let mut ca: ListChunked = falsy
.amortized_iter()
.zip(mask.amortized_iter())
.enumerate()
.map(|(idx, opt_s)| {
opt_s
.map(|s| {
let falsy = s.as_ref();
.map(|(idx, (opt_falsy, opt_mask))| {
match (opt_falsy, opt_mask) {
(Some(falsy), Some(mask)) => {
let falsy = falsy.as_ref();
let mask = mask.as_ref();
let mask = mask.bool()?;

// Safety:
// we are in bounds
Expand All @@ -97,9 +114,11 @@ The predicate produced {} values. Where the original DataFrame has {} values",
std::mem::swap(&mut chunks[0], &mut arr);
let truthy = &dummy;

truthy.zip_with(mask, falsy)
})
.transpose()
Some(truthy.zip_with(mask, falsy))
}
_ => None,
}
.transpose()
})
.collect::<Result<_>>()?;
ca.rename(truthy.name());
Expand All @@ -110,6 +129,7 @@ The predicate produced {} values. Where the original DataFrame has {} values",
(AggState::NotAggregated(_), AggState::AggregatedFlat(_)) => {
// this is now a list
let truthy = ac_truthy.aggregated();
let truthy = truthy.as_ref();
let truthy = truthy.list().unwrap();

// this is a flat series of len eq to group tuples
Expand All @@ -128,24 +148,38 @@ The predicate produced {} values. Where the original DataFrame has {} values",
let len = chunks.len();
std::slice::from_raw_parts_mut(ptr, len)
};
let mask = ac_mask.aggregated();
let mask = mask.as_ref();
let mask = mask.list()?;
if !matches!(mask.inner_dtype(), DataType::Boolean) {
return Err(PolarsError::ComputeError(
format!("expected mask of type bool, got {:?}", mask.inner_dtype()).into(),
));
}

let mut ca: ListChunked = truthy
.amortized_iter()
.zip(mask.amortized_iter())
.enumerate()
.map(|(idx, opt_s)| {
opt_s
.map(|s| {
let truthy = s.as_ref();
.map(|(idx, (opt_truthy, opt_mask))| {
match (opt_truthy, opt_mask) {
(Some(truthy), Some(mask)) => {
let truthy = truthy.as_ref();
let mask = mask.as_ref();
let mask = mask.bool()?;

// Safety:
// we are in bounds
let mut arr =
unsafe { Arc::from(arr_falsy.slice_unchecked(idx, 1)) };
std::mem::swap(&mut chunks[0], &mut arr);
let falsy = &dummy;

truthy.zip_with(mask, falsy)
})
.transpose()
Some(truthy.zip_with(mask, falsy))
}
_ => None,
}
.transpose()
})
.collect::<Result<_>>()?;
ca.rename(truthy.name());
Expand All @@ -156,6 +190,7 @@ The predicate produced {} values. Where the original DataFrame has {} values",
// Both are either a flat series or aggregated into a list,
// so we can flatten the Series and apply the operators
_ => {
let mask = mask_s.bool()?;
let out = ac_truthy.flat().zip_with(mask, ac_falsy.flat().as_ref())?;

assert!(!(out.len() != required_height), "The output of the `when -> then -> otherwise-expr` is of a different length than the groups.\
Expand Down
35 changes: 20 additions & 15 deletions polars/polars-lazy/src/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2284,50 +2284,55 @@ fn test_binary_agg_context_0() -> Result<()> {
fn test_binary_agg_context_1() -> Result<()> {
let df = df![
"groups" => [1, 1, 2, 2, 3, 3],
"vals" => [1, 2, 3, 4, 5, 6]
"vals" => [1, 13, 3, 87, 1, 6]
]?;

// groups
// 1 => [1, 13]
// 2 => [3, 87]
// 3 => [1, 6]

let out = df
.clone()
.lazy()
.stable_groupby([col("groups")])
.agg([when(col("vals").neq(lit(1)))
.then(col("vals").first())
.agg([when(col("vals").eq(lit(1)))
.then(col("vals").sum())
.otherwise(lit(90))
.alias("vals")])
.collect()?;

// [90, 1]
// [90, 3]
// [90, 5]
// if vals == 1 then sum(vals) else 90
// [14, 90]
// [90, 90]
// [7, 90]
let out = out.column("vals")?;
let out = out.explode()?;
let out = out.i32()?;
assert_eq!(
Vec::from(out),
&[Some(90), Some(1), Some(90), Some(3), Some(90), Some(5)]
&[Some(14), Some(90), Some(90), Some(90), Some(7), Some(90)]
);

let out = df
.lazy()
.stable_groupby([col("groups")])
.agg([when(col("vals").neq(lit(1)))
.agg([when(col("vals").eq(lit(1)))
.then(lit(90))
.otherwise(col("vals").first())
.otherwise(col("vals").sum())
.alias("vals")])
.collect()?;

dbg!(&out);

// [1, 90]
// [3, 90]
// [5, 90]
// if vals == 1 then 90 else sum(vals)
// [90, 14]
// [90, 90]
// [90, 7]
let out = out.column("vals")?;
let out = out.explode()?;
let out = out.i32()?;
assert_eq!(
Vec::from(out),
&[Some(1), Some(90), Some(3), Some(90), Some(5), Some(90)]
&[Some(90), Some(14), Some(90), Some(90), Some(90), Some(7)]
);

Ok(())
Expand Down

0 comments on commit 3cbb874

Please sign in to comment.