diff --git a/rust/datafusion/examples/simple_udf.rs b/rust/datafusion/examples/simple_udf.rs
new file mode 100644
index 0000000000000..a8d9fcef498ee
--- /dev/null
+++ b/rust/datafusion/examples/simple_udf.rs
@@ -0,0 +1,138 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::{
+    array::{Array, ArrayRef, Float32Array, Float64Array, Float64Builder},
+    datatypes::DataType,
+    record_batch::RecordBatch,
+    util::pretty,
+};
+
+use datafusion::error::Result;
+use datafusion::{physical_plan::functions::ScalarFunctionImplementation, prelude::*};
+use std::sync::Arc;
+
+// create a local execution context with an in-memory table
+fn create_context() -> Result<ExecutionContext> {
+    use arrow::datatypes::{Field, Schema};
+    use datafusion::datasource::MemTable;
+    // define a schema.
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("a", DataType::Float32, false),
+        Field::new("b", DataType::Float64, false),
+    ]));
+
+    // define data.
+    let batch = RecordBatch::try_new(
+        schema.clone(),
+        vec![
+            Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])),
+            Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
+        ],
+    )?;
+
+    // declare a new context. In the Spark API, this corresponds to a new Spark SQL session.
+    let mut ctx = ExecutionContext::new();
+
+    // declare a table in memory. In the Spark API, this corresponds to createDataFrame(...).
+    let provider = MemTable::new(schema, vec![vec![batch]])?;
+    ctx.register_table("t", Box::new(provider));
+    Ok(ctx)
+}
+
+/// In this example we declare a single-signature, single-return-type UDF that computes f64 exponentiation, a^b.
+fn main() -> Result<()> {
+    let mut ctx = create_context()?;
+
+    // First, declare the actual implementation of the calculation
+    let pow: ScalarFunctionImplementation = Arc::new(|args: &[ArrayRef]| {
+        // in DataFusion, all `args` and output are dynamically-typed arrays, which means that we need to:
+        // 1. cast the values to the type we want
+        // 2. perform the computation for every element in the array (using a loop or SIMD)
+        // 3. construct the resulting array
+
+        // this is guaranteed by DataFusion based on the function's signature.
+        assert_eq!(args.len(), 2);
+
+        // 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics!
+        let base = &args[0]
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .expect("cast failed");
+        let exponent = &args[1]
+            .as_any()
+            .downcast_ref::<Float64Array>()
+            .expect("cast failed");
+
+        // this is guaranteed by DataFusion. We place it here just to make it obvious.
+        assert_eq!(exponent.len(), base.len());
+
+        // 2. Arrow's builder is used to construct an Arrow array.
+        let mut builder = Float64Builder::new(base.len());
+        for index in 0..base.len() {
+            // in Arrow, any value can be null.
+            // Here we decide to make our UDF return null when either base or exponent is null.
+            if base.is_null(index) || exponent.is_null(index) {
+                builder.append_null()?;
+            } else {
+                // 3. computation. Since we do not have a SIMD `pow` operation at hand,
+                // we loop over each entry. The array's values are obtained via `.value(index)`.
+                let value = base.value(index).powf(exponent.value(index));
+                builder.append_value(value)?;
+            }
+        }
+        Ok(Arc::new(builder.finish()))
+    });
+
+    // Next:
+    // * give it a name so that it shows nicely when the plan is printed
+    // * declare what input it expects
+    // * declare its return type
+    let pow = create_udf(
+        "pow",
+        // expects two f64
+        vec![DataType::Float64, DataType::Float64],
+        // returns f64
+        Arc::new(DataType::Float64),
+        pow,
+    );
+
+    // finally, register the UDF
+    ctx.register_udf(pow);
+
+    // at this point, we can use it. Note that the code below can be in a
+    // scope in which we do not have access to `pow`.
+
+    // get a DataFrame from the context
+    let df = ctx.table("t")?;
+
+    // get the udf registry.
+    let f = df.registry();
+
+    // equivalent to `'SELECT pow(a, b) FROM t'`
+    let df = df.select(vec![f.udf("pow", vec![col("a"), col("b")])?])?;
+
+    // note that "a" is f32, not f64. DataFusion coerces the type to match the UDF's signature.
+
+    // execute the query
+    let results = df.collect()?;
+
+    // print the results
+    pretty::print_batches(&results)?;
+
+    Ok(())
+}
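The example above drives the UDF through the DataFrame API; the registered name is equally visible to the SQL planner. A minimal sketch of the SQL path, assuming the example's context and `pow` registration (`create_logical_plan` is the same entry point the tests in this diff use):

```rust
// Hypothetical continuation of the example's `main`: after
// `ctx.register_udf(pow)`, the SQL planner resolves "pow" by name and
// produces the same logical plan as the `select` above.
let plan = ctx.create_logical_plan("SELECT pow(a, b) FROM t")?;
```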
diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs
index 8fd2531d257d8..0e886c52031a4 100644
--- a/rust/datafusion/src/execution/context.rs
+++ b/rust/datafusion/src/execution/context.rs
@@ -40,13 +40,11 @@ use crate::logical_plan::{Expr, FunctionRegistry, LogicalPlan, LogicalPlanBuilde
 use crate::optimizer::filter_push_down::FilterPushDown;
 use crate::optimizer::optimizer::OptimizerRule;
 use crate::optimizer::projection_push_down::ProjectionPushDown;
-use crate::optimizer::type_coercion::TypeCoercionRule;
 use crate::physical_plan::common;
 use crate::physical_plan::csv::CsvReadOptions;
 use crate::physical_plan::merge::MergeExec;
 use crate::physical_plan::planner::DefaultPhysicalPlanner;
-use crate::physical_plan::udf::ScalarFunction;
-use crate::physical_plan::udf::ScalarFunctionRegistry;
+use crate::physical_plan::udf::ScalarUDF;
 use crate::physical_plan::ExecutionPlan;
 use crate::physical_plan::PhysicalPlanner;
 use crate::sql::{
@@ -180,7 +178,7 @@ impl ExecutionContext {
     }
 
     /// Register a scalar UDF
-    pub fn register_udf(&mut self, f: ScalarFunction) {
+    pub fn register_udf(&mut self, f: ScalarUDF) {
         self.state
             .scalar_functions
             .insert(f.name.clone(), Arc::new(f));
@@ -294,7 +292,6 @@ impl ExecutionContext {
         // Apply standard rewrites and optimizations
         let mut plan = ProjectionPushDown::new().optimize(&plan)?;
         plan = FilterPushDown::new().optimize(&plan)?;
-        plan = TypeCoercionRule::new().optimize(&plan)?;
 
         self.state.config.query_planner.rewrite_logical_plan(plan)
     }
@@ -377,12 +374,6 @@ impl ExecutionContext {
     }
 }
 
-impl ScalarFunctionRegistry for ExecutionContext {
-    fn lookup(&self, name: &str) -> Option<Arc<ScalarFunction>> {
-        self.state.scalar_functions.lookup(name)
-    }
-}
-
 /// A planner used to add extensions to DataFusion logical and physical plans.
 pub trait QueryPlanner {
     /// Given a `LogicalPlan`, create a new, modified `LogicalPlan`
@@ -468,7 +459,7 @@ pub struct ExecutionContextState {
     /// Data sources that are registered with the context
     pub datasources: HashMap<String, Box<dyn TableProvider + Send + Sync>>,
     /// Scalar functions that are registered with the context
-    pub scalar_functions: HashMap<String, Arc<ScalarFunction>>,
+    pub scalar_functions: HashMap<String, Arc<ScalarUDF>>,
     /// Context configuration
     pub config: ExecutionConfig,
 }
@@ -478,7 +469,7 @@ impl SchemaProvider for ExecutionContextState {
         self.datasources.get(name).map(|ds| ds.schema().clone())
     }
 
-    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarFunction>> {
+    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
         self.scalar_functions
             .get(name)
             .and_then(|func| Some(func.clone()))
@@ -510,8 +501,8 @@ mod tests {
 
     use super::*;
     use crate::datasource::MemTable;
-    use crate::logical_plan::{aggregate_expr, col};
-    use crate::physical_plan::udf::ScalarUdf;
+    use crate::logical_plan::{aggregate_expr, col, create_udf};
+    use crate::physical_plan::functions::ScalarFunctionImplementation;
    use crate::test;
     use arrow::array::{ArrayRef, Int32Array};
     use arrow::compute::add;
@@ -1032,7 +1023,7 @@ mod tests {
         let provider = MemTable::new(Arc::new(schema), vec![vec![batch]])?;
         ctx.register_table("t", Box::new(provider));
 
-        let myfunc: ScalarUdf = Arc::new(|args: &[ArrayRef]| {
+        let myfunc: ScalarFunctionImplementation = Arc::new(|args: &[ArrayRef]| {
             let l = &args[0]
                 .as_any()
                 .downcast_ref::<Int32Array>()
                 .expect("cast failed");
             let r = &args[1]
                 .as_any()
                 .downcast_ref::<Int32Array>()
                 .expect("cast failed");
             Ok(Arc::new(add(l, r)?))
         });
 
-        let my_add = ScalarFunction::new(
+        ctx.register_udf(create_udf(
             "my_add",
             vec![DataType::Int32, DataType::Int32],
-            DataType::Int32,
+            Arc::new(DataType::Int32),
             myfunc,
-        );
-
-        ctx.register_udf(my_add);
+        ));
 
         // from here on, we may be in a different scope. We would still like to be able
         // to call UDFs.
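The test comment above ("we may be in a different scope") is the point of the registry indirection: resolution is by name, so the closure behind `my_add` need not be in scope. A hedged sketch of that pattern, reusing the test's names (the `my_add_results` helper is hypothetical):

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

// Hypothetical helper in a scope with no access to `myfunc`: the UDF is
// looked up by name through the DataFrame's function registry.
fn my_add_results(
    ctx: &mut ExecutionContext,
) -> Result<Vec<arrow::record_batch::RecordBatch>> {
    let t = ctx.table("t")?;
    let f = t.registry();
    let t = t.select(vec![f.udf("my_add", vec![col("a"), col("b")])?])?;
    t.collect()
}
```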
diff --git a/rust/datafusion/src/execution/dataframe_impl.rs b/rust/datafusion/src/execution/dataframe_impl.rs
index 57c4df346bdc1..bf25d2a142e89 100644
--- a/rust/datafusion/src/execution/dataframe_impl.rs
+++ b/rust/datafusion/src/execution/dataframe_impl.rs
@@ -136,15 +136,8 @@ mod tests {
     use crate::datasource::csv::CsvReadOptions;
     use crate::execution::context::ExecutionContext;
     use crate::logical_plan::*;
-    use crate::{
-        physical_plan::udf::{ScalarFunction, ScalarUdf},
-        test,
-    };
-    use arrow::{
-        array::{ArrayRef, Float64Array},
-        compute::add,
-        datatypes::DataType,
-    };
+    use crate::{physical_plan::functions::ScalarFunctionImplementation, test};
+    use arrow::{array::ArrayRef, datatypes::DataType};
 
     #[test]
     fn select_columns() -> Result<()> {
@@ -250,39 +243,28 @@ mod tests {
         register_aggregate_csv(&mut ctx)?;
 
         // declare the udf
-        let my_add: ScalarUdf = Arc::new(|args: &[ArrayRef]| {
-            let l = &args[0]
-                .as_any()
-                .downcast_ref::<Float64Array>()
-                .expect("cast failed");
-            let r = &args[1]
-                .as_any()
-                .downcast_ref::<Float64Array>()
-                .expect("cast failed");
-            Ok(Arc::new(add(l, r)?))
-        });
-
-        let my_add = ScalarFunction::new(
-            "my_add",
-            vec![DataType::Float64],
-            DataType::Float64,
-            my_add,
-        );
+        let my_fn: ScalarFunctionImplementation =
+            Arc::new(|_: &[ArrayRef]| unimplemented!("my_fn is not implemented"));
 
-        // register the udf
-        ctx.register_udf(my_add);
+        // create and register the udf
+        ctx.register_udf(create_udf(
+            "my_fn",
+            vec![DataType::Float64],
+            Arc::new(DataType::Float64),
+            my_fn,
+        ));
 
         // build query with a UDF using DataFrame API
         let df = ctx.table("aggregate_test_100")?;
         let f = df.registry();
 
-        let df = df.select(vec![f.udf("my_add", vec![col("c12")])?])?;
+        let df = df.select(vec![f.udf("my_fn", vec![col("c12")])?])?;
         let plan = df.to_logical_plan();
 
         // build query using SQL
         let sql_plan =
-            ctx.create_logical_plan("SELECT my_add(c12) FROM aggregate_test_100")?;
+            ctx.create_logical_plan("SELECT my_fn(c12) FROM aggregate_test_100")?;
 
         // the two plans should be identical
         assert_same_plan(&plan, &sql_plan);
diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs
index c894c833cd075..29ebaa194faca 100644
--- a/rust/datafusion/src/logical_plan/mod.rs
+++ b/rust/datafusion/src/logical_plan/mod.rs
@@ -21,6 +21,7 @@
 //! Logical query plans can then be optimized and executed directly, or translated into
 //! physical query plans and executed.
 
+use fmt::Debug;
 use std::{any::Any, collections::HashSet, fmt, sync::Arc};
 
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
@@ -29,15 +30,15 @@ use crate::datasource::csv::{CsvFile, CsvReadOptions};
 use crate::datasource::parquet::ParquetTable;
 use crate::datasource::TableProvider;
 use crate::error::{ExecutionError, Result};
-use crate::physical_plan::udf;
 use crate::{
     physical_plan::{
-        expressions::binary_operator_data_type, functions, type_coercion::can_coerce_from,
+        expressions::binary_operator_data_type, functions,
+        type_coercion::can_coerce_from, udf::ScalarUDF,
     },
     sql::parser::FileType,
 };
 use arrow::record_batch::RecordBatch;
-use fmt::Debug;
+use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature};
 
 /// Operators applied to expressions
 #[derive(Debug, Clone, PartialEq, Eq)]
@@ -272,14 +273,14 @@ pub enum Expr {
     /// scalar function.
     ScalarFunction {
         /// The function
-        fun: functions::ScalarFunction,
+        fun: functions::BuiltinScalarFunction,
         /// List of expressions to feed to the functions as arguments
         args: Vec<Expr>,
     },
     /// scalar udf.
     ScalarUDF {
         /// The function
-        fun: Arc<udf::ScalarFunction>,
+        fun: Arc<ScalarUDF>,
         /// List of expressions to feed to the functions as arguments
         args: Vec<Expr>,
     },
@@ -302,7 +303,13 @@ impl Expr {
             Expr::Column(name) => Ok(schema.field_with_name(name)?.data_type().clone()),
             Expr::Literal(l) => l.get_datatype(),
             Expr::Cast { data_type, .. } => Ok(data_type.clone()),
-            Expr::ScalarUDF { fun, .. } => Ok(fun.return_type.clone()),
+            Expr::ScalarUDF { fun, args } => {
+                let data_types = args
+                    .iter()
+                    .map(|e| e.get_type(schema))
+                    .collect::<Result<Vec<_>>>()?;
+                Ok((fun.return_type)(&data_types)?.as_ref().clone())
+            }
             Expr::ScalarFunction { fun, args } => {
                 let data_types = args
                     .iter()
@@ -636,7 +643,7 @@ macro_rules! unary_math_expr {
         #[allow(missing_docs)]
        pub fn $FUNC(e: Expr) -> Expr {
             Expr::ScalarFunction {
-                fun: functions::ScalarFunction::$ENUM,
+                fun: functions::BuiltinScalarFunction::$ENUM,
                 args: vec![e],
             }
         }
@@ -665,7 +672,7 @@ unary_math_expr!(Log10, log10);
 /// returns the length of a string in bytes
 pub fn length(e: Expr) -> Expr {
     Expr::ScalarFunction {
-        fun: functions::ScalarFunction::Length,
+        fun: functions::BuiltinScalarFunction::Length,
         args: vec![e],
     }
 }
@@ -673,7 +680,7 @@ pub fn length(e: Expr) -> Expr {
 /// returns the concatenation of string expressions
 pub fn concat(args: Vec<Expr>) -> Expr {
     Expr::ScalarFunction {
-        fun: functions::ScalarFunction::Concat,
+        fun: functions::BuiltinScalarFunction::Concat,
         args,
     }
 }
@@ -686,6 +693,21 @@ pub fn aggregate_expr(name: &str, expr: Expr) -> Expr {
     }
 }
 
+/// Creates a new UDF with a specific signature and specific return type.
+/// This is a helper function to create a new UDF.
+/// The UDFs created by `create_udf` cover only a subset of all possible `ScalarUDF`s:
+/// * the UDF has a fixed return type
+/// * the UDF has a fixed signature (e.g. [f64, f64])
+pub fn create_udf(
+    name: &str,
+    input_types: Vec<DataType>,
+    return_type: Arc<DataType>,
+    fun: ScalarFunctionImplementation,
+) -> ScalarUDF {
+    let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone()));
+    ScalarUDF::new(name, &Signature::Exact(input_types), &return_type, &fun)
+}
+
 impl fmt::Debug for Expr {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         match self {
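`create_udf` pins the return type to a constant, but the underlying `ScalarUDF::new` accepts any `ReturnTypeFunction`, which is exactly what the `Expr::get_type` change above consumes. A hedged sketch of a UDF whose return type follows its argument type, using only APIs introduced in this patch (the `identity` name and the `my_impl` argument are hypothetical):

```rust
use std::sync::Arc;
use arrow::datatypes::DataType;
use datafusion::physical_plan::functions::{
    ReturnTypeFunction, ScalarFunctionImplementation, Signature,
};
use datafusion::physical_plan::udf::ScalarUDF;

// Build a UDF that accepts one f32 or f64 argument and returns the same type.
fn identity_udf(my_impl: ScalarFunctionImplementation) -> ScalarUDF {
    // the return type is computed from the coerced argument types
    let return_type: ReturnTypeFunction =
        Arc::new(|arg_types| Ok(Arc::new(arg_types[0].clone())));
    ScalarUDF::new(
        "identity",
        &Signature::Uniform(1, vec![DataType::Float32, DataType::Float64]),
        &return_type,
        &my_impl,
    )
}
```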
diff --git a/rust/datafusion/src/optimizer/mod.rs b/rust/datafusion/src/optimizer/mod.rs
index 9871acd9024b7..dffae5328b7b2 100644
--- a/rust/datafusion/src/optimizer/mod.rs
+++ b/rust/datafusion/src/optimizer/mod.rs
@@ -21,5 +21,4 @@
 pub mod filter_push_down;
 pub mod optimizer;
 pub mod projection_push_down;
-pub mod type_coercion;
 pub mod utils;
diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs
deleted file mode 100644
index 9e74b6934282e..0000000000000
--- a/rust/datafusion/src/optimizer/type_coercion.rs
+++ /dev/null
@@ -1,120 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! The type_coercion optimizer rule ensures that all operators are operating on
-//! compatible types by adding explicit cast operations to expressions. For example,
-//! the operation `c_float + c_int` would be rewritten as `c_float + CAST(c_int AS
-//! float)`. This keeps the runtime query execution code much simpler.
-
-use arrow::datatypes::Schema;
-
-use crate::error::Result;
-use crate::logical_plan::Expr;
-use crate::logical_plan::LogicalPlan;
-use crate::optimizer::optimizer::OptimizerRule;
-use crate::optimizer::utils;
-use crate::physical_plan::expressions::numerical_coercion;
-use utils::optimize_explain;
-
-/// Optimizer that applies coercion rules to expressions in the logical plan.
-///
-/// This optimizer does not alter the structure of the plan, it only changes expressions on it.
-pub struct TypeCoercionRule {}
-
-impl TypeCoercionRule {
-    /// Create a new type coercion optimizer rule using meta-data about registered
-    /// scalar functions
-    pub fn new() -> Self {
-        Self {}
-    }
-
-    /// Rewrite an expression to include explicit CAST operations when required
-    fn rewrite_expr(&self, expr: &Expr, schema: &Schema) -> Result<Expr> {
-        let expressions = utils::expr_sub_expressions(expr)?;
-
-        // recurse into the re-write
-        let mut expressions = expressions
-            .iter()
-            .map(|e| self.rewrite_expr(e, schema))
-            .collect::<Result<Vec<_>>>()?;
-
-        // modify `expressions` by introducing casts when necessary
-        match expr {
-            Expr::ScalarUDF { fun, .. } => {
-                // cast the inputs of scalar functions to the appropriate type where possible
-                for i in 0..expressions.len() {
-                    let actual_type = expressions[i].get_type(schema)?;
-                    let required_type = &fun.arg_types[i];
-                    if &actual_type != required_type {
-                        // attempt to coerce using numerical coercion
-                        // todo: also try string coercion.
-                        if let Some(cast_to_type) =
-                            numerical_coercion(&actual_type, required_type)
-                        {
-                            expressions[i] =
-                                expressions[i].cast_to(&cast_to_type, schema)?
-                        };
-                        // not possible: do nothing and let the plan fail with a clear error message
-                    };
-                }
-            }
-            _ => {}
-        };
-        utils::rewrite_expression(expr, &expressions)
-    }
-}
-
-impl OptimizerRule for TypeCoercionRule {
-    fn optimize(&mut self, plan: &LogicalPlan) -> Result<LogicalPlan> {
-        match plan {
-            LogicalPlan::Explain {
-                verbose,
-                plan,
-                stringified_plans,
-                schema,
-            } => optimize_explain(self, *verbose, &*plan, stringified_plans, &*schema),
-            _ => {
-                let inputs = utils::inputs(plan);
-                let expressions = utils::expressions(plan);
-
-                // apply the optimization to all inputs of the plan
-                let new_inputs = inputs
-                    .iter()
-                    .map(|plan| self.optimize(*plan))
-                    .collect::<Result<Vec<_>>>()?;
-
-                // re-write all expressions on this plan.
-                // This assumes a single input, [0]. It won't work for joins, subqueries and union operations with more than one input.
-                // It is currently not an issue as we do not have any plan with more than one input.
-                assert!(
-                    expressions.len() == 0 || inputs.len() > 0,
-                    "Assume that all plan nodes with expressions have inputs"
-                );
-
-                let new_expressions = expressions
-                    .iter()
-                    .map(|expr| self.rewrite_expr(expr, inputs[0].schema()))
-                    .collect::<Result<Vec<_>>>()?;
-
-                utils::from_plan(plan, &new_expressions, &new_inputs)
-            }
-        }
-    }
-
-    fn name(&self) -> &str {
-        return "type_coercion";
-    }
-}
diff --git a/rust/datafusion/src/physical_plan/functions.rs b/rust/datafusion/src/physical_plan/functions.rs
index abd6a325551e6..7f1e4fc4e7d44 100644
--- a/rust/datafusion/src/physical_plan/functions.rs
+++ b/rust/datafusion/src/physical_plan/functions.rs
@@ -36,16 +36,17 @@ use super::{
 use crate::error::{ExecutionError, Result};
 use crate::physical_plan::math_expressions;
 use crate::physical_plan::string_expressions;
-use crate::physical_plan::udf;
 use arrow::{
+    array::ArrayRef,
     compute::kernels::length::length,
     datatypes::{DataType, Schema},
+    record_batch::RecordBatch,
 };
+use fmt::{Debug, Formatter};
 use std::{fmt, str::FromStr, sync::Arc};
-use udf::ScalarUdf;
 
 /// A function's signature, which defines the function's supported argument types.
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub enum Signature {
     /// arbitrary number of arguments of a common type out of a list of valid types
     // A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])`
@@ -58,11 +59,21 @@ pub enum Signature {
     // A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])`
     // A function of two arguments of f64 or f32 is `Uniform(2, vec![DataType::Float32, DataType::Float64])`
     Uniform(usize, Vec<DataType>),
+    /// exact number of arguments of an exact type
+    Exact(Vec<DataType>),
 }
 
+/// Scalar function
+pub type ScalarFunctionImplementation =
+    Arc<dyn Fn(&[ArrayRef]) -> Result<ArrayRef> + Send + Sync>;
+
+/// A function's return type
+pub type ReturnTypeFunction =
+    Arc<dyn Fn(&[DataType]) -> Result<Arc<DataType>> + Send + Sync>;
+
 /// Enum of all built-in scalar functions
 #[derive(Debug, Clone, PartialEq, Eq)]
-pub enum ScalarFunction {
+pub enum BuiltinScalarFunction {
     /// sqrt
     Sqrt,
     /// sin
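With the hunks above, `Signature` now has three variants: `Variadic`, `Uniform`, and the new `Exact`. A hedged illustration of what each declares, derived from the doc comments above:

```rust
use arrow::datatypes::DataType;
use datafusion::physical_plan::functions::Signature;

// `concat(a, b, ...)`: any number of arguments, all coerced to one common string type
let concat_sig = Signature::Variadic(vec![DataType::Utf8, DataType::LargeUtf8]);

// two arguments that must share a single float type (f64 preferred over f32)
let uniform_sig = Signature::Uniform(2, vec![DataType::Float32, DataType::Float64]);

// the new variant: `create_udf(..., vec![f64, f64], ...)` produces exactly this
let exact_sig = Signature::Exact(vec![DataType::Float64, DataType::Float64]);
```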
@@ -103,36 +114,36 @@ pub enum ScalarFunction {
     Concat,
 }
 
-impl fmt::Display for ScalarFunction {
+impl fmt::Display for BuiltinScalarFunction {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         // lowercase of the debug.
         write!(f, "{}", format!("{:?}", self).to_lowercase())
     }
 }
 
-impl FromStr for ScalarFunction {
+impl FromStr for BuiltinScalarFunction {
     type Err = ExecutionError;
-    fn from_str(name: &str) -> Result<ScalarFunction> {
+    fn from_str(name: &str) -> Result<BuiltinScalarFunction> {
         Ok(match name {
-            "sqrt" => ScalarFunction::Sqrt,
-            "sin" => ScalarFunction::Sin,
-            "cos" => ScalarFunction::Cos,
-            "tan" => ScalarFunction::Tan,
-            "asin" => ScalarFunction::Asin,
-            "acos" => ScalarFunction::Acos,
-            "atan" => ScalarFunction::Atan,
-            "exp" => ScalarFunction::Exp,
-            "log" => ScalarFunction::Log,
-            "log2" => ScalarFunction::Log2,
-            "log10" => ScalarFunction::Log10,
-            "floor" => ScalarFunction::Floor,
-            "ceil" => ScalarFunction::Ceil,
-            "round" => ScalarFunction::Round,
-            "truc" => ScalarFunction::Trunc,
-            "abs" => ScalarFunction::Abs,
-            "signum" => ScalarFunction::Signum,
-            "length" => ScalarFunction::Length,
-            "concat" => ScalarFunction::Concat,
+            "sqrt" => BuiltinScalarFunction::Sqrt,
+            "sin" => BuiltinScalarFunction::Sin,
+            "cos" => BuiltinScalarFunction::Cos,
+            "tan" => BuiltinScalarFunction::Tan,
+            "asin" => BuiltinScalarFunction::Asin,
+            "acos" => BuiltinScalarFunction::Acos,
+            "atan" => BuiltinScalarFunction::Atan,
+            "exp" => BuiltinScalarFunction::Exp,
+            "log" => BuiltinScalarFunction::Log,
+            "log2" => BuiltinScalarFunction::Log2,
+            "log10" => BuiltinScalarFunction::Log10,
+            "floor" => BuiltinScalarFunction::Floor,
+            "ceil" => BuiltinScalarFunction::Ceil,
+            "round" => BuiltinScalarFunction::Round,
+            "trunc" => BuiltinScalarFunction::Trunc,
+            "abs" => BuiltinScalarFunction::Abs,
+            "signum" => BuiltinScalarFunction::Signum,
+            "length" => BuiltinScalarFunction::Length,
+            "concat" => BuiltinScalarFunction::Concat,
             _ => {
                 return Err(ExecutionError::General(format!(
                     "There is no built-in function named {}",
@@ -144,7 +155,10 @@ impl FromStr for ScalarFunction {
 }
 
 /// Returns the datatype of the scalar function
-pub fn return_type(fun: &ScalarFunction, arg_types: &Vec<DataType>) -> Result<DataType> {
+pub fn return_type(
+    fun: &BuiltinScalarFunction,
+    arg_types: &Vec<DataType>,
+) -> Result<DataType> {
     // Note that this function *must* return the same type that the respective physical expression returns
     // or the execution panics.
@@ -163,8 +177,8 @@ pub fn return_type(fun: &ScalarFunction, arg_types: &Vec<DataType>) -> Result<D
-        ScalarFunction::Length => Ok(DataType::UInt32),
-        ScalarFunction::Concat => Ok(DataType::Utf8),
+        BuiltinScalarFunction::Length => Ok(DataType::UInt32),
+        BuiltinScalarFunction::Concat => Ok(DataType::Utf8),
         _ => Ok(DataType::Float64),
     }
 }
@@ -172,30 +186,30 @@ pub fn return_type(fun: &ScalarFunction, arg_types: &Vec<DataType>) -> Result<D
 pub fn create_physical_expr(
-    fun: &ScalarFunction,
+    fun: &BuiltinScalarFunction,
     args: &Vec<Arc<dyn PhysicalExpr>>,
     input_schema: &Schema,
 ) -> Result<Arc<dyn PhysicalExpr>> {
-    let fun_expr: ScalarUdf = Arc::new(match fun {
-        ScalarFunction::Sqrt => math_expressions::sqrt,
-        ScalarFunction::Sin => math_expressions::sin,
-        ScalarFunction::Cos => math_expressions::cos,
-        ScalarFunction::Tan => math_expressions::tan,
-        ScalarFunction::Asin => math_expressions::asin,
-        ScalarFunction::Acos => math_expressions::acos,
-        ScalarFunction::Atan => math_expressions::atan,
-        ScalarFunction::Exp => math_expressions::exp,
-        ScalarFunction::Log => math_expressions::ln,
-        ScalarFunction::Log2 => math_expressions::log2,
-        ScalarFunction::Log10 => math_expressions::log10,
-        ScalarFunction::Floor => math_expressions::floor,
-        ScalarFunction::Ceil => math_expressions::ceil,
-        ScalarFunction::Round => math_expressions::round,
-        ScalarFunction::Trunc => math_expressions::trunc,
-        ScalarFunction::Abs => math_expressions::abs,
-        ScalarFunction::Signum => math_expressions::signum,
-        ScalarFunction::Length => |args| Ok(Arc::new(length(args[0].as_ref())?)),
-        ScalarFunction::Concat => {
+    let fun_expr: ScalarFunctionImplementation = Arc::new(match fun {
+        BuiltinScalarFunction::Sqrt => math_expressions::sqrt,
+        BuiltinScalarFunction::Sin => math_expressions::sin,
+        BuiltinScalarFunction::Cos => math_expressions::cos,
+        BuiltinScalarFunction::Tan => math_expressions::tan,
+        BuiltinScalarFunction::Asin => math_expressions::asin,
+        BuiltinScalarFunction::Acos => math_expressions::acos,
+        BuiltinScalarFunction::Atan => math_expressions::atan,
+        BuiltinScalarFunction::Exp => math_expressions::exp,
+        BuiltinScalarFunction::Log => math_expressions::ln,
+        BuiltinScalarFunction::Log2 => math_expressions::log2,
+        BuiltinScalarFunction::Log10 => math_expressions::log10,
+        BuiltinScalarFunction::Floor => math_expressions::floor,
+        BuiltinScalarFunction::Ceil => math_expressions::ceil,
+        BuiltinScalarFunction::Round => math_expressions::round,
+        BuiltinScalarFunction::Trunc => math_expressions::trunc,
+        BuiltinScalarFunction::Abs => math_expressions::abs,
+        BuiltinScalarFunction::Signum => math_expressions::signum,
+        BuiltinScalarFunction::Length => |args| Ok(Arc::new(length(args[0].as_ref())?)),
+        BuiltinScalarFunction::Concat => {
             |args| Ok(Arc::new(string_expressions::concatenate(args)?))
         }
     });
@@ -207,7 +221,7 @@ pub fn create_physical_expr(
         .map(|e| e.data_type(input_schema))
         .collect::<Result<Vec<_>>>()?;
 
-    Ok(Arc::new(udf::ScalarFunctionExpr::new(
+    Ok(Arc::new(ScalarFunctionExpr::new(
         &format!("{}", fun),
         fun_expr,
         args,
@@ -216,13 +230,13 @@ pub fn create_physical_expr(
 }
 
 /// the signatures supported by the function `fun`.
-fn signature(fun: &ScalarFunction) -> Signature {
+fn signature(fun: &BuiltinScalarFunction) -> Signature {
     // note: the physical expression must accept the type returned by this function or the execution panics.
 
     // for now, the list is small, as we do not have many built-in functions.
     match fun {
-        ScalarFunction::Length => Signature::Uniform(1, vec![DataType::Utf8]),
-        ScalarFunction::Concat => Signature::Variadic(vec![DataType::Utf8]),
+        BuiltinScalarFunction::Length => Signature::Uniform(1, vec![DataType::Utf8]),
+        BuiltinScalarFunction::Concat => Signature::Variadic(vec![DataType::Utf8]),
         // math expressions expect 1 argument of type f64 or f32
         // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
         // return the best approximation for it (in f64).
@@ -232,6 +246,80 @@ fn signature(fun: &ScalarFunction) -> Signature {
     }
 }
 
+/// Physical expression of a scalar function
+pub struct ScalarFunctionExpr {
+    fun: ScalarFunctionImplementation,
+    name: String,
+    args: Vec<Arc<dyn PhysicalExpr>>,
+    return_type: DataType,
+}
+
+impl Debug for ScalarFunctionExpr {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        f.debug_struct("ScalarFunctionExpr")
+            .field("fun", &"<FUNC>")
+            .field("name", &self.name)
+            .field("args", &self.args)
+            .field("return_type", &self.return_type)
+            .finish()
+    }
+}
+
+impl ScalarFunctionExpr {
+    /// Create a new Scalar function
+    pub fn new(
+        name: &str,
+        fun: ScalarFunctionImplementation,
+        args: Vec<Arc<dyn PhysicalExpr>>,
+        return_type: &DataType,
+    ) -> Self {
+        Self {
+            fun,
+            name: name.to_owned(),
+            args,
+            return_type: return_type.clone(),
+        }
+    }
+}
+
+impl fmt::Display for ScalarFunctionExpr {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        write!(
+            f,
+            "{}({})",
+            self.name,
+            self.args
+                .iter()
+                .map(|e| format!("{}", e))
+                .collect::<Vec<String>>()
+                .join(", ")
+        )
+    }
+}
+
+impl PhysicalExpr for ScalarFunctionExpr {
+    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
+        Ok(self.return_type.clone())
+    }
+
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        Ok(true)
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
+        // evaluate the arguments
+        let inputs = self
+            .args
+            .iter()
+            .map(|e| e.evaluate(batch))
+            .collect::<Result<Vec<_>>>()?;
+
+        // evaluate the function
+        let fun = self.fun.as_ref();
+        (fun)(&inputs)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -251,7 +339,8 @@ mod tests {
 
         let arg = lit(value);
 
-        let expr = create_physical_expr(&ScalarFunction::Exp, &vec![arg], &schema)?;
+        let expr =
+            create_physical_expr(&BuiltinScalarFunction::Exp, &vec![arg], &schema)?;
 
         // type is correct
         assert_eq!(expr.data_type(&schema)?, DataType::Float64);
@@ -289,7 +378,7 @@ mod tests {
 
         // concat(value, value)
         let expr = create_physical_expr(
-            &ScalarFunction::Concat,
+            &BuiltinScalarFunction::Concat,
             &vec![lit(value.clone()), lit(value)],
             &schema,
         )?;
@@ -317,7 +406,7 @@ mod tests {
 
     #[test]
     fn test_concat_error() -> Result<()> {
-        let result = return_type(&ScalarFunction::Concat, &vec![]);
+        let result = return_type(&BuiltinScalarFunction::Concat, &vec![]);
 
         if let Ok(_) = result {
             Err(ExecutionError::General(
                 "Function 'concat' cannot accept zero arguments".to_string(),
diff --git a/rust/datafusion/src/physical_plan/planner.rs b/rust/datafusion/src/physical_plan/planner.rs
index 5d83db1cf6bbb..17c0ad6483f59 100644
--- a/rust/datafusion/src/physical_plan/planner.rs
+++ b/rust/datafusion/src/physical_plan/planner.rs
@@ -38,7 +38,7 @@ use crate::physical_plan::merge::MergeExec;
 use crate::physical_plan::parquet::ParquetExec;
 use crate::physical_plan::projection::ProjectionExec;
 use crate::physical_plan::sort::SortExec;
-use crate::physical_plan::udf::ScalarFunctionExpr;
+use crate::physical_plan::udf;
 use crate::physical_plan::{expressions, Distribution};
 use crate::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr, PhysicalPlanner};
 use arrow::compute::SortOptions;
@@ -429,12 +429,12 @@ impl DefaultPhysicalPlanner {
                         ctx_state,
                     )?);
                 }
-                Ok(Arc::new(ScalarFunctionExpr::new(
-                    &fun.name,
-                    fun.fun.clone(),
-                    physical_args,
-                    &fun.return_type,
-                )))
+
+                udf::create_physical_expr(
+                    fun.clone().as_ref(),
+                    &physical_args,
+                    input_schema,
+                )
             }
             other => Err(ExecutionError::NotImplemented(format!(
                 "Physical plan does not support logical expression {:?}",
diff --git a/rust/datafusion/src/physical_plan/type_coercion.rs b/rust/datafusion/src/physical_plan/type_coercion.rs
index 8835ea0e8db2f..4ad6085bd1416 100644
--- a/rust/datafusion/src/physical_plan/type_coercion.rs
+++ b/rust/datafusion/src/physical_plan/type_coercion.rs
@@ -66,6 +66,7 @@ pub fn data_types(
                 .map(|_| current_types[0].clone())
                 .collect()]
         }
+        Signature::Exact(valid_types) => vec![valid_types.clone()],
     };
 
     if valid_types.contains(current_types) {
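The one-line `Exact` arm above is what lets UDF arguments participate in coercion: `data_types` proposes the declared type list verbatim, and the crate-internal `coerce` helper (used by `udf::create_physical_expr` in the next file) inserts casts where the proposal differs from the actual types. A hedged sketch of the intended behavior:

```rust
use arrow::datatypes::DataType;
use datafusion::physical_plan::functions::Signature;

// For a UDF declared Exact([Float64]) and invoked with a Float32 column,
// `data_types(&[Float32], &sig)` proposes [[Float64]]; since [Float32] is
// not among the proposals, `coerce` rewrites the argument to
// CAST(arg AS Float64) before the function is evaluated.
let sig = Signature::Exact(vec![DataType::Float64]);
```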
diff --git a/rust/datafusion/src/physical_plan/udf.rs b/rust/datafusion/src/physical_plan/udf.rs
index ffc206742329c..a02b38e0c42d1 100644
--- a/rust/datafusion/src/physical_plan/udf.rs
+++ b/rust/datafusion/src/physical_plan/udf.rs
@@ -17,144 +17,81 @@
 
 //! UDF support
 
+use fmt::{Debug, Formatter};
 use std::fmt;
 
-use arrow::array::ArrayRef;
-use arrow::datatypes::{DataType, Schema};
+use arrow::datatypes::Schema;
 
 use crate::error::Result;
 use crate::physical_plan::PhysicalExpr;
-use arrow::record_batch::RecordBatch;
-use fmt::{Debug, Formatter};
-use std::{collections::HashMap, sync::Arc};
-
-/// Scalar UDF
-pub type ScalarUdf = Arc<dyn Fn(&[ArrayRef]) -> Result<ArrayRef> + Send + Sync>;
+use super::{
+    functions::{
+        ReturnTypeFunction, ScalarFunctionExpr, ScalarFunctionImplementation, Signature,
+    },
+    type_coercion::coerce,
+};
+use std::sync::Arc;
 
-/// Scalar UDF Expression
+/// Logical representation of a UDF.
 #[derive(Clone)]
-pub struct ScalarFunction {
-    /// Function name
+pub struct ScalarUDF {
+    /// name
     pub name: String,
-    /// Function argument meta-data
-    pub arg_types: Vec<DataType>,
+    /// signature
+    pub signature: Signature,
     /// Return type
-    pub return_type: DataType,
-    /// UDF implementation
-    pub fun: ScalarUdf,
-}
-
-/// Something which provides information for particular scalar functions
-pub trait ScalarFunctionRegistry {
-    /// Return ScalarFunction for `name`
-    fn lookup(&self, name: &str) -> Option<Arc<ScalarFunction>>;
+    pub return_type: ReturnTypeFunction,
+    /// actual implementation
+    pub fun: ScalarFunctionImplementation,
 }
 
-impl ScalarFunctionRegistry for HashMap<String, Arc<ScalarFunction>> {
-    fn lookup(&self, name: &str) -> Option<Arc<ScalarFunction>> {
-        self.get(name).and_then(|func| Some(func.clone()))
-    }
-}
-
-impl Debug for ScalarFunction {
+impl Debug for ScalarUDF {
     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
-        f.debug_struct("ScalarFunction")
+        f.debug_struct("ScalarUDF")
             .field("name", &self.name)
-            .field("arg_types", &self.arg_types)
-            .field("return_type", &self.return_type)
+            .field("signature", &self.signature)
             .field("fun", &"<FUNC>")
             .finish()
     }
 }
 
-impl ScalarFunction {
-    /// Create a new ScalarFunction
+impl ScalarUDF {
+    /// Create a new ScalarUDF
     pub fn new(
         name: &str,
-        arg_types: Vec<DataType>,
-        return_type: DataType,
-        fun: ScalarUdf,
+        signature: &Signature,
+        return_type: &ReturnTypeFunction,
+        fun: &ScalarFunctionImplementation,
     ) -> Self {
         Self {
             name: name.to_owned(),
-            arg_types,
-            return_type,
-            fun,
-        }
-    }
-}
-
-/// Scalar UDF Physical Expression
-pub struct ScalarFunctionExpr {
-    fun: ScalarUdf,
-    name: String,
-    args: Vec<Arc<dyn PhysicalExpr>>,
-    return_type: DataType,
-}
-
-impl Debug for ScalarFunctionExpr {
-    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
-        f.debug_struct("ScalarFunctionExpr")
-            .field("fun", &"<FUNC>")
-            .field("name", &self.name)
-            .field("args", &self.args)
-            .field("return_type", &self.return_type)
-            .finish()
-    }
-}
-
-impl ScalarFunctionExpr {
-    /// Create a new Scalar function
-    pub fn new(
-        name: &str,
-        fun: ScalarUdf,
-        args: Vec<Arc<dyn PhysicalExpr>>,
-        return_type: &DataType,
-    ) -> Self {
-        Self {
-            fun,
-            name: name.to_owned(),
-            args,
+            signature: signature.clone(),
             return_type: return_type.clone(),
+            fun: fun.clone(),
         }
     }
 }
-
-impl fmt::Display for ScalarFunctionExpr {
-    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
-        write!(
-            f,
-            "{}({})",
-            self.name,
-            self.args
-                .iter()
-                .map(|e| format!("{}", e))
-                .collect::<Vec<String>>()
-                .join(", ")
-        )
-    }
-}
-
-impl PhysicalExpr for ScalarFunctionExpr {
-    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
-        Ok(self.return_type.clone())
-    }
-
-    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
-        Ok(true)
-    }
-
-    fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef> {
-        // evaluate the arguments
-        let inputs = self
-            .args
-            .iter()
-            .map(|e| e.evaluate(batch))
-            .collect::<Result<Vec<_>>>()?;
-
-        // evaluate the function
-        let fun = self.fun.as_ref();
-        (fun)(&inputs)
-    }
+/// Create a physical expression of the UDF.
+/// This function errors when `args` can't be coerced to a valid argument type of the UDF.
+pub fn create_physical_expr(
+    fun: &ScalarUDF,
+    args: &Vec<Arc<dyn PhysicalExpr>>,
+    input_schema: &Schema,
+) -> Result<Arc<dyn PhysicalExpr>> {
+    // coerce
+    let args = coerce(args, input_schema, &fun.signature)?;
+
+    let arg_types = args
+        .iter()
+        .map(|e| e.data_type(input_schema))
+        .collect::<Result<Vec<_>>>()?;
+
+    Ok(Arc::new(ScalarFunctionExpr::new(
+        &fun.name,
+        fun.fun.clone(),
+        args,
+        (fun.return_type)(&arg_types)?.as_ref(),
+    )))
 }
diff --git a/rust/datafusion/src/prelude.rs b/rust/datafusion/src/prelude.rs
index 25e5dd366f793..1b68347d6e26a 100644
--- a/rust/datafusion/src/prelude.rs
+++ b/rust/datafusion/src/prelude.rs
@@ -27,5 +27,7 @@
 pub use crate::dataframe::DataFrame;
 pub use crate::execution::context::{ExecutionConfig, ExecutionContext};
-pub use crate::logical_plan::{avg, col, concat, count, length, lit, max, min, sum};
+pub use crate::logical_plan::{
+    avg, col, concat, count, create_udf, length, lit, max, min, sum,
+};
 pub use crate::physical_plan::csv::CsvReadOptions;
diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs
index e1a36f40d4d52..ccb084c667071 100644
--- a/rust/datafusion/src/sql/planner.rs
+++ b/rust/datafusion/src/sql/planner.rs
@@ -28,7 +28,7 @@ use crate::logical_plan::{
 };
 use crate::{
     physical_plan::functions,
-    physical_plan::udf::ScalarFunction,
+    physical_plan::udf::ScalarUDF,
     sql::parser::{CreateExternalTable, FileType, Statement as DFStatement},
 };
@@ -48,7 +48,7 @@ pub trait SchemaProvider {
     /// Getter for a field description
     fn get_table_meta(&self, name: &str) -> Option<SchemaRef>;
     /// Getter for a UDF description
-    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarFunction>>;
+    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>>;
 }
 
 /// SQL query planner
@@ -485,7 +485,7 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> {
             let name: String = function.name.to_string();
 
             // first, scalar built-in
-            if let Ok(fun) = functions::ScalarFunction::from_str(&name) {
+            if let Ok(fun) = functions::BuiltinScalarFunction::from_str(&name) {
                 let args = function
                     .args
                     .iter()
@@ -528,21 +528,15 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> {
                 // finally, user-defined functions
                 _ => match self.schema_provider.get_function_meta(&name) {
                     Some(fm) => {
-                        let rex_args = function
+                        let args = function
                             .args
                             .iter()
                             .map(|a| self.sql_to_rex(a, schema))
                             .collect::<Result<Vec<Expr>>>()?;
 
-                        let mut safe_args: Vec<Expr> = vec![];
-                        for i in 0..rex_args.len() {
-                            safe_args
-                                .push(rex_args[i].cast_to(&fm.arg_types[i], schema)?);
-                        }
-
                         Ok(Expr::ScalarUDF {
                             fun: fm.clone(),
-                            args: safe_args,
+                            args,
                         })
                     }
                     _ => Err(ExecutionError::General(format!(
@@ -592,7 +586,8 @@ pub fn convert_data_type(sql: &SQLDataType) -> Result<DataType> {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::sql::parser::DFParser;
+    use crate::{logical_plan::create_udf, sql::parser::DFParser};
+    use functions::ScalarFunctionImplementation;
 
     #[test]
     fn select_no_relation() {
@@ -917,13 +912,15 @@ mod tests {
         }
     }
 
-    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarFunction>> {
+    fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
+        let f: ScalarFunctionImplementation =
+            Arc::new(|_| Err(ExecutionError::NotImplemented("".to_string())));
         match name {
-            "my_sqrt" => Some(Arc::new(ScalarFunction::new(
+            "my_sqrt" => Some(Arc::new(create_udf(
                 "my_sqrt",
                 vec![DataType::Float64],
-                DataType::Float64,
-                Arc::new(|_| Err(ExecutionError::NotImplemented("".to_string()))),
+                Arc::new(DataType::Float64),
+                f,
             ))),
             _ => None,
         }
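Removing the eager `cast_to` loop above changes where casts live: the logical plan now carries the raw UDF arguments, and the cast first appears when `udf::create_physical_expr` coerces them. A hedged, purely illustrative sketch (it mixes the `my_sqrt` name from the planner test with the `aggregate_test_100` table from the integration tests below; `ctx` setup omitted):

```rust
// The logical plan contains `ScalarUDF { fun: my_sqrt, args: [#c11] }` with the
// argument left uncast...
let plan = ctx.create_logical_plan("SELECT my_sqrt(c11) FROM aggregate_test_100")?;
// ...the CAST(#c11 AS Float64) is introduced only here, during physical planning.
let physical_plan = ctx.create_physical_plan(&plan)?;
```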
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index c0172adc546a8..050660c64a2b1 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -29,7 +29,7 @@ use datafusion::datasource::{csv::CsvReadOptions, MemTable};
 use datafusion::error::Result;
 use datafusion::execution::context::ExecutionContext;
 use datafusion::logical_plan::LogicalPlan;
-use datafusion::physical_plan::udf::ScalarFunction;
+use datafusion::prelude::create_udf;
 
 #[test]
 fn nyc() -> Result<()> {
@@ -201,6 +201,20 @@ fn csv_query_avg_sqrt() -> Result<()> {
     Ok(())
 }
 
+/// test that casting happens on udfs.
+/// c11 is f32, but `custom_sqrt` requires f64. Casting happens during physical planning,
+/// so the logical plan and the physical plan have the same schema.
+#[test]
+fn csv_query_custom_udf_with_cast() -> Result<()> {
+    let mut ctx = create_ctx()?;
+    register_aggregate_csv(&mut ctx)?;
+    let sql = "SELECT avg(custom_sqrt(c11)) FROM aggregate_test_100";
+    let actual = execute(&mut ctx, sql);
+    let expected = "0.6584408483418833".to_string();
+    assert_eq!(actual.join("\n"), expected);
+    Ok(())
+}
+
 /// sqrt(f32) is slightly different than sqrt(CAST(f32 AS double)))
 #[test]
 fn sqrt_f32_vs_f64() -> Result<()> {
@@ -247,10 +261,10 @@ fn create_ctx() -> Result<ExecutionContext> {
     let mut ctx = ExecutionContext::new();
 
     // register a custom UDF
-    ctx.register_udf(ScalarFunction::new(
+    ctx.register_udf(create_udf(
         "custom_sqrt",
         vec![DataType::Float64],
-        DataType::Float64,
+        Arc::new(DataType::Float64),
         Arc::new(custom_sqrt),
     ));
 
@@ -498,7 +512,6 @@ fn csv_explain_verbose() {
     // pain). Instead just check for a few key pieces.
     assert!(actual.contains("logical_plan"), "Actual: '{}'", actual);
     assert!(actual.contains("physical_plan"), "Actual: '{}'", actual);
-    assert!(actual.contains("type_coercion"), "Actual: '{}'", actual);
     assert!(actual.contains("#c2 Gt Int64(10)"), "Actual: '{}'", actual);
 }
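`custom_sqrt` itself is pre-existing test code and not shown in this diff; a plausible sketch of an implementation that satisfies the registration above (hypothetical, modeled on the `simple_udf` example rather than taken from the repository):

```rust
use std::sync::Arc;
use arrow::array::{Array, ArrayRef, Float64Array, Float64Builder};
use datafusion::error::Result;

// Hypothetical reimplementation: the Exact([Float64]) signature registered
// above guarantees a single f64 argument, so the downcast cannot mis-type.
fn custom_sqrt(args: &[ArrayRef]) -> Result<ArrayRef> {
    let input = args[0]
        .as_any()
        .downcast_ref::<Float64Array>()
        .expect("cast failed");
    let mut builder = Float64Builder::new(input.len());
    for i in 0..input.len() {
        // propagate nulls, take sqrt of every non-null value
        if input.is_null(i) {
            builder.append_null()?;
        } else {
            builder.append_value(input.value(i).sqrt())?;
        }
    }
    Ok(Arc::new(builder.finish()))
}
```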