Skip to content

Commit

Permalink
refactor has_expr
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Apr 16, 2021
1 parent 2bd5789 commit bcfbd30
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 312 deletions.
189 changes: 7 additions & 182 deletions polars/polars-lazy/src/dsl.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
//! Domain specific language for the Lazy api.
use crate::logical_plan::Context;
use crate::prelude::*;
use crate::utils::{has_expr, output_name, rename_field};
use polars_core::{
frame::groupby::{fmt_groupby_column, GroupByMethod},
prelude::*,
utils::get_supertype,
};
use crate::utils::{has_expr, output_name};
use polars_core::prelude::*;

#[cfg(feature = "temporal")]
use polars_core::utils::chrono::{NaiveDate, NaiveDateTime};
Expand Down Expand Up @@ -236,183 +232,12 @@ pub enum Expr {
}

impl Expr {
/// Resolve the [`DataType`] this expression evaluates to, where `schema` describes the input data.
///
/// Delegates to `to_field` and keeps only the data type of the resulting field;
/// an error from `to_field` is returned to the caller unchanged.
pub fn get_type(&self, schema: &Schema, context: Context) -> Result<DataType> {
    let field = self.to_field(schema, context)?;
    Ok(field.data_type().clone())
}

/// Get Field result of the expression. The schema is the input data.
pub(crate) fn to_field(&self, schema: &Schema, ctxt: Context) -> Result<Field> {
use Expr::*;
match self {
Window { function, .. } => function.to_field(schema, ctxt),
IsUnique(expr) => {
let field = expr.to_field(&schema, ctxt)?;
Ok(Field::new(field.name(), DataType::Boolean))
}
Duplicated(expr) => {
let field = expr.to_field(&schema, ctxt)?;
Ok(Field::new(field.name(), DataType::Boolean))
}
Reverse(expr) => expr.to_field(&schema, ctxt),
Explode(expr) => expr.to_field(&schema, ctxt),
Alias(expr, name) => Ok(Field::new(name, expr.get_type(schema, ctxt)?)),
Column(name) => {
let field = schema.field_with_name(name).map(|f| f.clone())?;
Ok(field)
}
Literal(sv) => Ok(Field::new("lit", sv.get_datatype())),
BinaryExpr { left, right, op } => {
let left_type = left.get_type(schema, ctxt)?;
let right_type = right.get_type(schema, ctxt)?;

let expr_type = match op {
Operator::Not
| Operator::Lt
| Operator::Gt
| Operator::Eq
| Operator::NotEq
| Operator::And
| Operator::LtEq
| Operator::GtEq
| Operator::Or
| Operator::NotLike
| Operator::Like => DataType::Boolean,
_ => get_supertype(&left_type, &right_type)?,
};

use Operator::*;
let out_field;
let out_name = match op {
Plus | Minus | Multiply | Divide | Modulus => {
out_field = left.to_field(schema, ctxt)?;
out_field.name().as_str()
}
Eq | Lt | GtEq | LtEq => "",
_ => "binary_expr",
};

Ok(Field::new(out_name, expr_type))
}
Not(_) => Ok(Field::new("not", DataType::Boolean)),
IsNull(_) => Ok(Field::new("is_null", DataType::Boolean)),
IsNotNull(_) => Ok(Field::new("is_not_null", DataType::Boolean)),
Sort { expr, .. } => expr.to_field(schema, ctxt),
SortBy { expr, .. } => expr.to_field(schema, ctxt),
Filter { input, .. } => input.to_field(schema, ctxt),
Agg(agg) => {
use AggExpr::*;
let field = match agg {
Min(expr) => {
field_by_context(expr.to_field(schema, ctxt)?, ctxt, GroupByMethod::Min)
}
Max(expr) => {
field_by_context(expr.to_field(schema, ctxt)?, ctxt, GroupByMethod::Max)
}
Median(expr) => {
field_by_context(expr.to_field(schema, ctxt)?, ctxt, GroupByMethod::Median)
}
Mean(expr) => {
field_by_context(expr.to_field(schema, ctxt)?, ctxt, GroupByMethod::Mean)
}
First(expr) => {
field_by_context(expr.to_field(schema, ctxt)?, ctxt, GroupByMethod::First)
}
Last(expr) => {
field_by_context(expr.to_field(schema, ctxt)?, ctxt, GroupByMethod::Last)
}
List(expr) => {
field_by_context(expr.to_field(schema, ctxt)?, ctxt, GroupByMethod::List)
}
NUnique(expr) => {
let field = expr.to_field(schema, ctxt)?;
let field = Field::new(field.name(), DataType::UInt32);
match ctxt {
Context::Default => field,
Context::Aggregation => {
let new_name =
fmt_groupby_column(field.name(), GroupByMethod::NUnique);
rename_field(&field, &new_name)
}
}
}
Sum(expr) => {
field_by_context(expr.to_field(schema, ctxt)?, ctxt, GroupByMethod::Sum)
}
Std(expr) => {
let field = expr.to_field(schema, ctxt)?;
let field = Field::new(field.name(), DataType::Float64);
field_by_context(field, ctxt, GroupByMethod::Std)
}
Var(expr) => {
let field = expr.to_field(schema, ctxt)?;
let field = Field::new(field.name(), DataType::Float64);
field_by_context(field, ctxt, GroupByMethod::Var)
}
Count(expr) => {
let field = expr.to_field(schema, ctxt)?;
let field = Field::new(field.name(), DataType::UInt32);
match ctxt {
Context::Default => field,
Context::Aggregation => {
let new_name =
fmt_groupby_column(field.name(), GroupByMethod::Count);
rename_field(&field, &new_name)
}
}
}
AggGroups(expr) => {
let field = expr.to_field(schema, ctxt)?;
let new_name = fmt_groupby_column(field.name(), GroupByMethod::Groups);
Field::new(&new_name, DataType::List(ArrowDataType::UInt32))
}
Quantile { expr, quantile } => field_by_context(
expr.to_field(schema, ctxt)?,
ctxt,
GroupByMethod::Quantile(*quantile),
),
};
Ok(field)
}
Cast { expr, data_type } => {
let field = expr.to_field(schema, ctxt)?;
Ok(Field::new(field.name(), data_type.clone()))
}
Ternary { truthy, .. } => truthy.to_field(schema, ctxt),
Udf {
output_type, input, ..
} => match output_type {
None => input.to_field(schema, ctxt),
Some(output_type) => {
let input_field = input.to_field(schema, ctxt)?;
Ok(Field::new(input_field.name(), output_type.clone()))
}
},
BinaryFunction {
input_a,
input_b,
output_field,
..
} => {
let field_a = input_a.to_field(schema, ctxt)?;
let field_b = input_b.to_field(schema, ctxt)?;
// if field is unknown we try to guess a return type. May fail.
Ok(output_field
.get_field(schema, ctxt, &field_a, &field_b)
.unwrap_or_else(|| {
Field::new(
"binary_expr",
get_supertype(field_a.data_type(), field_b.data_type())
.unwrap_or(DataType::Null),
)
}))
}
Shift { input, .. } => input.to_field(schema, ctxt),
Slice { input, .. } => input.to_field(schema, ctxt),
Wildcard => panic!("should be no wildcard at this point"),
Except(_) => panic!("should be no except at this point"),
}
// this is not called much and the expression depth is typically shallow
let mut arena = Arena::with_capacity(5);
let root = to_aexpr(self.clone(), &mut arena);
arena.get(root).to_field(schema, ctxt, &arena)
}
}

