Skip to content

Commit

Permalink
improve strictness/consistency of when then otherwise (#4241)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 3, 2022
1 parent 00b50c2 commit 6b0041a
Show file tree
Hide file tree
Showing 9 changed files with 278 additions and 79 deletions.
2 changes: 1 addition & 1 deletion polars/polars-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ impl PartialEq for AnyValue<'_> {
#[cfg(all(feature = "dtype-datetime", feature = "dtype-date"))]
(Datetime(l, tul, tzl), Datetime(r, tur, tzr)) => l == r && tul == tur && tzl == tzr,
(Boolean(l), Boolean(r)) => l == r,
(List(_), List(_)) => panic!("eq between list series not supported"),
(List(l), List(r)) => l == r,
#[cfg(feature = "object")]
(Object(_), Object(_)) => panic!("eq between object not supported"),
// should it?
Expand Down
8 changes: 8 additions & 0 deletions polars/polars-lazy/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ mod pow;
mod rolling;
#[cfg(feature = "row_hash")]
mod row_hash;
mod shift_and_fill;
#[cfg(feature = "sign")]
mod sign;
#[cfg(feature = "strings")]
Expand Down Expand Up @@ -60,6 +61,9 @@ pub enum FunctionExpr {
window_size: usize,
bias: bool,
},
ShiftAndFill {
periods: i64,
},
}

#[cfg(feature = "trigonometry")]
Expand Down Expand Up @@ -127,6 +131,7 @@ impl FunctionExpr {
ListContains => with_dtype(DataType::Boolean),
#[cfg(all(feature = "rolling_window", feature = "moment"))]
RollingSkew { .. } => float_dtype(),
ShiftAndFill { .. } => same_type(),
}
}
}
Expand Down Expand Up @@ -247,6 +252,9 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
RollingSkew { window_size, bias } => {
map_with_args!(rolling::rolling_skew, window_size, bias)
}
ShiftAndFill { periods } => {
map_as_slice!(shift_and_fill::shift_and_fill, periods)
}
}
}
}
26 changes: 26 additions & 0 deletions polars/polars-lazy/src/dsl/function_expr/shift_and_fill.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use super::*;

/// Shift `args[0]` by `periods` positions and put `args[1]` in the slots the shift vacates.
///
/// * `args[0]` - the series to shift.
/// * `args[1]` - the series providing the fill value.
/// * `periods` - shift distance; a positive value vacates slots at the start,
///   zero or a negative value vacates slots at the end.
///
/// # Errors
/// Returns whatever error `zip_with_same_type` reports.
///
/// # Panics
/// Panics on the `args` indexing if fewer than two series are supplied.
pub(super) fn shift_and_fill(args: &mut [Series], periods: i64) -> Result<Series> {
    let s = &args[0];
    let fill_value = &args[1];
    let len = s.len();

    // Never vacate more slots than the series has: an unclamped `periods` whose
    // magnitude exceeds `len` would produce a mask that is longer than the series.
    // `unsigned_abs` also sidesteps the overflow of negating `i64::MIN`.
    let n_fill = std::cmp::min(periods.unsigned_abs(), len as u64) as usize;
    let n_keep = len - n_fill;

    // `true` keeps the shifted value, `false` selects the fill value.
    let mut bits = MutableBitmap::with_capacity(len);
    if periods > 0 {
        // Shifting forward: the leading slots are the vacated ones.
        bits.extend_constant(n_fill, false);
        bits.extend_constant(n_keep, true);
    } else {
        // Shifting backward (or not at all): the trailing slots are the vacated ones.
        bits.extend_constant(n_keep, true);
        bits.extend_constant(n_fill, false);
    }
    let mask: BooleanChunked = BooleanArray::from_data_default(bits.into(), None).into();

    s.shift(periods).zip_with_same_type(&mask, fill_value)
}
45 changes: 5 additions & 40 deletions polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ use crate::dsl::function_expr::FunctionExpr;
#[cfg(feature = "trigonometry")]
use crate::dsl::function_expr::TrigonometricFunction;

