Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
372ff5c
commit 5771dc0
Showing
9 changed files
with
299 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
use std::sync::Arc; | ||
|
||
use polars_arrow::error::{polars_bail, PolarsResult}; | ||
use polars_core::prelude::Field; | ||
use polars_core::schema::Schema; | ||
|
||
use super::{Expr, GetOutput, SeriesUdf, SpecialEq}; | ||
use crate::prelude::{Context, FunctionOptions}; | ||
|
||
/// Represents a user-defined function | ||
#[derive(Clone)] | ||
pub struct UserDefinedFunction { | ||
/// name | ||
pub name: String, | ||
/// The function signature. | ||
pub input_fields: Vec<Field>, | ||
/// The function output type. | ||
pub return_type: GetOutput, | ||
/// The function implementation. | ||
pub fun: SpecialEq<Arc<dyn SeriesUdf>>, | ||
/// Options for the function. | ||
pub options: FunctionOptions, | ||
} | ||
|
||
impl std::fmt::Debug for UserDefinedFunction { | ||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { | ||
f.debug_struct("UserDefinedFunction") | ||
.field("name", &self.name) | ||
.field("signature", &self.input_fields) | ||
.field("fun", &"<FUNC>") | ||
.field("options", &self.options) | ||
.finish() | ||
} | ||
} | ||
|
||
impl UserDefinedFunction { | ||
/// Create a new UserDefinedFunction | ||
pub fn new( | ||
name: &str, | ||
input_fields: Vec<Field>, | ||
return_type: GetOutput, | ||
fun: impl SeriesUdf + 'static, | ||
) -> Self { | ||
Self { | ||
name: name.to_owned(), | ||
input_fields, | ||
return_type, | ||
fun: SpecialEq::new(Arc::new(fun)), | ||
options: FunctionOptions::default(), | ||
} | ||
} | ||
|
||
/// creates a logical expression with a call of the UDF | ||
/// This utility allows using the UDF without requiring access to the registry. | ||
/// The schema is validated and the query will fail if the schema is invalid. | ||
pub fn call(self, args: Vec<Expr>) -> PolarsResult<Expr> { | ||
if args.len() != self.input_fields.len() { | ||
polars_bail!(InvalidOperation: "expected {} arguments, got {}", self.input_fields.len(), args.len()) | ||
} | ||
let schema = Schema::from_iter(self.input_fields); | ||
|
||
if args | ||
.iter() | ||
.map(|e| e.to_field(&schema, Context::Default)) | ||
.collect::<PolarsResult<Vec<_>>>() | ||
.is_err() | ||
{ | ||
polars_bail!(InvalidOperation: "unexpected field in UDF \nexpected: {:?}\n received {:?}", schema, args) | ||
}; | ||
|
||
Ok(Expr::AnonymousFunction { | ||
input: args, | ||
function: self.fun, | ||
output_type: self.return_type, | ||
options: self.options, | ||
}) | ||
} | ||
|
||
/// creates a logical expression with a call of the UDF | ||
/// This does not do any schema validation and is therefore faster. | ||
/// | ||
/// Only use this if you are certain that the schema is correct. | ||
/// If the schema is invalid, the query will fail at runtime. | ||
pub fn call_unchecked(self, args: Vec<Expr>) -> Expr { | ||
Expr::AnonymousFunction { | ||
input: args, | ||
function: self.fun, | ||
output_type: self.return_type.clone(), | ||
options: self.options, | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
//! This module defines the function registry and user defined functions. | ||
|
||
use polars_arrow::error::{polars_bail, PolarsResult}; | ||
use polars_plan::prelude::udf::UserDefinedFunction; | ||
pub use polars_plan::prelude::{Context, FunctionOptions}; | ||
/// A registry that holds user defined functions. | ||
pub trait FunctionRegistry: Send + Sync { | ||
/// Register a function. | ||
fn register(&mut self, name: &str, fun: UserDefinedFunction) -> PolarsResult<()>; | ||
/// Call a user defined function. | ||
fn get_udf(&self, name: &str) -> PolarsResult<Option<UserDefinedFunction>>; | ||
/// Check if a function is registered. | ||
fn contains(&self, name: &str) -> bool; | ||
} | ||
|
||
/// A default registry that does not support registering or calling functions. | ||
pub struct DefaultFunctionRegistry {} | ||
|
||
impl FunctionRegistry for DefaultFunctionRegistry { | ||
fn register(&mut self, _name: &str, _fun: UserDefinedFunction) -> PolarsResult<()> { | ||
polars_bail!(ComputeError: "'register' not implemented on DefaultFunctionRegistry'") | ||
} | ||
|
||
fn get_udf(&self, _name: &str) -> PolarsResult<Option<UserDefinedFunction>> { | ||
polars_bail!(ComputeError: "'get_udf' not implemented on DefaultFunctionRegistry'") | ||
} | ||
fn contains(&self, _name: &str) -> bool { | ||
false | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.