Skip to content

Commit

Permalink
apply suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
pythonbrad committed Apr 11, 2023
1 parent 749dbbf commit 240f6a4
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 73 deletions.
104 changes: 37 additions & 67 deletions core/src/executor/evaluate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use {
super::{context::RowContext, select::select},
crate::{
ast::{Aggregate, Expr, Function},
data::{Interval, Literal, Row, Value},
data::{CustomFunction, Interval, Literal, Row, Value},
result::{Error, Result},
store::GStore,
},
Expand All @@ -19,7 +19,7 @@ use {
stream::{self, StreamExt, TryStreamExt},
},
im_rc::HashMap,
std::{borrow::Cow, collections::HashMap as StdHashMap, rc::Rc},
std::{borrow::Cow, rc::Rc},
};

pub use {error::EvaluateError, evaluated::Evaluated, stateless::evaluate_stateless};
Expand Down Expand Up @@ -286,13 +286,6 @@ async fn evaluate_function<'a, 'b: 'a, 'c: 'a, T: GStore>(
evaluate(storage, context, aggregated, expr)
};

let eval_with_context = |expr: &'a Expr, context: Rc<RowContext<'b>>| {
let context = Some(Rc::clone(&context));
let aggregated = aggregated.as_ref().map(Rc::clone);

evaluate(storage, context, aggregated, expr)
};

let name = func.to_string();

match func {
Expand All @@ -302,73 +295,50 @@ async fn evaluate_function<'a, 'b: 'a, 'c: 'a, T: GStore>(
f::concat(exprs)
}
Function::Custom { name, exprs } => {
let custom_func = storage
let CustomFunction {
func_name,
args,
body,
} = storage
.fetch_function(name)
.await?
.ok_or_else(|| EvaluateError::UnsupportedFunction(name.to_string()))?;
let args = stream::iter(exprs)
.then(eval)
.try_collect::<Vec<_>>()
.await?;
let args = args
.into_iter()
.map(Value::try_from)
.collect::<Result<Vec<_>>>()?;

let fargs = &custom_func.args;
let min = args.iter().filter(|arg| arg.default.is_none()).count();
let max = args.len();

let dargs = fargs
.iter()
.filter_map(|y| y.default.as_ref())
.collect::<Vec<_>>();
let dargs = stream::iter(dargs)
.then(eval)
.try_collect::<Vec<_>>()
.await?;
let dargs = dargs
.into_iter()
.map(Value::try_from)
.collect::<Result<Vec<_>>>()?;
let mut dargs = dargs.iter();

let min = fargs.len() - dargs.len();
let max = fargs.len();

let value = if (min..=max).contains(&args.len()) {
let mut hm = StdHashMap::new();

fargs
.iter()
.enumerate()
.try_for_each(|(i, farg)| -> Result<()> {
let arg = args.get(i).unwrap_or(&Value::Null);
arg.validate_type(&farg.data_type)?;
arg.validate_null(farg.default.is_some())?;
let value = if arg.is_null() {
dargs.next().unwrap()
} else {
arg
};
hm.insert(farg.name.to_owned(), value.to_owned());
Ok(())
})?;

let row = Row::Map(hm);
let rowcontext = RowContext::new(name, Cow::Owned(row), None);
let context = Rc::new(rowcontext);

eval_with_context(&custom_func.body, context).await
} else {
Err((EvaluateError::FunctionArgsLengthNotWithinRange {
name: custom_func.func_name.to_owned(),
if !(min..=max).contains(&exprs.len()) {
return Err((EvaluateError::FunctionArgsLengthNotWithinRange {
name: func_name.to_owned(),
expected_minimum: min,
expected_maximum: max,
found: args.len(),
found: exprs.len(),
})
.into())
};
.into());
}

let exprs = exprs.iter().chain(
args.iter()
.skip(exprs.len())
.filter_map(|arg| arg.default.as_ref()),
);

let context = stream::iter(args.iter().zip(exprs))
.then(|(arg, expr)| async {
eval(expr)
.await?
.try_into_value(&arg.data_type, true)
.map(|value| (arg.name.to_owned(), value))
})
.try_collect()
.await
.map(|values| {
let row = Cow::Owned(Row::Map(values));
let context = RowContext::new(name, row, None);
Some(Rc::new(context))
})?;

Ok(value?)
evaluate(storage, context, None, body).await
}
Function::ConcatWs { separator, exprs } => {
let separator = eval(separator).await?;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,16 @@ test_case!(custom, async move {
),
(
"SELECT add_two(1, null, 2) as r",
Ok(select!(
r
I64;
4
Ok(select_with_null!(
r; Null
)),
),
(
"SELECT add_two(1, 2) as r",
"SELECT add_two(1) as r",
Ok(select!(
r
I64;
4
3
)),
),
("DROP FUNCTION add_none", Ok(Payload::DropFunction)),
Expand Down

0 comments on commit 240f6a4

Please sign in to comment.