Skip to content

Commit

Permalink
Merge branch 'apache_main' into feature/optimize-projections
Browse files Browse the repository at this point in the history
  • Loading branch information
berkaysynnada committed May 14, 2024
2 parents 8983ae2 + b8fab5c commit 250ad8c
Show file tree
Hide file tree
Showing 50 changed files with 1,110 additions and 351 deletions.
180 changes: 180 additions & 0 deletions datafusion-examples/examples/simplify_udaf_expression.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow_schema::{Field, Schema};
use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use datafusion_expr::function::AggregateFunctionSimplification;
use datafusion_expr::simplify::SimplifyInfo;

use std::{any::Any, sync::Arc};

use datafusion::arrow::{array::Float32Array, record_batch::RecordBatch};
use datafusion::error::Result;
use datafusion::{assert_batches_eq, prelude::*};
use datafusion_common::cast::as_float64_array;
use datafusion_expr::{
expr::{AggregateFunction, AggregateFunctionDefinition},
function::AccumulatorArgs,
Accumulator, AggregateUDF, AggregateUDFImpl, GroupsAccumulator, Signature,
};

/// This example shows how to use the AggregateUDFImpl::simplify API to simplify/replace user
/// defined aggregate function with a different expression which is defined in the `simplify` method.

#[derive(Debug, Clone)]
struct BetterAvgUdaf {
signature: Signature,
}

impl BetterAvgUdaf {
/// Create a new instance of the GeoMeanUdaf struct
fn new() -> Self {
Self {
signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
}
}
}

impl AggregateUDFImpl for BetterAvgUdaf {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"better_avg"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Float64)
}

fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
unimplemented!("should not be invoked")
}

fn state_fields(
&self,
_name: &str,
_value_type: DataType,
_ordering_fields: Vec<arrow_schema::Field>,
) -> Result<Vec<arrow_schema::Field>> {
unimplemented!("should not be invoked")
}

fn groups_accumulator_supported(&self) -> bool {
true
}

fn create_groups_accumulator(&self) -> Result<Box<dyn GroupsAccumulator>> {
unimplemented!("should not get here");
}
// we override method, to return new expression which would substitute
// user defined function call
fn simplify(&self) -> Option<AggregateFunctionSimplification> {
// as an example for this functionality we replace UDF function
// with build-in aggregate function to illustrate the use
let simplify = |aggregate_function: datafusion_expr::expr::AggregateFunction,
_: &dyn SimplifyInfo| {
Ok(Expr::AggregateFunction(AggregateFunction {
func_def: AggregateFunctionDefinition::BuiltIn(
// yes it is the same Avg, `BetterAvgUdaf` was just a
// marketing pitch :)
datafusion_expr::aggregate_function::AggregateFunction::Avg,
),
args: aggregate_function.args,
distinct: aggregate_function.distinct,
filter: aggregate_function.filter,
order_by: aggregate_function.order_by,
null_treatment: aggregate_function.null_treatment,
}))
};

Some(Box::new(simplify))
}
}

// create local session context with an in-memory table
fn create_context() -> Result<SessionContext> {
use datafusion::datasource::MemTable;
// define a schema.
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float32, false),
Field::new("b", DataType::Float32, false),
]));

// define data in two partitions
let batch1 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0])),
Arc::new(Float32Array::from(vec![2.0, 2.0, 2.0])),
],
)?;
let batch2 = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Float32Array::from(vec![16.0])),
Arc::new(Float32Array::from(vec![2.0])),
],
)?;

let ctx = SessionContext::new();

// declare a table in memory. In spark API, this corresponds to createDataFrame(...).
let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
ctx.register_table("t", Arc::new(provider))?;
Ok(ctx)
}

#[tokio::main]
async fn main() -> Result<()> {
let ctx = create_context()?;

let better_avg = AggregateUDF::from(BetterAvgUdaf::new());
ctx.register_udaf(better_avg.clone());

let result = ctx
.sql("SELECT better_avg(a) FROM t group by b")
.await?
.collect()
.await?;

let expected = [
"+-----------------+",
"| better_avg(t.a) |",
"+-----------------+",
"| 7.5 |",
"+-----------------+",
];

assert_batches_eq!(expected, &result);

let df = ctx.table("t").await?;
let df = df.aggregate(vec![], vec![better_avg.call(vec![col("a")])])?;

let results = df.collect().await?;
let result = as_float64_array(results[0].column(0))?;

assert!((result.value(0) - 7.5).abs() < f64::EPSILON);
println!("The average of [2,4,8,16] is {}", result.value(0));

Ok(())
}
68 changes: 28 additions & 40 deletions datafusion/common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ macro_rules! config_namespace {
$(
stringify!($field_name) => self.$field_name.set(rem, value),
)*
_ => return Err(DataFusionError::Configuration(format!(
_ => return _config_err!(
"Config value \"{}\" not found on {}", key, stringify!($struct_name)
)))
)
}
}

