Skip to content

Commit

Permalink
feat(rust, sql): sql udfs (#10957)
Browse files Browse the repository at this point in the history
  • Loading branch information
universalmind303 committed Sep 20, 2023
1 parent 372ff5c commit 5771dc0
Show file tree
Hide file tree
Showing 9 changed files with 299 additions and 14 deletions.
4 changes: 4 additions & 0 deletions crates/polars-plan/src/dsl/expr_dyn_fn.rs
Expand Up @@ -11,6 +11,10 @@ use super::*;

/// A wrapper trait for any closure `Fn(Vec<Series>) -> PolarsResult<Series>`
pub trait SeriesUdf: Send + Sync {
fn as_any(&self) -> &dyn std::any::Any {
unimplemented!("as_any not implemented for this 'opaque' function")
}

fn call_udf(&self, s: &mut [Series]) -> PolarsResult<Option<Series>>;

fn try_serialize(&self, _buf: &mut Vec<u8>) -> PolarsResult<()> {
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-plan/src/dsl/mod.rs
Expand Up @@ -32,7 +32,7 @@ mod selector;
pub mod string;
#[cfg(feature = "dtype-struct")]
mod struct_;

pub mod udf;
use std::fmt::Debug;
use std::sync::Arc;

Expand All @@ -58,6 +58,7 @@ use polars_time::series::SeriesOpsTime;
pub(crate) use selector::Selector;
#[cfg(feature = "dtype-struct")]
pub use struct_::*;
pub use udf::UserDefinedFunction;

use crate::constants::MAP_LIST_NAME;
pub use crate::logical_plan::lit;
Expand Down
92 changes: 92 additions & 0 deletions crates/polars-plan/src/dsl/udf.rs
@@ -0,0 +1,92 @@
use std::sync::Arc;

use polars_arrow::error::{polars_bail, PolarsResult};
use polars_core::prelude::Field;
use polars_core::schema::Schema;

use super::{Expr, GetOutput, SeriesUdf, SpecialEq};
use crate::prelude::{Context, FunctionOptions};

/// Represents a user-defined function
///
/// Bundles a function implementation with the metadata that `call` and
/// `call_unchecked` need to build an `Expr::AnonymousFunction` from it.
#[derive(Clone)]
pub struct UserDefinedFunction {
/// The name of the function.
pub name: String,
/// The function signature: one field per expected argument.
/// `call` compares the number of arguments against the length of this list
/// and validates the arguments against a schema built from these fields.
pub input_fields: Vec<Field>,
/// The function output type; becomes the `output_type` of the generated expression.
pub return_type: GetOutput,
/// The function implementation; becomes the `function` of the generated expression.
pub fun: SpecialEq<Arc<dyn SeriesUdf>>,
/// Options for the function. `new` initialises these to `FunctionOptions::default()`.
pub options: FunctionOptions,
}

impl std::fmt::Debug for UserDefinedFunction {
    /// Renders the UDF for diagnostics.
    ///
    /// The implementation is printed as the placeholder `"<FUNC>"`, the input
    /// fields appear under the key `signature`, and the return type is left out.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let Self {
            name,
            input_fields,
            options,
            ..
        } = self;

        let mut out = f.debug_struct("UserDefinedFunction");
        out.field("name", name);
        out.field("signature", input_fields);
        out.field("fun", &"<FUNC>");
        out.field("options", options);
        out.finish()
    }
}

impl UserDefinedFunction {
    /// Create a new UserDefinedFunction
    ///
    /// * `name` - name of the function; copied into an owned `String`.
    /// * `input_fields` - the expected arguments, one field per argument.
    /// * `return_type` - how the output field of the function is determined.
    /// * `fun` - the implementation; it is wrapped in an `Arc` so the UDF can be cloned.
    ///
    /// The function options are initialised to `FunctionOptions::default()`.
    pub fn new(
        name: &str,
        input_fields: Vec<Field>,
        return_type: GetOutput,
        fun: impl SeriesUdf + 'static,
    ) -> Self {
        Self {
            name: name.to_owned(),
            input_fields,
            return_type,
            fun: SpecialEq::new(Arc::new(fun)),
            options: FunctionOptions::default(),
        }
    }

    /// creates a logical expression with a call of the UDF
    /// This utility allows using the UDF without requiring access to the registry.
    /// The schema is validated and the query will fail if the schema is invalid.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidOperation` error when
    /// * the number of `args` differs from the number of declared input fields, or
    /// * any argument cannot be resolved to a field against the schema built
    ///   from the declared input fields.
    pub fn call(self, args: Vec<Expr>) -> PolarsResult<Expr> {
        if args.len() != self.input_fields.len() {
            polars_bail!(InvalidOperation: "expected {} arguments, got {}", self.input_fields.len(), args.len())
        }
        let schema = Schema::from_iter(self.input_fields);

        // Only the success/failure of the resolution matters, so stop at the
        // first failing argument instead of collecting the resolved fields.
        let has_invalid_arg = args
            .iter()
            .any(|e| e.to_field(&schema, Context::Default).is_err());
        if has_invalid_arg {
            polars_bail!(InvalidOperation: "unexpected field in UDF \nexpected: {:?}\n received {:?}", schema, args)
        }

        Ok(Expr::AnonymousFunction {
            input: args,
            function: self.fun,
            output_type: self.return_type,
            options: self.options,
        })
    }

    /// creates a logical expression with a call of the UDF
    /// This does not do any schema validation and is therefore faster.
    ///
    /// Only use this if you are certain that the schema is correct.
    /// If the schema is invalid, the query will fail at runtime.
    pub fn call_unchecked(self, args: Vec<Expr>) -> Expr {
        // `self` is consumed, so every part can be moved into the expression
        // without cloning.
        Expr::AnonymousFunction {
            input: args,
            function: self.fun,
            output_type: self.return_type,
            options: self.options,
        }
    }
}
39 changes: 32 additions & 7 deletions crates/polars-sql/src/context.rs
Expand Up @@ -15,16 +15,28 @@ use sqlparser::ast::{
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::{Parser, ParserOptions};

use crate::function_registry::{DefaultFunctionRegistry, FunctionRegistry};
use crate::sql_expr::{parse_sql_expr, process_join_constraint};
use crate::table_functions::PolarsTableFunctions;

/// The SQLContext is the main entry point for executing SQL queries.
#[derive(Default, Clone)]
#[derive(Clone)]
pub struct SQLContext {
pub(crate) table_map: PlHashMap<String, LazyFrame>,
pub(crate) function_registry: Arc<dyn FunctionRegistry>,
cte_map: RefCell<PlHashMap<String, LazyFrame>>,
}

impl Default for SQLContext {
fn default() -> Self {
Self {
function_registry: Arc::new(DefaultFunctionRegistry {}),
table_map: Default::default(),
cte_map: Default::default(),
}
}
}

impl SQLContext {
/// Create a new SQLContext.
/// ```rust
Expand All @@ -34,12 +46,8 @@ impl SQLContext {
/// # }
/// ```
pub fn new() -> Self {
    // Delegate to `Default` so that all fields, including the function
    // registry, are initialised in a single place.
    Self::default()
}

/// Get the names of all registered tables, in sorted order.
pub fn get_tables(&self) -> Vec<String> {
let mut tables = Vec::from_iter(self.table_map.keys().cloned());
Expand Down Expand Up @@ -107,6 +115,23 @@ impl SQLContext {
self.cte_map.borrow_mut().clear();
res
}

/// Replace the function registry of the SQLContext.
/// The registry provides the ability to add custom functions to the SQLContext.
///
/// Consumes the context and returns it with `function_registry` installed.
pub fn with_function_registry(mut self, function_registry: Arc<dyn FunctionRegistry>) -> Self {
    // Swap in the new registry; the handle it replaces is released immediately.
    drop(std::mem::replace(
        &mut self.function_registry,
        function_registry,
    ));
    self
}

/// Get the function registry of the SQLContext
pub fn registry(&self) -> &Arc<dyn FunctionRegistry> {
&self.function_registry
}

/// Get a mutable reference to the function registry of the SQLContext
///
/// # Panics
///
/// Panics if the registry `Arc` is shared, i.e. another strong or weak handle
/// to it exists. This is the case, for example, after the context has been
/// cloned or when the caller kept a clone of the `Arc` that was passed to
/// `with_function_registry`.
pub fn registry_mut(&mut self) -> &mut dyn FunctionRegistry {
    Arc::get_mut(&mut self.function_registry)
        .expect("function registry is shared; exclusive access requires a unique `Arc`")
}
}

impl SQLContext {
Expand Down Expand Up @@ -702,7 +727,7 @@ impl SQLContext {
pub fn new_from_table_map(table_map: PlHashMap<String, LazyFrame>) -> Self {
    // Only the table map is supplied by the caller; every other field
    // (function registry, CTE map) takes its `Default` value.
    Self {
        table_map,
        ..Default::default()
    }
}
}
30 changes: 30 additions & 0 deletions crates/polars-sql/src/function_registry.rs
@@ -0,0 +1,30 @@
//! This module defines the function registry and user defined functions.

use polars_arrow::error::{polars_bail, PolarsResult};
use polars_plan::prelude::udf::UserDefinedFunction;
pub use polars_plan::prelude::{Context, FunctionOptions};
/// A registry that holds user defined functions.
///
/// Implementations must be `Send + Sync`.
pub trait FunctionRegistry: Send + Sync {
/// Register the function `fun` under `name`.
///
/// Returns an error if the registry cannot register the function.
fn register(&mut self, name: &str, fun: UserDefinedFunction) -> PolarsResult<()>;
/// Look up the user defined function registered under `name`.
///
/// This does not call the function; `Ok(None)` signals that no function is
/// available under that name.
fn get_udf(&self, name: &str) -> PolarsResult<Option<UserDefinedFunction>>;
/// Check if a function is registered under `name`.
fn contains(&self, name: &str) -> bool;
}

/// A default registry that does not support registering or calling functions.
///
/// `register` and `get_udf` always return an error and `contains` always
/// returns `false`.
pub struct DefaultFunctionRegistry {}

impl FunctionRegistry for DefaultFunctionRegistry {
    /// Always fails: the default registry cannot store functions.
    fn register(&mut self, _name: &str, _fun: UserDefinedFunction) -> PolarsResult<()> {
        polars_bail!(ComputeError: "'register' not implemented on 'DefaultFunctionRegistry'")
    }

    /// Always fails: the default registry holds no functions to look up.
    fn get_udf(&self, _name: &str) -> PolarsResult<Option<UserDefinedFunction>> {
        polars_bail!(ComputeError: "'get_udf' not implemented on 'DefaultFunctionRegistry'")
    }

    /// Always `false`: no function is ever registered in the default registry.
    fn contains(&self, _name: &str) -> bool {
        false
    }
}
40 changes: 34 additions & 6 deletions crates/polars-sql/src/functions.rs
@@ -1,4 +1,4 @@
use polars_core::prelude::{polars_bail, polars_err, PolarsError, PolarsResult};
use polars_core::prelude::{polars_bail, polars_err, PolarsResult};
use polars_lazy::dsl::Expr;
use polars_plan::dsl::count;
use polars_plan::logical_plan::LiteralValue;
Expand Down Expand Up @@ -362,6 +362,7 @@ pub(crate) enum PolarsSqlFunctions {
/// SELECT ARRAY_CONTAINS(column_1, 'foo') from df;
/// ```
ArrayContains,
Udf(String),
}

impl PolarsSqlFunctions {
Expand Down Expand Up @@ -435,9 +436,8 @@ impl PolarsSqlFunctions {
}
}

impl TryFrom<&'_ SQLFunction> for PolarsSqlFunctions {
type Error = PolarsError;
fn try_from(function: &'_ SQLFunction) -> Result<Self, Self::Error> {
impl PolarsSqlFunctions {
fn try_from_sql(function: &'_ SQLFunction, ctx: &'_ SQLContext) -> PolarsResult<Self> {
let function_name = function.name.0[0].value.to_lowercase();
Ok(match function_name.as_str() {
// ----
Expand Down Expand Up @@ -519,16 +519,22 @@ impl TryFrom<&'_ SQLFunction> for PolarsSqlFunctions {
"array_upper" => Self::ArrayMax,
"unnest" => Self::Explode,

other => polars_bail!(InvalidOperation: "unsupported SQL function: {}", other),
other => {
if ctx.function_registry.contains(other) {
Self::Udf(other.to_string())
} else {
polars_bail!(InvalidOperation: "unsupported SQL function: {}", other);
}
},
})
}
}

impl SqlFunctionVisitor<'_> {
pub(crate) fn visit_function(&self) -> PolarsResult<Expr> {
let function = self.func;
let function_name = PolarsSqlFunctions::try_from_sql(function, self.ctx)?;

let function_name: PolarsSqlFunctions = function.try_into()?;
use PolarsSqlFunctions::*;

match function_name {
Expand Down Expand Up @@ -686,6 +692,28 @@ impl SqlFunctionVisitor<'_> {
}),
ArrayUnique => self.visit_unary(|e| e.list().unique()),
Explode => self.visit_unary(|e| e.explode()),
Udf(func_name) => self.visit_udf(&func_name)
}
}

/// Builds the expression for a call of the user defined function `func_name`.
///
/// Each SQL argument is parsed into an expression, then the UDF is fetched
/// from the function registry of the context and called with those
/// expressions. Fails if an argument is not a plain expression, if the
/// registry lookup fails or if the registry has no UDF with that name.
fn visit_udf(&self, func_name: &str) -> PolarsResult<Expr> {
    let mut udf_args = Vec::new();
    for arg in extract_args(self.func) {
        match arg {
            FunctionArgExpr::Expr(sql_expr) => {
                udf_args.push(parse_sql_expr(sql_expr, self.ctx)?);
            },
            _ => polars_bail!(ComputeError: "Only expressions are supported in UDFs"),
        }
    }

    match self.ctx.function_registry.get_udf(func_name)? {
        Some(udf) => udf.call(udf_args),
        None => polars_bail!(ComputeError: "UDF {} not found", func_name),
    }
}

Expand Down
1 change: 1 addition & 0 deletions crates/polars-sql/src/lib.rs
Expand Up @@ -2,6 +2,7 @@
//! This crate provides a SQL interface for Polars DataFrames
#![deny(missing_docs)]
mod context;
pub mod function_registry;
mod functions;
pub mod keywords;
mod sql_expr;
Expand Down

0 comments on commit 5771dc0

Please sign in to comment.