Skip to content

Commit

Permalink
Slice expr args (#2786)
Browse files Browse the repository at this point in the history
* accept expression as slice expr args

* Slice expr: accept expression as arguments
  • Loading branch information
ritchie46 committed Feb 26, 2022
1 parent 9befa6d commit 1e7b8b2
Show file tree
Hide file tree
Showing 19 changed files with 344 additions and 71 deletions.
16 changes: 8 additions & 8 deletions polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,8 @@ pub enum Expr {
Slice {
input: Box<Expr>,
/// length is not yet known so we accept negative offsets
offset: i64,
length: usize,
offset: Box<Expr>,
length: Box<Expr>,
},
/// Can be used in a select statement to exclude a column from selection
Exclude(Box<Expr>, Vec<Excluded>),
Expand Down Expand Up @@ -391,7 +391,6 @@ pub enum Operator {
Minus,
Multiply,
Divide,
#[cfg(feature = "true_div")]
TrueDivide,
Modulus,
And,
Expand Down Expand Up @@ -722,23 +721,24 @@ impl Expr {
}

/// Slice the Series.
pub fn slice(self, offset: i64, length: usize) -> Self {
/// `offset` may be negative.
pub fn slice(self, offset: Expr, length: Expr) -> Self {
Expr::Slice {
input: Box::new(self),
offset,
length,
offset: Box::new(offset),
length: Box::new(length),
}
}

/// Get the first `n` elements of the Expr result
pub fn head(self, length: Option<usize>) -> Self {
self.slice(0, length.unwrap_or(10))
self.slice(lit(0), lit(length.unwrap_or(10) as u64))
}

/// Get the last `n` elements of the Expr result
pub fn tail(self, length: Option<usize>) -> Self {
let len = length.unwrap_or(10);
self.slice(-(len as i64), len)
self.slice(lit(-(len as i64)), lit(len as u64))
}

/// Get unique values of this expression.
Expand Down
8 changes: 0 additions & 8 deletions polars/polars-lazy/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -504,14 +504,6 @@ pub fn dtype_cols<DT: AsRef<[DataType]>>(dtype: DT) -> Expr {
Expr::DtypeColumn(dtypes)
}

/// Count the number of values in this Expression.
pub fn count(name: &str) -> Expr {
match name {
"" => col(name).count().alias("count"),
_ => col(name).count(),
}
}

/// Sum all the values in this Expression.
pub fn sum(name: &str) -> Expr {
col(name).sum()
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-lazy/src/logical_plan/aexpr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ pub enum AExpr {
Wildcard,
Slice {
input: Node,
offset: i64,
length: usize,
offset: Node,
length: Node,
},
Count,
Nth(i64),
Expand Down
8 changes: 4 additions & 4 deletions polars/polars-lazy/src/logical_plan/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ pub(crate) fn to_aexpr(expr: Expr, arena: &mut Arena<AExpr>) -> Node {
length,
} => AExpr::Slice {
input: to_aexpr(*input, arena),
offset,
length,
offset: to_aexpr(*offset, arena),
length: to_aexpr(*length, arena),
},
Expr::Wildcard => AExpr::Wildcard,
Expr::Count => AExpr::Count,
Expand Down Expand Up @@ -592,8 +592,8 @@ pub(crate) fn node_to_expr(node: Node, expr_arena: &Arena<AExpr>) -> Expr {
length,
} => Expr::Slice {
input: Box::new(node_to_expr(input, expr_arena)),
offset,
length,
offset: Box::new(node_to_expr(offset, expr_arena)),
length: Box::new(node_to_expr(length, expr_arena)),
},
AExpr::Count => Expr::Count,
AExpr::Nth(i) => Expr::Nth(i),
Expand Down
7 changes: 5 additions & 2 deletions polars/polars-lazy/src/logical_plan/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,11 @@ impl fmt::Debug for Expr {
input,
offset,
length,
} => write!(f, "SLICE {:?} offset: {} len: {}", input, offset, length),
} => write!(
f,
"{:?}.slice(offset={:?}, length={:?})",
input, offset, length
),
Wildcard => write!(f, "*"),
Exclude(column, names) => write!(f, "{:?}, EXCEPT {:?}", column, names),
KeepName(e) => write!(f, "KEEP NAME {:?}", e),
Expand All @@ -281,7 +285,6 @@ impl Debug for Operator {
Minus => "-",
Multiply => "*",
Divide => "//",
#[cfg(feature = "true_div")]
TrueDivide => "/",
Modulus => "%",
And => "&",
Expand Down
20 changes: 18 additions & 2 deletions polars/polars-lazy/src/logical_plan/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,15 @@ macro_rules! push_expr {
$push(e);
}
}
Slice { input, .. } => $push(input),
Slice {
input,
offset,
length,
} => {
$push(input);
$push(offset);
$push(length);
}
Exclude(e, _) => $push(e),
KeepName(e) => $push(e),
RenameAlias { expr, .. } => $push(expr),
Expand Down Expand Up @@ -228,7 +236,15 @@ impl AExpr {
push(e);
}
}
Slice { input, .. } => push(input),
Slice {
input,
offset,
length,
} => {
push(input);
push(offset);
push(length);
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,6 @@ impl OptimizationRule for SimplifyExprRule {
Operator::Minus => eval_binary_same_type!(left_aexpr, -, right_aexpr),
Operator::Multiply => eval_binary_same_type!(left_aexpr, *, right_aexpr),
Operator::Divide => eval_binary_same_type!(left_aexpr, /, right_aexpr),
#[cfg(feature = "true_div")]
Operator::TrueDivide => {
if let (AExpr::Literal(lit_left), AExpr::Literal(lit_right)) =
(left_aexpr, right_aexpr)
Expand Down
41 changes: 34 additions & 7 deletions polars/polars-lazy/src/physical_plan/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ pub(crate) fn apply_operator(left: &Series, right: &Series, op: Operator) -> Res
Operator::Minus => Ok(left - right),
Operator::Multiply => Ok(left * right),
Operator::Divide => Ok(left / right),
#[cfg(feature = "true_div")]
Operator::TrueDivide => {
use DataType::*;
match left.dtype() {
Expand Down Expand Up @@ -96,20 +95,46 @@ impl PhysicalExpr for BinaryExpr {
));
}

match (ac_l.agg_state(), ac_r.agg_state()) {
match (ac_l.agg_state(), ac_r.agg_state(), self.op) {
// Some aggregations must return boolean masks that fit the group. That's why not all literals can take this path.
// only literals that are used in arithmetic
(
AggState::AggregatedFlat(lhs),
AggState::Literal(rhs),
Operator::Plus
| Operator::Minus
| Operator::Divide
| Operator::Multiply
| Operator::Modulus
| Operator::TrueDivide,
)
| (
AggState::Literal(lhs),
AggState::AggregatedFlat(rhs),
Operator::Plus
| Operator::Minus
| Operator::Divide
| Operator::Multiply
| Operator::Modulus
| Operator::TrueDivide,
) => {
let out = apply_operator(lhs, rhs, self.op)?;

ac_l.with_series(out, true);
Ok(ac_l)
}
// One of the two exprs is aggregated with flat aggregation, e.g. `e.min(), e.max(), e.first()`

// if the groups_len == df.len we can just apply all flat.
// within an aggregation a `col().first() - lit(0)` must still produce a boolean array of group length,
// that's why a literal also takes this branch
(AggState::AggregatedFlat(s), AggState::NotAggregated(_) | AggState::Literal(_))
(AggState::AggregatedFlat(s), AggState::NotAggregated(_) | AggState::Literal(_), _)
if s.len() != df.height() =>
{
// this is a flat series of len eq to group tuples
let l = ac_l.aggregated();
let l = l.as_ref();
let arr_l = &l.chunks()[0];
assert_eq!(l.len(), groups.len());

// we create a dummy Series that is not cloned nor moved
// so we can swap the ArrayRef during the hot loop
Expand Down Expand Up @@ -154,6 +179,7 @@ impl PhysicalExpr for BinaryExpr {
(
AggState::Literal(_) | AggState::AggregatedList(_) | AggState::NotAggregated(_),
AggState::AggregatedFlat(s),
_,
) if s.len() != df.height() => {
// this is now a list
let l = ac_l.aggregated();
Expand Down Expand Up @@ -198,8 +224,9 @@ impl PhysicalExpr for BinaryExpr {
ac_l.with_update_groups(UpdateGroups::WithGroupsLen);
Ok(ac_l)
}
(AggState::AggregatedList(_), AggState::NotAggregated(_) | AggState::Literal(_))
| (AggState::NotAggregated(_) | AggState::Literal(_), AggState::AggregatedList(_)) => {
(AggState::AggregatedList(_), AggState::NotAggregated(_) | AggState::Literal(_), _)
| (AggState::NotAggregated(_) | AggState::Literal(_), AggState::AggregatedList(_), _) =>
{
ac_l.sort_by_groups();
ac_r.sort_by_groups();

Expand All @@ -215,7 +242,7 @@ impl PhysicalExpr for BinaryExpr {
Ok(ac_l)
}
// flatten the Series and apply the operators
(AggState::AggregatedList(_), AggState::AggregatedList(_)) => {
(AggState::AggregatedList(_), AggState::AggregatedList(_), _) => {
let out = apply_operator(
ac_l.flat_naive().as_ref(),
ac_r.flat_naive().as_ref(),
Expand Down
4 changes: 4 additions & 0 deletions polars/polars-lazy/src/physical_plan/expressions/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ impl PhysicalExpr for CountExpr {
fn to_field(&self, _input_schema: &Schema) -> Result<Field> {
Ok(Field::new("count", DataType::UInt32))
}

fn as_agg_expr(&self) -> Result<&dyn PhysicalAggregation> {
Ok(self)
}
}

impl PhysicalAggregation for CountExpr {
Expand Down

0 comments on commit 1e7b8b2

Please sign in to comment.