From 240f6a4f5064638bb2029c7c61925e36bf691159 Mon Sep 17 00:00:00 2001 From: Fomegne Date: Tue, 11 Apr 2023 21:57:10 +0100 Subject: [PATCH] apply suggestions --- core/src/executor/evaluate/mod.rs | 104 +++++++----------- .../src/function/{custom.rs => custom/mod.rs} | 10 +- 2 files changed, 41 insertions(+), 73 deletions(-) rename test-suite/src/function/{custom.rs => custom/mod.rs} (96%) diff --git a/core/src/executor/evaluate/mod.rs b/core/src/executor/evaluate/mod.rs index a0856e259..8f4731bf3 100644 --- a/core/src/executor/evaluate/mod.rs +++ b/core/src/executor/evaluate/mod.rs @@ -8,7 +8,7 @@ use { super::{context::RowContext, select::select}, crate::{ ast::{Aggregate, Expr, Function}, - data::{Interval, Literal, Row, Value}, + data::{CustomFunction, Interval, Literal, Row, Value}, result::{Error, Result}, store::GStore, }, @@ -19,7 +19,7 @@ use { stream::{self, StreamExt, TryStreamExt}, }, im_rc::HashMap, - std::{borrow::Cow, collections::HashMap as StdHashMap, rc::Rc}, + std::{borrow::Cow, rc::Rc}, }; pub use {error::EvaluateError, evaluated::Evaluated, stateless::evaluate_stateless}; @@ -286,13 +286,6 @@ async fn evaluate_function<'a, 'b: 'a, 'c: 'a, T: GStore>( evaluate(storage, context, aggregated, expr) }; - let eval_with_context = |expr: &'a Expr, context: Rc>| { - let context = Some(Rc::clone(&context)); - let aggregated = aggregated.as_ref().map(Rc::clone); - - evaluate(storage, context, aggregated, expr) - }; - let name = func.to_string(); match func { @@ -302,73 +295,50 @@ async fn evaluate_function<'a, 'b: 'a, 'c: 'a, T: GStore>( f::concat(exprs) } Function::Custom { name, exprs } => { - let custom_func = storage + let CustomFunction { + func_name, + args, + body, + } = storage .fetch_function(name) .await? .ok_or_else(|| EvaluateError::UnsupportedFunction(name.to_string()))?; - let args = stream::iter(exprs) - .then(eval) - .try_collect::>() - .await?; - let args = args - .into_iter() - .map(Value::try_from) - .collect::>>()?; - let fargs = &custom_func.args; + let min = args.iter().filter(|arg| arg.default.is_none()).count(); + let max = args.len(); - let dargs = fargs - .iter() - .filter_map(|y| y.default.as_ref()) - .collect::>(); - let dargs = stream::iter(dargs) - .then(eval) - .try_collect::>() - .await?; - let dargs = dargs - .into_iter() - .map(Value::try_from) - .collect::>>()?; - let mut dargs = dargs.iter(); - - let min = fargs.len() - dargs.len(); - let max = fargs.len(); - - let value = if (min..=max).contains(&args.len()) { - let mut hm = StdHashMap::new(); - - fargs - .iter() - .enumerate() - .try_for_each(|(i, farg)| -> Result<()> { - let arg = args.get(i).unwrap_or(&Value::Null); - arg.validate_type(&farg.data_type)?; - arg.validate_null(farg.default.is_some())?; - let value = if arg.is_null() { - dargs.next().unwrap() - } else { - arg - }; - hm.insert(farg.name.to_owned(), value.to_owned()); - Ok(()) - })?; - - let row = Row::Map(hm); - let rowcontext = RowContext::new(name, Cow::Owned(row), None); - let context = Rc::new(rowcontext); - - eval_with_context(&custom_func.body, context).await - } else { - Err((EvaluateError::FunctionArgsLengthNotWithinRange { - name: custom_func.func_name.to_owned(), + if !(min..=max).contains(&exprs.len()) { + return Err((EvaluateError::FunctionArgsLengthNotWithinRange { + name: func_name.to_owned(), expected_minimum: min, expected_maximum: max, - found: args.len(), + found: exprs.len(), }) - .into()) - }; + .into()); + } + + let exprs = exprs.iter().chain( + args.iter() + .skip(exprs.len()) + .filter_map(|arg| arg.default.as_ref()), + ); + + let context = stream::iter(args.iter().zip(exprs)) + .then(|(arg, expr)| async { + eval(expr) + .await? + .try_into_value(&arg.data_type, true) + .map(|value| (arg.name.to_owned(), value)) + }) + .try_collect() + .await + .map(|values| { + let row = Cow::Owned(Row::Map(values)); + let context = RowContext::new(name, row, None); + Some(Rc::new(context)) + })?; - Ok(value?) + evaluate(storage, context, None, body).await } Function::ConcatWs { separator, exprs } => { let separator = eval(separator).await?; diff --git a/test-suite/src/function/custom.rs b/test-suite/src/function/custom/mod.rs similarity index 96% rename from test-suite/src/function/custom.rs rename to test-suite/src/function/custom/mod.rs index 96461c18c..61a8b4399 100644 --- a/test-suite/src/function/custom.rs +++ b/test-suite/src/function/custom/mod.rs @@ -76,18 +76,16 @@ test_case!(custom, async move { ), ( "SELECT add_two(1, null, 2) as r", - Ok(select!( - r - I64; - 4 + Ok(select_with_null!( + r; Null )), ), ( - "SELECT add_two(1, 2) as r", + "SELECT add_two(1) as r", Ok(select!( r I64; - 4 + 3 )), ), ("DROP FUNCTION add_none", Ok(Payload::DropFunction)),