ARROW-9751: [Rust] [DataFusion] Allow UDFs to accept multiple data types per argument

This PR aligns UDF registration and declaration with our built-in functions, so that we can leverage coercion rules on their arguments.

For ease of use, this PR introduces a function `create_udf` that simplifies the creation of UDFs with a fixed signature and fixed return type, so that users have a simple interface to declare them.

Underneath, however, UDFs have the same capabilities as built-in functions: they can be just as generic (arbitrary argument types, etc.).
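
For reference, below is a condensed sketch of the new interface, assembled from the updated test in `context.rs` and the new example in this PR (the `my_add` function and the use of Arrow's `add` compute kernel come from that test; error handling is kept minimal):

    use std::sync::Arc;

    use arrow::array::{ArrayRef, Int32Array};
    use arrow::compute::add;
    use arrow::datatypes::DataType;
    use datafusion::error::Result;
    use datafusion::physical_plan::functions::ScalarFunctionImplementation;
    use datafusion::prelude::*;

    fn main() -> Result<()> {
        let mut ctx = ExecutionContext::new();

        // the implementation: a closure over dynamically-typed Arrow arrays
        let myfunc: ScalarFunctionImplementation = Arc::new(|args: &[ArrayRef]| {
            let l = &args[0]
                .as_any()
                .downcast_ref::<Int32Array>()
                .expect("cast failed");
            let r = &args[1]
                .as_any()
                .downcast_ref::<Int32Array>()
                .expect("cast failed");
            Ok(Arc::new(add(l, r)?))
        });

        // declare the name, argument types and return type, and register the UDF
        ctx.register_udf(create_udf(
            "my_add",
            vec![DataType::Int32, DataType::Int32],
            Arc::new(DataType::Int32),
            myfunc,
        ));

        Ok(())
    }

Once registered, the UDF is callable through the DataFrame API via the function registry (`df.registry().udf("my_add", vec![col("a"), col("b")])`) or directly from SQL (`SELECT my_add(a, b) FROM t` for a registered table `t`), with DataFusion coercing argument types to the declared signature when needed.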

Specific achievements of this PR:

* Added example (120 LOC) of how to declare and register a UDF
* Deprecated the type coercion optimizer rule, since it was causing logical schemas to become misaligned and our end-to-end tests to fail when implicit casting was required, and replaced it with what we already do for built-ins
* Made UDFs use the same interfaces as built-in functions

Note that this PR is built on top of apache#8032.

Closes apache#7967 from jorgecarleitao/clean

Authored-by: Jorge C. Leitao <jorgecarleitao@gmail.com>
Signed-off-by: Andy Grove <andygrove73@gmail.com>
jorgecarleitao authored and GeorgeAp committed Jun 7, 2021
1 parent 1ce4e1a commit 94bcb86
Showing 13 changed files with 427 additions and 378 deletions.
138 changes: 138 additions & 0 deletions rust/datafusion/examples/simple_udf.rs
@@ -0,0 +1,138 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::{
array::{Array, ArrayRef, Float32Array, Float64Array, Float64Builder},
datatypes::DataType,
record_batch::RecordBatch,
util::pretty,
};

use datafusion::error::Result;
use datafusion::{physical_plan::functions::ScalarFunctionImplementation, prelude::*};
use std::sync::Arc;

// create local execution context with an in-memory table
fn create_context() -> Result<ExecutionContext> {
use arrow::datatypes::{Field, Schema};
use datafusion::datasource::MemTable;
// define a schema.
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float32, false),
Field::new("b", DataType::Float64, false),
]));

// define data.
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1])),
Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0])),
],
)?;

// declare a new context. In the Spark API, this corresponds to a new Spark SQL session
let mut ctx = ExecutionContext::new();

// declare a table in memory. In the Spark API, this corresponds to createDataFrame(...).
let provider = MemTable::new(schema, vec![vec![batch]])?;
ctx.register_table("t", Box::new(provider));
Ok(ctx)
}

/// In this example we declare a single-type, single-return-type UDF that computes a^b over f64 values
fn main() -> Result<()> {
let mut ctx = create_context()?;

// First, declare the actual implementation of the calculation
let pow: ScalarFunctionImplementation = Arc::new(|args: &[ArrayRef]| {
// in DataFusion, all `args` and output are dynamically-typed arrays, which means that we need to:
// 1. cast the values to the type we want
// 2. perform the computation for every element in the array (using a loop or SIMD)
// 3. construct the resulting array

// this is guaranteed by DataFusion based on the function's signature.
assert_eq!(args.len(), 2);

// 1. cast both arguments to f64. These casts MUST be aligned with the signature or this function panics!
let base = &args[0]
.as_any()
.downcast_ref::<Float64Array>()
.expect("cast failed");
let exponent = &args[1]
.as_any()
.downcast_ref::<Float64Array>()
.expect("cast failed");

// this is guaranteed by DataFusion. We place it just to make it obvious.
assert_eq!(exponent.len(), base.len());

// 2. Arrow's builder is used to construct an Arrow array.
let mut builder = Float64Builder::new(base.len());
for index in 0..base.len() {
// in Arrow, any value can be null.
// Here we decide to make our UDF return null when either base or exponent is null.
if base.is_null(index) || exponent.is_null(index) {
builder.append_null()?;
} else {
// 3. computation. Since we do not have a SIMD `pow` operation at hand,
// we loop over each entry. An array's values are obtained via `.value(index)`.
let value = base.value(index).powf(exponent.value(index));
builder.append_value(value)?;
}
}
Ok(Arc::new(builder.finish()))
});

// Next:
// * give it a name so that it shows nicely when the plan is printed
// * declare what input it expects
// * declare its return type
let pow = create_udf(
"pow",
// expects two f64
vec![DataType::Float64, DataType::Float64],
// returns f64
Arc::new(DataType::Float64),
pow,
);

// finally, register the UDF
ctx.register_udf(pow);

// at this point, we can use it. Note that the code below can be in a
// scope in which we do not have access to `pow`.

// get a DataFrame from the context
let df = ctx.table("t")?;

// get the udf registry.
let f = df.registry();

// equivalent to `'SELECT pow(a, b) FROM t'`
let df = df.select(vec![f.udf("pow", vec![col("a"), col("b")])?])?;

// note that "b" is f32, not f64. DataFusion coerces the types to match the UDF's signature.

// execute the query
let results = df.collect()?;

// print the results
pretty::print_batches(&results)?;

Ok(())
}
31 changes: 10 additions & 21 deletions rust/datafusion/src/execution/context.rs
@@ -40,13 +40,11 @@ use crate::logical_plan::{Expr, FunctionRegistry, LogicalPlan, LogicalPlanBuilde
use crate::optimizer::filter_push_down::FilterPushDown;
use crate::optimizer::optimizer::OptimizerRule;
use crate::optimizer::projection_push_down::ProjectionPushDown;
use crate::optimizer::type_coercion::TypeCoercionRule;
use crate::physical_plan::common;
use crate::physical_plan::csv::CsvReadOptions;
use crate::physical_plan::merge::MergeExec;
use crate::physical_plan::planner::DefaultPhysicalPlanner;
use crate::physical_plan::udf::ScalarFunction;
use crate::physical_plan::udf::ScalarFunctionRegistry;
use crate::physical_plan::udf::ScalarUDF;
use crate::physical_plan::ExecutionPlan;
use crate::physical_plan::PhysicalPlanner;
use crate::sql::{
@@ -180,7 +178,7 @@ impl ExecutionContext {
}

/// Register a scalar UDF
pub fn register_udf(&mut self, f: ScalarFunction) {
pub fn register_udf(&mut self, f: ScalarUDF) {
self.state
.scalar_functions
.insert(f.name.clone(), Arc::new(f));
@@ -294,7 +292,6 @@ impl ExecutionContext {
// Apply standard rewrites and optimizations
let mut plan = ProjectionPushDown::new().optimize(&plan)?;
plan = FilterPushDown::new().optimize(&plan)?;
plan = TypeCoercionRule::new().optimize(&plan)?;

self.state.config.query_planner.rewrite_logical_plan(plan)
}
@@ -377,12 +374,6 @@ impl ExecutionContext {
}
}

impl ScalarFunctionRegistry for ExecutionContext {
fn lookup(&self, name: &str) -> Option<Arc<ScalarFunction>> {
self.state.scalar_functions.lookup(name)
}
}

/// A planner used to add extensions to DataFusion logical and physical plans.
pub trait QueryPlanner {
/// Given a `LogicalPlan`, create a new, modified `LogicalPlan`
@@ -468,7 +459,7 @@ pub struct ExecutionContextState {
/// Data sources that are registered with the context
pub datasources: HashMap<String, Arc<dyn TableProvider + Send + Sync>>,
/// Scalar functions that are registered with the context
pub scalar_functions: HashMap<String, Arc<ScalarFunction>>,
pub scalar_functions: HashMap<String, Arc<ScalarUDF>>,
/// Context configuration
pub config: ExecutionConfig,
}
@@ -478,7 +469,7 @@ impl SchemaProvider for ExecutionContextState {
self.datasources.get(name).map(|ds| ds.schema().clone())
}

fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarFunction>> {
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
self.scalar_functions
.get(name)
.and_then(|func| Some(func.clone()))
@@ -510,8 +501,8 @@ mod tests {

use super::*;
use crate::datasource::MemTable;
use crate::logical_plan::{aggregate_expr, col};
use crate::physical_plan::udf::ScalarUdf;
use crate::logical_plan::{aggregate_expr, col, create_udf};
use crate::physical_plan::functions::ScalarFunctionImplementation;
use crate::test;
use arrow::array::{ArrayRef, Int32Array};
use arrow::compute::add;
@@ -1032,7 +1023,7 @@ mod tests {
let provider = MemTable::new(Arc::new(schema), vec![vec![batch]])?;
ctx.register_table("t", Box::new(provider));

let myfunc: ScalarUdf = Arc::new(|args: &[ArrayRef]| {
let myfunc: ScalarFunctionImplementation = Arc::new(|args: &[ArrayRef]| {
let l = &args[0]
.as_any()
.downcast_ref::<Int32Array>()
@@ -1044,14 +1035,12 @@ mod tests {
Ok(Arc::new(add(l, r)?))
});

let my_add = ScalarFunction::new(
ctx.register_udf(create_udf(
"my_add",
vec![DataType::Int32, DataType::Int32],
DataType::Int32,
Arc::new(DataType::Int32),
myfunc,
);

ctx.register_udf(my_add);
));

// from here on, we may be in a different scope. We would still like to be able
// to call UDFs.
44 changes: 13 additions & 31 deletions rust/datafusion/src/execution/dataframe_impl.rs
@@ -136,15 +136,8 @@ mod tests {
use crate::datasource::csv::CsvReadOptions;
use crate::execution::context::ExecutionContext;
use crate::logical_plan::*;
use crate::{
physical_plan::udf::{ScalarFunction, ScalarUdf},
test,
};
use arrow::{
array::{ArrayRef, Float64Array},
compute::add,
datatypes::DataType,
};
use crate::{physical_plan::functions::ScalarFunctionImplementation, test};
use arrow::{array::ArrayRef, datatypes::DataType};

#[test]
fn select_columns() -> Result<()> {
@@ -250,39 +243,28 @@ mod tests {
register_aggregate_csv(&mut ctx)?;

// declare the udf
let my_add: ScalarUdf = Arc::new(|args: &[ArrayRef]| {
let l = &args[0]
.as_any()
.downcast_ref::<Float64Array>()
.expect("cast failed");
let r = &args[1]
.as_any()
.downcast_ref::<Float64Array>()
.expect("cast failed");
Ok(Arc::new(add(l, r)?))
});

let my_add = ScalarFunction::new(
"my_add",
vec![DataType::Float64],
DataType::Float64,
my_add,
);
let my_fn: ScalarFunctionImplementation =
Arc::new(|_: &[ArrayRef]| unimplemented!("my_fn is not implemented"));

// register the udf
ctx.register_udf(my_add);
// create and register the udf
ctx.register_udf(create_udf(
"my_fn",
vec![DataType::Float64],
Arc::new(DataType::Float64),
my_fn,
));

// build query with a UDF using DataFrame API
let df = ctx.table("aggregate_test_100")?;

let f = df.registry();

let df = df.select(vec![f.udf("my_add", vec![col("c12")])?])?;
let df = df.select(vec![f.udf("my_fn", vec![col("c12")])?])?;
let plan = df.to_logical_plan();

// build query using SQL
let sql_plan =
ctx.create_logical_plan("SELECT my_add(c12) FROM aggregate_test_100")?;
ctx.create_logical_plan("SELECT my_fn(c12) FROM aggregate_test_100")?;

// the two plans should be identical
assert_same_plan(&plan, &sql_plan);