Expand Down Expand Up @@ -181,7 +181,8 @@ config_namespace! {
/// Type of `TableProvider` to use when loading `default` schema
pub format: Option<String>, default = None

/// If the file has a header
/// Default value for `format.has_header` for `CREATE EXTERNAL TABLE`
/// if not specified explicitly in the statement.
pub has_header: bool, default = false
}
}
Expand Down Expand Up @@ -675,22 +676,17 @@ impl ConfigOptions {

/// Set a configuration option
pub fn set(&mut self, key: &str, value: &str) -> Result<()> {
let (prefix, key) = key.split_once('.').ok_or_else(|| {
DataFusionError::Configuration(format!(
"could not find config namespace for key \"{key}\"",
))
})?;
let Some((prefix, key)) = key.split_once('.') else {
return _config_err!("could not find config namespace for key \"{key}\"");
};

if prefix == "datafusion" {
return ConfigField::set(self, key, value);
}
let e = self.extensions.0.get_mut(prefix);
let e = e.ok_or_else(|| {
DataFusionError::Configuration(format!(
"Could not find config namespace \"{prefix}\""
))
})?;
let Some(e) = self.extensions.0.get_mut(prefix) else {
return _config_err!("Could not find config namespace \"{prefix}\"");
};
e.0.set(key, value)
}

Expand Down Expand Up @@ -1278,22 +1274,17 @@ impl TableOptions {
///
/// A result indicating success or failure in setting the configuration option.
pub fn set(&mut self, key: &str, value: &str) -> Result<()> {
let (prefix, _) = key.split_once('.').ok_or_else(|| {
DataFusionError::Configuration(format!(
"could not find config namespace for key \"{key}\""
))
})?;
let Some((prefix, _)) = key.split_once('.') else {
return _config_err!("could not find config namespace for key \"{key}\"");
};

if prefix == "format" {
return ConfigField::set(self, key, value);
}
let e = self.extensions.0.get_mut(prefix);
let e = e.ok_or_else(|| {
DataFusionError::Configuration(format!(
"Could not find config namespace \"{prefix}\""
))
})?;
let Some(e) = self.extensions.0.get_mut(prefix) else {
return _config_err!("Could not find config namespace \"{prefix}\"");
};
e.0.set(key, value)
}

Expand Down Expand Up @@ -1412,19 +1403,19 @@ impl ConfigField for TableParquetOptions {
fn set(&mut self, key: &str, value: &str) -> Result<()> {
// Determine if the key is a global, metadata, or column-specific setting
if key.starts_with("metadata::") {
let k =
match key.split("::").collect::<Vec<_>>()[..] {
[_meta] | [_meta, ""] => return Err(DataFusionError::Configuration(
let k = match key.split("::").collect::<Vec<_>>()[..] {
[_meta] | [_meta, ""] => {
return _config_err!(
"Invalid metadata key provided, missing key in metadata::<key>"
.to_string(),
)),
[_meta, k] => k.into(),
_ => {
return Err(DataFusionError::Configuration(format!(
)
}
[_meta, k] => k.into(),
_ => {
return _config_err!(
"Invalid metadata key provided, found too many '::' in \"{key}\""
)))
}
};
)
}
};
self.key_value_metadata.insert(k, Some(value.into()));
Ok(())
} else if key.contains("::") {
Expand Down Expand Up @@ -1497,10 +1488,7 @@ macro_rules! config_namespace_with_hashmap {

inner_value.set(inner_key, value)
}
_ => Err(DataFusionError::Configuration(format!(
"Unrecognized key '{}'.",
key
))),
_ => _config_err!("Unrecognized key '{key}'."),
}
}

Expand Down
24 changes: 11 additions & 13 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ use std::ops::ControlFlow;
use std::sync::{Arc, Weak};

use super::options::ReadOptions;
#[cfg(feature = "array_expressions")]
use crate::functions_array;
use crate::{
catalog::information_schema::{InformationSchemaProvider, INFORMATION_SCHEMA},
catalog::listing_schema::ListingSchemaProvider,
Expand Down Expand Up @@ -53,53 +55,49 @@ use crate::{
},
optimizer::analyzer::{Analyzer, AnalyzerRule},
optimizer::optimizer::{Optimizer, OptimizerConfig, OptimizerRule},
physical_expr::{create_physical_expr, PhysicalExpr},
physical_optimizer::optimizer::{PhysicalOptimizer, PhysicalOptimizerRule},
physical_plan::ExecutionPlan,
physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner},
variable::{VarProvider, VarType},
};

#[cfg(feature = "array_expressions")]
use crate::functions_array;
use crate::{functions, functions_aggregate};

use arrow::datatypes::{DataType, SchemaRef};
use arrow::record_batch::RecordBatch;
use arrow_schema::Schema;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use datafusion_common::tree_node::TreeNode;
use datafusion_common::{
alias::AliasGenerator,
config::{ConfigExtension, TableOptions},
exec_err, not_impl_err, plan_datafusion_err, plan_err,
tree_node::{TreeNodeRecursion, TreeNodeVisitor},
tree_node::{TreeNode, TreeNodeRecursion, TreeNodeVisitor},
DFSchema, SchemaReference, TableReference,
};
use datafusion_execution::registry::SerializerRegistry;
use datafusion_expr::{
expr_rewriter::FunctionRewrite,
logical_plan::{DdlStatement, Statement},
simplify::SimplifyInfo,
var_provider::is_system_variables,
Expr, ExprSchemable, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
};
use datafusion_optimizer::simplify_expressions::ExprSimplifier;
use datafusion_sql::{
parser::{CopyToSource, CopyToStatement, DFParser},
planner::{object_name_to_table_reference, ContextProvider, ParserOptions, SqlToRel},
ResolvedTableReference,
};
use parking_lot::RwLock;
use sqlparser::dialect::dialect_from_str;

use async_trait::async_trait;
use chrono::{DateTime, Utc};
use parking_lot::RwLock;
use url::Url;
use uuid::Uuid;

use crate::physical_expr::PhysicalExpr;
pub use datafusion_execution::config::SessionConfig;
pub use datafusion_execution::TaskContext;
pub use datafusion_expr::execution_props::ExecutionProps;
use datafusion_expr::expr_rewriter::FunctionRewrite;
use datafusion_expr::simplify::SimplifyInfo;
use datafusion_optimizer::simplify_expressions::ExprSimplifier;
use datafusion_physical_expr::create_physical_expr;

mod avro;
mod csv;
Expand Down
Loading

0 comments on commit 250ad8c

Please sign in to comment.