use polars_arrow::array::default_arrays::FromData;
#[cfg(feature = "diff")]
use polars_core::series::ops::NullBehavior;
use polars_core::series::IsSorted;
Expand Down Expand Up @@ -859,47 +858,13 @@ impl Expr {
}
}

/// Build the `when(mask).then(self.shift(periods)).otherwise(fill_value)` expression
/// that shifts this expression and fills the vacated slots.
///
/// * `periods` - shift distance; a positive value vacates slots at the start,
///   zero or a negative value vacates slots at the end.
/// * `fill_value` - expression used for the vacated slots.
pub fn shift_and_fill_impl(self, periods: i64, fill_value: Expr) -> Self {
    // Boolean mask with the same length as the input series:
    // `true` selects the shifted value, `false` selects the fill value.
    let mask = self.clone().apply(
        move |s: Series| {
            let len = s.len();
            // Clamp the number of vacated slots to the series length so the mask
            // never grows longer than the series when `|periods| > len`;
            // `unsigned_abs` also avoids overflow when negating `i64::MIN`.
            let n_fill = std::cmp::min(periods.unsigned_abs(), len as u64) as usize;
            let n_keep = len - n_fill;

            let mut bits = MutableBitmap::with_capacity(len);
            if periods > 0 {
                bits.extend_constant(n_fill, false);
                bits.extend_constant(n_keep, true);
            } else {
                bits.extend_constant(n_keep, true);
                bits.extend_constant(n_fill, false);
            }
            let mask = BooleanArray::from_data_default(bits.into(), None);
            let ca: BooleanChunked = mask.into();
            Ok(ca.into_series())
        },
        GetOutput::from_type(DataType::Boolean),
    );

    // Note:
    // The order of the then | otherwise is important
    when(mask).then(self.shift(periods)).otherwise(fill_value)
}

/// Shift the values in the array by some period and fill the resulting empty values.
///
/// * `periods` - signed number of positions to shift by.
/// * `fill_value` - anything convertible into an `Expr`; it is converted with `.into()`
///   before being passed on.
pub fn shift_and_fill<E: Into<Expr>>(self, periods: i64, fill_value: E) -> Self {
self.shift_and_fill_impl(periods, fill_value.into())
self.apply_many_private(
FunctionExpr::ShiftAndFill { periods },
&[fill_value.into()],
"shift_and_fill",
)
}

/// Get an array with the cumulative sum computed at every element
Expand Down
50 changes: 50 additions & 0 deletions polars/polars-lazy/src/logical_plan/optimizer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,56 @@ impl OptimizationRule for TypeCoercionRule {
options,
})
}
AExpr::Function {
// only for `DataType::Unknown` as it still has to be set.
function: FunctionExpr::ShiftAndFill { periods },
ref input,
options,
} => {
let input_schema = get_schema(lp_arena, lp_node);
let self_node = input[0];
let other_node = input[1];
let (left, type_self) = get_aexpr_and_type(expr_arena, self_node, &input_schema)?;
let (fill_value, type_other) =
get_aexpr_and_type(expr_arena, other_node, &input_schema)?;

if type_self == type_other {
return None;
}

let super_type = get_supertype(&type_self, &type_other).ok()?;
let super_type =
modify_supertype(super_type, left, fill_value, &type_self, &type_other);

// only cast if the type is not already the super type.
// this can prevent an expensive flattening and subsequent aggregation
// in a groupby context. To be able to cast the groups need to be
// flattened
let new_node_self = if type_self != super_type {
expr_arena.add(AExpr::Cast {
expr: self_node,
data_type: super_type.clone(),
strict: false,
})
} else {
self_node
};
let new_node_other = if type_other != super_type {
expr_arena.add(AExpr::Cast {
expr: other_node,
data_type: super_type,
strict: false,
})
} else {
other_node
};

Some(AExpr::Function {
function: FunctionExpr::ShiftAndFill { periods },
input: vec![new_node_self, new_node_other],
options,
})
}

