diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs index 99511e969386..3f0c5033afdc 100644 --- a/datafusion-cli/src/functions.rs +++ b/datafusion-cli/src/functions.rs @@ -27,6 +27,7 @@ use datafusion::common::{plan_err, Column}; use datafusion::datasource::function::TableFunctionImpl; use datafusion::datasource::TableProvider; use datafusion::error::Result; +use datafusion::execution::SessionState; use datafusion::logical_expr::Expr; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::ExecutionPlan; @@ -317,7 +318,11 @@ fn fixed_len_byte_array_to_string(val: Option<&FixedLenByteArray>) -> Option Result> { + fn call( + &self, + _state: &SessionState, + exprs: &[Expr], + ) -> Result> { let filename = match exprs.first() { Some(Expr::Literal(ScalarValue::Utf8(Some(s)))) => s, // single quote: parquet_metadata('x.parquet') Some(Expr::Column(Column { name, .. })) => name, // double quote: parquet_metadata("x.parquet") diff --git a/datafusion-examples/examples/simple_udtf.rs b/datafusion-examples/examples/simple_udtf.rs index fe7f37cc00e3..260cce6d09e6 100644 --- a/datafusion-examples/examples/simple_udtf.rs +++ b/datafusion-examples/examples/simple_udtf.rs @@ -25,6 +25,7 @@ use datafusion::datasource::function::TableFunctionImpl; use datafusion::datasource::TableProvider; use datafusion::error::Result; use datafusion::execution::context::ExecutionProps; +use datafusion::execution::SessionState; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; @@ -130,7 +131,11 @@ impl TableProvider for LocalCsvTable { struct LocalCsvTableFunc {} impl TableFunctionImpl for LocalCsvTableFunc { - fn call(&self, exprs: &[Expr]) -> Result> { + fn call( + &self, + _state: &SessionState, + exprs: &[Expr], + ) -> Result> { let Some(Expr::Literal(ScalarValue::Utf8(Some(ref path)))) = exprs.first() else { return plan_err!("read_csv requires at least one string argument"); }; diff --git a/datafusion/core/src/datasource/function.rs b/datafusion/core/src/datasource/function.rs index 14bbc431f973..0044b8ecf0fa 100644 --- a/datafusion/core/src/datasource/function.rs +++ b/datafusion/core/src/datasource/function.rs @@ -17,6 +17,8 @@ //! A table that uses a function to generate data +use crate::execution::SessionState; + use super::TableProvider; use datafusion_common::Result; @@ -27,7 +29,8 @@ use std::sync::Arc; /// A trait for table function implementations pub trait TableFunctionImpl: Sync + Send { /// Create a table provider - fn call(&self, args: &[Expr]) -> Result>; + fn call(&self, state: &SessionState, args: &[Expr]) + -> Result>; } /// A table that uses a function to generate data @@ -55,7 +58,11 @@ impl TableFunction { } /// Get the function implementation and generate a table - pub fn create_table_provider(&self, args: &[Expr]) -> Result> { - self.fun.call(args) + pub fn create_table_provider( + &self, + state: &SessionState, + args: &[Expr], + ) -> Result> { + self.fun.call(state, args) } } diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 2ea6c8878fd1..1247f84e285c 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -397,7 +397,7 @@ impl SessionContext { Arc::clone(&factory) as Arc, )); let new_state = SessionStateBuilder::new_from_existing(self.state()) - .with_catalog_list(catalog_list) + .with_catalog_list(Some(catalog_list)) .build(); let ctx = SessionContext::new_with_state(new_state); factory.session_store().with_state(ctx.state_weak_ref()); diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index f24fec665f49..f5b63ea02f7d 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -280,7 +280,7 @@ impl SessionState { SessionStateBuilder::new() .with_config(config) .with_runtime_env(runtime) - .with_catalog_list(catalog_list) + .with_catalog_list(Some(catalog_list)) .with_default_features() .build() } @@ -296,7 +296,7 @@ impl SessionState { SessionStateBuilder::new() .with_config(config) .with_runtime_env(runtime) - .with_catalog_list(catalog_list) + .with_catalog_list(Some(catalog_list)) .with_default_features() .build() } @@ -932,6 +932,7 @@ impl SessionState { /// be used for all values unless explicitly provided. /// /// See example on [`SessionState`] +#[derive(Clone)] pub struct SessionStateBuilder { session_id: Option, analyzer: Option, @@ -1140,9 +1141,9 @@ impl SessionStateBuilder { /// Set the [`CatalogProviderList`] pub fn with_catalog_list( mut self, - catalog_list: Arc, + catalog_list: Option>, ) -> Self { - self.catalog_list = Some(catalog_list); + self.catalog_list = catalog_list; self } @@ -1543,7 +1544,7 @@ impl ContextProvider for SessionContextProvider<'_> { .get(name) .cloned() .ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?; - let provider = tbl_func.create_table_provider(&args)?; + let provider = tbl_func.create_table_provider(self.state, &args)?; Ok(provider_as_source(provider)) } @@ -1876,7 +1877,7 @@ mod tests { let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; let session_state = SessionStateBuilder::new() - .with_catalog_list(Arc::new(MemoryCatalogProviderList::new())) + .with_catalog_list(Some(Arc::new(MemoryCatalogProviderList::new()))) .build(); let table_ref = session_state.resolve_table_ref("employee").to_string(); session_state diff --git a/datafusion/core/tests/user_defined/user_defined_table_functions.rs b/datafusion/core/tests/user_defined/user_defined_table_functions.rs index 5fd3b7a03384..b9718b5c7bdd 100644 --- a/datafusion/core/tests/user_defined/user_defined_table_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_table_functions.rs @@ -24,7 +24,7 @@ use datafusion::arrow::record_batch::RecordBatch; use datafusion::datasource::function::TableFunctionImpl; use datafusion::datasource::TableProvider; use datafusion::error::Result; -use datafusion::execution::TaskContext; +use datafusion::execution::{SessionState, TaskContext}; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::{collect, ExecutionPlan}; use datafusion::prelude::SessionContext; @@ -194,7 +194,11 @@ impl SimpleCsvTable { struct SimpleCsvTableFunc {} impl TableFunctionImpl for SimpleCsvTableFunc { - fn call(&self, exprs: &[Expr]) -> Result> { + fn call( + &self, + _state: &SessionState, + exprs: &[Expr], + ) -> Result> { let mut new_exprs = vec![]; let mut filepath = String::new(); for expr in exprs { diff --git a/docs/source/library-user-guide/adding-udfs.md b/docs/source/library-user-guide/adding-udfs.md index fe3990b90c3c..cb2af297da58 100644 --- a/docs/source/library-user-guide/adding-udfs.md +++ b/docs/source/library-user-guide/adding-udfs.md @@ -562,6 +562,7 @@ In the `call` method, you parse the input `Expr`s and return a `TableProvider`. ```rust use datafusion::common::plan_err; use datafusion::datasource::function::TableFunctionImpl; +use datafusion::execution::SessionState; // Other imports here /// A table function that returns a table provider with the value as a single column @@ -569,7 +570,10 @@ use datafusion::datasource::function::TableFunctionImpl; pub struct EchoFunction {} impl TableFunctionImpl for EchoFunction { - fn call(&self, exprs: &[Expr]) -> Result> { + fn call(&self, + _state: &SessionState, + exprs: &[Expr], + ) -> Result> { let Some(Expr::Literal(ScalarValue::Int64(Some(value)))) = exprs.get(0) else { return plan_err!("First argument must be an integer"); };