Expand Down Expand Up @@ -978,7 +803,7 @@ impl Expr {
/// Should be used in aggregation context. If you want to filter on a DataFrame level, use
/// [LazyFrame::filter](LazyFrame::filter)
pub fn filter(self, predicate: Expr) -> Self {
if has_expr(&self, &Expr::Wildcard) {
if has_expr(&self, |e| matches!(e, Expr::Wildcard)) {
panic!("filter '*' not allowed, use LazyFrame::filter")
};
Expr::Filter {
Expand Down
15 changes: 0 additions & 15 deletions polars/polars-lazy/src/dummies.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::dsl::{BinaryUdfOutputField, NoEq, SeriesBinaryUdf};
use crate::logical_plan::Context;
use crate::prelude::*;
use polars_core::prelude::*;
use std::sync::Arc;

Expand All @@ -16,17 +15,3 @@ impl Default for NoEq<Arc<dyn BinaryUdfOutputField>> {
NoEq::new(Arc::new(output_field))
}
}

/// Build a placeholder `AExpr::SortBy` whose `expr`, `by` and `reverse` fields
/// all hold their `Default` values.
// NOTE(review): presumably callers only care about the variant, not the field values — confirm.
pub(crate) fn dummy_aexpr_sort_by() -> AExpr {
    let expr = Default::default();
    let by = Default::default();
    let reverse = Default::default();
    AExpr::SortBy { expr, by, reverse }
}
/// Build a placeholder `AExpr::Filter` whose `input` and `by` fields both hold
/// their `Default` values.
// NOTE(review): presumably callers only care about the variant, not the field values — confirm.
pub(crate) fn dummy_aexpr_filter() -> AExpr {
    let input = Default::default();
    let by = Default::default();
    AExpr::Filter { input, by }
}
7 changes: 3 additions & 4 deletions polars/polars-lazy/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -770,12 +770,11 @@ fn rewrite_projections(exprs: Vec<Expr>, schema: &Schema) -> Vec<Expr> {
}
}

let has_wildcard = has_expr(&expr, &Expr::Wildcard);
let has_wildcard = has_expr(&expr, |e| matches!(e, Expr::Wildcard));

if has_wildcard {
// if the wildcard is being counted, count a single column instead
let dummy = &Expr::Agg(AggExpr::Count(Box::new(Expr::Wildcard)));
if has_expr(&expr, dummy) {
if has_expr(&expr, |e| matches!(e, Expr::Agg(AggExpr::Count(_)))) {
let new_name = Arc::new(schema.field(0).unwrap().name().clone());
let expr = rename_expr_root_name(&expr, new_name).unwrap();

Expand Down Expand Up @@ -1004,7 +1003,7 @@ impl LogicalPlanBuilder {

/// Apply a filter
pub fn filter(self, predicate: Expr) -> Self {
let predicate = if has_expr(&predicate, &Expr::Wildcard) {
let predicate = if has_expr(&predicate, |e| matches!(e, Expr::Wildcard)) {
let it = self.0.schema().fields().iter().map(|field| {
replace_wildcard_with_column(predicate.clone(), Arc::new(field.name().clone()))
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,14 @@ impl AggregatePushdown {
lp_arena: &mut Arena<ALogicalPlan>,
expr_arena: &mut Arena<AExpr>,
) -> Option<ALogicalPlan> {
let dummy_node = usize::max_value();
let dummy_min = AExpr::Agg(AAggExpr::Min(Node(dummy_node)));
let dummy_max = AExpr::Agg(AAggExpr::Max(Node(dummy_node)));
let dummy_first = AExpr::Agg(AAggExpr::First(Node(dummy_node)));
let dummy_last = AExpr::Agg(AAggExpr::First(Node(dummy_node)));
let dummy_sum = AExpr::Agg(AAggExpr::Sum(Node(dummy_node)));

// only do aggregation pushdown if all projections are aggregations
#[allow(clippy::blocks_in_if_conditions)]
if !self.processed_state
&& expr.iter().all(|node| {
(has_aexpr(*node, expr_arena, &dummy_min)
|| has_aexpr(*node, expr_arena, &dummy_max)
|| has_aexpr(*node, expr_arena, &dummy_first)
|| has_aexpr(*node, expr_arena, &dummy_sum)
|| has_aexpr(*node, expr_arena, &dummy_last))
&& {
let roots = aexpr_to_root_nodes(*node, expr_arena);
roots.len() == 1
}
has_aexpr(*node, expr_arena, |e| matches!(e, AExpr::Agg(_))) && {
let roots = aexpr_to_root_nodes(*node, expr_arena);
roots.len() == 1
}
})
{
// add to state
Expand Down

0 comments on commit bcfbd30

Please sign in to comment.