_ => None,
}
Expand Down
141 changes: 103 additions & 38 deletions polars/polars-lazy/src/physical_plan/expressions/ternary.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::physical_plan::state::{ExecutionState, StateFlags};
use crate::prelude::*;
use polars_arrow::utils::CustomIterTools;
use polars_core::frame::groupby::GroupsProxy;
use polars_core::prelude::*;
use polars_core::POOL;
Expand Down Expand Up @@ -135,7 +136,7 @@ impl PhysicalExpr for TernaryExpr {

let ac_mask = ac_mask?;
let mut ac_truthy = ac_truthy?;
let ac_falsy = ac_falsy?;
let mut ac_falsy = ac_falsy?;

let mask_s = ac_mask.flat_naive();

Expand All @@ -151,49 +152,113 @@ impl PhysicalExpr for TernaryExpr {
// truthy -> aggregated-flat | literal
// falsy -> aggregated-flat | literal
// simply align lengths and zip
(
Literal(truthy) | AggregatedFlat(truthy),
AggregatedFlat(falsy) | Literal(falsy),
)
(Literal(truthy) | AggregatedFlat(truthy), AggregatedFlat(falsy) | Literal(falsy))
| (AggregatedList(truthy), AggregatedList(falsy))
if matches!(ac_mask.agg_state(), AggState::AggregatedFlat(_))
=>
{
let mut truthy = truthy.clone();
let mut falsy = falsy.clone();
let mut mask = ac_mask.series().bool()?.clone();
expand_lengths(&mut truthy, &mut falsy, &mut mask);
let mut out = truthy.zip_with(&mask, &falsy).unwrap();
out.rename(truthy.name());
ac_truthy.with_series(out, true);
Ok(ac_truthy)
}

// Fallthrough from previous branch. That means the mask is not aggregated
// and may be longer than branches, so we iterate over the groups
(AggregatedFlat(agg), Literal(_)) |
(Literal(_), AggregatedFlat(agg))
// todo! check if we need to take this branch for those
//| (AggregatedFlat(agg), NotAggregated(_)) |
// (NotAggregated(_), AggregatedFlat(agg))

// if the groups_len == df.len we can just apply all flat.
// make sure that the order of the aggregation does not matter!
// this is the case when we have a flat aggregation combined with a literal:
// `sum(..) - lit(..)`.
if agg.len() != df.height()
=> {
finish_as_iters(ac_truthy, ac_falsy, ac_mask)
if matches!(ac_mask.agg_state(), AggState::AggregatedFlat(_)) =>
{
let mut truthy = truthy.clone();
let mut falsy = falsy.clone();
let mut mask = ac_mask.series().bool()?.clone();
expand_lengths(&mut truthy, &mut falsy, &mut mask);
let mut out = truthy.zip_with(&mask, &falsy).unwrap();
out.rename(truthy.name());
ac_truthy.with_series(out, true);
Ok(ac_truthy)
}

// Same branch as above, but without escape hatch for `num_groups == df.height()`

// we cannot flatten a list because that changes the order, so we apply over groups
(AggregatedList(_), NotAggregated(_)) |
(NotAggregated(_), AggregatedList(_)) => {
(AggregatedList(_), NotAggregated(_)) | (NotAggregated(_), AggregatedList(_)) => {
finish_as_iters(ac_truthy, ac_falsy, ac_mask)
}

// then:
// col().shift()
// otherwise:
// None
(AggregatedList(_), Literal(_)) | (Literal(_), AggregatedList(_)) => {
let mask = mask_s.bool()?;
let check_length = |ca: &ListChunked, mask: &BooleanChunked| {
if ca.len() != mask.len() {
Err(PolarsError::ComputeError(format!("the predicates length: '{}' does not match the length of the groups: {}", mask.len(), ca.len()).into()))
} else {
Ok(())
}
};

if ac_falsy.is_literal() && self.falsy.as_expression().map(has_null) == Some(true) {
let s = ac_truthy.aggregated();
let ca = s.list().unwrap();
check_length(ca, mask)?;
let mut out: ListChunked = ca
.into_iter()
.zip(mask.into_iter())
.map(|(truthy, take)| match (truthy, take) {
(Some(v), Some(true)) => Some(v),
(Some(_), Some(false)) => None,
_ => None,
})
.collect_trusted();
out.rename(ac_truthy.series().name());
ac_truthy.with_series(out.into_series(), true);
Ok(ac_truthy)
} else if ac_truthy.is_literal()
&& self.truthy.as_expression().map(has_null) == Some(true)
{
let s = ac_falsy.aggregated();
let ca = s.list().unwrap();
check_length(ca, mask)?;
let mut out: ListChunked = ca
.into_iter()
.zip(mask.into_iter())
.map(|(falsy, take)| match (falsy, take) {
(Some(_), Some(true)) => None,
(Some(v), Some(false)) => Some(v),
_ => None,
})
.collect_trusted();
out.rename(ac_truthy.series().name());
ac_truthy.with_series(out.into_series(), true);
Ok(ac_truthy)
}
// then:
// col().shift()
// otherwise:
// lit(list)
else if ac_truthy.is_literal() {
let literal = ac_truthy.series();
let s = ac_falsy.aggregated();
let ca = s.list().unwrap();
check_length(ca, mask)?;
let mut out: ListChunked = ca
.into_iter()
.zip(mask.into_iter())
.map(|(falsy, take)| match (falsy, take) {
(Some(_), Some(true)) => Some(literal.clone()),
(Some(v), Some(false)) => Some(v),
_ => None,
})
.collect_trusted();
out.rename(ac_truthy.series().name());
ac_truthy.with_series(out.into_series(), true);
Ok(ac_truthy)
} else {
let literal = ac_falsy.series();
let s = ac_truthy.aggregated();
let ca = s.list().unwrap();
check_length(ca, mask)?;
let mut out: ListChunked = ca
.into_iter()
.zip(mask.into_iter())
.map(|(truthy, take)| match (truthy, take) {
(Some(v), Some(true)) => Some(v),
(Some(_), Some(false)) => Some(literal.clone()),
_ => None,
})
.collect_trusted();
out.rename(ac_truthy.series().name());
ac_truthy.with_series(out.into_series(), true);
Ok(ac_truthy)
}
}
// Both are either a flat series or aggregated into a list,
// so we can flatten the Series and apply the operators
_ => {
Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::*;
use crate::dsl::AggExpr::List;
use polars_arrow::prelude::QuantileInterpolOptions;
use polars_core::frame::explode::MeltArgs;
use polars_core::series::ops::NullBehavior;
Expand Down
6 changes: 6 additions & 0 deletions polars/polars-lazy/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ pub(crate) fn has_nth(current_expr: &Expr) -> bool {
has_expr(current_expr, |e| matches!(e, Expr::Nth(_)))
}

/// Report whether `has_expr` finds a `Null` literal for `current_expr`.
pub(crate) fn has_null(current_expr: &Expr) -> bool {
    /// Predicate handed to `has_expr`: true only for `Expr::Literal(LiteralValue::Null)`.
    fn is_null_literal(e: &Expr) -> bool {
        matches!(e, Expr::Literal(LiteralValue::Null))
    }

    has_expr(current_expr, is_null_literal)
}

/// output name of expr
pub(crate) fn expr_output_name(expr: &Expr) -> Result<Arc<str>> {
for e in expr {
Expand Down

0 comments on commit 6b0041a

Please sign in to comment.