From 230c68c02bf0c3d5b7d50d24145eb50604420d4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marko=20Milenkovi=C4=87?= Date: Mon, 13 May 2024 12:53:56 +0100 Subject: [PATCH 01/11] Add `simplify` method to aggregate function (#10354) * add simplify method for aggregate function * simplify returns closure --- .../examples/simplify_udaf_expression.rs | 180 ++++++++++++++++++ datafusion/expr/src/function.rs | 13 ++ datafusion/expr/src/udaf.rs | 33 +++- .../simplify_expressions/expr_simplifier.rs | 105 +++++++++- 4 files changed, 328 insertions(+), 3 deletions(-) create mode 100644 datafusion-examples/examples/simplify_udaf_expression.rs diff --git a/datafusion-examples/examples/simplify_udaf_expression.rs b/datafusion-examples/examples/simplify_udaf_expression.rs new file mode 100644 index 000000000000..92deb20272e4 --- /dev/null +++ b/datafusion-examples/examples/simplify_udaf_expression.rs @@ -0,0 +1,180 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_schema::{Field, Schema}; +use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; +use datafusion_expr::function::AggregateFunctionSimplification; +use datafusion_expr::simplify::SimplifyInfo; + +use std::{any::Any, sync::Arc}; + +use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch}; +use datafusion::error::Result; +use datafusion::{assert_batches_eq, prelude::*}; +use datafusion_common::cast::as_float64_array; +use datafusion_expr::{ + expr::{AggregateFunction, AggregateFunctionDefinition}, + function::AccumulatorArgs, + Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature, +}; + +/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user +/// defined aggregate function with a different expression which is defined in the `simplify` method. + +#[derive(Debug, Clone)] +struct BetterAvgUdaf { + signature: Signature, +} + +impl BetterAvgUdaf { + /// Create a new instance of the BetterAvgUdaf struct + fn new() -> Self { + Self { + signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable), + } + } +} + +impl AggregateUDFImpl for BetterAvgUdaf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "better_avg" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + unimplemented!("should not be invoked") + } + + fn state_fields( + &self, + _name: &str, + _value_type: DataType, + _ordering_fields: Vec, + ) -> Result> { + unimplemented!("should not be invoked") + } + + fn groups_accumulator_supported(&self) -> bool { + true + } + + fn create_groups_accumulator(&self) -> Result> { + unimplemented!("should not get here"); + } + // we override method, to return new expression which would substitute + // user defined function call + fn simplify(&self) -> Option { + // as 
an example for this functionality we replace UDF function + // with built-in aggregate function to illustrate the use + let simplify = |aggregate_function: datafusion_expr::expr::AggregateFunction, + _: &dyn SimplifyInfo| { + Ok(Expr::AggregateFunction(AggregateFunction { + func_def: AggregateFunctionDefinition::BuiltIn( + // yes it is the same Avg, `BetterAvgUdaf` was just a + // marketing pitch :) + datafusion_expr::aggregate_function::AggregateFunction::Avg, + ), + args: aggregate_function.args, + distinct: aggregate_function.distinct, + filter: aggregate_function.filter, + order_by: aggregate_function.order_by, + null_treatment: aggregate_function.null_treatment, + })) + }; + + Some(Box::new(simplify)) + } +} + +// create local session context with an in-memory table +fn create_context() -> Result { + use datafusion::datasource::MemTable; + // define a schema. + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Float32, false), + Field::new("b", DataType::Float32, false), + ])); + + // define data in two partitions + let batch1 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])), + Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])), + ], + )?; + let batch2 = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Float32Array::from(vec![16.0])), + Arc::new(Float32Array::from(vec![2.0])), + ], + )?; + + let ctx = SessionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?; + ctx.register_table("t", Arc::new(provider))?; + Ok(ctx) +} + +#[tokio::main] +async fn main() -> Result<()> { + let ctx = create_context()?; + + let better_avg = AggregateUDF::from(BetterAvgUdaf::new()); + ctx.register_udaf(better_avg.clone()); + + let result = ctx + .sql("SELECT better_avg(a) FROM t group by b") + .await? 
+ .collect() + .await?; + + let expected = [ + "+-----------------+", + "| better_avg(t.a) |", + "+-----------------+", + "| 7.5 |", + "+-----------------+", + ]; + + assert_batches_eq!(expected, &result); + + let df = ctx.table("t").await?; + let df = df.aggregate(vec![], vec![better_avg.call(vec![col("a")])])?; + + let results = df.collect().await?; + let result = as_float64_array(results[0].column(0))?; + + assert!((result.value(0) - 7.5).abs() < f64::EPSILON); + println!("The average of [2,4,8,16] is {}", result.value(0)); + + Ok(()) +} diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 7a92a50ae15d..4e4d77924a9d 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -97,3 +97,16 @@ pub type PartitionEvaluatorFactory = /// its state, given its return datatype. pub type StateTypeFunction = Arc Result>> + Send + Sync>; + +/// [crate::udaf::AggregateUDFImpl::simplify] simplifier closure +/// A closure with two arguments: +/// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked +/// * 'info': [crate::simplify::SimplifyInfo] +/// +/// closure returns simplified [Expr] or an error. +pub type AggregateFunctionSimplification = Box< + dyn Fn( + crate::expr::AggregateFunction, + &dyn crate::simplify::SimplifyInfo, + ) -> Result, +>; diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index e5a47ddcd8b6..95121d78e7aa 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -17,7 +17,7 @@ //! 
[`AggregateUDF`]: User Defined Aggregate Functions -use crate::function::AccumulatorArgs; +use crate::function::{AccumulatorArgs, AggregateFunctionSimplification}; use crate::groups_accumulator::GroupsAccumulator; use crate::utils::format_state_name; use crate::{Accumulator, Expr}; @@ -199,6 +199,12 @@ impl AggregateUDF { pub fn coerce_types(&self, _args: &[DataType]) -> Result> { not_impl_err!("coerce_types not implemented for {:?} yet", self.name()) } + /// Do the function rewrite + /// + /// See [`AggregateUDFImpl::simplify`] for more details. + pub fn simplify(&self) -> Option { + self.inner.simplify() + } } impl From for AggregateUDF @@ -358,6 +364,31 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { fn aliases(&self) -> &[String] { &[] } + + /// Optionally apply per-UDaF simplification / rewrite rules. + /// + /// This can be used to apply function specific simplification rules during + /// optimization (e.g. `arrow_cast` --> `Expr::Cast`). The default + /// implementation does nothing. + /// + /// Note that DataFusion handles simplifying arguments and "constant + /// folding" (replacing a function call with constant arguments such as + /// `my_add(1,2) --> 3` ). Thus, there is no need to implement such + /// optimizations manually for specific UDFs. + /// + /// # Returns + /// + /// [None] if simplify is not defined or, + /// + /// Or, a closure with two arguments: + /// * 'aggregate_function': [crate::expr::AggregateFunction] for which simplified has been invoked + /// * 'info': [crate::simplify::SimplifyInfo] + /// + /// closure returns simplified [Expr] or an error. + /// + fn simplify(&self) -> Option { + None + } } /// AggregateUDF that adds an alias to the underlying function. 
It is better to diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 5122de4f09a7..55052542a8bf 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -32,7 +32,7 @@ use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, }; use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue}; -use datafusion_expr::expr::{InList, InSubquery}; +use datafusion_expr::expr::{AggregateFunctionDefinition, InList, InSubquery}; use datafusion_expr::simplify::ExprSimplifyResult; use datafusion_expr::{ and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility, @@ -1382,6 +1382,16 @@ impl<'a, S: SimplifyInfo> TreeNodeRewriter for Simplifier<'a, S> { } } + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction { + func_def: AggregateFunctionDefinition::UDF(ref udaf), + .. + }) => match (udaf.simplify(), expr) { + (Some(simplify_function), Expr::AggregateFunction(af)) => { + Transformed::yes(simplify_function(af, info)?) 
+ } + (_, expr) => Transformed::no(expr), + }, + // // Rules for Between // @@ -1748,7 +1758,9 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result { #[cfg(test)] mod tests { use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema}; - use datafusion_expr::{interval_arithmetic::Interval, *}; + use datafusion_expr::{ + function::AggregateFunctionSimplification, interval_arithmetic::Interval, *, + }; use std::{ collections::HashMap, ops::{BitAnd, BitOr, BitXor}, @@ -3698,4 +3710,93 @@ mod tests { assert_eq!(expr, expected); assert_eq!(num_iter, 2); } + #[test] + fn test_simplify_udaf() { + let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_with_simplify()); + let aggregate_function_expr = + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + udaf.into(), + vec![], + false, + None, + None, + None, + )); + + let expected = col("result_column"); + assert_eq!(simplify(aggregate_function_expr), expected); + + let udaf = AggregateUDF::new_from_impl(SimplifyMockUdaf::new_without_simplify()); + let aggregate_function_expr = + Expr::AggregateFunction(datafusion_expr::expr::AggregateFunction::new_udf( + udaf.into(), + vec![], + false, + None, + None, + None, + )); + + let expected = aggregate_function_expr.clone(); + assert_eq!(simplify(aggregate_function_expr), expected); + } + + /// A Mock UDAF which defines `simplify` to be used in tests + /// related to UDAF simplification + #[derive(Debug, Clone)] + struct SimplifyMockUdaf { + simplify: bool, + } + + impl SimplifyMockUdaf { + /// make simplify method return new expression + fn new_with_simplify() -> Self { + Self { simplify: true } + } + /// make simplify method return no change + fn new_without_simplify() -> Self { + Self { simplify: false } + } + } + + impl AggregateUDFImpl for SimplifyMockUdaf { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "mock_simplify" + } + + fn signature(&self) -> &Signature { + unimplemented!() + } + + fn 
return_type(&self, _arg_types: &[DataType]) -> Result { + unimplemented!("not needed for tests") + } + + fn accumulator( + &self, + _acc_args: function::AccumulatorArgs, + ) -> Result> { + unimplemented!("not needed for tests") + } + + fn groups_accumulator_supported(&self) -> bool { + unimplemented!("not needed for testing") + } + + fn create_groups_accumulator(&self) -> Result> { + unimplemented!("not needed for testing") + } + + fn simplify(&self) -> Option { + if self.simplify { + Some(Box::new(|_, _| Ok(col("result_column")))) + } else { + None + } + } + } } From 5fac581efbaffd0e6a9edf931182517524526afd Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 13 May 2024 05:14:19 -0700 Subject: [PATCH 02/11] Add cast array test to sqllogictest (#10474) --- datafusion/sqllogictest/test_files/cast.slt | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/datafusion/sqllogictest/test_files/cast.slt b/datafusion/sqllogictest/test_files/cast.slt index 73862be60d9b..4554c9292b6e 100644 --- a/datafusion/sqllogictest/test_files/cast.slt +++ b/datafusion/sqllogictest/test_files/cast.slt @@ -56,3 +56,16 @@ query I SELECT 10::bigint unsigned ---- 10 + +# cast array +query ? +SELECT CAST(MAKE_ARRAY(1, 2, 3) AS VARCHAR[]) +---- +[1, 2, 3] + + +# cast empty array +query ? 
+SELECT CAST(MAKE_ARRAY() AS VARCHAR[]) +---- +[] From 53de994423fc85f655da232db7e807c2a38276ea Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 13 May 2024 09:10:13 -0400 Subject: [PATCH 03/11] Add Expr::try_as_col, deprecate Expr::try_into_col (#10448) --- datafusion/expr/src/expr.rs | 23 +++++++++++++++++++ datafusion/expr/src/expr_rewriter/mod.rs | 2 +- datafusion/expr/src/logical_plan/builder.rs | 10 +++++--- datafusion/expr/src/logical_plan/plan.rs | 14 +++++++++-- datafusion/optimizer/src/push_down_filter.rs | 4 ++-- .../simplify_expressions/inlist_simplifier.rs | 2 +- datafusion/proto/src/logical_plan/mod.rs | 12 +++++++--- 7 files changed, 55 insertions(+), 12 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index f0f41a4c55c5..660a45c27a29 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1264,6 +1264,7 @@ impl Expr { }) } + #[deprecated(since = "39.0.0", note = "use try_as_col instead")] pub fn try_into_col(&self) -> Result { match self { Expr::Column(it) => Ok(it.clone()), @@ -1271,6 +1272,28 @@ impl Expr { } } + /// Return a reference to the inner `Column` if any + /// + /// returns `None` if the expression is not a `Column` + /// + /// Example + /// ``` + /// # use datafusion_common::Column; + /// use datafusion_expr::{col, Expr}; + /// let expr = col("foo"); + /// assert_eq!(expr.try_as_col(), Some(&Column::from("foo"))); + /// + /// let expr = col("foo").alias("bar"); + /// assert_eq!(expr.try_as_col(), None); + /// ``` + pub fn try_as_col(&self) -> Option<&Column> { + if let Expr::Column(it) = self { + Some(it) + } else { + None + } + } + /// Return all referenced columns of this expression. 
pub fn to_columns(&self) -> Result> { let mut using_columns = HashSet::new(); diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 700dd560ec0b..1441374bdba3 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -221,7 +221,7 @@ pub fn coerce_plan_expr_for_schema( let exprs: Vec = plan.schema().iter().map(Expr::from).collect(); let new_exprs = coerce_exprs_for_schema(exprs, plan.schema(), schema)?; - let add_project = new_exprs.iter().any(|expr| expr.try_into_col().is_err()); + let add_project = new_exprs.iter().any(|expr| expr.try_as_col().is_none()); if add_project { let projection = Projection::try_new(new_exprs, Arc::new(plan.clone()))?; Ok(LogicalPlan::Projection(projection)) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 3f15b84784f1..2c6cfd8f9d20 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1489,7 +1489,7 @@ pub fn wrap_projection_for_join_if_necessary( let mut projection = expand_wildcard(input_schema, &input, None)?; let join_key_items = alias_join_keys .iter() - .flat_map(|expr| expr.try_into_col().is_err().then_some(expr)) + .flat_map(|expr| expr.try_as_col().is_none().then_some(expr)) .cloned() .collect::>(); projection.extend(join_key_items); @@ -1504,8 +1504,12 @@ pub fn wrap_projection_for_join_if_necessary( let join_on = alias_join_keys .into_iter() .map(|key| { - key.try_into_col() - .or_else(|_| Ok(Column::from_name(key.display_name()?))) + if let Some(col) = key.try_as_col() { + Ok(col.clone()) + } else { + let name = key.display_name()?; + Ok(Column::from_name(name)) + } }) .collect::>>()?; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 9832b69f841a..266e7abc341a 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ 
-369,8 +369,18 @@ impl LogicalPlan { // The join keys in using-join must be columns. let columns = on.iter().try_fold(HashSet::new(), |mut accumu, (l, r)| { - accumu.insert(l.try_into_col()?); - accumu.insert(r.try_into_col()?); + let Some(l) = l.try_as_col().cloned() else { + return internal_err!( + "Invalid join key. Expected column, found {l:?}" + ); + }; + let Some(r) = r.try_as_col().cloned() else { + return internal_err!( + "Invalid join key. Expected column, found {r:?}" + ); + }; + accumu.insert(l); + accumu.insert(r); Result::<_, DataFusionError>::Ok(accumu) })?; using_columns.push(columns); diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs index 9ce135b0d646..57b38bd0d0fd 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -535,8 +535,8 @@ fn push_down_join( .on .iter() .filter_map(|(l, r)| { - let left_col = l.try_into_col().ok()?; - let right_col = r.try_into_col().ok()?; + let left_col = l.try_as_col().cloned()?; + let right_col = r.try_as_col().cloned()?; Some((left_col, right_col)) }) .collect::>(); diff --git a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs index 9dcb8ed15563..c8638eb72395 100644 --- a/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/inlist_simplifier.rs @@ -52,7 +52,7 @@ impl TreeNodeRewriter for ShortenInListSimplifier { // expressions list.len() == 1 || list.len() <= THRESHOLD_INLINE_INLIST - && expr.try_into_col().is_ok() + && expr.try_as_col().is_some() ) { let first_val = list[0].clone(); diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index a6352bcefc3e..83e58c3a22cc 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -45,8 +45,9 @@ use datafusion::{ 
prelude::SessionContext, }; use datafusion_common::{ - context, internal_err, not_impl_err, parsers::CompressionTypeVariant, - plan_datafusion_err, DataFusionError, Result, TableReference, + context, internal_datafusion_err, internal_err, not_impl_err, + parsers::CompressionTypeVariant, plan_datafusion_err, DataFusionError, Result, + TableReference, }; use datafusion_expr::{ dml, @@ -695,7 +696,12 @@ impl AsLogicalPlan for LogicalPlanNode { // The equijoin keys in using-join must be column. let using_keys = left_keys .into_iter() - .map(|key| key.try_into_col()) + .map(|key| { + key.try_as_col().cloned() + .ok_or_else(|| internal_datafusion_err!( + "Using join keys must be column references, got: {key:?}" + )) + }) .collect::, _>>()?; builder.join_using( into_logical_plan!(join.right, ctx, extension_codec)?, From 3491f6bd5003624dc064db410eeaa41ef3f86acf Mon Sep 17 00:00:00 2001 From: Abrar Khan Date: Mon, 13 May 2024 18:53:59 +0530 Subject: [PATCH 04/11] Implement `From>` for `LogicalPlanBuilder` (#10466) * implement From> for LogicalPlanBuilder * make fmt happy * added test case and doc comment --- datafusion/expr/src/logical_plan/builder.rs | 47 ++++++++++++++++++++- 1 file changed, 45 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 2c6cfd8f9d20..6055537ac511 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -42,8 +42,9 @@ use crate::utils::{ expand_wildcard, find_valid_equijoin_key_pair, group_window_expr_by_sort_keys, }; use crate::{ - and, binary_expr, DmlStatement, Expr, ExprSchemable, Operator, RecursiveQuery, - TableProviderFilterPushDown, TableSource, WriteOp, + and, binary_expr, logical_plan::tree_node::unwrap_arc, DmlStatement, Expr, + ExprSchemable, Operator, RecursiveQuery, TableProviderFilterPushDown, TableSource, + WriteOp, }; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; @@ 
-1138,6 +1139,31 @@ impl LogicalPlanBuilder { )?)) } } + +/// Converts a `Arc` into `LogicalPlanBuilder` +/// fn employee_schema() -> Schema { +/// Schema::new(vec![ +/// Field::new("id", DataType::Int32, false), +/// Field::new("first_name", DataType::Utf8, false), +/// Field::new("last_name", DataType::Utf8, false), +/// Field::new("state", DataType::Utf8, false), +/// Field::new("salary", DataType::Int32, false), +/// ]) +/// } +/// let plan = table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))? +/// .sort(vec![ +/// Expr::Sort(expr::Sort::new(Box::new(col("state")), true, true)), +/// Expr::Sort(expr::Sort::new(Box::new(col("salary")), false, false)), +/// ])? +/// .build()?; +/// let plan_builder: LogicalPlanBuilder = Arc::new(plan).into(); + +impl From> for LogicalPlanBuilder { + fn from(plan: Arc) -> Self { + LogicalPlanBuilder::from(unwrap_arc(plan)) + } +} + pub fn change_redundant_column(fields: &Fields) -> Vec { let mut name_map = HashMap::new(); fields @@ -2144,4 +2170,21 @@ mod tests { ); Ok(()) } + + #[test] + fn plan_builder_from_logical_plan() -> Result<()> { + let plan = + table_scan(Some("employee_csv"), &employee_schema(), Some(vec![3, 4]))? + .sort(vec![ + Expr::Sort(expr::Sort::new(Box::new(col("state")), true, true)), + Expr::Sort(expr::Sort::new(Box::new(col("salary")), false, false)), + ])? 
+ .build()?; + + let plan_expected = format!("{plan:?}"); + let plan_builder: LogicalPlanBuilder = Arc::new(plan).into(); + assert_eq!(plan_expected, format!("{:?}", plan_builder.plan)); + + Ok(()) + } } From c7dbfeb79a0f41b6098184de33499546697ef631 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 13 May 2024 10:49:27 -0400 Subject: [PATCH 05/11] Minor: Improve documentation for `catalog.has_header` config option (#10452) * Minor: document catalog.has_header better * update docs * update test --- datafusion/common/src/config.rs | 3 ++- datafusion/sqllogictest/test_files/information_schema.slt | 2 +- docs/source/user-guide/configs.md | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index c60f843393f8..0f1d9b8f0264 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -181,7 +181,8 @@ config_namespace! { /// Type of `TableProvider` to use when loading `default` schema pub format: Option, default = None - /// If the file has a header + /// Default value for `format.has_header` for `CREATE EXTERNAL TABLE` + /// if not specified explicitly in the statement. 
pub has_header: bool, default = false } } diff --git a/datafusion/sqllogictest/test_files/information_schema.slt b/datafusion/sqllogictest/test_files/information_schema.slt index de00cf9d0547..6f31973fdb6f 100644 --- a/datafusion/sqllogictest/test_files/information_schema.slt +++ b/datafusion/sqllogictest/test_files/information_schema.slt @@ -246,7 +246,7 @@ datafusion.catalog.create_default_catalog_and_schema true Whether the default ca datafusion.catalog.default_catalog datafusion The default catalog name - this impacts what SQL queries use if not specified datafusion.catalog.default_schema public The default schema name - this impacts what SQL queries use if not specified datafusion.catalog.format NULL Type of `TableProvider` to use when loading `default` schema -datafusion.catalog.has_header false If the file has a header +datafusion.catalog.has_header false Default value for `format.has_header` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. datafusion.catalog.information_schema true Should DataFusion provide access to `information_schema` virtual tables for displaying schema information datafusion.catalog.location NULL Location scanned to load tables for `default` schema datafusion.execution.aggregate.scalar_update_factor 10 Specifies the threshold for using `ScalarValue`s to update accumulators during high-cardinality aggregations for each input batch. The aggregation is considered high-cardinality if the number of affected groups is greater than or equal to `batch_size / scalar_update_factor`. In such cases, `ScalarValue`s are utilized for updating accumulators, rather than the default batch-slice approach. This can lead to performance improvements. By adjusting the `scalar_update_factor`, you can balance the trade-off between more efficient accumulator updates and the number of groups affected. 
diff --git a/docs/source/user-guide/configs.md b/docs/source/user-guide/configs.md index ef2a2a4119e3..0cfd81eff75a 100644 --- a/docs/source/user-guide/configs.md +++ b/docs/source/user-guide/configs.md @@ -43,7 +43,7 @@ Environment variables are read during `SessionConfig` initialisation so they mus | datafusion.catalog.information_schema | false | Should DataFusion provide access to `information_schema` virtual tables for displaying schema information | | datafusion.catalog.location | NULL | Location scanned to load tables for `default` schema | | datafusion.catalog.format | NULL | Type of `TableProvider` to use when loading `default` schema | -| datafusion.catalog.has_header | false | If the file has a header | +| datafusion.catalog.has_header | false | Default value for `format.has_header` for `CREATE EXTERNAL TABLE` if not specified explicitly in the statement. | | datafusion.execution.batch_size | 8192 | Default batch size while creating new batches, it's especially useful for buffer-in-memory batches since creating tiny batches would result in too much metadata memory consumption | | datafusion.execution.coalesce_batches | true | When set to true, record batches will be examined between each operator and small batches will be coalesced into larger batches. This is helpful when there are highly selective filters or joins that could produce tiny output batches. 
The target batch size is determined by the configuration setting | | datafusion.execution.collect_statistics | false | Should DataFusion collect statistics after listing files | From 9cc981b06115ee40b53384c287689ce0e07950bc Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 13 May 2024 10:49:51 -0400 Subject: [PATCH 06/11] Minor: Simplify conjunction and disjunction, improve docs (#10446) --- datafusion/expr/src/utils.rs | 37 ++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 0c1084674d8e..43e8ff7b23d6 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -1107,7 +1107,7 @@ fn split_binary_impl<'a>( /// assert_eq!(conjunction(split), Some(expr)); /// ``` pub fn conjunction(filters: impl IntoIterator) -> Option { - filters.into_iter().reduce(|accum, expr| accum.and(expr)) + filters.into_iter().reduce(Expr::and) } /// Combines an array of filter expressions into a single filter @@ -1115,12 +1115,41 @@ pub fn conjunction(filters: impl IntoIterator) -> Option { /// logical OR. /// /// Returns None if the filters array is empty. +/// +/// # Example +/// ``` +/// # use datafusion_expr::{col, lit}; +/// # use datafusion_expr::utils::disjunction; +/// // a=1 OR b=2 +/// let expr = col("a").eq(lit(1)).or(col("b").eq(lit(2))); +/// +/// // [a=1, b=2] +/// let split = vec![ +/// col("a").eq(lit(1)), +/// col("b").eq(lit(2)), +/// ]; +/// +/// // use disjunction to join them together with `OR` +/// assert_eq!(disjunction(split), Some(expr)); +/// ``` pub fn disjunction(filters: impl IntoIterator) -> Option { - filters.into_iter().reduce(|accum, expr| accum.or(expr)) + filters.into_iter().reduce(Expr::or) } -/// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] with -/// its predicate be all `predicates` ANDed. 
+/// Returns a new [LogicalPlan] that filters the output of `plan` with a +/// [LogicalPlan::Filter] with all `predicates` ANDed. +/// +/// # Example +/// Before: +/// ```text +/// plan +/// ``` +/// +/// After: +/// ```text +/// Filter(predicate) +/// plan +/// ``` pub fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> Result { // reduce filters to a single filter with an AND let predicate = predicates From a2eca291ad9d586222f042ab4c068feeb055526b Mon Sep 17 00:00:00 2001 From: ClSlaid Date: Tue, 14 May 2024 00:03:25 +0800 Subject: [PATCH 07/11] Stop copying LogicalPlan and Exprs in `ReplaceDistinctWithAggregate` (#10460) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * patch: implement rewrite for RDWA Signed-off-by: cailue * refactor: rewrite replace_distinct_aggregate Signed-off-by: 蔡略 * patch: recorrect aggr_expr Signed-off-by: 蔡略 * Update datafusion/optimizer/src/replace_distinct_aggregate.rs --------- Signed-off-by: cailue Signed-off-by: 蔡略 Co-authored-by: Andrew Lamb --- .../src/replace_distinct_aggregate.rs | 73 +++++++++++++------ 1 file changed, 49 insertions(+), 24 deletions(-) diff --git a/datafusion/optimizer/src/replace_distinct_aggregate.rs b/datafusion/optimizer/src/replace_distinct_aggregate.rs index 4f68e2623f40..404f054cb9fa 100644 --- a/datafusion/optimizer/src/replace_distinct_aggregate.rs +++ b/datafusion/optimizer/src/replace_distinct_aggregate.rs @@ -19,7 +19,9 @@ use crate::optimizer::{ApplyOrder, ApplyOrder::BottomUp}; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{Column, Result}; +use datafusion_common::tree_node::Transformed; +use datafusion_common::{internal_err, Column, Result}; +use datafusion_expr::expr_rewriter::normalize_cols; use datafusion_expr::utils::expand_wildcard; use datafusion_expr::{ aggregate_function::AggregateFunction as AggregateFunctionFunc, col, @@ -66,20 +68,24 @@ impl ReplaceDistinctWithAggregate { } impl OptimizerRule for 
ReplaceDistinctWithAggregate { - fn try_optimize( + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( &self, - plan: &LogicalPlan, + plan: LogicalPlan, _config: &dyn OptimizerConfig, - ) -> Result> { + ) -> Result> { match plan { LogicalPlan::Distinct(Distinct::All(input)) => { - let group_expr = expand_wildcard(input.schema(), input, None)?; - let aggregate = LogicalPlan::Aggregate(Aggregate::try_new( - input.clone(), + let group_expr = expand_wildcard(input.schema(), &input, None)?; + let aggr_plan = LogicalPlan::Aggregate(Aggregate::try_new( + input, group_expr, vec![], )?); - Ok(Some(aggregate)) + Ok(Transformed::yes(aggr_plan)) } LogicalPlan::Distinct(Distinct::On(DistinctOn { select_expr, @@ -88,13 +94,15 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { input, schema, })) => { + let expr_cnt = on_expr.len(); + // Construct the aggregation expression to be used to fetch the selected expressions. let aggr_expr = select_expr - .iter() + .into_iter() .map(|e| { Expr::AggregateFunction(AggregateFunction::new( AggregateFunctionFunc::FirstValue, - vec![e.clone()], + vec![e], false, None, sort_expr.clone(), @@ -103,45 +111,62 @@ impl OptimizerRule for ReplaceDistinctWithAggregate { }) .collect::>(); + let aggr_expr = normalize_cols(aggr_expr, input.as_ref())?; + let group_expr = normalize_cols(on_expr, input.as_ref())?; + // Build the aggregation plan - let plan = LogicalPlanBuilder::from(input.as_ref().clone()) - .aggregate(on_expr.clone(), aggr_expr.to_vec())? 
- .build()?; + let plan = LogicalPlan::Aggregate(Aggregate::try_new( + input, group_expr, aggr_expr, + )?); + // TODO use LogicalPlanBuilder directly rather than recreating the Aggregate + // when https://github.com/apache/datafusion/issues/10485 is available + let lpb = LogicalPlanBuilder::from(plan); - let plan = if let Some(sort_expr) = sort_expr { + let plan = if let Some(mut sort_expr) = sort_expr { // While sort expressions were used in the `FIRST_VALUE` aggregation itself above, // this on it's own isn't enough to guarantee the proper output order of the grouping // (`ON`) expression, so we need to sort those as well. - LogicalPlanBuilder::from(plan) - .sort(sort_expr[..on_expr.len()].to_vec())? - .build()? + + // truncate the sort_expr to the length of on_expr + sort_expr.truncate(expr_cnt); + + lpb.sort(sort_expr)?.build()? } else { - plan + lpb.build()? }; // Whereas the aggregation plan by default outputs both the grouping and the aggregation // expressions, for `DISTINCT ON` we only need to emit the original selection expressions. + let project_exprs = plan .schema() .iter() - .skip(on_expr.len()) + .skip(expr_cnt) .zip(schema.iter()) .map(|((new_qualifier, new_field), (old_qualifier, old_field))| { - Ok(col(Column::from((new_qualifier, new_field))) - .alias_qualified(old_qualifier.cloned(), old_field.name())) + col(Column::from((new_qualifier, new_field))) + .alias_qualified(old_qualifier.cloned(), old_field.name()) }) - .collect::>>()?; + .collect::>(); let plan = LogicalPlanBuilder::from(plan) .project(project_exprs)? 
.build()?; - Ok(Some(plan)) + Ok(Transformed::yes(plan)) } - _ => Ok(None), + _ => Ok(Transformed::no(plan)), } } + fn try_optimize( + &self, + _plan: &LogicalPlan, + _config: &dyn OptimizerConfig, + ) -> Result> { + internal_err!("Should have called ReplaceDistinctWithAggregate::rewrite") + } + fn name(&self) -> &str { "replace_distinct_aggregate" } From adf0bfc757d2f9ba48c45d368578d07806858b89 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 13 May 2024 14:00:35 -0400 Subject: [PATCH 08/11] Stop copying LogicalPlan and Exprs in `EliminateCrossJoin` (4% faster planning) (#10431) * Stop copying LogicalPlan and Exprs in `EliminateCrossJoin` * Clarify when can_flatten_join_inputs runs * Use a single `map` --- .../optimizer/src/eliminate_cross_join.rs | 298 +++++++++++------- datafusion/optimizer/src/join_key_set.rs | 73 ++++- 2 files changed, 254 insertions(+), 117 deletions(-) diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 923be7574803..9d871c50ad99 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -18,11 +18,13 @@ //! [`EliminateCrossJoin`] converts `CROSS JOIN` to `INNER JOIN` if join predicates are available. use std::sync::Arc; -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::{OptimizerConfig, OptimizerRule}; use crate::join_key_set::JoinKeySet; -use datafusion_common::{plan_err, Result}; +use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::{internal_err, Result}; use datafusion_expr::expr::{BinaryExpr, Expr}; +use datafusion_expr::logical_plan::tree_node::unwrap_arc; use datafusion_expr::logical_plan::{ CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, }; @@ -39,65 +41,109 @@ impl EliminateCrossJoin { } } -/// Attempt to reorder join to eliminate cross joins to inner joins. -/// for queries: -/// 'select ... 
from a, b where a.x = b.y and b.xx = 100;' -/// 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);' -/// 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z) -/// or (a.x = b.y and b.xx = 200 and a.z=c.z);' -/// 'select ... from a, b where a.x > b.y' +/// Eliminate cross joins by rewriting them to inner joins when possible. +/// +/// # Example +/// The initial plan for this query: +/// ```sql +/// select ... from a, b where a.x = b.y and b.xx = 100; +/// ``` +/// +/// Looks like this: +/// ```text +/// Filter(a.x = b.y AND b.xx = 100) +/// CrossJoin +/// TableScan a +/// TableScan b +/// ``` +/// +/// After the rule is applied, the plan will look like this: +/// ```text +/// Filter(b.xx = 100) +/// InnerJoin(a.x = b.y) +/// TableScan a +/// TableScan b +/// ``` +/// +/// # Other Examples +/// * 'select ... from a, b where a.x = b.y and b.xx = 100;' +/// * 'select ... from a, b where (a.x = b.y and b.xx = 100) or (a.x = b.y and b.xx = 200);' +/// * 'select ... from a, b, c where (a.x = b.y and b.xx = 100 and a.z = c.z) +/// * or (a.x = b.y and b.xx = 200 and a.z=c.z);' +/// * 'select ... from a, b where a.x > b.y' +/// /// For above queries, the join predicate is available in filters and they are moved to /// join nodes appropriately +/// /// This fix helps to improve the performance of TPCH Q19. 
issue#78 impl OptimizerRule for EliminateCrossJoin { fn try_optimize( &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _plan: &LogicalPlan, + _config: &dyn OptimizerConfig, ) -> Result> { + internal_err!("Should have called EliminateCrossJoin::rewrite") + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { + let plan_schema = plan.schema().clone(); let mut possible_join_keys = JoinKeySet::new(); let mut all_inputs: Vec = vec![]; - let parent_predicate = match plan { - LogicalPlan::Filter(filter) => { - let input = filter.input.as_ref(); - match input { - LogicalPlan::Join(Join { - join_type: JoinType::Inner, - .. - }) - | LogicalPlan::CrossJoin(_) => { - if !try_flatten_join_inputs( - input, - &mut possible_join_keys, - &mut all_inputs, - )? { - return Ok(None); - } - extract_possible_join_keys( - &filter.predicate, - &mut possible_join_keys, - ); - Some(&filter.predicate) - } - _ => { - return utils::optimize_children(self, plan, config); - } - } + + let parent_predicate = if let LogicalPlan::Filter(filter) = plan { + // if input isn't a join that can potentially be rewritten + // avoid unwrapping the input + let rewriteable = matches!( + filter.input.as_ref(), + LogicalPlan::Join(Join { + join_type: JoinType::Inner, + .. + }) | LogicalPlan::CrossJoin(_) + ); + + if !rewriteable { + // recursively try to rewrite children + return rewrite_children(self, LogicalPlan::Filter(filter), config); } + + if !can_flatten_join_inputs(&filter.input) { + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } + + let Filter { + input, predicate, .. + } = filter; + flatten_join_inputs( + unwrap_arc(input), + &mut possible_join_keys, + &mut all_inputs, + )?; + + extract_possible_join_keys(&predicate, &mut possible_join_keys); + Some(predicate) + } else if matches!( + plan, LogicalPlan::Join(Join { join_type: JoinType::Inner, .. 
- }) => { - if !try_flatten_join_inputs( - plan, - &mut possible_join_keys, - &mut all_inputs, - )? { - return Ok(None); - } - None + }) + ) { + if !can_flatten_join_inputs(&plan) { + return Ok(Transformed::no(plan)); } - _ => return utils::optimize_children(self, plan, config), + flatten_join_inputs(plan, &mut possible_join_keys, &mut all_inputs)?; + None + } else { + // recursively try to rewrite children + return rewrite_children(self, plan, config); }; // Join keys are handled locally: @@ -105,36 +151,36 @@ impl OptimizerRule for EliminateCrossJoin { let mut left = all_inputs.remove(0); while !all_inputs.is_empty() { left = find_inner_join( - &left, + left, &mut all_inputs, &possible_join_keys, &mut all_join_keys, )?; } - left = utils::optimize_children(self, &left, config)?.unwrap_or(left); + left = rewrite_children(self, left, config)?.data; - if plan.schema() != left.schema() { + if &plan_schema != left.schema() { left = LogicalPlan::Projection(Projection::new_from_schema( Arc::new(left), - plan.schema().clone(), + plan_schema.clone(), )); } let Some(predicate) = parent_predicate else { - return Ok(Some(left)); + return Ok(Transformed::yes(left)); }; // If there are no join keys then do nothing: if all_join_keys.is_empty() { - Filter::try_new(predicate.clone(), Arc::new(left)) - .map(|f| Some(LogicalPlan::Filter(f))) + Filter::try_new(predicate, Arc::new(left)) + .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))) } else { // Remove join expressions from filter: - match remove_join_expressions(predicate.clone(), &all_join_keys) { + match remove_join_expressions(predicate, &all_join_keys) { Some(filter_expr) => Filter::try_new(filter_expr, Arc::new(left)) - .map(|f| Some(LogicalPlan::Filter(f))), - _ => Ok(Some(left)), + .map(|filter| Transformed::yes(LogicalPlan::Filter(filter))), + _ => Ok(Transformed::yes(left)), } } } @@ -144,49 +190,89 @@ impl OptimizerRule for EliminateCrossJoin { } } +fn rewrite_children( + optimizer: &impl OptimizerRule, + 
plan: LogicalPlan, + config: &dyn OptimizerConfig, +) -> Result> { + let transformed_plan = plan.map_children(|input| optimizer.rewrite(input, config))?; + + // recompute schema if the plan was transformed + if transformed_plan.transformed { + transformed_plan.map_data(|plan| plan.recompute_schema()) + } else { + Ok(transformed_plan) + } +} + /// Recursively accumulate possible_join_keys and inputs from inner joins /// (including cross joins). /// -/// Returns a boolean indicating whether the flattening was successful. -fn try_flatten_join_inputs( - plan: &LogicalPlan, +/// Assumes can_flatten_join_inputs has returned true and thus the plan can be +/// flattened. Adds all leaf inputs to `all_inputs` and join_keys to +/// possible_join_keys +fn flatten_join_inputs( + plan: LogicalPlan, possible_join_keys: &mut JoinKeySet, all_inputs: &mut Vec, -) -> Result { - let children = match plan { +) -> Result<()> { + match plan { LogicalPlan::Join(join) if join.join_type == JoinType::Inner => { + // checked in can_flatten_join_inputs if join.filter.is_some() { - // The filter of inner join will lost, skip this rule. 
- // issue: https://github.com/apache/datafusion/issues/4844 - return Ok(false); + return internal_err!( + "should not have filter in inner join in flatten_join_inputs" + ); } - possible_join_keys.insert_all(join.on.iter()); - vec![&join.left, &join.right] + possible_join_keys.insert_all_owned(join.on); + flatten_join_inputs(unwrap_arc(join.left), possible_join_keys, all_inputs)?; + flatten_join_inputs(unwrap_arc(join.right), possible_join_keys, all_inputs)?; } LogicalPlan::CrossJoin(join) => { - vec![&join.left, &join.right] + flatten_join_inputs(unwrap_arc(join.left), possible_join_keys, all_inputs)?; + flatten_join_inputs(unwrap_arc(join.right), possible_join_keys, all_inputs)?; } _ => { - return plan_err!("flatten_join_inputs just can call join/cross_join"); + all_inputs.push(plan); } }; + Ok(()) +} - for child in children.iter() { - let child = child.as_ref(); +/// Returns true if the plan is a Join or Cross join could be flattened with +/// `flatten_join_inputs` +/// +/// Must stay in sync with `flatten_join_inputs` +fn can_flatten_join_inputs(plan: &LogicalPlan) -> bool { + // can only flatten inner / cross joins + match plan { + LogicalPlan::Join(join) if join.join_type == JoinType::Inner => { + // The filter of inner join will lost, skip this rule. + // issue: https://github.com/apache/datafusion/issues/4844 + if join.filter.is_some() { + return false; + } + } + LogicalPlan::CrossJoin(_) => {} + _ => return false, + }; + + for child in plan.inputs() { match child { LogicalPlan::Join(Join { join_type: JoinType::Inner, .. }) | LogicalPlan::CrossJoin(_) => { - if !try_flatten_join_inputs(child, possible_join_keys, all_inputs)? { - return Ok(false); + if !can_flatten_join_inputs(child) { + return false; } } - _ => all_inputs.push(child.clone()), + // the child is not a join/cross join + _ => (), } } - Ok(true) + true } /// Finds the next to join with the left input plan, @@ -202,7 +288,7 @@ fn try_flatten_join_inputs( /// 1. 
Removes the first plan from `rights` /// 2. Returns `left_input CROSS JOIN right`. fn find_inner_join( - left_input: &LogicalPlan, + left_input: LogicalPlan, rights: &mut Vec, possible_join_keys: &JoinKeySet, all_join_keys: &mut JoinKeySet, @@ -237,7 +323,7 @@ fn find_inner_join( )?); return Ok(LogicalPlan::Join(Join { - left: Arc::new(left_input.clone()), + left: Arc::new(left_input), right: Arc::new(right_input), join_type: JoinType::Inner, join_constraint: JoinConstraint::On, @@ -259,7 +345,7 @@ fn find_inner_join( )?); Ok(LogicalPlan::CrossJoin(CrossJoin { - left: Arc::new(left_input.clone()), + left: Arc::new(left_input), right: Arc::new(right), schema: join_schema, })) @@ -341,12 +427,12 @@ mod tests { Operator::{And, Or}, }; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: Vec<&str>) { + fn assert_optimized_plan_eq(plan: LogicalPlan, expected: Vec<&str>) { + let starting_schema = plan.schema().clone(); let rule = EliminateCrossJoin::new(); - let optimized_plan = rule - .try_optimize(plan, &OptimizerContext::new()) - .unwrap() - .expect("failed to optimize plan"); + let transformed_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap(); + assert!(transformed_plan.transformed, "failed to optimize plan"); + let optimized_plan = transformed_plan.data; let formatted = optimized_plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -355,13 +441,13 @@ mod tests { "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" ); - assert_eq!(plan.schema(), optimized_plan.schema()) + assert_eq!(&starting_schema, optimized_plan.schema()) } - fn assert_optimization_rule_fails(plan: &LogicalPlan) { + fn assert_optimization_rule_fails(plan: LogicalPlan) { let rule = EliminateCrossJoin::new(); - let optimized_plan = rule.try_optimize(plan, &OptimizerContext::new()).unwrap(); - assert!(optimized_plan.is_none()); + let transformed_plan = rule.rewrite(plan, &OptimizerContext::new()).unwrap(); + 
assert!(!transformed_plan.transformed) } #[test] @@ -386,7 +472,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -414,7 +500,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -441,7 +527,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -471,7 +557,7 @@ mod tests { " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -501,7 +587,7 @@ mod tests { " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -527,7 +613,7 @@ mod tests { " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -551,7 +637,7 @@ mod tests { .filter(col("t1.a").gt(lit(15u32)))? 
.build()?; - assert_optimization_rule_fails(&plan); + assert_optimization_rule_fails(plan); Ok(()) } @@ -598,7 +684,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -675,7 +761,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -750,7 +836,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -825,7 +911,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -904,7 +990,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -987,7 +1073,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1074,7 +1160,7 @@ mod tests { " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1100,7 +1186,7 @@ mod tests { " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1128,7 +1214,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1156,7 +1242,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1184,7 +1270,7 @@ 
mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } @@ -1224,7 +1310,7 @@ mod tests { " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_eq(plan, expected); Ok(()) } diff --git a/datafusion/optimizer/src/join_key_set.rs b/datafusion/optimizer/src/join_key_set.rs index c47afa012c17..cd8ed382f069 100644 --- a/datafusion/optimizer/src/join_key_set.rs +++ b/datafusion/optimizer/src/join_key_set.rs @@ -66,20 +66,46 @@ impl JoinKeySet { } } + /// Same as [`Self::insert`] but avoids cloning expression if they + /// are owned + pub fn insert_owned(&mut self, left: Expr, right: Expr) -> bool { + if self.contains(&left, &right) { + false + } else { + self.inner.insert((left, right)); + true + } + } + /// Inserts potentially many join keys into the set, copying only when necessary /// /// returns true if any of the pairs were inserted pub fn insert_all<'a>( &mut self, - iter: impl Iterator, + iter: impl IntoIterator, ) -> bool { let mut inserted = false; - for (left, right) in iter { + for (left, right) in iter.into_iter() { inserted |= self.insert(left, right); } inserted } + /// Same as [`Self::insert_all`] but avoids cloning expressions if they are + /// already owned + /// + /// returns true if any of the pairs were inserted + pub fn insert_all_owned( + &mut self, + iter: impl IntoIterator, + ) -> bool { + let mut inserted = false; + for (left, right) in iter.into_iter() { + inserted |= self.insert_owned(left, right); + } + inserted + } + /// Inserts any join keys that are common to both `s1` and `s2` into self pub fn insert_intersection(&mut self, s1: JoinKeySet, s2: JoinKeySet) { // note can't use inner.intersection as we need to consider both (l, r) @@ -156,6 +182,15 @@ mod test { assert_eq!(set.len(), 2); } + #[test] + fn test_insert_owned() { + let mut set = JoinKeySet::new(); + 
assert!(set.insert_owned(col("a"), col("b"))); + assert!(set.contains(&col("a"), &col("b"))); + assert!(set.contains(&col("b"), &col("a"))); + assert!(!set.contains(&col("a"), &col("c"))); + } + #[test] fn test_contains() { let mut set = JoinKeySet::new(); @@ -217,18 +252,34 @@ mod test { } #[test] - fn test_insert_many() { + fn test_insert_all() { let mut set = JoinKeySet::new(); // insert (a=b), (b=c), (b=a) - set.insert_all( - vec![ - &(col("a"), col("b")), - &(col("b"), col("c")), - &(col("b"), col("a")), - ] - .into_iter(), - ); + set.insert_all(vec![ + &(col("a"), col("b")), + &(col("b"), col("c")), + &(col("b"), col("a")), + ]); + assert_eq!(set.len(), 2); + assert!(set.contains(&col("a"), &col("b"))); + assert!(set.contains(&col("b"), &col("c"))); + assert!(set.contains(&col("b"), &col("a"))); + + // should not contain (a=c) + assert!(!set.contains(&col("a"), &col("c"))); + } + + #[test] + fn test_insert_all_owned() { + let mut set = JoinKeySet::new(); + + // insert (a=b), (b=c), (b=a) + set.insert_all_owned(vec![ + (col("a"), col("b")), + (col("b"), col("c")), + (col("b"), col("a")), + ]); assert_eq!(set.len(), 2); assert!(set.contains(&col("a"), &col("b"))); assert!(set.contains(&col("b"), &col("c"))); From 5b74c2d1f8923b8f4f7cf7a660459a80bd947790 Mon Sep 17 00:00:00 2001 From: Mehmet Ozan Kabak Date: Mon, 13 May 2024 21:18:29 +0300 Subject: [PATCH 09/11] Improved ergonomy for `CREATE EXTERNAL TABLE OPTIONS`: Don't require quotations for simple namespaced keys like `foo.bar` (#10483) * Don't require quotations for simple namespaced keys like foo.bar * Add comments clarifying parse error cases for unquoted namespaced keys --- datafusion/common/src/config.rs | 65 ++++++++----------- datafusion/core/src/execution/context/mod.rs | 24 ++++--- .../tests/cases/roundtrip_logical_plan.rs | 18 ++--- datafusion/sql/src/parser.rs | 24 +++++-- .../test_files/create_external_table.slt | 21 ++++-- .../test_files/tpch/create_tables.slt.part | 2 +- 6 files changed, 84 
insertions(+), 70 deletions(-) diff --git a/datafusion/common/src/config.rs b/datafusion/common/src/config.rs index 0f1d9b8f0264..a4f937b6e2a3 100644 --- a/datafusion/common/src/config.rs +++ b/datafusion/common/src/config.rs @@ -130,9 +130,9 @@ macro_rules! config_namespace { $( stringify!($field_name) => self.$field_name.set(rem, value), )* - _ => return Err(DataFusionError::Configuration(format!( + _ => return _config_err!( "Config value \"{}\" not found on {}", key, stringify!($struct_name) - ))) + ) } } @@ -676,22 +676,17 @@ impl ConfigOptions { /// Set a configuration option pub fn set(&mut self, key: &str, value: &str) -> Result<()> { - let (prefix, key) = key.split_once('.').ok_or_else(|| { - DataFusionError::Configuration(format!( - "could not find config namespace for key \"{key}\"", - )) - })?; + let Some((prefix, key)) = key.split_once('.') else { + return _config_err!("could not find config namespace for key \"{key}\""); + }; if prefix == "datafusion" { return ConfigField::set(self, key, value); } - let e = self.extensions.0.get_mut(prefix); - let e = e.ok_or_else(|| { - DataFusionError::Configuration(format!( - "Could not find config namespace \"{prefix}\"" - )) - })?; + let Some(e) = self.extensions.0.get_mut(prefix) else { + return _config_err!("Could not find config namespace \"{prefix}\""); + }; e.0.set(key, value) } @@ -1279,22 +1274,17 @@ impl TableOptions { /// /// A result indicating success or failure in setting the configuration option. 
pub fn set(&mut self, key: &str, value: &str) -> Result<()> { - let (prefix, _) = key.split_once('.').ok_or_else(|| { - DataFusionError::Configuration(format!( - "could not find config namespace for key \"{key}\"" - )) - })?; + let Some((prefix, _)) = key.split_once('.') else { + return _config_err!("could not find config namespace for key \"{key}\""); + }; if prefix == "format" { return ConfigField::set(self, key, value); } - let e = self.extensions.0.get_mut(prefix); - let e = e.ok_or_else(|| { - DataFusionError::Configuration(format!( - "Could not find config namespace \"{prefix}\"" - )) - })?; + let Some(e) = self.extensions.0.get_mut(prefix) else { + return _config_err!("Could not find config namespace \"{prefix}\""); + }; e.0.set(key, value) } @@ -1413,19 +1403,19 @@ impl ConfigField for TableParquetOptions { fn set(&mut self, key: &str, value: &str) -> Result<()> { // Determine if the key is a global, metadata, or column-specific setting if key.starts_with("metadata::") { - let k = - match key.split("::").collect::>()[..] { - [_meta] | [_meta, ""] => return Err(DataFusionError::Configuration( + let k = match key.split("::").collect::>()[..] { + [_meta] | [_meta, ""] => { + return _config_err!( "Invalid metadata key provided, missing key in metadata::" - .to_string(), - )), - [_meta, k] => k.into(), - _ => { - return Err(DataFusionError::Configuration(format!( + ) + } + [_meta, k] => k.into(), + _ => { + return _config_err!( "Invalid metadata key provided, found too many '::' in \"{key}\"" - ))) - } - }; + ) + } + }; self.key_value_metadata.insert(k, Some(value.into())); Ok(()) } else if key.contains("::") { @@ -1498,10 +1488,7 @@ macro_rules! 
config_namespace_with_hashmap { inner_value.set(inner_key, value) } - _ => Err(DataFusionError::Configuration(format!( - "Unrecognized key '{}'.", - key - ))), + _ => _config_err!("Unrecognized key '{key}'."), } } diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index e69a249410b1..2fc1a19c3386 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -23,6 +23,8 @@ use std::ops::ControlFlow; use std::sync::{Arc, Weak}; use super::options::ReadOptions; +#[cfg(feature = "array_expressions")] +use crate::functions_array; use crate::{ catalog::information_schema::{InformationSchemaProvider, INFORMATION_SCHEMA}, catalog::listing_schema::ListingSchemaProvider, @@ -53,53 +55,49 @@ use crate::{ }, optimizer::analyzer::{Analyzer, AnalyzerRule}, optimizer::optimizer::{Optimizer, OptimizerConfig, OptimizerRule}, + physical_expr::{create_physical_expr, PhysicalExpr}, physical_optimizer::optimizer::{PhysicalOptimizer, PhysicalOptimizerRule}, physical_plan::ExecutionPlan, physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}, variable::{VarProvider, VarType}, }; - -#[cfg(feature = "array_expressions")] -use crate::functions_array; use crate::{functions, functions_aggregate}; use arrow::datatypes::{DataType, SchemaRef}; use arrow::record_batch::RecordBatch; use arrow_schema::Schema; -use async_trait::async_trait; -use chrono::{DateTime, Utc}; -use datafusion_common::tree_node::TreeNode; use datafusion_common::{ alias::AliasGenerator, config::{ConfigExtension, TableOptions}, exec_err, not_impl_err, plan_datafusion_err, plan_err, - tree_node::{TreeNodeRecursion, TreeNodeVisitor}, + tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor}, DFSchema, SchemaReference, TableReference, }; use datafusion_execution::registry::SerializerRegistry; use datafusion_expr::{ + expr_rewriter::FunctionRewrite, logical_plan::{DdlStatement, Statement}, + simplify::SimplifyInfo, 
var_provider::is_system_variables, Expr, ExprSchemable, StringifiedPlan, UserDefinedLogicalNode, WindowUDF, }; +use datafusion_optimizer::simplify_expressions::ExprSimplifier; use datafusion_sql::{ parser::{CopyToSource, CopyToStatement, DFParser}, planner::{object_name_to_table_reference, ContextProvider, ParserOptions, SqlToRel}, ResolvedTableReference, }; -use parking_lot::RwLock; use sqlparser::dialect::dialect_from_str; + +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use parking_lot::RwLock; use url::Url; use uuid::Uuid; -use crate::physical_expr::PhysicalExpr; pub use datafusion_execution::config::SessionConfig; pub use datafusion_execution::TaskContext; pub use datafusion_expr::execution_props::ExecutionProps; -use datafusion_expr::expr_rewriter::FunctionRewrite; -use datafusion_expr::simplify::SimplifyInfo; -use datafusion_optimizer::simplify_expressions::ExprSimplifier; -use datafusion_physical_expr::create_physical_expr; mod avro; mod csv; diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index e5e57c0bc893..2927fd01d1b3 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -15,6 +15,12 @@ // specific language governing permissions and limitations // under the License. 
+use std::any::Any; +use std::collections::HashMap; +use std::fmt::{self, Debug, Formatter}; +use std::sync::Arc; +use std::vec; + use arrow::array::{ArrayRef, FixedSizeListArray}; use arrow::datatypes::{ DataType, Field, Fields, Int32Type, IntervalDayTimeType, IntervalMonthDayNanoType, @@ -24,6 +30,7 @@ use datafusion::datasource::provider::TableProviderFactory; use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::FunctionRegistry; use datafusion::functions_aggregate::covariance::{covar_pop, covar_samp}; use datafusion::functions_aggregate::expr_fn::first_value; use datafusion::prelude::*; @@ -51,16 +58,11 @@ use datafusion_proto::bytes::{ logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, }; use datafusion_proto::logical_plan::to_proto::serialize_expr; -use datafusion_proto::logical_plan::LogicalExtensionCodec; -use datafusion_proto::logical_plan::{from_proto, DefaultLogicalExtensionCodec}; +use datafusion_proto::logical_plan::{ + from_proto, DefaultLogicalExtensionCodec, LogicalExtensionCodec, +}; use datafusion_proto::protobuf; -use std::any::Any; -use std::collections::HashMap; -use std::fmt::{self, Debug, Formatter}; -use std::sync::Arc; -use std::vec; -use datafusion::execution::FunctionRegistry; use prost::Message; #[cfg(feature = "json")] diff --git a/datafusion/sql/src/parser.rs b/datafusion/sql/src/parser.rs index f61c9cda6345..d09317271d23 100644 --- a/datafusion/sql/src/parser.rs +++ b/datafusion/sql/src/parser.rs @@ -462,7 +462,21 @@ impl<'a> DFParser<'a> { pub fn parse_option_key(&mut self) -> Result { let next_token = self.parser.next_token(); match next_token.token { - Token::Word(Word { value, .. }) => Ok(value), + Token::Word(Word { value, .. 
}) => { + let mut parts = vec![value]; + while self.parser.consume_token(&Token::Period) { + let next_token = self.parser.next_token(); + if let Token::Word(Word { value, .. }) = next_token.token { + parts.push(value); + } else { + // Unquoted namespaced keys have to conform to the syntax + // "[\.]*". If we have a key that breaks this + // pattern, error out: + return self.parser.expected("key name", next_token); + } + } + Ok(parts.join(".")) + } Token::SingleQuotedString(s) => Ok(s), Token::DoubleQuotedString(s) => Ok(s), Token::EscapedStringLiteral(s) => Ok(s), @@ -712,15 +726,15 @@ impl<'a> DFParser<'a> { } else { self.parser.expect_keyword(Keyword::HEADER)?; self.parser.expect_keyword(Keyword::ROW)?; - return parser_err!("WITH HEADER ROW clause is no longer in use. Please use the OPTIONS clause with 'format.has_header' set appropriately, e.g., OPTIONS ('format.has_header' 'true')"); + return parser_err!("WITH HEADER ROW clause is no longer in use. Please use the OPTIONS clause with 'format.has_header' set appropriately, e.g., OPTIONS (format.has_header true)"); } } Keyword::DELIMITER => { - return parser_err!("DELIMITER clause is no longer in use. Please use the OPTIONS clause with 'format.delimiter' set appropriately, e.g., OPTIONS ('format.delimiter' ',')"); + return parser_err!("DELIMITER clause is no longer in use. Please use the OPTIONS clause with 'format.delimiter' set appropriately, e.g., OPTIONS (format.delimiter ',')"); } Keyword::COMPRESSION => { self.parser.expect_keyword(Keyword::TYPE)?; - return parser_err!("COMPRESSION TYPE clause is no longer in use. Please use the OPTIONS clause with 'format.compression' set appropriately, e.g., OPTIONS ('format.compression' 'gzip')"); + return parser_err!("COMPRESSION TYPE clause is no longer in use. 
Please use the OPTIONS clause with 'format.compression' set appropriately, e.g., OPTIONS (format.compression gzip)"); } Keyword::PARTITIONED => { self.parser.expect_keyword(Keyword::BY)?; @@ -933,7 +947,7 @@ mod tests { expect_parse_ok(sql, expected)?; // positive case with delimiter - let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' OPTIONS ('format.delimiter' '|')"; + let sql = "CREATE EXTERNAL TABLE t(c1 int) STORED AS CSV LOCATION 'foo.csv' OPTIONS (format.delimiter '|')"; let display = None; let expected = Statement::CreateExternalTable(CreateExternalTable { name: "t".into(), diff --git a/datafusion/sqllogictest/test_files/create_external_table.slt b/datafusion/sqllogictest/test_files/create_external_table.slt index fca177bb61f0..607c909fd63d 100644 --- a/datafusion/sqllogictest/test_files/create_external_table.slt +++ b/datafusion/sqllogictest/test_files/create_external_table.slt @@ -190,8 +190,8 @@ LOCATION 'test_files/scratch/create_external_table/manual_partitioning/'; statement error DataFusion error: Error during planning: Option format.delimiter is specified multiple times CREATE EXTERNAL TABLE t STORED AS CSV OPTIONS ( 'format.delimiter' '*', - 'format.has_header' 'true', - 'format.delimiter' '|') + 'format.has_header' 'true', + 'format.delimiter' '|') LOCATION 'foo.csv'; # If a config does not belong to any namespace, we assume it is a 'format' option and apply the 'format' prefix for backwards compatibility. @@ -201,7 +201,20 @@ CREATE EXTERNAL TABLE IF NOT EXISTS region ( r_name VARCHAR, r_comment VARCHAR, r_rev VARCHAR, -) STORED AS CSV LOCATION 'test_files/tpch/data/region.tbl' +) STORED AS CSV LOCATION 'test_files/tpch/data/region.tbl' OPTIONS ( 'format.delimiter' '|', - 'has_header' 'false'); \ No newline at end of file + 'has_header' 'false'); + +# Verify that we do not need quotations for simple namespaced keys. 
+statement ok +CREATE EXTERNAL TABLE IF NOT EXISTS region ( + r_regionkey BIGINT, + r_name VARCHAR, + r_comment VARCHAR, + r_rev VARCHAR, +) STORED AS CSV LOCATION 'test_files/tpch/data/region.tbl' +OPTIONS ( + format.delimiter '|', + has_header false, + compression gzip); diff --git a/datafusion/sqllogictest/test_files/tpch/create_tables.slt.part b/datafusion/sqllogictest/test_files/tpch/create_tables.slt.part index 111d24773055..75bcbc198bef 100644 --- a/datafusion/sqllogictest/test_files/tpch/create_tables.slt.part +++ b/datafusion/sqllogictest/test_files/tpch/create_tables.slt.part @@ -121,4 +121,4 @@ CREATE EXTERNAL TABLE IF NOT EXISTS region ( r_name VARCHAR, r_comment VARCHAR, r_rev VARCHAR, -) STORED AS CSV LOCATION 'test_files/tpch/data/region.tbl' OPTIONS ('format.delimiter' '|'); \ No newline at end of file +) STORED AS CSV LOCATION 'test_files/tpch/data/region.tbl' OPTIONS ('format.delimiter' '|'); From 18fc37629250d22faa6ead109725ebf94a4fa532 Mon Sep 17 00:00:00 2001 From: Jonah Gao Date: Tue, 14 May 2024 07:32:29 +0800 Subject: [PATCH 10/11] feat: allow `array_slice` to take an optional stride parameter (#10469) * feat: allow array_slice to take an optional stride parameter * Use ScalarUDF::call * Use create_function and add test * format * fix cargo doc --- datafusion/functions-array/src/array_has.rs | 6 +-- datafusion/functions-array/src/cardinality.rs | 2 +- datafusion/functions-array/src/concat.rs | 6 +-- datafusion/functions-array/src/dimension.rs | 4 +- datafusion/functions-array/src/empty.rs | 2 +- datafusion/functions-array/src/except.rs | 2 +- datafusion/functions-array/src/extract.rs | 23 +++++----- datafusion/functions-array/src/flatten.rs | 2 +- datafusion/functions-array/src/length.rs | 2 +- datafusion/functions-array/src/macros.rs | 44 +++++++++---------- datafusion/functions-array/src/make_array.rs | 2 +- datafusion/functions-array/src/position.rs | 4 +- datafusion/functions-array/src/range.rs | 4 +- 
datafusion/functions-array/src/remove.rs | 6 +-- datafusion/functions-array/src/repeat.rs | 2 +- datafusion/functions-array/src/replace.rs | 6 +-- datafusion/functions-array/src/resize.rs | 2 +- datafusion/functions-array/src/reverse.rs | 2 +- datafusion/functions-array/src/rewrite.rs | 2 +- datafusion/functions-array/src/set_ops.rs | 6 +-- datafusion/functions-array/src/sort.rs | 2 +- datafusion/functions-array/src/string.rs | 4 +- .../tests/cases/roundtrip_logical_plan.rs | 6 +++ 23 files changed, 74 insertions(+), 67 deletions(-) diff --git a/datafusion/functions-array/src/array_has.rs b/datafusion/functions-array/src/array_has.rs index e5e8add95fbe..43d6046f4f82 100644 --- a/datafusion/functions-array/src/array_has.rs +++ b/datafusion/functions-array/src/array_has.rs @@ -34,19 +34,19 @@ use std::any::Any; use std::sync::Arc; // Create static instances of ScalarUDFs for each function -make_udf_function!(ArrayHas, +make_udf_expr_and_func!(ArrayHas, array_has, first_array second_array, // arg name "returns true, if the element appears in the first array, otherwise false.", // doc array_has_udf // internal function name ); -make_udf_function!(ArrayHasAll, +make_udf_expr_and_func!(ArrayHasAll, array_has_all, first_array second_array, // arg name "returns true if each element of the second array appears in the first array; otherwise, it returns false.", // doc array_has_all_udf // internal function name ); -make_udf_function!(ArrayHasAny, +make_udf_expr_and_func!(ArrayHasAny, array_has_any, first_array second_array, // arg name "returns true if at least one element of the second array appears in the first array; otherwise, it returns false.", // doc diff --git a/datafusion/functions-array/src/cardinality.rs b/datafusion/functions-array/src/cardinality.rs index ed9f8d01f973..d6f2456313bc 100644 --- a/datafusion/functions-array/src/cardinality.rs +++ b/datafusion/functions-array/src/cardinality.rs @@ -29,7 +29,7 @@ use datafusion_expr::{ColumnarValue, Expr, 
ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( Cardinality, cardinality, array, diff --git a/datafusion/functions-array/src/concat.rs b/datafusion/functions-array/src/concat.rs index f9d9bf4356ff..a6fed84fa765 100644 --- a/datafusion/functions-array/src/concat.rs +++ b/datafusion/functions-array/src/concat.rs @@ -36,7 +36,7 @@ use datafusion_expr::{ use crate::utils::{align_array_dimensions, check_datatypes, make_scalar_function}; -make_udf_function!( +make_udf_expr_and_func!( ArrayAppend, array_append, array element, // arg name @@ -96,7 +96,7 @@ impl ScalarUDFImpl for ArrayAppend { } } -make_udf_function!( +make_udf_expr_and_func!( ArrayPrepend, array_prepend, element array, @@ -156,7 +156,7 @@ impl ScalarUDFImpl for ArrayPrepend { } } -make_udf_function!( +make_udf_expr_and_func!( ArrayConcat, array_concat, "Concatenates arrays.", diff --git a/datafusion/functions-array/src/dimension.rs b/datafusion/functions-array/src/dimension.rs index 569eff66f7f4..1dc6520f1bc7 100644 --- a/datafusion/functions-array/src/dimension.rs +++ b/datafusion/functions-array/src/dimension.rs @@ -33,7 +33,7 @@ use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayDims, array_dims, array, @@ -88,7 +88,7 @@ impl ScalarUDFImpl for ArrayDims { } } -make_udf_function!( +make_udf_expr_and_func!( ArrayNdims, array_ndims, array, diff --git a/datafusion/functions-array/src/empty.rs b/datafusion/functions-array/src/empty.rs index d5fa174eee5f..9fe2c870496b 100644 --- a/datafusion/functions-array/src/empty.rs +++ b/datafusion/functions-array/src/empty.rs @@ -28,7 +28,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayEmpty, array_empty, array, diff --git 
a/datafusion/functions-array/src/except.rs b/datafusion/functions-array/src/except.rs index 444c7c758771..a56bab1e0611 100644 --- a/datafusion/functions-array/src/except.rs +++ b/datafusion/functions-array/src/except.rs @@ -31,7 +31,7 @@ use std::any::Any; use std::collections::HashSet; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayExcept, array_except, first_array second_array, diff --git a/datafusion/functions-array/src/extract.rs b/datafusion/functions-array/src/extract.rs index 0dbd106b6f18..842f4ec1b839 100644 --- a/datafusion/functions-array/src/extract.rs +++ b/datafusion/functions-array/src/extract.rs @@ -44,7 +44,7 @@ use std::sync::Arc; use crate::utils::make_scalar_function; // Create static instances of ScalarUDFs for each function -make_udf_function!( +make_udf_expr_and_func!( ArrayElement, array_element, array element, @@ -52,15 +52,9 @@ make_udf_function!( array_element_udf ); -make_udf_function!( - ArraySlice, - array_slice, - array begin end stride, - "returns a slice of the array.", - array_slice_udf -); +create_func!(ArraySlice, array_slice_udf); -make_udf_function!( +make_udf_expr_and_func!( ArrayPopFront, array_pop_front, array, @@ -68,7 +62,7 @@ make_udf_function!( array_pop_front_udf ); -make_udf_function!( +make_udf_expr_and_func!( ArrayPopBack, array_pop_back, array, @@ -224,6 +218,15 @@ where Ok(arrow::array::make_array(data)) } +#[doc = "returns a slice of the array."] +pub fn array_slice(array: Expr, begin: Expr, end: Expr, stride: Option) -> Expr { + let args = match stride { + Some(stride) => vec![array, begin, end, stride], + None => vec![array, begin, end], + }; + array_slice_udf().call(args) +} + #[derive(Debug)] pub(super) struct ArraySlice { signature: Signature, diff --git a/datafusion/functions-array/src/flatten.rs b/datafusion/functions-array/src/flatten.rs index e2b50c6c02cc..294d41ada7c3 100644 --- a/datafusion/functions-array/src/flatten.rs +++ b/datafusion/functions-array/src/flatten.rs @@ -31,7 
+31,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( Flatten, flatten, array, diff --git a/datafusion/functions-array/src/length.rs b/datafusion/functions-array/src/length.rs index 9bbd11950d21..9cdcaddf8dff 100644 --- a/datafusion/functions-array/src/length.rs +++ b/datafusion/functions-array/src/length.rs @@ -32,7 +32,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayLength, array_length, array, diff --git a/datafusion/functions-array/src/macros.rs b/datafusion/functions-array/src/macros.rs index c49f5830b8d5..4e00aa39bd84 100644 --- a/datafusion/functions-array/src/macros.rs +++ b/datafusion/functions-array/src/macros.rs @@ -19,8 +19,8 @@ /// /// 1. Single `ScalarUDF` instance /// -/// Creates a singleton `ScalarUDF` of the `$UDF` function named `$GNAME` and a -/// function named `$NAME` which returns that function named $NAME. +/// Creates a singleton `ScalarUDF` of the `$UDF` function named `STATIC_$(UDF)` and a +/// function named `$SCALAR_UDF_FUNC` which returns that function named `STATIC_$(UDF)`. /// /// This is used to ensure creating the list of `ScalarUDF` only happens once. /// @@ -41,10 +41,9 @@ /// * `arg`: 0 or more named arguments for the function /// * `DOC`: documentation string for the function /// * `SCALAR_UDF_FUNC`: name of the function to create (just) the `ScalarUDF` -/// * `GNAME`: name for the single static instance of the `ScalarUDF` /// /// [`ScalarUDFImpl`]: datafusion_expr::ScalarUDFImpl -macro_rules! make_udf_function { +macro_rules! make_udf_expr_and_func { ($UDF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr , $SCALAR_UDF_FN:ident) => { paste::paste! { // "fluent expr_fn" style function @@ -55,25 +54,7 @@ macro_rules! 
make_udf_function { vec![$($arg),*], )) } - - /// Singleton instance of [`$UDF`], ensures the UDF is only created once - /// named STATIC_$(UDF). For example `STATIC_ArrayToString` - #[allow(non_upper_case_globals)] - static [< STATIC_ $UDF >]: std::sync::OnceLock> = - std::sync::OnceLock::new(); - - /// ScalarFunction that returns a [`ScalarUDF`] for [`$UDF`] - /// - /// [`ScalarUDF`]: datafusion_expr::ScalarUDF - pub fn $SCALAR_UDF_FN() -> std::sync::Arc { - [< STATIC_ $UDF >] - .get_or_init(|| { - std::sync::Arc::new(datafusion_expr::ScalarUDF::new_from_impl( - <$UDF>::new(), - )) - }) - .clone() - } + create_func!($UDF, $SCALAR_UDF_FN); } }; ($UDF:ty, $EXPR_FN:ident, $DOC:expr , $SCALAR_UDF_FN:ident) => { @@ -86,7 +67,24 @@ macro_rules! make_udf_function { arg, )) } + create_func!($UDF, $SCALAR_UDF_FN); + } + }; +} +/// Creates a singleton `ScalarUDF` of the `$UDF` function named `STATIC_$(UDF)` and a +/// function named `$SCALAR_UDF_FN` which returns that function named `STATIC_$(UDF)`. +/// +/// This is used to ensure creating the list of `ScalarUDF` only happens once. +/// +/// # Arguments +/// * `UDF`: name of the [`ScalarUDFImpl`] +/// * `SCALAR_UDF_FN`: name of the function to create (just) the `ScalarUDF` +/// +/// [`ScalarUDFImpl`]: datafusion_expr::ScalarUDFImpl +macro_rules! create_func { + ($UDF:ty, $SCALAR_UDF_FN:ident) => { + paste::paste! { /// Singleton instance of [`$UDF`], ensures the UDF is only created once /// named STATIC_$(UDF). 
For example `STATIC_ArrayToString` #[allow(non_upper_case_globals)] diff --git a/datafusion/functions-array/src/make_array.rs b/datafusion/functions-array/src/make_array.rs index 4f7dda933f42..4723464dfaf2 100644 --- a/datafusion/functions-array/src/make_array.rs +++ b/datafusion/functions-array/src/make_array.rs @@ -35,7 +35,7 @@ use datafusion_expr::{Expr, TypeSignature}; use crate::utils::make_scalar_function; -make_udf_function!( +make_udf_expr_and_func!( MakeArray, make_array, "Returns an Arrow array using the specified input expressions.", diff --git a/datafusion/functions-array/src/position.rs b/datafusion/functions-array/src/position.rs index a5a7a7405aa9..efdb7dff0ce6 100644 --- a/datafusion/functions-array/src/position.rs +++ b/datafusion/functions-array/src/position.rs @@ -37,7 +37,7 @@ use itertools::Itertools; use crate::utils::{compare_element_to_list, make_scalar_function}; -make_udf_function!( +make_udf_expr_and_func!( ArrayPosition, array_position, array element index, @@ -168,7 +168,7 @@ fn generic_position( Ok(Arc::new(UInt64Array::from(data))) } -make_udf_function!( +make_udf_expr_and_func!( ArrayPositions, array_positions, array element, // arg name diff --git a/datafusion/functions-array/src/range.rs b/datafusion/functions-array/src/range.rs index 150fe5960266..9a9829f96100 100644 --- a/datafusion/functions-array/src/range.rs +++ b/datafusion/functions-array/src/range.rs @@ -35,7 +35,7 @@ use datafusion_expr::{ use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( Range, range, start stop step, @@ -106,7 +106,7 @@ impl ScalarUDFImpl for Range { } } -make_udf_function!( +make_udf_expr_and_func!( GenSeries, gen_series, start stop step, diff --git a/datafusion/functions-array/src/remove.rs b/datafusion/functions-array/src/remove.rs index 21e373081054..7645c1a57573 100644 --- a/datafusion/functions-array/src/remove.rs +++ b/datafusion/functions-array/src/remove.rs @@ -32,7 +32,7 @@ use 
datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayRemove, array_remove, array element, @@ -81,7 +81,7 @@ impl ScalarUDFImpl for ArrayRemove { } } -make_udf_function!( +make_udf_expr_and_func!( ArrayRemoveN, array_remove_n, array element max, @@ -130,7 +130,7 @@ impl ScalarUDFImpl for ArrayRemoveN { } } -make_udf_function!( +make_udf_expr_and_func!( ArrayRemoveAll, array_remove_all, array element, diff --git a/datafusion/functions-array/src/repeat.rs b/datafusion/functions-array/src/repeat.rs index 89b766bdcdfc..df623c114818 100644 --- a/datafusion/functions-array/src/repeat.rs +++ b/datafusion/functions-array/src/repeat.rs @@ -34,7 +34,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayRepeat, array_repeat, element count, // arg name diff --git a/datafusion/functions-array/src/replace.rs b/datafusion/functions-array/src/replace.rs index c32305bb454b..7cea4945836e 100644 --- a/datafusion/functions-array/src/replace.rs +++ b/datafusion/functions-array/src/replace.rs @@ -38,19 +38,19 @@ use std::any::Any; use std::sync::Arc; // Create static instances of ScalarUDFs for each function -make_udf_function!(ArrayReplace, +make_udf_expr_and_func!(ArrayReplace, array_replace, array from to, "replaces the first occurrence of the specified element with another specified element.", array_replace_udf ); -make_udf_function!(ArrayReplaceN, +make_udf_expr_and_func!(ArrayReplaceN, array_replace_n, array from to max, "replaces the first `max` occurrences of the specified element with another specified element.", array_replace_n_udf ); -make_udf_function!(ArrayReplaceAll, +make_udf_expr_and_func!(ArrayReplaceAll, array_replace_all, array from to, "replaces all occurrences of the specified element with another specified element.", diff --git 
a/datafusion/functions-array/src/resize.rs b/datafusion/functions-array/src/resize.rs index 561e98e8b76f..63f28c9afa77 100644 --- a/datafusion/functions-array/src/resize.rs +++ b/datafusion/functions-array/src/resize.rs @@ -30,7 +30,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayResize, array_resize, array size value, diff --git a/datafusion/functions-array/src/reverse.rs b/datafusion/functions-array/src/reverse.rs index 9be640565703..3076013899ef 100644 --- a/datafusion/functions-array/src/reverse.rs +++ b/datafusion/functions-array/src/reverse.rs @@ -30,7 +30,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArrayReverse, array_reverse, array, diff --git a/datafusion/functions-array/src/rewrite.rs b/datafusion/functions-array/src/rewrite.rs index 416e79cbc079..5280355a8224 100644 --- a/datafusion/functions-array/src/rewrite.rs +++ b/datafusion/functions-array/src/rewrite.rs @@ -171,7 +171,7 @@ impl FunctionRewrite for ArrayFunctionRewriter { stop, stride, }, - }) => Transformed::yes(array_slice(*expr, *start, *stop, *stride)), + }) => Transformed::yes(array_slice(*expr, *start, *stop, Some(*stride))), _ => Transformed::no(expr), }; diff --git a/datafusion/functions-array/src/set_ops.rs b/datafusion/functions-array/src/set_ops.rs index 5f3087fafd6f..40676b7cdcb8 100644 --- a/datafusion/functions-array/src/set_ops.rs +++ b/datafusion/functions-array/src/set_ops.rs @@ -37,7 +37,7 @@ use std::fmt::{Display, Formatter}; use std::sync::Arc; // Create static instances of ScalarUDFs for each function -make_udf_function!( +make_udf_expr_and_func!( ArrayUnion, array_union, array1 array2, @@ -45,7 +45,7 @@ make_udf_function!( array_union_udf ); -make_udf_function!( +make_udf_expr_and_func!( ArrayIntersect, array_intersect, 
first_array second_array, @@ -53,7 +53,7 @@ make_udf_function!( array_intersect_udf ); -make_udf_function!( +make_udf_expr_and_func!( ArrayDistinct, array_distinct, array, diff --git a/datafusion/functions-array/src/sort.rs b/datafusion/functions-array/src/sort.rs index af78712065fc..16f271ef10ff 100644 --- a/datafusion/functions-array/src/sort.rs +++ b/datafusion/functions-array/src/sort.rs @@ -30,7 +30,7 @@ use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility} use std::any::Any; use std::sync::Arc; -make_udf_function!( +make_udf_expr_and_func!( ArraySort, array_sort, array desc null_first, diff --git a/datafusion/functions-array/src/string.rs b/datafusion/functions-array/src/string.rs index 38059035005b..4122ddbd45eb 100644 --- a/datafusion/functions-array/src/string.rs +++ b/datafusion/functions-array/src/string.rs @@ -102,7 +102,7 @@ macro_rules! call_array_function { } // Create static instances of ScalarUDFs for each function -make_udf_function!( +make_udf_expr_and_func!( ArrayToString, array_to_string, array delimiter, // arg name @@ -160,7 +160,7 @@ impl ScalarUDFImpl for ArrayToString { } } -make_udf_function!( +make_udf_expr_and_func!( StringToArray, string_to_array, string delimiter null_string, // arg name diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 2927fd01d1b3..ec215937dca8 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -582,7 +582,13 @@ async fn roundtrip_expr_api() -> Result<()> { make_array(vec![lit(1), lit(2), lit(3)]), lit(1), lit(2), + Some(lit(1)), + ), + array_slice( + make_array(vec![lit(1), lit(2), lit(3)]), lit(1), + lit(2), + None, ), array_pop_front(make_array(vec![lit(1), lit(2), lit(3)])), array_pop_back(make_array(vec![lit(1), lit(2), lit(3)])), From b8fab5cdf418e1fba5e6012b815a5bc40c7771cc Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Tue, 
14 May 2024 10:22:28 +0800 Subject: [PATCH 11/11] Replace `GetFieldAccess` with indexing function in `SqlToRel ` (#10375) * use func in parser Signed-off-by: jayzhan211 * add tests Signed-off-by: jayzhan211 * add test Signed-off-by: jayzhan211 * rm test1 Signed-off-by: jayzhan211 * parser done Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * fix exprapi test Signed-off-by: jayzhan211 * fix test Signed-off-by: jayzhan211 * fix conflicts Signed-off-by: jayzhan211 --------- Signed-off-by: jayzhan211 --- datafusion/core/tests/expr_api/mod.rs | 14 +-- datafusion/functions-array/src/rewrite.rs | 29 +---- datafusion/sql/src/expr/identifier.rs | 17 ++- datafusion/sql/src/expr/mod.rs | 48 ++++++++- datafusion/sqllogictest/test_files/expr.slt | 114 +++++++++++++++++++- 5 files changed, 172 insertions(+), 50 deletions(-) diff --git a/datafusion/core/tests/expr_api/mod.rs b/datafusion/core/tests/expr_api/mod.rs index 0dde7604cce2..d7e839824b3b 100644 --- a/datafusion/core/tests/expr_api/mod.rs +++ b/datafusion/core/tests/expr_api/mod.rs @@ -60,9 +60,8 @@ fn test_eq_with_coercion() { #[test] fn test_get_field() { - // field access Expr::field() requires a rewrite to work evaluate_expr_test( - col("props").field("a"), + get_field(col("props"), lit("a")), vec![ "+------------+", "| expr |", @@ -77,11 +76,8 @@ fn test_get_field() { #[test] fn test_nested_get_field() { - // field access Expr::field() requires a rewrite to work, test when it is - // not the root expression evaluate_expr_test( - col("props") - .field("a") + get_field(col("props"), lit("a")) .eq(lit("2021-02-02")) .or(col("id").eq(lit(1))), vec![ @@ -98,9 +94,8 @@ fn test_nested_get_field() { #[test] fn test_list() { - // list access also requires a rewrite to work evaluate_expr_test( - col("list").index(lit(1i64)), + array_element(col("list"), lit(1i64)), vec![ "+------+", "| expr |", "+------+", "| one |", "| two |", "| five |", "+------+", @@ -110,9 +105,8 @@ fn test_list() { #[test] fn test_list_range() 
{ - // range access also requires a rewrite to work evaluate_expr_test( - col("list").range(lit(1i64), lit(2i64)), + array_slice(col("list"), lit(1i64), lit(2i64), None), vec![ "+--------------+", "| expr |", diff --git a/datafusion/functions-array/src/rewrite.rs b/datafusion/functions-array/src/rewrite.rs index 5280355a8224..a7aba78c1dbe 100644 --- a/datafusion/functions-array/src/rewrite.rs +++ b/datafusion/functions-array/src/rewrite.rs @@ -19,7 +19,6 @@ use crate::array_has::array_has_all; use crate::concat::{array_append, array_concat, array_prepend}; -use crate::extract::{array_element, array_slice}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::Transformed; use datafusion_common::utils::list_ndims; @@ -27,8 +26,7 @@ use datafusion_common::Result; use datafusion_common::{Column, DFSchema}; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::expr_rewriter::FunctionRewrite; -use datafusion_expr::{BinaryExpr, Expr, GetFieldAccess, GetIndexedField, Operator}; -use datafusion_functions::expr_fn::get_field; +use datafusion_expr::{BinaryExpr, Expr, Operator}; /// Rewrites expressions into function calls to array functions pub(crate) struct ArrayFunctionRewriter {} @@ -148,31 +146,6 @@ impl FunctionRewrite for ArrayFunctionRewriter { Transformed::yes(array_prepend(*left, *right)) } - Expr::GetIndexedField(GetIndexedField { - expr, - field: GetFieldAccess::NamedStructField { name }, - }) => { - let name = Expr::Literal(name); - Transformed::yes(get_field(*expr, name)) - } - - // expr[idx] ==> array_element(expr, idx) - Expr::GetIndexedField(GetIndexedField { - expr, - field: GetFieldAccess::ListIndex { key }, - }) => Transformed::yes(array_element(*expr, *key)), - - // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride) - Expr::GetIndexedField(GetIndexedField { - expr, - field: - GetFieldAccess::ListRange { - start, - stop, - stride, - }, - }) => Transformed::yes(array_slice(*expr, *start, *stop, 
Some(*stride))), - _ => Transformed::no(expr), }; Ok(transformed) diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs index 713ad6f72c24..d297b2e4df5b 100644 --- a/datafusion/sql/src/expr/identifier.rs +++ b/datafusion/sql/src/expr/identifier.rs @@ -19,9 +19,9 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use arrow_schema::Field; use datafusion_common::{ internal_err, plan_datafusion_err, Column, DFSchema, DataFusionError, Result, - TableReference, + ScalarValue, TableReference, }; -use datafusion_expr::{Case, Expr}; +use datafusion_expr::{expr::ScalarFunction, lit, Case, Expr}; use sqlparser::ast::{Expr as SQLExpr, Ident}; impl<'a, S: ContextProvider> SqlToRel<'a, S> { @@ -133,7 +133,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ); } let nested_name = nested_names[0].to_string(); - Ok(Expr::Column(Column::from((qualifier, field))).field(nested_name)) + + let col = Expr::Column(Column::from((qualifier, field))); + if let Some(udf) = + self.context_provider.get_function_meta("get_field") + { + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + udf, + vec![col, lit(ScalarValue::from(nested_name))], + ))) + } else { + internal_err!("get_field not found") + } } // found matching field with no spare identifier(s) Some((field, qualifier, _nested_names)) => { diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index ed5421edfbb0..6445c3f7a885 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -29,7 +29,7 @@ use datafusion_expr::expr::InList; use datafusion_expr::expr::ScalarFunction; use datafusion_expr::{ col, expr, lit, AggregateFunction, Between, BinaryExpr, Cast, Expr, ExprSchemable, - GetFieldAccess, GetIndexedField, Like, Literal, Operator, TryCast, + GetFieldAccess, Like, Literal, Operator, TryCast, }; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; @@ -1019,10 +1019,48 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { expr 
}; - Ok(Expr::GetIndexedField(GetIndexedField::new( - Box::new(expr), - self.plan_indices(indices, schema, planner_context)?, - ))) + let field = self.plan_indices(indices, schema, planner_context)?; + match field { + GetFieldAccess::NamedStructField { name } => { + if let Some(udf) = self.context_provider.get_function_meta("get_field") { + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + udf, + vec![expr, lit(name)], + ))) + } else { + internal_err!("get_field not found") + } + } + // expr[idx] ==> array_element(expr, idx) + GetFieldAccess::ListIndex { key } => { + if let Some(udf) = + self.context_provider.get_function_meta("array_element") + { + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + udf, + vec![expr, *key], + ))) + } else { + internal_err!("array_element not found") + } + } + // expr[start, stop, stride] ==> array_slice(expr, start, stop, stride) + GetFieldAccess::ListRange { + start, + stop, + stride, + } => { + if let Some(udf) = self.context_provider.get_function_meta("array_slice") + { + Ok(Expr::ScalarFunction(ScalarFunction::new_udf( + udf, + vec![expr, *start, *stop, *stride], + ))) + } else { + internal_err!("array_slice not found") + } + } + } } } diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 4b5f4d770a03..2dc00cbc5001 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -2324,28 +2324,134 @@ host3 3.3 # can have an aggregate function with an inner CASE WHEN query TR -select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +select + t2.server_host as host, + sum(( + case 
when t2.server_host is not null + then t2.server_load2 + end + )) + from ( + select + struct(time,load1,load2,host)['c2'] as server_load2, + struct(time,load1,load2,host)['c3'] as server_host + from t1 + ) t2 + where server_host IS NOT NULL + group by server_host order by host; ---- host1 101 host2 202 host3 303 +# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364 +query error +select + t2.server['c3'] as host, + sum(( + case when t2.server['c3'] is not null + then t2.server['c2'] + end + )) + from ( + select + struct(time,load1,load2,host) as server + from t1 + ) t2 + where t2.server['c3'] IS NOT NULL + group by t2.server['c3'] order by host; + # can have 2 projections with aggr(short_circuited), with different short-circuited expr query TRR -select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +select + t2.server_host as host, + sum(coalesce(server_load1)), + sum(( + case when t2.server_host is not null + then t2.server_load2 + end + )) + from ( + select + struct(time,load1,load2,host)['c1'] as server_load1, + struct(time,load1,load2,host)['c2'] as server_load2, + struct(time,load1,load2,host)['c3'] as server_host + from t1 + ) t2 + where server_host IS NOT NULL + group by server_host order by host; ---- host1 1.1 101 host2 2.2 202 host3 3.3 303 -# can have 2 projections with aggr(short_circuited), with the same short-circuited expr (e.g. 
CASE WHEN) +# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364 +query error +select + t2.server['c3'] as host, + sum(coalesce(server['c1'])), + sum(( + case when t2.server['c3'] is not null + then t2.server['c2'] + end + )) + from ( + select + struct(time,load1,load2,host) as server, + from t1 + ) t2 + where server_host IS NOT NULL + group by server_host order by host; + query TRR -select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c1']), sum((case when t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] is not null then t2."struct(t1.time,t1.load1,t1.load2,t1.host)" end)['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host; +select + t2.server_host as host, + sum(( + case when t2.server_host is not null + then server_load1 + end + )), + sum(( + case when server_host is not null + then server_load2 + end + )) + from ( + select + struct(time,load1,load2,host)['c1'] as server_load1, + struct(time,load1,load2,host)['c2'] as server_load2, + struct(time,load1,load2,host)['c3'] as server_host + from t1 + ) t2 + where server_host IS NOT NULL + group by server_host order by host; ---- host1 1.1 101 host2 2.2 202 host3 3.3 303 +# TODO: Issue tracked in https://github.com/apache/datafusion/issues/10364 +query error +select + t2.server['c3'] as host, + sum(( + case when t2.server['c3'] is not null + then t2.server['c1'] + end + )), + sum(( + case when t2.server['c3'] is not null + then t2.server['c2'] + end + )) + from ( + select + struct(time,load1,load2,host) as server + from t1 + ) t2 + where t2.server['c3'] IS NOT NULL + group by t2.server['c3'] order by host; + # can have 2 projections with aggr(short_circuited), with the same 
short-circuited expr (e.g. coalesce) query TRR select t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] as host, sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c1']), sum(coalesce(t2."struct(t1.time,t1.load1,t1.load2,t1.host)")['c2']) from (select struct(time,load1,load2,host) from t1) t2 where t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] IS NOT NULL group by t2."struct(t1.time,t1.load1,t1.load2,t1.host)"['c3'] order by host;