diff --git a/proto/expr.proto b/proto/expr.proto index 957701887659..802397f0456f 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -97,6 +97,7 @@ message ExprNode { TRANSLATE = 216; COALESCE = 217; CONCAT_WS = 218; + CONCAT_WS_VARIADIC = 285; ABS = 219; SPLIT_PART = 220; CEIL = 221; @@ -106,6 +107,8 @@ message ExprNode { CHAR_LENGTH = 225; REPEAT = 226; CONCAT_OP = 227; + CONCAT = 286; + CONCAT_VARIADIC = 287; // BOOL_OUT is different from CAST-bool-to-varchar in PostgreSQL. BOOL_OUT = 228; OCTET_LENGTH = 229; @@ -182,6 +185,7 @@ message ExprNode { LEFT = 317; RIGHT = 318; FORMAT = 319; + FORMAT_VARIADIC = 326; PGWIRE_SEND = 320; PGWIRE_RECV = 321; CONVERT_FROM = 322; @@ -232,9 +236,11 @@ message ExprNode { // jsonb ->> int, jsonb ->> text that returns text JSONB_ACCESS_STR = 601; // jsonb #> text[] -> jsonb - JSONB_EXTRACT_PATH = 613; + JSONB_EXTRACT_PATH = 627; + JSONB_EXTRACT_PATH_VARIADIC = 613; // jsonb #>> text[] -> text - JSONB_EXTRACT_PATH_TEXT = 614; + JSONB_EXTRACT_PATH_TEXT = 628; + JSONB_EXTRACT_PATH_TEXT_VARIADIC = 614; JSONB_TYPEOF = 602; JSONB_ARRAY_LENGTH = 603; IS_JSON = 604; @@ -261,7 +267,9 @@ message ExprNode { JSONB_STRIP_NULLS = 616; TO_JSONB = 617; JSONB_BUILD_ARRAY = 618; + JSONB_BUILD_ARRAY_VARIADIC = 625; JSONB_BUILD_OBJECT = 619; + JSONB_BUILD_OBJECT_VARIADIC = 626; JSONB_PATH_EXISTS = 620; JSONB_PATH_MATCH = 621; JSONB_PATH_QUERY_ARRAY = 622; diff --git a/src/common/src/array/list_array.rs b/src/common/src/array/list_array.rs index 45ef845835be..8f0235f8145d 100644 --- a/src/common/src/array/list_array.rs +++ b/src/common/src/array/list_array.rs @@ -612,6 +612,24 @@ impl Debug for ListRef<'_> { } } +impl Row for ListRef<'_> { + fn datum_at(&self, index: usize) -> DatumRef<'_> { + self.array.value_at(self.start as usize + index) + } + + unsafe fn datum_at_unchecked(&self, index: usize) -> DatumRef<'_> { + self.array.value_at_unchecked(self.start as usize + index) + } + + fn len(&self) -> usize { + self.len() + } + + fn iter(&self) -> impl Iterator> { + (*self).iter() + } +} + impl ToText for ListRef<'_> { // This function will be invoked when pgwire prints a list value in string. // Refer to PostgreSQL `array_out` or `appendPGArray`. diff --git a/src/expr/core/src/expr/mod.rs b/src/expr/core/src/expr/mod.rs index d1ced2b3322c..6dbb3906f561 100644 --- a/src/expr/core/src/expr/mod.rs +++ b/src/expr/core/src/expr/mod.rs @@ -209,4 +209,6 @@ where pub struct Context { pub arg_types: Vec, pub return_type: DataType, + /// Whether the function is variadic. + pub variadic: bool, } diff --git a/src/expr/impl/src/scalar/cast.rs b/src/expr/impl/src/scalar/cast.rs index bf8afc7712f9..bff225ad687f 100644 --- a/src/expr/impl/src/scalar/cast.rs +++ b/src/expr/impl/src/scalar/cast.rs @@ -302,6 +302,7 @@ mod tests { let ctx = Context { arg_types: vec![DataType::Varchar], return_type: DataType::from_str("int[]").unwrap(), + variadic: false, }; assert_eq!( str_to_list("{}", &ctx).unwrap(), @@ -314,6 +315,7 @@ mod tests { let ctx = Context { arg_types: vec![DataType::Varchar], return_type: DataType::from_str("int[]").unwrap(), + variadic: false, }; assert_eq!(str_to_list("{1, 2, 3}", &ctx).unwrap(), list123); @@ -322,6 +324,7 @@ mod tests { let ctx = Context { arg_types: vec![DataType::Varchar], return_type: DataType::from_str("int[][]").unwrap(), + variadic: false, }; assert_eq!(str_to_list("{{1, 2, 3}}", &ctx).unwrap(), nested_list123); @@ -334,6 +337,7 @@ mod tests { let ctx = Context { arg_types: vec![DataType::Varchar], return_type: DataType::from_str("int[][][]").unwrap(), + variadic: false, }; assert_eq!( str_to_list("{{{1, 2, 3}}, {{44, 55, 66}}}", &ctx).unwrap(), @@ -344,6 +348,7 @@ mod tests { let ctx = Context { arg_types: vec![DataType::from_str("int[][]").unwrap()], return_type: DataType::from_str("varchar[][]").unwrap(), + variadic: false, }; let double_nested_varchar_list123_445566 = ListValue::from_iter([ list_cast(nested_list123.as_scalar_ref(), &ctx).unwrap(), @@ -354,6 +359,7 @@ mod tests { let ctx = Context { arg_types: vec![DataType::Varchar], return_type: DataType::from_str("varchar[][][]").unwrap(), + variadic: false, }; assert_eq!( str_to_list("{{{1, 2, 3}}, {{44, 55, 66}}}", &ctx).unwrap(), @@ -367,6 +373,7 @@ mod tests { let ctx = Context { arg_types: vec![DataType::Varchar], return_type: DataType::from_str("int[]").unwrap(), + variadic: false, }; assert!(str_to_list("{{}", &ctx).is_err()); assert!(str_to_list("{}}", &ctx).is_err()); @@ -385,6 +392,7 @@ mod tests { ("a", DataType::Int32), ("b", DataType::Int32), ])), + variadic: false, }; assert_eq!( struct_cast( @@ -420,6 +428,7 @@ mod tests { let ctx_str_to_int16 = Context { arg_types: vec![DataType::Varchar], return_type: DataType::Int16, + variadic: false, }; test_str_to_int16::(|x| str_parse(x, &ctx_str_to_int16).unwrap()).await; } diff --git a/src/expr/impl/src/scalar/concat.rs b/src/expr/impl/src/scalar/concat.rs new file mode 100644 index 000000000000..9359c75a7af2 --- /dev/null +++ b/src/expr/impl/src/scalar/concat.rs @@ -0,0 +1,73 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt::Write; + +use risingwave_common::row::Row; +use risingwave_common::types::ToText; +use risingwave_expr::function; + +/// Concatenates the text representations of all the arguments. NULL arguments are ignored. +/// +/// # Example +/// +/// ```slt +/// query T +/// select concat('abcde', 2, NULL, 22); +/// ---- +/// abcde222 +/// +/// query T +/// select concat(variadic array['abcde', '2', NULL, '22']); +/// ---- +/// abcde222 +/// ``` +#[function("concat(variadic anyarray) -> varchar")] +fn concat(vals: impl Row, writer: &mut impl Write) { + for string in vals.iter().flatten() { + string.write(writer).unwrap(); + } +} + +#[cfg(test)] +mod tests { + use risingwave_common::array::DataChunk; + use risingwave_common::row::Row; + use risingwave_common::test_prelude::DataChunkTestExt; + use risingwave_common::types::ToOwnedDatum; + use risingwave_common::util::iter_util::ZipEqDebug; + use risingwave_expr::expr::build_from_pretty; + + #[tokio::test] + async fn test_concat() { + let concat = build_from_pretty("(concat:varchar $0:varchar $1:varchar $2:varchar)"); + let (input, expected) = DataChunk::from_pretty( + "T T T T + a b c abc + . b c bc + . . . (empty)", + ) + .split_column_at(3); + + // test eval + let output = concat.eval(&input).await.unwrap(); + assert_eq!(&output, expected.column_at(0)); + + // test eval_row + for (row, expected) in input.rows().zip_eq_debug(expected.rows()) { + let result = concat.eval_row(&row.to_owned_row()).await.unwrap(); + assert_eq!(result, expected.datum_at(0).to_owned_datum()); + } + } +} diff --git a/src/expr/impl/src/scalar/concat_ws.rs b/src/expr/impl/src/scalar/concat_ws.rs index ec979d4cbacd..4c2ce3d56ac5 100644 --- a/src/expr/impl/src/scalar/concat_ws.rs +++ b/src/expr/impl/src/scalar/concat_ws.rs @@ -20,8 +20,22 @@ use risingwave_expr::function; /// Concatenates all but the first argument, with separators. The first argument is used as the /// separator string, and should not be NULL. Other NULL arguments are ignored. -#[function("concat_ws(varchar, ...) -> varchar")] -fn concat_ws(sep: &str, vals: impl Row, writer: &mut impl Write) -> Option<()> { +/// +/// # Example +/// +/// ```slt +/// query T +/// select concat_ws(',', 'abcde', 2, NULL, 22); +/// ---- +/// abcde,2,22 +/// +/// query T +/// select concat_ws(',', variadic array['abcde', 2, NULL, 22] :: varchar[]); +/// ---- +/// abcde,2,22 +/// ``` +#[function("concat_ws(varchar, variadic anyarray) -> varchar")] +fn concat_ws(sep: &str, vals: impl Row, writer: &mut impl Write) { let mut string_iter = vals.iter().flatten(); if let Some(string) = string_iter.next() { string.write(writer).unwrap(); @@ -30,7 +44,6 @@ fn concat_ws(sep: &str, vals: impl Row, writer: &mut impl Write) -> Option<()> { write!(writer, "{}", sep).unwrap(); string.write(writer).unwrap(); } - Some(()) } #[cfg(test)] diff --git a/src/expr/impl/src/scalar/format.rs b/src/expr/impl/src/scalar/format.rs index 24256ee87f9f..50195638e4d0 100644 --- a/src/expr/impl/src/scalar/format.rs +++ b/src/expr/impl/src/scalar/format.rs @@ -23,11 +23,25 @@ use thiserror_ext::AsReport; use super::string::quote_ident; /// Formats arguments according to a format string. +/// +/// # Example +/// +/// ```slt +/// query T +/// select format('%s %s', 'Hello', 'World'); +/// ---- +/// Hello World +/// +/// query T +/// select format('%s %s', variadic array['Hello', 'World']); +/// ---- +/// Hello World +/// ``` #[function( - "format(varchar, ...) -> varchar", + "format(varchar, variadic anyarray) -> varchar", prebuild = "Formatter::from_str($0).map_err(|e| ExprError::Parse(e.to_report_string().into()))?" )] -fn format(formatter: &Formatter, row: impl Row, writer: &mut impl Write) -> Result<()> { +fn format(row: impl Row, formatter: &Formatter, writer: &mut impl Write) -> Result<()> { let mut args = row.iter(); for node in &formatter.nodes { match node { diff --git a/src/expr/impl/src/scalar/jsonb_access.rs b/src/expr/impl/src/scalar/jsonb_access.rs index 36ced44bf357..05578e34b17d 100644 --- a/src/expr/impl/src/scalar/jsonb_access.rs +++ b/src/expr/impl/src/scalar/jsonb_access.rs @@ -14,7 +14,8 @@ use std::fmt::Write; -use risingwave_common::types::{JsonbRef, ListRef}; +use risingwave_common::row::Row; +use risingwave_common::types::JsonbRef; use risingwave_expr::function; /// Extracts JSON object field with the given key. @@ -91,9 +92,14 @@ pub fn jsonb_array_element(v: JsonbRef<'_>, p: i32) -> Option> { /// select jsonb_extract_path('{"a": {"b": ["foo","bar"]}}', 'a', 'b', '1'); /// ---- /// "bar" +/// +/// query T +/// select jsonb_extract_path('{"a": {"b": ["foo","bar"]}}', variadic array['a', 'b', '1']); +/// ---- +/// "bar" /// ``` -#[function("jsonb_extract_path(jsonb, varchar[]) -> jsonb")] -pub fn jsonb_extract_path<'a>(v: JsonbRef<'a>, path: ListRef<'_>) -> Option> { +#[function("jsonb_extract_path(jsonb, variadic varchar[]) -> jsonb")] +pub fn jsonb_extract_path(v: JsonbRef<'_>, path: impl Row) -> Option> { let mut jsonb = v; for key in path.iter() { // return null if any element is null @@ -192,11 +198,16 @@ pub fn jsonb_array_element_str(v: JsonbRef<'_>, p: i32, writer: &mut impl Write) /// select jsonb_extract_path_text('{"a": {"b": ["foo","bar"]}}', 'a', 'b', '1'); /// ---- /// bar +/// +/// query T +/// select jsonb_extract_path_text('{"a": {"b": ["foo","bar"]}}', variadic array['a', 'b', '1']); +/// ---- +/// bar /// ``` -#[function("jsonb_extract_path_text(jsonb, varchar[]) -> varchar")] +#[function("jsonb_extract_path_text(jsonb, variadic varchar[]) -> varchar")] pub fn jsonb_extract_path_text( v: JsonbRef<'_>, - path: ListRef<'_>, + path: impl Row, writer: &mut impl Write, ) -> Option<()> { let jsonb = jsonb_extract_path(v, path)?; diff --git a/src/expr/impl/src/scalar/jsonb_build.rs b/src/expr/impl/src/scalar/jsonb_build.rs index 85b24d7126f1..5949faf9bc5c 100644 --- a/src/expr/impl/src/scalar/jsonb_build.rs +++ b/src/expr/impl/src/scalar/jsonb_build.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use itertools::Either; use jsonbb::Builder; use risingwave_common::row::Row; use risingwave_common::types::{JsonbVal, ScalarRefImpl}; @@ -31,13 +32,25 @@ use super::{ToJsonb, ToTextDisplay}; /// select jsonb_build_array(1, 2, 'foo', 4, 5); /// ---- /// [1, 2, "foo", 4, 5] +/// +/// query T +/// select jsonb_build_array(variadic array[1, 2, 4, 5]); +/// ---- +/// [1, 2, 4, 5] /// ``` -#[function("jsonb_build_array(...) -> jsonb")] +#[function("jsonb_build_array(variadic anyarray) -> jsonb")] fn jsonb_build_array(args: impl Row, ctx: &Context) -> Result { let mut builder = Builder::>::new(); builder.begin_array(); - for (value, ty) in args.iter().zip_eq_debug(&ctx.arg_types) { - value.add_to(ty, &mut builder)?; + if ctx.variadic { + for (value, ty) in args.iter().zip_eq_debug(&ctx.arg_types) { + value.add_to(ty, &mut builder)?; + } + } else { + let ty = ctx.arg_types[0].as_list(); + for value in args.iter() { + value.add_to(ty, &mut builder)?; + } } builder.end_array(); Ok(builder.finish().into()) @@ -54,8 +67,13 @@ fn jsonb_build_array(args: impl Row, ctx: &Context) -> Result { /// select jsonb_build_object('foo', 1, 2, 'bar'); /// ---- /// {"2": "bar", "foo": 1} +/// +/// query T +/// select jsonb_build_object(variadic array['foo', '1', '2', 'bar']); +/// ---- +/// {"2": "bar", "foo": "1"} /// ``` -#[function("jsonb_build_object(...) -> jsonb")] +#[function("jsonb_build_object(variadic anyarray) -> jsonb")] fn jsonb_build_object(args: impl Row, ctx: &Context) -> Result { if args.len() % 2 == 1 { return Err(ExprError::InvalidParam { @@ -65,9 +83,13 @@ fn jsonb_build_object(args: impl Row, ctx: &Context) -> Result { } let mut builder = Builder::>::new(); builder.begin_object(); + let arg_types = match ctx.variadic { + true => Either::Left(ctx.arg_types.iter()), + false => Either::Right(itertools::repeat_n(ctx.arg_types[0].as_list(), args.len())), + }; for (i, [(key, _), (value, value_type)]) in args .iter() - .zip_eq_debug(&ctx.arg_types) + .zip_eq_debug(arg_types) .array_chunks() .enumerate() { diff --git a/src/expr/impl/src/scalar/mod.rs b/src/expr/impl/src/scalar/mod.rs index c5dc3defc57b..27135a739763 100644 --- a/src/expr/impl/src/scalar/mod.rs +++ b/src/expr/impl/src/scalar/mod.rs @@ -35,6 +35,7 @@ mod case; mod cast; mod cmp; mod coalesce; +mod concat; mod concat_op; mod concat_ws; mod conjunction; diff --git a/src/expr/macro/src/gen.rs b/src/expr/macro/src/gen.rs index 89033ccdbd71..1da6470cf38e 100644 --- a/src/expr/macro/src/gen.rs +++ b/src/expr/macro/src/gen.rs @@ -23,6 +23,42 @@ use super::*; impl FunctionAttr { /// Expands the wildcard in function arguments or return type. pub fn expand(&self) -> Vec { + // handle variadic argument + if self + .args + .last() + .is_some_and(|arg| arg.starts_with("variadic")) + { + // expand: foo(a, b, variadic anyarray) + // to: foo(a, b, ...) + // + foo_variadic(a, b, anyarray) + let mut attrs = Vec::new(); + attrs.extend( + FunctionAttr { + args: { + let mut args = self.args.clone(); + *args.last_mut().unwrap() = "...".to_string(); + args + }, + ..self.clone() + } + .expand(), + ); + attrs.extend( + FunctionAttr { + name: format!("{}_variadic", self.name), + args: { + let mut args = self.args.clone(); + let last = args.last_mut().unwrap(); + *last = last.strip_prefix("variadic ").unwrap().into(); + args + }, + ..self.clone() + } + .expand(), + ); + return attrs; + } let args = self.args.iter().map(|ty| types::expand_type_wildcard(ty)); let ret = types::expand_type_wildcard(&self.ret); // multi_cartesian_product should emit an empty set if the input is empty. @@ -314,8 +350,8 @@ impl FunctionAttr { // inputs: [ Option ] let mut output = quote! { #fn_name #generic( #(#non_prebuilt_inputs,)* - #prebuilt_arg #variadic_args + #prebuilt_arg #context #writer ) #await_ }; @@ -501,6 +537,7 @@ impl FunctionAttr { let context = Context { return_type, arg_types: children.iter().map(|c| c.return_type()).collect(), + variadic: #variadic, }; #[derive(Debug)] @@ -851,6 +888,7 @@ impl FunctionAttr { let context = Context { return_type: agg.return_type.clone(), arg_types: agg.args.arg_types().to_owned(), + variadic: false, }; struct Agg { diff --git a/src/frontend/planner_test/tests/testdata/input/expr.yaml b/src/frontend/planner_test/tests/testdata/input/expr.yaml index cc8f104e1ac8..e19ee4a29126 100644 --- a/src/frontend/planner_test/tests/testdata/input/expr.yaml +++ b/src/frontend/planner_test/tests/testdata/input/expr.yaml @@ -454,3 +454,15 @@ expected_outputs: - batch_plan - stream_error +- name: without variadic keyword + sql: | + create table t(a varchar, b varchar); + select concat_ws(',', a, b) from t; + expected_outputs: + - batch_plan +- name: with variadic keyword + sql: | + create table t(a varchar, b varchar); + select concat_ws(',', variadic array[a, b]) from t; + expected_outputs: + - batch_plan diff --git a/src/frontend/planner_test/tests/testdata/output/expr.yaml b/src/frontend/planner_test/tests/testdata/output/expr.yaml index ef6727193f78..4c23cadf7cb4 100644 --- a/src/frontend/planner_test/tests/testdata/output/expr.yaml +++ b/src/frontend/planner_test/tests/testdata/output/expr.yaml @@ -635,3 +635,19 @@ Not supported: streaming nested-loop join HINT: The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. Consider rewriting the query to use dynamic filter as a substitute if possible. See also: https://github.com/risingwavelabs/rfcs/blob/main/rfcs/0033-dynamic-filter.md +- name: without variadic keyword + sql: | + create table t(a varchar, b varchar); + select concat_ws(',', a, b) from t; + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchProject { exprs: [ConcatWs(',':Varchar, t.a, t.b) as $expr1] } + └─BatchScan { table: t, columns: [t.a, t.b], distribution: SomeShard } +- name: with variadic keyword + sql: | + create table t(a varchar, b varchar); + select concat_ws(',', variadic array[a, b]) from t; + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchProject { exprs: [ConcatWsVariadic(',':Varchar, Array(t.a, t.b)) as $expr1] } + └─BatchScan { table: t, columns: [t.a, t.b], distribution: SomeShard } diff --git a/src/frontend/src/binder/expr/binary_op.rs b/src/frontend/src/binder/expr/binary_op.rs index d8a2ddae1b1a..77eff662802c 100644 --- a/src/frontend/src/binder/expr/binary_op.rs +++ b/src/frontend/src/binder/expr/binary_op.rs @@ -113,8 +113,8 @@ impl Binder { BinaryOperator::Arrow => ExprType::JsonbAccess, BinaryOperator::LongArrow => ExprType::JsonbAccessStr, BinaryOperator::HashMinus => ExprType::JsonbDeletePath, - BinaryOperator::HashArrow => ExprType::JsonbExtractPath, - BinaryOperator::HashLongArrow => ExprType::JsonbExtractPathText, + BinaryOperator::HashArrow => ExprType::JsonbExtractPathVariadic, + BinaryOperator::HashLongArrow => ExprType::JsonbExtractPathTextVariadic, BinaryOperator::Prefix => ExprType::StartsWith, BinaryOperator::Contains => { let left_type = (!bound_left.is_untyped()).then(|| bound_left.return_type()); diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index 9918f3e73fa7..6a72b2f2fd20 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -273,7 +273,7 @@ impl Binder { } } - self.bind_builtin_scalar_function(function_name.as_str(), inputs) + self.bind_builtin_scalar_function(function_name.as_str(), inputs, f.variadic) } fn bind_array_transform(&mut self, f: Function) -> Result { @@ -791,6 +791,7 @@ impl Binder { &mut self, function_name: &str, inputs: Vec, + variadic: bool, ) -> Result { type Inputs = Vec; @@ -1075,7 +1076,7 @@ impl Binder { guard_by_len(2, raw(|binder, inputs| { let (arg0, arg1) = inputs.into_iter().next_tuple().unwrap(); // rewrite into `CASE WHEN 0 < arg1 AND arg1 <= array_ndims(arg0) THEN 1 END` - let ndims_expr = binder.bind_builtin_scalar_function("array_ndims", vec![arg0])?; + let ndims_expr = binder.bind_builtin_scalar_function("array_ndims", vec![arg0], false)?; let arg1 = arg1.cast_implicit(DataType::Int32)?; FunctionCall::new( @@ -1102,36 +1103,8 @@ impl Binder { ("jsonb_array_element", raw_call(ExprType::JsonbAccess)), ("jsonb_object_field_text", raw_call(ExprType::JsonbAccessStr)), ("jsonb_array_element_text", raw_call(ExprType::JsonbAccessStr)), - ("jsonb_extract_path", raw(|_binder, mut inputs| { - // rewrite: jsonb_extract_path(jsonb, s1, s2...) - // to: jsonb_extract_path(jsonb, array[s1, s2...]) - if inputs.len() < 2 { - return Err(ErrorCode::ExprError("unexpected arguments number".into()).into()); - } - inputs[0].cast_implicit_mut(DataType::Jsonb)?; - let mut variadic_inputs = inputs.split_off(1); - for input in &mut variadic_inputs { - input.cast_implicit_mut(DataType::Varchar)?; - } - let array = FunctionCall::new_unchecked(ExprType::Array, variadic_inputs, DataType::List(Box::new(DataType::Varchar))); - inputs.push(array.into()); - Ok(FunctionCall::new_unchecked(ExprType::JsonbExtractPath, inputs, DataType::Jsonb).into()) - })), - ("jsonb_extract_path_text", raw(|_binder, mut inputs| { - // rewrite: jsonb_extract_path_text(jsonb, s1, s2...) - // to: jsonb_extract_path_text(jsonb, array[s1, s2...]) - if inputs.len() < 2 { - return Err(ErrorCode::ExprError("unexpected arguments number".into()).into()); - } - inputs[0].cast_implicit_mut(DataType::Jsonb)?; - let mut variadic_inputs = inputs.split_off(1); - for input in &mut variadic_inputs { - input.cast_implicit_mut(DataType::Varchar)?; - } - let array = FunctionCall::new_unchecked(ExprType::Array, variadic_inputs, DataType::List(Box::new(DataType::Varchar))); - inputs.push(array.into()); - Ok(FunctionCall::new_unchecked(ExprType::JsonbExtractPathText, inputs, DataType::Varchar).into()) - })), + ("jsonb_extract_path", raw_call(ExprType::JsonbExtractPath)), + ("jsonb_extract_path_text", raw_call(ExprType::JsonbExtractPathText)), ("jsonb_typeof", raw_call(ExprType::JsonbTypeof)), ("jsonb_array_length", raw_call(ExprType::JsonbArrayLength)), ("jsonb_concat", raw_call(ExprType::JsonbConcat)), @@ -1388,6 +1361,26 @@ impl Binder { tree }); + if variadic { + let func = match function_name { + "format" => ExprType::FormatVariadic, + "concat" => ExprType::ConcatVariadic, + "concat_ws" => ExprType::ConcatWsVariadic, + "jsonb_build_array" => ExprType::JsonbBuildArrayVariadic, + "jsonb_build_object" => ExprType::JsonbBuildObjectVariadic, + "jsonb_extract_path" => ExprType::JsonbExtractPathVariadic, + "jsonb_extract_path_text" => ExprType::JsonbExtractPathTextVariadic, + _ => { + return Err(ErrorCode::BindError(format!( + "VARIADIC argument is not allowed in function \"{}\"", + function_name + )) + .into()) + } + }; + return Ok(FunctionCall::new(func, inputs)?.into()); + } + match HANDLES.get(function_name) { Some(handle) => handle(self, inputs), None => { diff --git a/src/frontend/src/binder/relation/table_function.rs b/src/frontend/src/binder/relation/table_function.rs index f7d5c803ea63..a0c70f58f1cb 100644 --- a/src/frontend/src/binder/relation/table_function.rs +++ b/src/frontend/src/binder/relation/table_function.rs @@ -108,6 +108,7 @@ impl Binder { let func = self.bind_function(Function { name, args, + variadic: false, over: None, distinct: false, order_by: vec![], diff --git a/src/frontend/src/expr/pure.rs b/src/frontend/src/expr/pure.rs index bad7706c83f2..bc56733a4423 100644 --- a/src/frontend/src/expr/pure.rs +++ b/src/frontend/src/expr/pure.rs @@ -92,6 +92,7 @@ impl ExprVisitor for ImpureAnalyzer { | expr_node::Type::Translate | expr_node::Type::Coalesce | expr_node::Type::ConcatWs + | expr_node::Type::ConcatWsVariadic | expr_node::Type::Abs | expr_node::Type::SplitPart | expr_node::Type::Ceil @@ -102,6 +103,8 @@ impl ExprVisitor for ImpureAnalyzer { | expr_node::Type::CharLength | expr_node::Type::Repeat | expr_node::Type::ConcatOp + | expr_node::Type::Concat + | expr_node::Type::ConcatVariadic | expr_node::Type::BoolOut | expr_node::Type::OctetLength | expr_node::Type::BitLength @@ -181,7 +184,9 @@ impl ExprVisitor for ImpureAnalyzer { | expr_node::Type::JsonbAccess | expr_node::Type::JsonbAccessStr | expr_node::Type::JsonbExtractPath + | expr_node::Type::JsonbExtractPathVariadic | expr_node::Type::JsonbExtractPathText + | expr_node::Type::JsonbExtractPathTextVariadic | expr_node::Type::JsonbTypeof | expr_node::Type::JsonbArrayLength | expr_node::Type::JsonbObject @@ -194,7 +199,9 @@ impl ExprVisitor for ImpureAnalyzer { | expr_node::Type::JsonbExistsAll | expr_node::Type::JsonbStripNulls | expr_node::Type::JsonbBuildArray + | expr_node::Type::JsonbBuildArrayVariadic | expr_node::Type::JsonbBuildObject + | expr_node::Type::JsonbBuildObjectVariadic | expr_node::Type::JsonbPathExists | expr_node::Type::JsonbPathMatch | expr_node::Type::JsonbPathQueryArray @@ -225,6 +232,7 @@ impl ExprVisitor for ImpureAnalyzer { | expr_node::Type::ArrayPositions | expr_node::Type::StringToArray | expr_node::Type::Format + | expr_node::Type::FormatVariadic | expr_node::Type::PgwireSend | expr_node::Type::PgwireRecv | expr_node::Type::ArrayTransform diff --git a/src/frontend/src/expr/type_inference/func.rs b/src/frontend/src/expr/type_inference/func.rs index f80a131e7173..9f28dfeb74c8 100644 --- a/src/frontend/src/expr/type_inference/func.rs +++ b/src/frontend/src/expr/type_inference/func.rs @@ -350,6 +350,10 @@ fn infer_type_for_special( ensure_arity!("coalesce", 1 <= | inputs |); align_types(inputs.iter_mut()).map(Some).map_err(Into::into) } + ExprType::Concat => { + ensure_arity!("concat", 1 <= | inputs |); + Ok(Some(DataType::Varchar)) + } ExprType::ConcatWs => { ensure_arity!("concat_ws", 2 <= | inputs |); // 0-th arg must be string @@ -610,6 +614,22 @@ fn infer_type_for_special( } Ok(Some(DataType::Jsonb)) } + ExprType::JsonbExtractPath => { + ensure_arity!("jsonb_extract_path", 2 <= | inputs |); + inputs[0].cast_implicit_mut(DataType::Jsonb)?; + for input in inputs.iter_mut().skip(1) { + input.cast_implicit_mut(DataType::Varchar)?; + } + Ok(Some(DataType::Jsonb)) + } + ExprType::JsonbExtractPathText => { + ensure_arity!("jsonb_extract_path_text", 2 <= | inputs |); + inputs[0].cast_implicit_mut(DataType::Jsonb)?; + for input in inputs.iter_mut().skip(1) { + input.cast_implicit_mut(DataType::Varchar)?; + } + Ok(Some(DataType::Varchar)) + } _ => Ok(None), } } diff --git a/src/frontend/src/optimizer/plan_expr_visitor/strong.rs b/src/frontend/src/optimizer/plan_expr_visitor/strong.rs index dca2063dac92..a80a696c2c16 100644 --- a/src/frontend/src/optimizer/plan_expr_visitor/strong.rs +++ b/src/frontend/src/optimizer/plan_expr_visitor/strong.rs @@ -147,7 +147,10 @@ impl Strong { | ExprType::Round | ExprType::Ascii | ExprType::Translate + | ExprType::Concat + | ExprType::ConcatVariadic | ExprType::ConcatWs + | ExprType::ConcatWsVariadic | ExprType::Abs | ExprType::SplitPart | ExprType::ToChar @@ -217,6 +220,7 @@ impl Strong { | ExprType::Left | ExprType::Right | ExprType::Format + | ExprType::FormatVariadic | ExprType::PgwireSend | ExprType::PgwireRecv | ExprType::ConvertFrom @@ -255,7 +259,9 @@ impl Strong { | ExprType::JsonbAccess | ExprType::JsonbAccessStr | ExprType::JsonbExtractPath + | ExprType::JsonbExtractPathVariadic | ExprType::JsonbExtractPathText + | ExprType::JsonbExtractPathTextVariadic | ExprType::JsonbTypeof | ExprType::JsonbArrayLength | ExprType::IsJson @@ -271,7 +277,9 @@ impl Strong { | ExprType::JsonbStripNulls | ExprType::ToJsonb | ExprType::JsonbBuildArray + | ExprType::JsonbBuildArrayVariadic | ExprType::JsonbBuildObject + | ExprType::JsonbBuildObjectVariadic | ExprType::JsonbPathExists | ExprType::JsonbPathMatch | ExprType::JsonbPathQueryArray diff --git a/src/sqlparser/src/ast/mod.rs b/src/sqlparser/src/ast/mod.rs index 7051c10862d4..c6317d2930c2 100644 --- a/src/sqlparser/src/ast/mod.rs +++ b/src/sqlparser/src/ast/mod.rs @@ -2272,6 +2272,8 @@ impl fmt::Display for FunctionArg { pub struct Function { pub name: ObjectName, pub args: Vec, + /// whether the last argument is variadic, e.g. `foo(a, b, variadic c)` + pub variadic: bool, pub over: Option, // aggregate functions may specify eg `COUNT(DISTINCT x)` pub distinct: bool, @@ -2286,6 +2288,7 @@ impl Function { Self { name, args: vec![], + variadic: false, over: None, distinct: false, order_by: vec![], @@ -2299,17 +2302,22 @@ impl fmt::Display for Function { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "{}({}{}{}{})", + "{}({}", self.name, if self.distinct { "DISTINCT " } else { "" }, - display_comma_separated(&self.args), - if !self.order_by.is_empty() { - " ORDER BY " - } else { - "" - }, - display_comma_separated(&self.order_by), )?; + if self.variadic { + for arg in &self.args[0..self.args.len() - 1] { + write!(f, "{}, ", arg)?; + } + write!(f, "VARIADIC {}", self.args.last().unwrap())?; + } else { + write!(f, "{}", display_comma_separated(&self.args))?; + } + if !self.order_by.is_empty() { + write!(f, " ORDER BY {}", display_comma_separated(&self.order_by))?; + } + write!(f, ")")?; if let Some(o) = &self.over { write!(f, " OVER ({})", o)?; } diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index 22f035002414..8c7a87fe51b9 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -815,7 +815,7 @@ impl Parser { pub fn parse_function(&mut self, name: ObjectName) -> Result { self.expect_token(&Token::LParen)?; let distinct = self.parse_all_or_distinct()?; - let (args, order_by) = self.parse_optional_args()?; + let (args, order_by, variadic) = self.parse_optional_args()?; let over = if self.parse_keyword(Keyword::OVER) { // TBD: support window names (`OVER mywin`) in place of inline specification self.expect_token(&Token::LParen)?; @@ -873,6 +873,7 @@ impl Parser { Ok(Expr::Function(Function { name, args, + variadic, over, distinct, order_by, @@ -4626,7 +4627,8 @@ impl Parser { let name = self.parse_object_name()?; // Postgres,table-valued functions: if self.consume_token(&Token::LParen) { - let (args, order_by) = self.parse_optional_args()?; + // ignore VARIADIC here + let (args, order_by, _variadic) = self.parse_optional_args()?; // Table-valued functions do not support ORDER BY, should return error if it appears if !order_by.is_empty() { return parser_err!("Table-valued functions do not support ORDER BY clauses"); @@ -4909,33 +4911,46 @@ impl Parser { Ok(Assignment { id, value }) } - fn parse_function_args(&mut self) -> Result { - if self.peek_nth_token(1) == Token::RArrow { + /// Parse a `[VARIADIC] name => expr`. + fn parse_function_args(&mut self) -> Result<(bool, FunctionArg), ParserError> { + let variadic = self.parse_keyword(Keyword::VARIADIC); + let arg = if self.peek_nth_token(1) == Token::RArrow { let name = self.parse_identifier()?; self.expect_token(&Token::RArrow)?; let arg = self.parse_wildcard_or_expr()?.into(); - Ok(FunctionArg::Named { name, arg }) + FunctionArg::Named { name, arg } } else { - Ok(FunctionArg::Unnamed(self.parse_wildcard_or_expr()?.into())) - } + FunctionArg::Unnamed(self.parse_wildcard_or_expr()?.into()) + }; + Ok((variadic, arg)) } pub fn parse_optional_args( &mut self, - ) -> Result<(Vec, Vec), ParserError> { + ) -> Result<(Vec, Vec, bool), ParserError> { if self.consume_token(&Token::RParen) { - Ok((vec![], vec![])) + Ok((vec![], vec![], false)) } else { let args = self.parse_comma_separated(Parser::parse_function_args)?; + if args + .iter() + .take(args.len() - 1) + .any(|(variadic, _)| *variadic) + { + return parser_err!("VARIADIC argument must be last"); + } + let variadic = args.last().map(|(variadic, _)| *variadic).unwrap_or(false); + let args = args.into_iter().map(|(_, arg)| arg).collect(); + let order_by = if self.parse_keywords(&[Keyword::ORDER, Keyword::BY]) { self.parse_comma_separated(Parser::parse_order_by_expr)? } else { vec![] }; self.expect_token(&Token::RParen)?; - Ok((args, order_by)) + Ok((args, order_by, variadic)) } } diff --git a/src/sqlparser/tests/sqlparser_common.rs b/src/sqlparser/tests/sqlparser_common.rs index 0fc2f3c2530f..b447dce37d31 100644 --- a/src/sqlparser/tests/sqlparser_common.rs +++ b/src/sqlparser/tests/sqlparser_common.rs @@ -346,6 +346,7 @@ fn parse_select_count_wildcard() { &Expr::Function(Function { name: ObjectName(vec![Ident::new_unchecked("COUNT")]), args: vec![FunctionArg::Unnamed(FunctionArgExpr::Wildcard(None))], + variadic: false, over: None, distinct: false, order_by: vec![], @@ -367,6 +368,7 @@ fn parse_select_count_distinct() { op: UnaryOperator::Plus, expr: Box::new(Expr::Identifier(Ident::new_unchecked("x"))), }))], + variadic: false, over: None, distinct: true, order_by: vec![], @@ -1086,6 +1088,7 @@ fn parse_select_having() { left: Box::new(Expr::Function(Function { name: ObjectName(vec![Ident::new_unchecked("COUNT")]), args: vec![FunctionArg::Unnamed(FunctionArgExpr::Wildcard(None))], + variadic: false, over: None, distinct: false, order_by: vec![], @@ -1842,6 +1845,7 @@ fn parse_named_argument_function() { ))), }, ], + variadic: false, over: None, distinct: false, order_by: vec![], @@ -1870,6 +1874,7 @@ fn parse_window_functions() { &Expr::Function(Function { name: ObjectName(vec![Ident::new_unchecked("row_number")]), args: vec![], + variadic: false, over: Some(WindowSpec { partition_by: vec![], order_by: vec![OrderByExpr { @@ -1910,6 +1915,7 @@ fn parse_aggregate_with_order_by() { Ident::new_unchecked("b") ))), ], + variadic: false, over: None, distinct: false, order_by: vec![ @@ -1941,6 +1947,7 @@ fn parse_aggregate_with_filter() { args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr( Expr::Identifier(Ident::new_unchecked("a")) )),], + variadic: false, over: None, distinct: false, order_by: vec![], @@ -2199,6 +2206,7 @@ fn parse_delimited_identifiers() { &Expr::Function(Function { name: ObjectName(vec![Ident::with_quote_unchecked('"', "myfun")]), args: vec![], + variadic: false, over: None, distinct: false, order_by: vec![], diff --git a/src/sqlparser/tests/sqlparser_postgres.rs b/src/sqlparser/tests/sqlparser_postgres.rs index 6a5dec5d809c..9ef7a7f085c4 100644 --- a/src/sqlparser/tests/sqlparser_postgres.rs +++ b/src/sqlparser/tests/sqlparser_postgres.rs @@ -1276,3 +1276,17 @@ fn parse_double_quoted_string_as_alias() { let sql = "SELECT x \"x1\" FROM t"; assert!(parse_sql_statements(sql).is_ok()); } + +#[test] +fn parse_variadic_argument() { + let sql = "SELECT foo(a, b, VARIADIC c)"; + _ = verified_stmt(sql); + + let sql = "SELECT foo(VARIADIC a, b, VARIADIC c)"; + assert_eq!( + parse_sql_statements(sql), + Err(ParserError::ParserError( + "VARIADIC argument must be last".to_string() + )) + ); +} diff --git a/src/sqlparser/tests/testdata/lambda.yaml b/src/sqlparser/tests/testdata/lambda.yaml index 6db19af63fcc..4ff15c60bd77 100644 --- a/src/sqlparser/tests/testdata/lambda.yaml +++ b/src/sqlparser/tests/testdata/lambda.yaml @@ -1,10 +1,10 @@ # This file is automatically generated. See `src/sqlparser/test_runner/src/bin/apply.rs` for more information. - input: select array_transform(array[1,2,3], |x| x * 2) formatted_sql: SELECT array_transform(ARRAY[1, 2, 3], |x| x * 2) - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "array_transform", quote_style: None }]), args: [Unnamed(Expr(Array(Array { elem: [Value(Number("1")), Value(Number("2")), Value(Number("3"))], named: true }))), Unnamed(Expr(LambdaFunction { args: [Ident { value: "x", quote_style: None }], body: BinaryOp { left: Identifier(Ident { value: "x", quote_style: None }), op: Multiply, right: Value(Number("2")) } }))], over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "array_transform", quote_style: None }]), args: [Unnamed(Expr(Array(Array { elem: [Value(Number("1")), Value(Number("2")), Value(Number("3"))], named: true }))), Unnamed(Expr(LambdaFunction { args: [Ident { value: "x", quote_style: None }], body: BinaryOp { left: Identifier(Ident { value: "x", quote_style: None }), op: Multiply, right: Value(Number("2")) } }))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select array_transform(array[], |s| case when s ilike 'apple%' then 'apple' when s ilike 'google%' then 'google' else 'unknown' end) formatted_sql: SELECT array_transform(ARRAY[], |s| CASE WHEN s ILIKE 'apple%' THEN 'apple' WHEN s ILIKE 'google%' THEN 'google' ELSE 'unknown' END) - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "array_transform", quote_style: None }]), args: [Unnamed(Expr(Array(Array { elem: [], named: true }))), Unnamed(Expr(LambdaFunction { args: [Ident { value: "s", quote_style: None }], body: Case { operand: None, conditions: [BinaryOp { left: Identifier(Ident { value: "s", quote_style: None }), op: ILike, right: Value(SingleQuotedString("apple%")) }, BinaryOp { left: Identifier(Ident { value: "s", quote_style: None }), op: ILike, right: Value(SingleQuotedString("google%")) }], results: [Value(SingleQuotedString("apple")), Value(SingleQuotedString("google"))], else_result: Some(Value(SingleQuotedString("unknown"))) } }))], over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "array_transform", quote_style: None }]), args: [Unnamed(Expr(Array(Array { elem: [], named: true }))), Unnamed(Expr(LambdaFunction { args: [Ident { value: "s", quote_style: None }], body: Case { operand: None, conditions: [BinaryOp { left: Identifier(Ident { value: "s", quote_style: None }), op: ILike, right: Value(SingleQuotedString("apple%")) }, BinaryOp { left: Identifier(Ident { value: "s", quote_style: None }), op: ILike, right: Value(SingleQuotedString("google%")) }], results: [Value(SingleQuotedString("apple")), Value(SingleQuotedString("google"))], else_result: Some(Value(SingleQuotedString("unknown"))) } }))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select array_transform(array[], |x, y| x + y * 2) formatted_sql: SELECT array_transform(ARRAY[], |x, y| x + y * 2) - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "array_transform", quote_style: None }]), args: [Unnamed(Expr(Array(Array { elem: [], named: true }))), Unnamed(Expr(LambdaFunction { args: [Ident { value: "x", quote_style: None }, Ident { value: "y", quote_style: None }], body: BinaryOp { left: Identifier(Ident { value: "x", quote_style: None }), op: Plus, right: BinaryOp { left: Identifier(Ident { value: "y", quote_style: None }), op: Multiply, right: Value(Number("2")) } } }))], over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "array_transform", quote_style: None }]), args: [Unnamed(Expr(Array(Array { elem: [], named: true }))), Unnamed(Expr(LambdaFunction { args: [Ident { value: "x", quote_style: None }, Ident { value: "y", quote_style: None }], body: BinaryOp { left: Identifier(Ident { value: "x", quote_style: None }), op: Plus, right: BinaryOp { left: Identifier(Ident { value: "y", quote_style: None }), op: Multiply, right: Value(Number("2")) } } }))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' diff --git a/src/sqlparser/tests/testdata/qualified_operator.yaml b/src/sqlparser/tests/testdata/qualified_operator.yaml index 23658fd17ce2..814113edbe5b 100644 --- a/src/sqlparser/tests/testdata/qualified_operator.yaml +++ b/src/sqlparser/tests/testdata/qualified_operator.yaml @@ -19,10 +19,10 @@ formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Identifier(Ident { value: "operator", quote_style: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select "operator"(foo.bar); formatted_sql: SELECT "operator"(foo.bar) - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "operator", quote_style: Some(''"'') }]), args: [Unnamed(Expr(CompoundIdentifier([Ident { value: "foo", quote_style: None }, Ident { value: "bar", quote_style: None }])))], over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "operator", quote_style: Some(''"'') }]), args: [Unnamed(Expr(CompoundIdentifier([Ident { value: "foo", quote_style: None }, Ident { value: "bar", quote_style: None }])))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select operator operator(+) operator(+) "operator"(9) operator from operator; formatted_sql: SELECT operator OPERATOR(+) OPERATOR(+) "operator"(9) AS operator FROM operator - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [ExprWithAlias { expr: BinaryOp { left: Identifier(Ident { value: "operator", quote_style: None }), op: PGQualified(QualifiedOperator { schema: None, name: "+" }), right: UnaryOp { op: PGQualified(QualifiedOperator { schema: None, name: "+" }), expr: Function(Function { name: ObjectName([Ident { value: "operator", quote_style: Some(''"'') }]), args: [Unnamed(Expr(Value(Number("9"))))], over: None, distinct: false, order_by: [], filter: None, within_group: None }) } }, alias: Ident { value: "operator", quote_style: None } }], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "operator", quote_style: None }]), alias: None, for_system_time_as_of_proctime: false }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [ExprWithAlias { expr: BinaryOp { left: Identifier(Ident { value: "operator", quote_style: None }), op: PGQualified(QualifiedOperator { schema: None, name: "+" }), right: UnaryOp { op: PGQualified(QualifiedOperator { schema: None, name: "+" }), expr: Function(Function { name: ObjectName([Ident { value: "operator", quote_style: Some(''"'') }]), args: [Unnamed(Expr(Value(Number("9"))))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: None }) } }, alias: Ident { value: "operator", quote_style: None } }], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "operator", quote_style: None }]), alias: None, for_system_time_as_of_proctime: false }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select 3 operator(-) 2 - 1; formatted_sql: SELECT 3 OPERATOR(-) 2 - 1 formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(BinaryOp { left: Value(Number("3")), op: PGQualified(QualifiedOperator { schema: None, name: "-" }), right: BinaryOp { left: Value(Number("2")), op: Minus, right: Value(Number("1")) } })], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' diff --git a/src/sqlparser/tests/testdata/select.yaml b/src/sqlparser/tests/testdata/select.yaml index 1fb897166a1a..5321abd469ee 100644 --- a/src/sqlparser/tests/testdata/select.yaml +++ b/src/sqlparser/tests/testdata/select.yaml @@ -1,7 +1,7 @@ # This file is automatically generated. See `src/sqlparser/test_runner/src/bin/apply.rs` for more information. - input: SELECT sqrt(id) FROM foo formatted_sql: SELECT sqrt(id) FROM foo - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "sqrt", quote_style: None }]), args: [Unnamed(Expr(Identifier(Ident { value: "id", quote_style: None })))], over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "foo", quote_style: None }]), alias: None, for_system_time_as_of_proctime: false }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "sqrt", quote_style: None }]), args: [Unnamed(Expr(Identifier(Ident { value: "id", quote_style: None })))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "foo", quote_style: None }]), alias: None, for_system_time_as_of_proctime: false }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: SELECT INT '1' formatted_sql: SELECT INT '1' - input: SELECT (foo).v1.v2 FROM foo @@ -132,7 +132,7 @@ formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Identifier(Ident { value: "id1", quote_style: None })), UnnamedExpr(Identifier(Ident { value: "a1", quote_style: None })), UnnamedExpr(Identifier(Ident { value: "id2", quote_style: None })), UnnamedExpr(Identifier(Ident { value: "a2", quote_style: None }))], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "stream", quote_style: None }]), alias: Some(TableAlias { name: Ident { value: "S", quote_style: None }, columns: [] }), for_system_time_as_of_proctime: false }, joins: [Join { relation: Table { name: ObjectName([Ident { value: "version", quote_style: None }]), alias: Some(TableAlias { name: Ident { value: "V", quote_style: None }, columns: [] }), for_system_time_as_of_proctime: true }, join_operator: Inner(On(BinaryOp { left: Identifier(Ident { value: "id1", quote_style: None }), op: Eq, right: Identifier(Ident { value: "id2", quote_style: None }) })) }] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select percentile_cont(0.3) within group (order by x desc) from unnest(array[1,2,4,5,10]) as x formatted_sql: SELECT percentile_cont(0.3) FROM unnest(ARRAY[1, 2, 4, 5, 10]) AS x - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "percentile_cont", quote_style: None }]), args: [Unnamed(Expr(Value(Number("0.3"))))], over: None, distinct: false, order_by: [], filter: None, within_group: Some(OrderByExpr { expr: Identifier(Ident { value: "x", quote_style: None }), asc: Some(false), nulls_first: None }) }))], from: [TableWithJoins { relation: TableFunction { name: ObjectName([Ident { value: "unnest", quote_style: None }]), alias: Some(TableAlias { name: Ident { value: "x", quote_style: None }, columns: [] }), args: [Unnamed(Expr(Array(Array { elem: [Value(Number("1")), Value(Number("2")), Value(Number("4")), Value(Number("5")), Value(Number("10"))], named: true })))], with_ordinality: false }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "percentile_cont", quote_style: None }]), args: [Unnamed(Expr(Value(Number("0.3"))))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: Some(OrderByExpr { expr: Identifier(Ident { value: "x", quote_style: None }), asc: Some(false), nulls_first: None }) }))], from: [TableWithJoins { relation: TableFunction { name: ObjectName([Ident { value: "unnest", quote_style: None }]), alias: Some(TableAlias { name: Ident { value: "x", quote_style: None }, columns: [] }), args: [Unnamed(Expr(Array(Array { elem: [Value(Number("1")), Value(Number("2")), Value(Number("4")), Value(Number("5")), Value(Number("10"))], named: true })))], with_ordinality: false }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select percentile_cont(0.3) within group (order by x, y desc) from t error_msg: 'sql parser error: only one arg in order by is expected here' - input: select 'apple' ~~ 'app%' diff --git a/src/tests/sqlsmith/src/sql_gen/agg.rs b/src/tests/sqlsmith/src/sql_gen/agg.rs index 26e65dccde2a..441e71e86c9f 100644 --- a/src/tests/sqlsmith/src/sql_gen/agg.rs +++ b/src/tests/sqlsmith/src/sql_gen/agg.rs @@ -142,6 +142,7 @@ fn make_agg_func( Function { name: ObjectName(vec![Ident::new_unchecked(func_name)]), args, + variadic: false, over: None, distinct, order_by, diff --git a/src/tests/sqlsmith/src/sql_gen/functions.rs b/src/tests/sqlsmith/src/sql_gen/functions.rs index 415d73d1cf3f..9265d45842e0 100644 --- a/src/tests/sqlsmith/src/sql_gen/functions.rs +++ b/src/tests/sqlsmith/src/sql_gen/functions.rs @@ -258,6 +258,7 @@ pub fn make_simple_func(func_name: &str, exprs: &[Expr]) -> Function { Function { name: ObjectName(vec![Ident::new_unchecked(func_name)]), args, + variadic: false, over: None, distinct: false, order_by: vec![],