From dbd946d15df25f8782be43608c6c3faf58764f1a Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Tue, 23 Jan 2024 16:38:06 +0800 Subject: [PATCH 01/14] support VARIADIC argument in parser Signed-off-by: Runji Wang --- src/sqlparser/src/ast/mod.rs | 24 ++++++++++------ src/sqlparser/src/parser.rs | 35 ++++++++++++++++------- src/sqlparser/tests/sqlparser_common.rs | 8 ++++++ src/sqlparser/tests/sqlparser_postgres.rs | 14 +++++++++ 4 files changed, 63 insertions(+), 18 deletions(-) diff --git a/src/sqlparser/src/ast/mod.rs b/src/sqlparser/src/ast/mod.rs index b11d4dc784bb..1ccfaba7ba5e 100644 --- a/src/sqlparser/src/ast/mod.rs +++ b/src/sqlparser/src/ast/mod.rs @@ -2260,6 +2260,8 @@ impl fmt::Display for FunctionArg { pub struct Function { pub name: ObjectName, pub args: Vec, + /// whether the last argument is variadic, e.g. `foo(a, b, variadic c)` + pub variadic: bool, pub over: Option, // aggregate functions may specify eg `COUNT(DISTINCT x)` pub distinct: bool, @@ -2274,6 +2276,7 @@ impl Function { Self { name, args: vec![], + variadic: false, over: None, distinct: false, order_by: vec![], @@ -2287,17 +2290,22 @@ impl fmt::Display for Function { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "{}({}{}{}{})", + "{}({}", self.name, if self.distinct { "DISTINCT " } else { "" }, - display_comma_separated(&self.args), - if !self.order_by.is_empty() { - " ORDER BY " - } else { - "" - }, - display_comma_separated(&self.order_by), )?; + if self.variadic { + for arg in &self.args[0..self.args.len() - 1] { + write!(f, "{}, ", arg)?; + } + write!(f, "VARIADIC {}", self.args.last().unwrap())?; + } else { + write!(f, "{}", display_comma_separated(&self.args))?; + } + if !self.order_by.is_empty() { + write!(f, " ORDER BY {}", display_comma_separated(&self.order_by))?; + } + write!(f, ")")?; if let Some(o) = &self.over { write!(f, " OVER ({})", o)?; } diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index e88a1df3157d..9724ae685a0d 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -804,7 +804,7 @@ impl Parser { pub fn parse_function(&mut self, name: ObjectName) -> Result { self.expect_token(&Token::LParen)?; let distinct = self.parse_all_or_distinct()?; - let (args, order_by) = self.parse_optional_args()?; + let (args, order_by, variadic) = self.parse_optional_args()?; let over = if self.parse_keyword(Keyword::OVER) { // TBD: support window names (`OVER mywin`) in place of inline specification self.expect_token(&Token::LParen)?; @@ -862,6 +862,7 @@ impl Parser { Ok(Expr::Function(Function { name, args, + variadic, over, distinct, order_by, @@ -4564,7 +4565,8 @@ impl Parser { let name = self.parse_object_name()?; // Postgres,table-valued functions: if self.consume_token(&Token::LParen) { - let (args, order_by) = self.parse_optional_args()?; + // ignore VARIADIC here + let (args, order_by, _variadic) = self.parse_optional_args()?; // Table-valued functions do not support ORDER BY, should return error if it appears if !order_by.is_empty() { return parser_err!("Table-valued functions do not support ORDER BY clauses"); @@ -4847,33 +4849,46 @@ impl Parser { Ok(Assignment { id, value }) } - fn parse_function_args(&mut self) -> Result { - if self.peek_nth_token(1) == Token::RArrow { + /// Parse a `[VARIADIC] name => expr`. + fn parse_function_args(&mut self) -> Result<(bool, FunctionArg), ParserError> { + let variadic = self.parse_keyword(Keyword::VARIADIC); + let arg = if self.peek_nth_token(1) == Token::RArrow { let name = self.parse_identifier()?; self.expect_token(&Token::RArrow)?; let arg = self.parse_wildcard_or_expr()?.into(); - Ok(FunctionArg::Named { name, arg }) + FunctionArg::Named { name, arg } } else { - Ok(FunctionArg::Unnamed(self.parse_wildcard_or_expr()?.into())) - } + FunctionArg::Unnamed(self.parse_wildcard_or_expr()?.into()) + }; + Ok((variadic, arg)) } pub fn parse_optional_args( &mut self, - ) -> Result<(Vec, Vec), ParserError> { + ) -> Result<(Vec, Vec, bool), ParserError> { if self.consume_token(&Token::RParen) { - Ok((vec![], vec![])) + Ok((vec![], vec![], false)) } else { let args = self.parse_comma_separated(Parser::parse_function_args)?; + if args + .iter() + .take(args.len() - 1) + .any(|(variadic, _)| *variadic) + { + return parser_err!("VARIADIC argument must be last"); + } + let variadic = args.last().map(|(variadic, _)| *variadic).unwrap_or(false); + let args = args.into_iter().map(|(_, arg)| arg).collect(); + let order_by = if self.parse_keywords(&[Keyword::ORDER, Keyword::BY]) { self.parse_comma_separated(Parser::parse_order_by_expr)? } else { vec![] }; self.expect_token(&Token::RParen)?; - Ok((args, order_by)) + Ok((args, order_by, variadic)) } } diff --git a/src/sqlparser/tests/sqlparser_common.rs b/src/sqlparser/tests/sqlparser_common.rs index 0fc2f3c2530f..b447dce37d31 100644 --- a/src/sqlparser/tests/sqlparser_common.rs +++ b/src/sqlparser/tests/sqlparser_common.rs @@ -346,6 +346,7 @@ fn parse_select_count_wildcard() { &Expr::Function(Function { name: ObjectName(vec![Ident::new_unchecked("COUNT")]), args: vec![FunctionArg::Unnamed(FunctionArgExpr::Wildcard(None))], + variadic: false, over: None, distinct: false, order_by: vec![], @@ -367,6 +368,7 @@ fn parse_select_count_distinct() { op: UnaryOperator::Plus, expr: Box::new(Expr::Identifier(Ident::new_unchecked("x"))), }))], + variadic: false, over: None, distinct: true, order_by: vec![], @@ -1086,6 +1088,7 @@ fn parse_select_having() { left: Box::new(Expr::Function(Function { name: ObjectName(vec![Ident::new_unchecked("COUNT")]), args: vec![FunctionArg::Unnamed(FunctionArgExpr::Wildcard(None))], + variadic: false, over: None, distinct: false, order_by: vec![], @@ -1842,6 +1845,7 @@ fn parse_named_argument_function() { ))), }, ], + variadic: false, over: None, distinct: false, order_by: vec![], @@ -1870,6 +1874,7 @@ fn parse_window_functions() { &Expr::Function(Function { name: ObjectName(vec![Ident::new_unchecked("row_number")]), args: vec![], + variadic: false, over: Some(WindowSpec { partition_by: vec![], order_by: vec![OrderByExpr { @@ -1910,6 +1915,7 @@ fn parse_aggregate_with_order_by() { Ident::new_unchecked("b") ))), ], + variadic: false, over: None, distinct: false, order_by: vec![ @@ -1941,6 +1947,7 @@ fn parse_aggregate_with_filter() { args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr( Expr::Identifier(Ident::new_unchecked("a")) )),], + variadic: false, over: None, distinct: false, order_by: vec![], @@ -2199,6 +2206,7 @@ fn parse_delimited_identifiers() { &Expr::Function(Function { name: ObjectName(vec![Ident::with_quote_unchecked('"', "myfun")]), args: vec![], + variadic: false, over: None, distinct: false, order_by: vec![], diff --git a/src/sqlparser/tests/sqlparser_postgres.rs b/src/sqlparser/tests/sqlparser_postgres.rs index 99e2c185fdcf..90a670745fd1 100644 --- a/src/sqlparser/tests/sqlparser_postgres.rs +++ b/src/sqlparser/tests/sqlparser_postgres.rs @@ -1271,3 +1271,17 @@ fn parse_double_quoted_string_as_alias() { let sql = "SELECT x \"x1\" FROM t"; assert!(parse_sql_statements(sql).is_ok()); } + +#[test] +fn parse_variadic_argument() { + let sql = "SELECT foo(a, b, VARIADIC c)"; + _ = verified_stmt(sql); + + let sql = "SELECT foo(VARIADIC a, b, VARIADIC c)"; + assert_eq!( + parse_sql_statements(sql), + Err(ParserError::ParserError( + "VARIADIC argument must be last".to_string() + )) + ); +} From 1b1d6cbc418128d9ae3a64ce6366dac390975375 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Tue, 23 Jan 2024 16:39:28 +0800 Subject: [PATCH 02/14] fix error in rw Signed-off-by: Runji Wang --- src/frontend/src/binder/relation/table_function.rs | 1 + src/tests/sqlsmith/src/sql_gen/agg.rs | 1 + src/tests/sqlsmith/src/sql_gen/functions.rs | 1 + 3 files changed, 3 insertions(+) diff --git a/src/frontend/src/binder/relation/table_function.rs b/src/frontend/src/binder/relation/table_function.rs index dbd015351bb5..c4113bdc512a 100644 --- a/src/frontend/src/binder/relation/table_function.rs +++ b/src/frontend/src/binder/relation/table_function.rs @@ -108,6 +108,7 @@ impl Binder { let func = self.bind_function(Function { name, args, + variadic: false, over: None, distinct: false, order_by: vec![], diff --git a/src/tests/sqlsmith/src/sql_gen/agg.rs b/src/tests/sqlsmith/src/sql_gen/agg.rs index 26e65dccde2a..441e71e86c9f 100644 --- a/src/tests/sqlsmith/src/sql_gen/agg.rs +++ b/src/tests/sqlsmith/src/sql_gen/agg.rs @@ -142,6 +142,7 @@ fn make_agg_func( Function { name: ObjectName(vec![Ident::new_unchecked(func_name)]), args, + variadic: false, over: None, distinct, order_by, diff --git a/src/tests/sqlsmith/src/sql_gen/functions.rs b/src/tests/sqlsmith/src/sql_gen/functions.rs index 415d73d1cf3f..9265d45842e0 100644 --- a/src/tests/sqlsmith/src/sql_gen/functions.rs +++ b/src/tests/sqlsmith/src/sql_gen/functions.rs @@ -258,6 +258,7 @@ pub fn make_simple_func(func_name: &str, exprs: &[Expr]) -> Function { Function { name: ObjectName(vec![Ident::new_unchecked(func_name)]), args, + variadic: false, over: None, distinct: false, order_by: vec![], From 68e99462019646498acc41480d89b6c26c57a5d7 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Tue, 23 Jan 2024 17:09:45 +0800 Subject: [PATCH 03/14] add function `format_variadic` Signed-off-by: Runji Wang --- proto/expr.proto | 1 + src/common/src/array/list_array.rs | 18 ++++++++++++++++++ src/expr/impl/src/scalar/format.rs | 6 +++++- src/expr/macro/src/gen.rs | 2 +- src/frontend/src/expr/pure.rs | 1 + 5 files changed, 26 insertions(+), 2 deletions(-) diff --git a/proto/expr.proto b/proto/expr.proto index f62ee2936d11..87a01ee40dfd 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -176,6 +176,7 @@ message ExprNode { LEFT = 317; RIGHT = 318; FORMAT = 319; + FORMAT_VARIADIC = 324; PGWIRE_SEND = 320; PGWIRE_RECV = 321; CONVERT_FROM = 322; diff --git a/src/common/src/array/list_array.rs b/src/common/src/array/list_array.rs index 45ef845835be..8f0235f8145d 100644 --- a/src/common/src/array/list_array.rs +++ b/src/common/src/array/list_array.rs @@ -612,6 +612,24 @@ impl Debug for ListRef<'_> { } } +impl Row for ListRef<'_> { + fn datum_at(&self, index: usize) -> DatumRef<'_> { + self.array.value_at(self.start as usize + index) + } + + unsafe fn datum_at_unchecked(&self, index: usize) -> DatumRef<'_> { + self.array.value_at_unchecked(self.start as usize + index) + } + + fn len(&self) -> usize { + self.len() + } + + fn iter(&self) -> impl Iterator> { + (*self).iter() + } +} + impl ToText for ListRef<'_> { // This function will be invoked when pgwire prints a list value in string. // Refer to PostgreSQL `array_out` or `appendPGArray`. diff --git a/src/expr/impl/src/scalar/format.rs b/src/expr/impl/src/scalar/format.rs index 24256ee87f9f..1997b06d2f4f 100644 --- a/src/expr/impl/src/scalar/format.rs +++ b/src/expr/impl/src/scalar/format.rs @@ -27,7 +27,11 @@ use super::string::quote_ident; "format(varchar, ...) -> varchar", prebuild = "Formatter::from_str($0).map_err(|e| ExprError::Parse(e.to_report_string().into()))?" )] -fn format(formatter: &Formatter, row: impl Row, writer: &mut impl Write) -> Result<()> { +#[function( + "format_variadic(varchar, anyarray) -> varchar", + prebuild = "Formatter::from_str($0).map_err(|e| ExprError::Parse(e.to_report_string().into()))?" +)] +fn format(row: impl Row, formatter: &Formatter, writer: &mut impl Write) -> Result<()> { let mut args = row.iter(); for node in &formatter.nodes { match node { diff --git a/src/expr/macro/src/gen.rs b/src/expr/macro/src/gen.rs index 28dd14c315f1..be6cb2b44cf6 100644 --- a/src/expr/macro/src/gen.rs +++ b/src/expr/macro/src/gen.rs @@ -314,8 +314,8 @@ impl FunctionAttr { // inputs: [ Option ] let mut output = quote! { #fn_name #generic( #(#non_prebuilt_inputs,)* - #prebuilt_arg #variadic_args + #prebuilt_arg #context #writer ) #await_ }; diff --git a/src/frontend/src/expr/pure.rs b/src/frontend/src/expr/pure.rs index 7e7378a65752..829ca32c6d75 100644 --- a/src/frontend/src/expr/pure.rs +++ b/src/frontend/src/expr/pure.rs @@ -219,6 +219,7 @@ impl ExprVisitor for ImpureAnalyzer { | expr_node::Type::ArrayPositions | expr_node::Type::StringToArray | expr_node::Type::Format + | expr_node::Type::FormatVariadic | expr_node::Type::PgwireSend | expr_node::Type::PgwireRecv | expr_node::Type::ArrayTransform From 2408f42bad6d3d00e6ae317c7a1e7942456014d8 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Tue, 23 Jan 2024 17:34:58 +0800 Subject: [PATCH 04/14] support variadic for `concat_ws` Signed-off-by: Runji Wang --- proto/expr.proto | 1 + src/expr/impl/src/scalar/concat_ws.rs | 15 +++++++++++++++ src/expr/impl/src/scalar/format.rs | 14 ++++++++++++++ src/frontend/src/binder/expr/function.rs | 23 +++++++++++++++++++++-- src/frontend/src/expr/pure.rs | 1 + 5 files changed, 52 insertions(+), 2 deletions(-) diff --git a/proto/expr.proto b/proto/expr.proto index 87a01ee40dfd..10091dc2e09e 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -91,6 +91,7 @@ message ExprNode { TRANSLATE = 216; COALESCE = 217; CONCAT_WS = 218; + CONCAT_WS_VARIADIC = 285; ABS = 219; SPLIT_PART = 220; CEIL = 221; diff --git a/src/expr/impl/src/scalar/concat_ws.rs b/src/expr/impl/src/scalar/concat_ws.rs index ec979d4cbacd..ba4b1bdcc4ab 100644 --- a/src/expr/impl/src/scalar/concat_ws.rs +++ b/src/expr/impl/src/scalar/concat_ws.rs @@ -20,7 +20,22 @@ use risingwave_expr::function; /// Concatenates all but the first argument, with separators. The first argument is used as the /// separator string, and should not be NULL. Other NULL arguments are ignored. +/// +/// # Example +/// +/// ```slt +/// query T +/// select concat_ws(',', 'abcde', 2, NULL, 22); +/// ---- +/// abcde,2,22 +/// +/// query T +/// select concat_ws(',', variadic array['abcde', 2, NULL, 22] :: varchar[]); +/// ---- +/// abcde,2,22 +/// ``` #[function("concat_ws(varchar, ...) -> varchar")] +#[function("concat_ws_variadic(varchar, anyarray) -> varchar")] fn concat_ws(sep: &str, vals: impl Row, writer: &mut impl Write) -> Option<()> { let mut string_iter = vals.iter().flatten(); if let Some(string) = string_iter.next() { diff --git a/src/expr/impl/src/scalar/format.rs b/src/expr/impl/src/scalar/format.rs index 1997b06d2f4f..2339953fdffe 100644 --- a/src/expr/impl/src/scalar/format.rs +++ b/src/expr/impl/src/scalar/format.rs @@ -23,6 +23,20 @@ use thiserror_ext::AsReport; use super::string::quote_ident; /// Formats arguments according to a format string. +/// +/// # Example +/// +/// ```slt +/// query T +/// select format('%s %s', 'Hello', 'World'); +/// ---- +/// Hello World +/// +/// query T +/// select format('%s %s', variadic array['Hello', 'World']); +/// ---- +/// Hello World +/// ``` #[function( "format(varchar, ...) -> varchar", prebuild = "Formatter::from_str($0).map_err(|e| ExprError::Parse(e.to_report_string().into()))?" diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index c0ffc86cd5eb..eefd38998d49 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -333,7 +333,7 @@ impl Binder { } } - self.bind_builtin_scalar_function(function_name.as_str(), inputs) + self.bind_builtin_scalar_function(function_name.as_str(), inputs, f.variadic) } fn bind_array_transform(&mut self, f: Function) -> Result { @@ -711,6 +711,7 @@ impl Binder { &mut self, function_name: &str, inputs: Vec, + variadic: bool, ) -> Result { type Inputs = Vec; @@ -990,7 +991,7 @@ impl Binder { guard_by_len(2, raw(|binder, inputs| { let (arg0, arg1) = inputs.into_iter().next_tuple().unwrap(); // rewrite into `CASE WHEN 0 < arg1 AND arg1 <= array_ndims(arg0) THEN 1 END` - let ndims_expr = binder.bind_builtin_scalar_function("array_ndims", vec![arg0])?; + let ndims_expr = binder.bind_builtin_scalar_function("array_ndims", vec![arg0], false)?; let arg1 = arg1.cast_implicit(DataType::Int32)?; FunctionCall::new( @@ -1386,6 +1387,24 @@ impl Binder { tree }); + if variadic { + match function_name { + "format" => { + return Ok(FunctionCall::new(ExprType::FormatVariadic, inputs)?.into()); + } + "concat_ws" => { + return Ok(FunctionCall::new(ExprType::ConcatWsVariadic, inputs)?.into()); + } + _ => { + return Err(ErrorCode::BindError(format!( + "VARIADIC argument is not allowed in function \"{}\"", + function_name + )) + .into()) + } + } + } + match HANDLES.get(function_name) { Some(handle) => handle(self, inputs), None => { diff --git a/src/frontend/src/expr/pure.rs b/src/frontend/src/expr/pure.rs index 829ca32c6d75..1313d85b4708 100644 --- a/src/frontend/src/expr/pure.rs +++ b/src/frontend/src/expr/pure.rs @@ -88,6 +88,7 @@ impl ExprVisitor for ImpureAnalyzer { | expr_node::Type::Translate | expr_node::Type::Coalesce | expr_node::Type::ConcatWs + | expr_node::Type::ConcatWsVariadic | expr_node::Type::Abs | expr_node::Type::SplitPart | expr_node::Type::Ceil From 8d2998aaeaacd5ab62b00a6dc4d782cd413d42ab Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Wed, 24 Jan 2024 14:14:10 +0800 Subject: [PATCH 05/14] update parser test Signed-off-by: Runji Wang --- src/sqlparser/tests/testdata/lambda.yaml | 6 +++--- src/sqlparser/tests/testdata/qualified_operator.yaml | 4 ++-- src/sqlparser/tests/testdata/select.yaml | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/sqlparser/tests/testdata/lambda.yaml b/src/sqlparser/tests/testdata/lambda.yaml index 6db19af63fcc..4ff15c60bd77 100644 --- a/src/sqlparser/tests/testdata/lambda.yaml +++ b/src/sqlparser/tests/testdata/lambda.yaml @@ -1,10 +1,10 @@ # This file is automatically generated. See `src/sqlparser/test_runner/src/bin/apply.rs` for more information. - input: select array_transform(array[1,2,3], |x| x * 2) formatted_sql: SELECT array_transform(ARRAY[1, 2, 3], |x| x * 2) - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "array_transform", quote_style: None }]), args: [Unnamed(Expr(Array(Array { elem: [Value(Number("1")), Value(Number("2")), Value(Number("3"))], named: true }))), Unnamed(Expr(LambdaFunction { args: [Ident { value: "x", quote_style: None }], body: BinaryOp { left: Identifier(Ident { value: "x", quote_style: None }), op: Multiply, right: Value(Number("2")) } }))], over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "array_transform", quote_style: None }]), args: [Unnamed(Expr(Array(Array { elem: [Value(Number("1")), Value(Number("2")), Value(Number("3"))], named: true }))), Unnamed(Expr(LambdaFunction { args: [Ident { value: "x", quote_style: None }], body: BinaryOp { left: Identifier(Ident { value: "x", quote_style: None }), op: Multiply, right: Value(Number("2")) } }))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select array_transform(array[], |s| case when s ilike 'apple%' then 'apple' when s ilike 'google%' then 'google' else 'unknown' end) formatted_sql: SELECT array_transform(ARRAY[], |s| CASE WHEN s ILIKE 'apple%' THEN 'apple' WHEN s ILIKE 'google%' THEN 'google' ELSE 'unknown' END) - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "array_transform", quote_style: None }]), args: [Unnamed(Expr(Array(Array { elem: [], named: true }))), Unnamed(Expr(LambdaFunction { args: [Ident { value: "s", quote_style: None }], body: Case { operand: None, conditions: [BinaryOp { left: Identifier(Ident { value: "s", quote_style: None }), op: ILike, right: Value(SingleQuotedString("apple%")) }, BinaryOp { left: Identifier(Ident { value: "s", quote_style: None }), op: ILike, right: Value(SingleQuotedString("google%")) }], results: [Value(SingleQuotedString("apple")), Value(SingleQuotedString("google"))], else_result: Some(Value(SingleQuotedString("unknown"))) } }))], over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "array_transform", quote_style: None }]), args: [Unnamed(Expr(Array(Array { elem: [], named: true }))), Unnamed(Expr(LambdaFunction { args: [Ident { value: "s", quote_style: None }], body: Case { operand: None, conditions: [BinaryOp { left: Identifier(Ident { value: "s", quote_style: None }), op: ILike, right: Value(SingleQuotedString("apple%")) }, BinaryOp { left: Identifier(Ident { value: "s", quote_style: None }), op: ILike, right: Value(SingleQuotedString("google%")) }], results: [Value(SingleQuotedString("apple")), Value(SingleQuotedString("google"))], else_result: Some(Value(SingleQuotedString("unknown"))) } }))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select array_transform(array[], |x, y| x + y * 2) formatted_sql: SELECT array_transform(ARRAY[], |x, y| x + y * 2) - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "array_transform", quote_style: None }]), args: [Unnamed(Expr(Array(Array { elem: [], named: true }))), Unnamed(Expr(LambdaFunction { args: [Ident { value: "x", quote_style: None }, Ident { value: "y", quote_style: None }], body: BinaryOp { left: Identifier(Ident { value: "x", quote_style: None }), op: Plus, right: BinaryOp { left: Identifier(Ident { value: "y", quote_style: None }), op: Multiply, right: Value(Number("2")) } } }))], over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "array_transform", quote_style: None }]), args: [Unnamed(Expr(Array(Array { elem: [], named: true }))), Unnamed(Expr(LambdaFunction { args: [Ident { value: "x", quote_style: None }, Ident { value: "y", quote_style: None }], body: BinaryOp { left: Identifier(Ident { value: "x", quote_style: None }), op: Plus, right: BinaryOp { left: Identifier(Ident { value: "y", quote_style: None }), op: Multiply, right: Value(Number("2")) } } }))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' diff --git a/src/sqlparser/tests/testdata/qualified_operator.yaml b/src/sqlparser/tests/testdata/qualified_operator.yaml index 23658fd17ce2..814113edbe5b 100644 --- a/src/sqlparser/tests/testdata/qualified_operator.yaml +++ b/src/sqlparser/tests/testdata/qualified_operator.yaml @@ -19,10 +19,10 @@ formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Identifier(Ident { value: "operator", quote_style: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select "operator"(foo.bar); formatted_sql: SELECT "operator"(foo.bar) - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "operator", quote_style: Some(''"'') }]), args: [Unnamed(Expr(CompoundIdentifier([Ident { value: "foo", quote_style: None }, Ident { value: "bar", quote_style: None }])))], over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "operator", quote_style: Some(''"'') }]), args: [Unnamed(Expr(CompoundIdentifier([Ident { value: "foo", quote_style: None }, Ident { value: "bar", quote_style: None }])))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select operator operator(+) operator(+) "operator"(9) operator from operator; formatted_sql: SELECT operator OPERATOR(+) OPERATOR(+) "operator"(9) AS operator FROM operator - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [ExprWithAlias { expr: BinaryOp { left: Identifier(Ident { value: "operator", quote_style: None }), op: PGQualified(QualifiedOperator { schema: None, name: "+" }), right: UnaryOp { op: PGQualified(QualifiedOperator { schema: None, name: "+" }), expr: Function(Function { name: ObjectName([Ident { value: "operator", quote_style: Some(''"'') }]), args: [Unnamed(Expr(Value(Number("9"))))], over: None, distinct: false, order_by: [], filter: None, within_group: None }) } }, alias: Ident { value: "operator", quote_style: None } }], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "operator", quote_style: None }]), alias: None, for_system_time_as_of_proctime: false }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [ExprWithAlias { expr: BinaryOp { left: Identifier(Ident { value: "operator", quote_style: None }), op: PGQualified(QualifiedOperator { schema: None, name: "+" }), right: UnaryOp { op: PGQualified(QualifiedOperator { schema: None, name: "+" }), expr: Function(Function { name: ObjectName([Ident { value: "operator", quote_style: Some(''"'') }]), args: [Unnamed(Expr(Value(Number("9"))))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: None }) } }, alias: Ident { value: "operator", quote_style: None } }], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "operator", quote_style: None }]), alias: None, for_system_time_as_of_proctime: false }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select 3 operator(-) 2 - 1; formatted_sql: SELECT 3 OPERATOR(-) 2 - 1 formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(BinaryOp { left: Value(Number("3")), op: PGQualified(QualifiedOperator { schema: None, name: "-" }), right: BinaryOp { left: Value(Number("2")), op: Minus, right: Value(Number("1")) } })], from: [], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' diff --git a/src/sqlparser/tests/testdata/select.yaml b/src/sqlparser/tests/testdata/select.yaml index 1fb897166a1a..5321abd469ee 100644 --- a/src/sqlparser/tests/testdata/select.yaml +++ b/src/sqlparser/tests/testdata/select.yaml @@ -1,7 +1,7 @@ # This file is automatically generated. See `src/sqlparser/test_runner/src/bin/apply.rs` for more information. - input: SELECT sqrt(id) FROM foo formatted_sql: SELECT sqrt(id) FROM foo - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "sqrt", quote_style: None }]), args: [Unnamed(Expr(Identifier(Ident { value: "id", quote_style: None })))], over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "foo", quote_style: None }]), alias: None, for_system_time_as_of_proctime: false }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "sqrt", quote_style: None }]), args: [Unnamed(Expr(Identifier(Ident { value: "id", quote_style: None })))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: None }))], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "foo", quote_style: None }]), alias: None, for_system_time_as_of_proctime: false }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: SELECT INT '1' formatted_sql: SELECT INT '1' - input: SELECT (foo).v1.v2 FROM foo @@ -132,7 +132,7 @@ formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Identifier(Ident { value: "id1", quote_style: None })), UnnamedExpr(Identifier(Ident { value: "a1", quote_style: None })), UnnamedExpr(Identifier(Ident { value: "id2", quote_style: None })), UnnamedExpr(Identifier(Ident { value: "a2", quote_style: None }))], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "stream", quote_style: None }]), alias: Some(TableAlias { name: Ident { value: "S", quote_style: None }, columns: [] }), for_system_time_as_of_proctime: false }, joins: [Join { relation: Table { name: ObjectName([Ident { value: "version", quote_style: None }]), alias: Some(TableAlias { name: Ident { value: "V", quote_style: None }, columns: [] }), for_system_time_as_of_proctime: true }, join_operator: Inner(On(BinaryOp { left: Identifier(Ident { value: "id1", quote_style: None }), op: Eq, right: Identifier(Ident { value: "id2", quote_style: None }) })) }] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select percentile_cont(0.3) within group (order by x desc) from unnest(array[1,2,4,5,10]) as x formatted_sql: SELECT percentile_cont(0.3) FROM unnest(ARRAY[1, 2, 4, 5, 10]) AS x - formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "percentile_cont", quote_style: None }]), args: [Unnamed(Expr(Value(Number("0.3"))))], over: None, distinct: false, order_by: [], filter: None, within_group: Some(OrderByExpr { expr: Identifier(Ident { value: "x", quote_style: None }), asc: Some(false), nulls_first: None }) }))], from: [TableWithJoins { relation: TableFunction { name: ObjectName([Ident { value: "unnest", quote_style: None }]), alias: Some(TableAlias { name: Ident { value: "x", quote_style: None }, columns: [] }), args: [Unnamed(Expr(Array(Array { elem: [Value(Number("1")), Value(Number("2")), Value(Number("4")), Value(Number("5")), Value(Number("10"))], named: true })))], with_ordinality: false }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' + formatted_ast: 'Query(Query { with: None, body: Select(Select { distinct: All, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "percentile_cont", quote_style: None }]), args: [Unnamed(Expr(Value(Number("0.3"))))], variadic: false, over: None, distinct: false, order_by: [], filter: None, within_group: Some(OrderByExpr { expr: Identifier(Ident { value: "x", quote_style: None }), asc: Some(false), nulls_first: None }) }))], from: [TableWithJoins { relation: TableFunction { name: ObjectName([Ident { value: "unnest", quote_style: None }]), alias: Some(TableAlias { name: Ident { value: "x", quote_style: None }, columns: [] }), args: [Unnamed(Expr(Array(Array { elem: [Value(Number("1")), Value(Number("2")), Value(Number("4")), Value(Number("5")), Value(Number("10"))], named: true })))], with_ordinality: false }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None })' - input: select percentile_cont(0.3) within group (order by x, y desc) from t error_msg: 'sql parser error: only one arg in order by is expected here' - input: select 'apple' ~~ 'app%' From 4a0ce6f761631910ccc903723f4bdabd94a635b9 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Wed, 24 Jan 2024 17:29:10 +0800 Subject: [PATCH 06/14] support "variadic" keyword in `#[function]` Signed-off-by: Runji Wang --- proto/expr.proto | 2 ++ src/expr/impl/src/scalar/concat_ws.rs | 3 +- src/expr/impl/src/scalar/format.rs | 6 +--- src/expr/impl/src/scalar/jsonb_build.rs | 14 +++++++-- src/expr/macro/src/gen.rs | 36 ++++++++++++++++++++++++ src/frontend/src/binder/expr/function.rs | 26 ++++++++--------- src/frontend/src/expr/pure.rs | 2 ++ 7 files changed, 67 insertions(+), 22 deletions(-) diff --git a/proto/expr.proto b/proto/expr.proto index 10091dc2e09e..62edaeca3d91 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -255,7 +255,9 @@ message ExprNode { JSONB_STRIP_NULLS = 616; TO_JSONB = 617; JSONB_BUILD_ARRAY = 618; + JSONB_BUILD_ARRAY_VARIADIC = 624; JSONB_BUILD_OBJECT = 619; + JSONB_BUILD_OBJECT_VARIADIC = 625; JSONB_PATH_EXISTS = 620; JSONB_PATH_MATCH = 621; JSONB_PATH_QUERY_ARRAY = 622; diff --git a/src/expr/impl/src/scalar/concat_ws.rs b/src/expr/impl/src/scalar/concat_ws.rs index ba4b1bdcc4ab..0b95724cc73a 100644 --- a/src/expr/impl/src/scalar/concat_ws.rs +++ b/src/expr/impl/src/scalar/concat_ws.rs @@ -34,8 +34,7 @@ use risingwave_expr::function; /// ---- /// abcde,2,22 /// ``` -#[function("concat_ws(varchar, ...) -> varchar")] -#[function("concat_ws_variadic(varchar, anyarray) -> varchar")] +#[function("concat_ws(varchar, variadic anyarray) -> varchar")] fn concat_ws(sep: &str, vals: impl Row, writer: &mut impl Write) -> Option<()> { let mut string_iter = vals.iter().flatten(); if let Some(string) = string_iter.next() { diff --git a/src/expr/impl/src/scalar/format.rs b/src/expr/impl/src/scalar/format.rs index 2339953fdffe..50195638e4d0 100644 --- a/src/expr/impl/src/scalar/format.rs +++ b/src/expr/impl/src/scalar/format.rs @@ -38,11 +38,7 @@ use super::string::quote_ident; /// Hello World /// ``` #[function( - "format(varchar, ...) -> varchar", - prebuild = "Formatter::from_str($0).map_err(|e| ExprError::Parse(e.to_report_string().into()))?" -)] -#[function( - "format_variadic(varchar, anyarray) -> varchar", + "format(varchar, variadic anyarray) -> varchar", prebuild = "Formatter::from_str($0).map_err(|e| ExprError::Parse(e.to_report_string().into()))?" )] fn format(row: impl Row, formatter: &Formatter, writer: &mut impl Write) -> Result<()> { diff --git a/src/expr/impl/src/scalar/jsonb_build.rs b/src/expr/impl/src/scalar/jsonb_build.rs index 85b24d7126f1..ddf22def1526 100644 --- a/src/expr/impl/src/scalar/jsonb_build.rs +++ b/src/expr/impl/src/scalar/jsonb_build.rs @@ -31,8 +31,13 @@ use super::{ToJsonb, ToTextDisplay}; /// select jsonb_build_array(1, 2, 'foo', 4, 5); /// ---- /// [1, 2, "foo", 4, 5] +/// +/// query T +/// select jsonb_build_array(variadic array[1, 2, 4, 5]); +/// ---- +/// [1, 2, 4, 5] /// ``` -#[function("jsonb_build_array(...) -> jsonb")] +#[function("jsonb_build_array(variadic anyarray) -> jsonb")] fn jsonb_build_array(args: impl Row, ctx: &Context) -> Result { let mut builder = Builder::>::new(); builder.begin_array(); @@ -54,8 +59,13 @@ fn jsonb_build_array(args: impl Row, ctx: &Context) -> Result { /// select jsonb_build_object('foo', 1, 2, 'bar'); /// ---- /// {"2": "bar", "foo": 1} +/// +/// query T +/// select jsonb_build_object(variadic array['foo', '1', '2', 'bar']); +/// ---- +/// {"2": "bar", "foo": 1} /// ``` -#[function("jsonb_build_object(...) -> jsonb")] +#[function("jsonb_build_object(variadic anyarray) -> jsonb")] fn jsonb_build_object(args: impl Row, ctx: &Context) -> Result { if args.len() % 2 == 1 { return Err(ExprError::InvalidParam { diff --git a/src/expr/macro/src/gen.rs b/src/expr/macro/src/gen.rs index be6cb2b44cf6..80f074b14322 100644 --- a/src/expr/macro/src/gen.rs +++ b/src/expr/macro/src/gen.rs @@ -23,6 +23,42 @@ use super::*; impl FunctionAttr { /// Expands the wildcard in function arguments or return type. pub fn expand(&self) -> Vec { + // handle variadic argument + if self + .args + .last() + .is_some_and(|arg| arg.starts_with("variadic")) + { + // expand: foo(a, b, variadic anyarray) + // to: foo(a, b, ...) + // + foo_variadic(a, b, anyarray) + let mut attrs = Vec::new(); + attrs.extend( + FunctionAttr { + args: { + let mut args = self.args.clone(); + *args.last_mut().unwrap() = "...".to_string(); + args + }, + ..self.clone() + } + .expand(), + ); + attrs.extend( + FunctionAttr { + name: format!("{}_variadic", self.name), + args: { + let mut args = self.args.clone(); + let last = args.last_mut().unwrap(); + *last = last.strip_prefix("variadic ").unwrap().into(); + args + }, + ..self.clone() + } + .expand(), + ); + return attrs; + } let args = self.args.iter().map(|ty| types::expand_type_wildcard(ty)); let ret = types::expand_type_wildcard(&self.ret); // multi_cartesian_product should emit an empty set if the input is empty. diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index eefd38998d49..e17a834f600c 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -1388,21 +1388,21 @@ impl Binder { }); if variadic { - match function_name { - "format" => { - return Ok(FunctionCall::new(ExprType::FormatVariadic, inputs)?.into()); + return match function_name { + "format" => Ok(FunctionCall::new(ExprType::FormatVariadic, inputs)?.into()), + "concat_ws" => Ok(FunctionCall::new(ExprType::ConcatWsVariadic, inputs)?.into()), + "jsonb_build_array" => { + Ok(FunctionCall::new(ExprType::JsonbBuildArrayVariadic, inputs)?.into()) } - "concat_ws" => { - return Ok(FunctionCall::new(ExprType::ConcatWsVariadic, inputs)?.into()); + "jsonb_build_object" => { + Ok(FunctionCall::new(ExprType::JsonbBuildObjectVariadic, inputs)?.into()) } - _ => { - return Err(ErrorCode::BindError(format!( - "VARIADIC argument is not allowed in function \"{}\"", - function_name - )) - .into()) - } - } + _ => Err(ErrorCode::BindError(format!( + "VARIADIC argument is not allowed in function \"{}\"", + function_name + )) + .into()), + }; } match HANDLES.get(function_name) { diff --git a/src/frontend/src/expr/pure.rs b/src/frontend/src/expr/pure.rs index 1313d85b4708..5445b6b26b35 100644 --- a/src/frontend/src/expr/pure.rs +++ b/src/frontend/src/expr/pure.rs @@ -191,7 +191,9 @@ impl ExprVisitor for ImpureAnalyzer { | expr_node::Type::JsonbExistsAll | expr_node::Type::JsonbStripNulls | expr_node::Type::JsonbBuildArray + | expr_node::Type::JsonbBuildArrayVariadic | expr_node::Type::JsonbBuildObject + | expr_node::Type::JsonbBuildObjectVariadic | expr_node::Type::JsonbPathExists | expr_node::Type::JsonbPathMatch | expr_node::Type::JsonbPathQueryArray From cc37ccf5e2e56eddb8ae1834856b00f6e271f645 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Wed, 24 Jan 2024 17:55:27 +0800 Subject: [PATCH 07/14] support variadic for `jsonb_extract_path[_text]` Signed-off-by: Runji Wang --- proto/expr.proto | 6 +- src/expr/impl/src/scalar/jsonb_access.rs | 21 +++++-- src/frontend/src/binder/expr/binary_op.rs | 4 +- src/frontend/src/binder/expr/function.rs | 59 ++++++-------------- src/frontend/src/expr/pure.rs | 2 + src/frontend/src/expr/type_inference/func.rs | 16 ++++++ 6 files changed, 56 insertions(+), 52 deletions(-) diff --git a/proto/expr.proto b/proto/expr.proto index 62edaeca3d91..64223da67ede 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -226,9 +226,11 @@ message ExprNode { // jsonb ->> int, jsonb ->> text that returns text JSONB_ACCESS_STR = 601; // jsonb #> text[] -> jsonb - JSONB_EXTRACT_PATH = 613; + JSONB_EXTRACT_PATH = 626; + JSONB_EXTRACT_PATH_VARIADIC = 613; // jsonb #>> text[] -> text - JSONB_EXTRACT_PATH_TEXT = 614; + JSONB_EXTRACT_PATH_TEXT = 627; + JSONB_EXTRACT_PATH_TEXT_VARIADIC = 614; JSONB_TYPEOF = 602; JSONB_ARRAY_LENGTH = 603; IS_JSON = 604; diff --git a/src/expr/impl/src/scalar/jsonb_access.rs b/src/expr/impl/src/scalar/jsonb_access.rs index 36ced44bf357..176e44810f23 100644 --- a/src/expr/impl/src/scalar/jsonb_access.rs +++ b/src/expr/impl/src/scalar/jsonb_access.rs @@ -14,7 +14,8 @@ use std::fmt::Write; -use risingwave_common::types::{JsonbRef, ListRef}; +use risingwave_common::row::Row; +use risingwave_common::types::JsonbRef; use risingwave_expr::function; /// Extracts JSON object field with the given key. @@ -91,9 +92,14 @@ pub fn jsonb_array_element(v: JsonbRef<'_>, p: i32) -> Option> { /// select jsonb_extract_path('{"a": {"b": ["foo","bar"]}}', 'a', 'b', '1'); /// ---- /// "bar" +/// +/// query T +/// select jsonb_extract_path('{"a": {"b": ["foo","bar"]}}', variadic array['a', 'b', '1']); +/// ---- +/// "bar" /// ``` -#[function("jsonb_extract_path(jsonb, varchar[]) -> jsonb")] -pub fn jsonb_extract_path<'a>(v: JsonbRef<'a>, path: ListRef<'_>) -> Option> { +#[function("jsonb_extract_path(jsonb, variadic varchar[]) -> jsonb")] +pub fn jsonb_extract_path<'a>(v: JsonbRef<'a>, path: impl Row) -> Option> { let mut jsonb = v; for key in path.iter() { // return null if any element is null @@ -192,11 +198,16 @@ pub fn jsonb_array_element_str(v: JsonbRef<'_>, p: i32, writer: &mut impl Write) /// select jsonb_extract_path_text('{"a": {"b": ["foo","bar"]}}', 'a', 'b', '1'); /// ---- /// bar +/// +/// query T +/// select jsonb_extract_path_text('{"a": {"b": ["foo","bar"]}}', variadic array['a', 'b', '1']); +/// ---- +/// bar /// ``` -#[function("jsonb_extract_path_text(jsonb, varchar[]) -> varchar")] +#[function("jsonb_extract_path_text(jsonb, variadic varchar[]) -> varchar")] pub fn jsonb_extract_path_text( v: JsonbRef<'_>, - path: ListRef<'_>, + path: impl Row, writer: &mut impl Write, ) -> Option<()> { let jsonb = jsonb_extract_path(v, path)?; diff --git a/src/frontend/src/binder/expr/binary_op.rs b/src/frontend/src/binder/expr/binary_op.rs index a2efee44fe34..fd14ffbd90a2 100644 --- a/src/frontend/src/binder/expr/binary_op.rs +++ b/src/frontend/src/binder/expr/binary_op.rs @@ -113,8 +113,8 @@ impl Binder { BinaryOperator::Arrow => ExprType::JsonbAccess, BinaryOperator::LongArrow => ExprType::JsonbAccessStr, BinaryOperator::HashMinus => ExprType::JsonbDeletePath, - BinaryOperator::HashArrow => ExprType::JsonbExtractPath, - BinaryOperator::HashLongArrow => ExprType::JsonbExtractPathText, + BinaryOperator::HashArrow => ExprType::JsonbExtractPathVariadic, + BinaryOperator::HashLongArrow => ExprType::JsonbExtractPathTextVariadic, BinaryOperator::Prefix => ExprType::StartsWith, BinaryOperator::Contains => { let left_type = (!bound_left.is_untyped()).then(|| bound_left.return_type()); diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index e17a834f600c..59ea9a805576 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -1018,36 +1018,8 @@ impl Binder { ("jsonb_array_element", raw_call(ExprType::JsonbAccess)), ("jsonb_object_field_text", raw_call(ExprType::JsonbAccessStr)), ("jsonb_array_element_text", raw_call(ExprType::JsonbAccessStr)), - ("jsonb_extract_path", raw(|_binder, mut inputs| { - // rewrite: jsonb_extract_path(jsonb, s1, s2...) - // to: jsonb_extract_path(jsonb, array[s1, s2...]) - if inputs.len() < 2 { - return Err(ErrorCode::ExprError("unexpected arguments number".into()).into()); - } - inputs[0].cast_implicit_mut(DataType::Jsonb)?; - let mut variadic_inputs = inputs.split_off(1); - for input in &mut variadic_inputs { - input.cast_implicit_mut(DataType::Varchar)?; - } - let array = FunctionCall::new_unchecked(ExprType::Array, variadic_inputs, DataType::List(Box::new(DataType::Varchar))); - inputs.push(array.into()); - Ok(FunctionCall::new_unchecked(ExprType::JsonbExtractPath, inputs, DataType::Jsonb).into()) - })), - ("jsonb_extract_path_text", raw(|_binder, mut inputs| { - // rewrite: jsonb_extract_path_text(jsonb, s1, s2...) - // to: jsonb_extract_path_text(jsonb, array[s1, s2...]) - if inputs.len() < 2 { - return Err(ErrorCode::ExprError("unexpected arguments number".into()).into()); - } - inputs[0].cast_implicit_mut(DataType::Jsonb)?; - let mut variadic_inputs = inputs.split_off(1); - for input in &mut variadic_inputs { - input.cast_implicit_mut(DataType::Varchar)?; - } - let array = FunctionCall::new_unchecked(ExprType::Array, variadic_inputs, DataType::List(Box::new(DataType::Varchar))); - inputs.push(array.into()); - Ok(FunctionCall::new_unchecked(ExprType::JsonbExtractPathText, inputs, DataType::Varchar).into()) - })), + ("jsonb_extract_path", raw_call(ExprType::JsonbExtractPath)), + ("jsonb_extract_path_text", raw_call(ExprType::JsonbExtractPathText)), ("jsonb_typeof", raw_call(ExprType::JsonbTypeof)), ("jsonb_array_length", raw_call(ExprType::JsonbArrayLength)), ("jsonb_concat", raw_call(ExprType::JsonbConcat)), @@ -1388,21 +1360,22 @@ impl Binder { }); if variadic { - return match function_name { - "format" => Ok(FunctionCall::new(ExprType::FormatVariadic, inputs)?.into()), - "concat_ws" => Ok(FunctionCall::new(ExprType::ConcatWsVariadic, inputs)?.into()), - "jsonb_build_array" => { - Ok(FunctionCall::new(ExprType::JsonbBuildArrayVariadic, inputs)?.into()) - } - "jsonb_build_object" => { - Ok(FunctionCall::new(ExprType::JsonbBuildObjectVariadic, inputs)?.into()) + let func = match function_name { + "format" => ExprType::FormatVariadic, + "concat_ws" => ExprType::ConcatWsVariadic, + "jsonb_build_array" => ExprType::JsonbBuildArrayVariadic, + "jsonb_build_object" => ExprType::JsonbBuildObjectVariadic, + "jsonb_extract_path" => ExprType::JsonbExtractPathVariadic, + "jsonb_extract_path_text" => ExprType::JsonbExtractPathTextVariadic, + _ => { + return Err(ErrorCode::BindError(format!( + "VARIADIC argument is not allowed in function \"{}\"", + function_name + )) + .into()) } - _ => Err(ErrorCode::BindError(format!( - "VARIADIC argument is not allowed in function \"{}\"", - function_name - )) - .into()), }; + return Ok(FunctionCall::new(func, inputs)?.into()); } match HANDLES.get(function_name) { diff --git a/src/frontend/src/expr/pure.rs b/src/frontend/src/expr/pure.rs index 5445b6b26b35..7a954a04d7ee 100644 --- a/src/frontend/src/expr/pure.rs +++ b/src/frontend/src/expr/pure.rs @@ -178,7 +178,9 @@ impl ExprVisitor for ImpureAnalyzer { | expr_node::Type::JsonbAccess | expr_node::Type::JsonbAccessStr | expr_node::Type::JsonbExtractPath + | expr_node::Type::JsonbExtractPathVariadic | expr_node::Type::JsonbExtractPathText + | expr_node::Type::JsonbExtractPathTextVariadic | expr_node::Type::JsonbTypeof | expr_node::Type::JsonbArrayLength | expr_node::Type::JsonbObject diff --git a/src/frontend/src/expr/type_inference/func.rs b/src/frontend/src/expr/type_inference/func.rs index 1ccbaa28e9da..10a52f8c3b91 100644 --- a/src/frontend/src/expr/type_inference/func.rs +++ b/src/frontend/src/expr/type_inference/func.rs @@ -595,6 +595,22 @@ fn infer_type_for_special( } Ok(Some(DataType::Jsonb)) } + ExprType::JsonbExtractPath => { + ensure_arity!("jsonb_extract_path", 2 <= | inputs |); + inputs[0].cast_implicit_mut(DataType::Jsonb)?; + for input in inputs.iter_mut().skip(1) { + input.cast_implicit_mut(DataType::Varchar)?; + } + Ok(Some(DataType::Jsonb)) + } + ExprType::JsonbExtractPathText => { + ensure_arity!("jsonb_extract_path_text", 2 <= | inputs |); + inputs[0].cast_implicit_mut(DataType::Jsonb)?; + for input in inputs.iter_mut().skip(1) { + input.cast_implicit_mut(DataType::Varchar)?; + } + Ok(Some(DataType::Varchar)) + } _ => Ok(None), } } From f123201c0a5b6695c7e81fa2ebd59cdf2835248f Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Wed, 24 Jan 2024 18:10:25 +0800 Subject: [PATCH 08/14] support variadic for `concat` Signed-off-by: Runji Wang --- proto/expr.proto | 3 +- src/expr/impl/src/scalar/concat.rs | 73 +++++++++++++++++++ src/expr/impl/src/scalar/concat_op.rs | 35 --------- src/expr/impl/src/scalar/concat_ws.rs | 3 +- src/expr/impl/src/scalar/mod.rs | 2 +- .../tests/testdata/output/array.yaml | 2 +- .../tests/testdata/output/pg_catalog.yaml | 2 +- .../tests/testdata/output/subquery_expr.yaml | 2 +- src/frontend/src/binder/expr/binary_op.rs | 6 +- src/frontend/src/binder/expr/function.rs | 1 + src/frontend/src/expr/pure.rs | 3 +- src/frontend/src/expr/type_inference/func.rs | 2 +- 12 files changed, 86 insertions(+), 48 deletions(-) create mode 100644 src/expr/impl/src/scalar/concat.rs delete mode 100644 src/expr/impl/src/scalar/concat_op.rs diff --git a/proto/expr.proto b/proto/expr.proto index 64223da67ede..69387671a6b1 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -100,7 +100,8 @@ message ExprNode { MD5 = 224; CHAR_LENGTH = 225; REPEAT = 226; - CONCAT_OP = 227; + CONCAT = 227; + CONCAT_VARIADIC = 286; // BOOL_OUT is different from CAST-bool-to-varchar in PostgreSQL. BOOL_OUT = 228; OCTET_LENGTH = 229; diff --git a/src/expr/impl/src/scalar/concat.rs b/src/expr/impl/src/scalar/concat.rs new file mode 100644 index 000000000000..9359c75a7af2 --- /dev/null +++ b/src/expr/impl/src/scalar/concat.rs @@ -0,0 +1,73 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt::Write; + +use risingwave_common::row::Row; +use risingwave_common::types::ToText; +use risingwave_expr::function; + +/// Concatenates the text representations of all the arguments. NULL arguments are ignored. +/// +/// # Example +/// +/// ```slt +/// query T +/// select concat('abcde', 2, NULL, 22); +/// ---- +/// abcde222 +/// +/// query T +/// select concat(variadic array['abcde', '2', NULL, '22']); +/// ---- +/// abcde222 +/// ``` +#[function("concat(variadic anyarray) -> varchar")] +fn concat(vals: impl Row, writer: &mut impl Write) { + for string in vals.iter().flatten() { + string.write(writer).unwrap(); + } +} + +#[cfg(test)] +mod tests { + use risingwave_common::array::DataChunk; + use risingwave_common::row::Row; + use risingwave_common::test_prelude::DataChunkTestExt; + use risingwave_common::types::ToOwnedDatum; + use risingwave_common::util::iter_util::ZipEqDebug; + use risingwave_expr::expr::build_from_pretty; + + #[tokio::test] + async fn test_concat() { + let concat = build_from_pretty("(concat:varchar $0:varchar $1:varchar $2:varchar)"); + let (input, expected) = DataChunk::from_pretty( + "T T T T + a b c abc + . b c bc + . . . (empty)", + ) + .split_column_at(3); + + // test eval + let output = concat.eval(&input).await.unwrap(); + assert_eq!(&output, expected.column_at(0)); + + // test eval_row + for (row, expected) in input.rows().zip_eq_debug(expected.rows()) { + let result = concat.eval_row(&row.to_owned_row()).await.unwrap(); + assert_eq!(result, expected.datum_at(0).to_owned_datum()); + } + } +} diff --git a/src/expr/impl/src/scalar/concat_op.rs b/src/expr/impl/src/scalar/concat_op.rs deleted file mode 100644 index 399700b51c1f..000000000000 --- a/src/expr/impl/src/scalar/concat_op.rs +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2024 RisingWave Labs -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::fmt::Write; - -use risingwave_expr::function; - -#[function("concat_op(varchar, varchar) -> varchar")] -pub fn concat_op(left: &str, right: &str, writer: &mut impl Write) { - writer.write_str(left).unwrap(); - writer.write_str(right).unwrap(); -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_concat_op() { - let mut s = String::new(); - concat_op("114", "514", &mut s); - assert_eq!(s, "114514") - } -} diff --git a/src/expr/impl/src/scalar/concat_ws.rs b/src/expr/impl/src/scalar/concat_ws.rs index 0b95724cc73a..4c2ce3d56ac5 100644 --- a/src/expr/impl/src/scalar/concat_ws.rs +++ b/src/expr/impl/src/scalar/concat_ws.rs @@ -35,7 +35,7 @@ use risingwave_expr::function; /// abcde,2,22 /// ``` #[function("concat_ws(varchar, variadic anyarray) -> varchar")] -fn concat_ws(sep: &str, vals: impl Row, writer: &mut impl Write) -> Option<()> { +fn concat_ws(sep: &str, vals: impl Row, writer: &mut impl Write) { let mut string_iter = vals.iter().flatten(); if let Some(string) = string_iter.next() { string.write(writer).unwrap(); @@ -44,7 +44,6 @@ fn concat_ws(sep: &str, vals: impl Row, writer: &mut impl Write) -> Option<()> { write!(writer, "{}", sep).unwrap(); string.write(writer).unwrap(); } - Some(()) } #[cfg(test)] diff --git a/src/expr/impl/src/scalar/mod.rs b/src/expr/impl/src/scalar/mod.rs index bc925e3f2831..76843681c647 100644 --- a/src/expr/impl/src/scalar/mod.rs +++ b/src/expr/impl/src/scalar/mod.rs @@ -35,7 +35,7 @@ mod case; mod cast; mod cmp; mod coalesce; -mod concat_op; +mod concat; mod concat_ws; mod conjunction; mod date_trunc; diff --git a/src/frontend/planner_test/tests/testdata/output/array.yaml b/src/frontend/planner_test/tests/testdata/output/array.yaml index e578ddac5307..8f90fbb539ea 100644 --- a/src/frontend/planner_test/tests/testdata/output/array.yaml +++ b/src/frontend/planner_test/tests/testdata/output/array.yaml @@ -233,7 +233,7 @@ sql: | select ('{c,' || 'd}')::varchar[]; logical_plan: |- - LogicalProject { exprs: [ConcatOp('{c,':Varchar, 'd}':Varchar)::List(Varchar) as $expr1] } + LogicalProject { exprs: [Concat('{c,':Varchar, 'd}':Varchar)::List(Varchar) as $expr1] } └─LogicalValues { rows: [[]], schema: Schema { fields: [] } } - name: unknown to varchar[] in implicit context sql: | diff --git a/src/frontend/planner_test/tests/testdata/output/pg_catalog.yaml b/src/frontend/planner_test/tests/testdata/output/pg_catalog.yaml index 7842e311e47a..6d4ceba83626 100644 --- a/src/frontend/planner_test/tests/testdata/output/pg_catalog.yaml +++ b/src/frontend/planner_test/tests/testdata/output/pg_catalog.yaml @@ -229,7 +229,7 @@ - sql: | select ('pg' || '_namespace')::regclass logical_plan: |- - LogicalProject { exprs: [CastRegclass(ConcatOp('pg':Varchar, '_namespace':Varchar)) as $expr1] } + LogicalProject { exprs: [CastRegclass(Concat('pg':Varchar, '_namespace':Varchar)) as $expr1] } └─LogicalValues { rows: [[]], schema: Schema { fields: [] } } batch_plan: |- BatchProject { exprs: [CastRegclass('pg_namespace':Varchar) as $expr1] } diff --git a/src/frontend/planner_test/tests/testdata/output/subquery_expr.yaml b/src/frontend/planner_test/tests/testdata/output/subquery_expr.yaml index e7c8940b5a8c..4eeb25605b20 100644 --- a/src/frontend/planner_test/tests/testdata/output/subquery_expr.yaml +++ b/src/frontend/planner_test/tests/testdata/output/subquery_expr.yaml @@ -241,7 +241,7 @@ logical_plan: |- LogicalProject { exprs: [sum(*VALUES*_0.column_0), max($expr1), string_agg($expr2, ',':Varchar)] } └─LogicalAgg { aggs: [sum(*VALUES*_0.column_0), max($expr1), string_agg($expr2, ',':Varchar)] } - └─LogicalProject { exprs: [*VALUES*_0.column_0, ((*VALUES*_0.column_1 + *VALUES*_0.column_2) + 10:Int32) as $expr1, ConcatOp(*VALUES*_0.column_2::Varchar, '~':Varchar) as $expr2, ',':Varchar] } + └─LogicalProject { exprs: [*VALUES*_0.column_0, ((*VALUES*_0.column_1 + *VALUES*_0.column_2) + 10:Int32) as $expr1, Concat(*VALUES*_0.column_2::Varchar, '~':Varchar) as $expr2, ',':Varchar] } └─LogicalValues { rows: [[1:Int32, 2:Int32, 3:Int32], [4:Int32, 5:Int32, 6:Int32]], schema: Schema { fields: [*VALUES*_0.column_0:Int32, *VALUES*_0.column_1:Int32, *VALUES*_0.column_2:Int32] } } - sql: | select 1 + (select 2 from t); diff --git a/src/frontend/src/binder/expr/binary_op.rs b/src/frontend/src/binder/expr/binary_op.rs index fd14ffbd90a2..d446920cc7ff 100644 --- a/src/frontend/src/binder/expr/binary_op.rs +++ b/src/frontend/src/binder/expr/binary_op.rs @@ -171,9 +171,7 @@ impl Binder { (Some(_), Some(DataType::List { .. })) => ExprType::ArrayPrepend, // string concatenation - (Some(DataType::Varchar), _) | (_, Some(DataType::Varchar)) => { - ExprType::ConcatOp - } + (Some(DataType::Varchar), _) | (_, Some(DataType::Varchar)) => ExprType::Concat, (Some(DataType::Jsonb), Some(DataType::Jsonb)) | (Some(DataType::Jsonb), None) @@ -190,7 +188,7 @@ impl Binder { } // string concatenation - (None, _) | (_, None) => ExprType::ConcatOp, + (None, _) | (_, None) => ExprType::Concat, // invalid (Some(left_type), Some(right_type)) => { diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index 59ea9a805576..6c9b56d97ef0 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -1362,6 +1362,7 @@ impl Binder { if variadic { let func = match function_name { "format" => ExprType::FormatVariadic, + "concat" => ExprType::ConcatVariadic, "concat_ws" => ExprType::ConcatWsVariadic, "jsonb_build_array" => ExprType::JsonbBuildArrayVariadic, "jsonb_build_object" => ExprType::JsonbBuildObjectVariadic, diff --git a/src/frontend/src/expr/pure.rs b/src/frontend/src/expr/pure.rs index 7a954a04d7ee..7f655161a8d2 100644 --- a/src/frontend/src/expr/pure.rs +++ b/src/frontend/src/expr/pure.rs @@ -98,7 +98,8 @@ impl ExprVisitor for ImpureAnalyzer { | expr_node::Type::Md5 | expr_node::Type::CharLength | expr_node::Type::Repeat - | expr_node::Type::ConcatOp + | expr_node::Type::Concat + | expr_node::Type::ConcatVariadic | expr_node::Type::BoolOut | expr_node::Type::OctetLength | expr_node::Type::BitLength diff --git a/src/frontend/src/expr/type_inference/func.rs b/src/frontend/src/expr/type_inference/func.rs index 10a52f8c3b91..40355ad5c84b 100644 --- a/src/frontend/src/expr/type_inference/func.rs +++ b/src/frontend/src/expr/type_inference/func.rs @@ -346,7 +346,7 @@ fn infer_type_for_special( } Ok(Some(DataType::Varchar)) } - ExprType::ConcatOp => { + ExprType::Concat => { for input in inputs { input.cast_explicit_mut(DataType::Varchar)?; } From 4cb34e18c7ee19e7d27d58a76776043789f0ab24 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Wed, 24 Jan 2024 20:10:24 +0800 Subject: [PATCH 09/14] fix clippy and revert concat_op Signed-off-by: Runji Wang --- proto/expr.proto | 5 +-- src/expr/impl/src/scalar/concat_op.rs | 35 +++++++++++++++++++ src/expr/impl/src/scalar/jsonb_access.rs | 2 +- src/expr/impl/src/scalar/mod.rs | 1 + .../tests/testdata/output/array.yaml | 2 +- .../tests/testdata/output/pg_catalog.yaml | 2 +- .../tests/testdata/output/subquery_expr.yaml | 2 +- src/frontend/src/binder/expr/binary_op.rs | 6 ++-- src/frontend/src/expr/pure.rs | 1 + src/frontend/src/expr/type_inference/func.rs | 6 +++- 10 files changed, 53 insertions(+), 9 deletions(-) create mode 100644 src/expr/impl/src/scalar/concat_op.rs diff --git a/proto/expr.proto b/proto/expr.proto index 69387671a6b1..f1c7ea801cfe 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -100,8 +100,9 @@ message ExprNode { MD5 = 224; CHAR_LENGTH = 225; REPEAT = 226; - CONCAT = 227; - CONCAT_VARIADIC = 286; + CONCAT_OP = 227; + CONCAT = 286; + CONCAT_VARIADIC = 287; // BOOL_OUT is different from CAST-bool-to-varchar in PostgreSQL. BOOL_OUT = 228; OCTET_LENGTH = 229; diff --git a/src/expr/impl/src/scalar/concat_op.rs b/src/expr/impl/src/scalar/concat_op.rs new file mode 100644 index 000000000000..399700b51c1f --- /dev/null +++ b/src/expr/impl/src/scalar/concat_op.rs @@ -0,0 +1,35 @@ +// Copyright 2024 RisingWave Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::fmt::Write; + +use risingwave_expr::function; + +#[function("concat_op(varchar, varchar) -> varchar")] +pub fn concat_op(left: &str, right: &str, writer: &mut impl Write) { + writer.write_str(left).unwrap(); + writer.write_str(right).unwrap(); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_concat_op() { + let mut s = String::new(); + concat_op("114", "514", &mut s); + assert_eq!(s, "114514") + } +} diff --git a/src/expr/impl/src/scalar/jsonb_access.rs b/src/expr/impl/src/scalar/jsonb_access.rs index 176e44810f23..05578e34b17d 100644 --- a/src/expr/impl/src/scalar/jsonb_access.rs +++ b/src/expr/impl/src/scalar/jsonb_access.rs @@ -99,7 +99,7 @@ pub fn jsonb_array_element(v: JsonbRef<'_>, p: i32) -> Option> { /// "bar" /// ``` #[function("jsonb_extract_path(jsonb, variadic varchar[]) -> jsonb")] -pub fn jsonb_extract_path<'a>(v: JsonbRef<'a>, path: impl Row) -> Option> { +pub fn jsonb_extract_path(v: JsonbRef<'_>, path: impl Row) -> Option> { let mut jsonb = v; for key in path.iter() { // return null if any element is null diff --git a/src/expr/impl/src/scalar/mod.rs b/src/expr/impl/src/scalar/mod.rs index 76843681c647..9dfa1726b6b7 100644 --- a/src/expr/impl/src/scalar/mod.rs +++ b/src/expr/impl/src/scalar/mod.rs @@ -36,6 +36,7 @@ mod cast; mod cmp; mod coalesce; mod concat; +mod concat_op; mod concat_ws; mod conjunction; mod date_trunc; diff --git a/src/frontend/planner_test/tests/testdata/output/array.yaml b/src/frontend/planner_test/tests/testdata/output/array.yaml index 8f90fbb539ea..e578ddac5307 100644 --- a/src/frontend/planner_test/tests/testdata/output/array.yaml +++ b/src/frontend/planner_test/tests/testdata/output/array.yaml @@ -233,7 +233,7 @@ sql: | select ('{c,' || 'd}')::varchar[]; logical_plan: |- - LogicalProject { exprs: [Concat('{c,':Varchar, 'd}':Varchar)::List(Varchar) as $expr1] } + LogicalProject { exprs: [ConcatOp('{c,':Varchar, 'd}':Varchar)::List(Varchar) as $expr1] } └─LogicalValues { rows: [[]], schema: Schema { fields: [] } } - name: unknown to varchar[] in implicit context sql: | diff --git a/src/frontend/planner_test/tests/testdata/output/pg_catalog.yaml b/src/frontend/planner_test/tests/testdata/output/pg_catalog.yaml index 6d4ceba83626..7842e311e47a 100644 --- a/src/frontend/planner_test/tests/testdata/output/pg_catalog.yaml +++ b/src/frontend/planner_test/tests/testdata/output/pg_catalog.yaml @@ -229,7 +229,7 @@ - sql: | select ('pg' || '_namespace')::regclass logical_plan: |- - LogicalProject { exprs: [CastRegclass(Concat('pg':Varchar, '_namespace':Varchar)) as $expr1] } + LogicalProject { exprs: [CastRegclass(ConcatOp('pg':Varchar, '_namespace':Varchar)) as $expr1] } └─LogicalValues { rows: [[]], schema: Schema { fields: [] } } batch_plan: |- BatchProject { exprs: [CastRegclass('pg_namespace':Varchar) as $expr1] } diff --git a/src/frontend/planner_test/tests/testdata/output/subquery_expr.yaml b/src/frontend/planner_test/tests/testdata/output/subquery_expr.yaml index 4eeb25605b20..e7c8940b5a8c 100644 --- a/src/frontend/planner_test/tests/testdata/output/subquery_expr.yaml +++ b/src/frontend/planner_test/tests/testdata/output/subquery_expr.yaml @@ -241,7 +241,7 @@ logical_plan: |- LogicalProject { exprs: [sum(*VALUES*_0.column_0), max($expr1), string_agg($expr2, ',':Varchar)] } └─LogicalAgg { aggs: [sum(*VALUES*_0.column_0), max($expr1), string_agg($expr2, ',':Varchar)] } - └─LogicalProject { exprs: [*VALUES*_0.column_0, ((*VALUES*_0.column_1 + *VALUES*_0.column_2) + 10:Int32) as $expr1, Concat(*VALUES*_0.column_2::Varchar, '~':Varchar) as $expr2, ',':Varchar] } + └─LogicalProject { exprs: [*VALUES*_0.column_0, ((*VALUES*_0.column_1 + *VALUES*_0.column_2) + 10:Int32) as $expr1, ConcatOp(*VALUES*_0.column_2::Varchar, '~':Varchar) as $expr2, ',':Varchar] } └─LogicalValues { rows: [[1:Int32, 2:Int32, 3:Int32], [4:Int32, 5:Int32, 6:Int32]], schema: Schema { fields: [*VALUES*_0.column_0:Int32, *VALUES*_0.column_1:Int32, *VALUES*_0.column_2:Int32] } } - sql: | select 1 + (select 2 from t); diff --git a/src/frontend/src/binder/expr/binary_op.rs b/src/frontend/src/binder/expr/binary_op.rs index d446920cc7ff..fd14ffbd90a2 100644 --- a/src/frontend/src/binder/expr/binary_op.rs +++ b/src/frontend/src/binder/expr/binary_op.rs @@ -171,7 +171,9 @@ impl Binder { (Some(_), Some(DataType::List { .. })) => ExprType::ArrayPrepend, // string concatenation - (Some(DataType::Varchar), _) | (_, Some(DataType::Varchar)) => ExprType::Concat, + (Some(DataType::Varchar), _) | (_, Some(DataType::Varchar)) => { + ExprType::ConcatOp + } (Some(DataType::Jsonb), Some(DataType::Jsonb)) | (Some(DataType::Jsonb), None) @@ -188,7 +190,7 @@ impl Binder { } // string concatenation - (None, _) | (_, None) => ExprType::Concat, + (None, _) | (_, None) => ExprType::ConcatOp, // invalid (Some(left_type), Some(right_type)) => { diff --git a/src/frontend/src/expr/pure.rs b/src/frontend/src/expr/pure.rs index 7f655161a8d2..0e8350707fe5 100644 --- a/src/frontend/src/expr/pure.rs +++ b/src/frontend/src/expr/pure.rs @@ -98,6 +98,7 @@ impl ExprVisitor for ImpureAnalyzer { | expr_node::Type::Md5 | expr_node::Type::CharLength | expr_node::Type::Repeat + | expr_node::Type::ConcatOp | expr_node::Type::Concat | expr_node::Type::ConcatVariadic | expr_node::Type::BoolOut diff --git a/src/frontend/src/expr/type_inference/func.rs b/src/frontend/src/expr/type_inference/func.rs index 40355ad5c84b..96c9999c36af 100644 --- a/src/frontend/src/expr/type_inference/func.rs +++ b/src/frontend/src/expr/type_inference/func.rs @@ -335,6 +335,10 @@ fn infer_type_for_special( ensure_arity!("coalesce", 1 <= | inputs |); align_types(inputs.iter_mut()).map(Some).map_err(Into::into) } + ExprType::Concat => { + ensure_arity!("concat", 1 <= | inputs |); + Ok(Some(DataType::Varchar)) + } ExprType::ConcatWs => { ensure_arity!("concat_ws", 2 <= | inputs |); // 0-th arg must be string @@ -346,7 +350,7 @@ fn infer_type_for_special( } Ok(Some(DataType::Varchar)) } - ExprType::Concat => { + ExprType::ConcatOp => { for input in inputs { input.cast_explicit_mut(DataType::Varchar)?; } From cea8d3edba5cc1f7e6620e1c7c539a43d2664aed Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Thu, 25 Jan 2024 12:51:37 +0800 Subject: [PATCH 10/14] fix variadic for jsonb_build_array and jsonb_build_object Signed-off-by: Runji Wang --- src/expr/core/src/expr/mod.rs | 2 ++ src/expr/impl/src/scalar/jsonb_build.rs | 18 +++++++++++++++--- src/expr/macro/src/gen.rs | 2 ++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/expr/core/src/expr/mod.rs b/src/expr/core/src/expr/mod.rs index d1ced2b3322c..6dbb3906f561 100644 --- a/src/expr/core/src/expr/mod.rs +++ b/src/expr/core/src/expr/mod.rs @@ -209,4 +209,6 @@ where pub struct Context { pub arg_types: Vec, pub return_type: DataType, + /// Whether the function is variadic. + pub variadic: bool, } diff --git a/src/expr/impl/src/scalar/jsonb_build.rs b/src/expr/impl/src/scalar/jsonb_build.rs index ddf22def1526..63bf7d99b288 100644 --- a/src/expr/impl/src/scalar/jsonb_build.rs +++ b/src/expr/impl/src/scalar/jsonb_build.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +use itertools::Either; use jsonbb::Builder; use risingwave_common::row::Row; use risingwave_common::types::{JsonbVal, ScalarRefImpl}; @@ -41,8 +42,15 @@ use super::{ToJsonb, ToTextDisplay}; fn jsonb_build_array(args: impl Row, ctx: &Context) -> Result { let mut builder = Builder::>::new(); builder.begin_array(); - for (value, ty) in args.iter().zip_eq_debug(&ctx.arg_types) { - value.add_to(ty, &mut builder)?; + if ctx.variadic { + for (value, ty) in args.iter().zip_eq_debug(&ctx.arg_types) { + value.add_to(ty, &mut builder)?; + } + } else { + let ty = ctx.arg_types[0].as_list(); + for value in args.iter() { + value.add_to(ty, &mut builder)?; + } } builder.end_array(); Ok(builder.finish().into()) @@ -75,9 +83,13 @@ fn jsonb_build_object(args: impl Row, ctx: &Context) -> Result { } let mut builder = Builder::>::new(); builder.begin_object(); + let arg_types = match ctx.variadic { + true => Either::Left(ctx.arg_types.iter()), + false => Either::Right(itertools::repeat_n(ctx.arg_types[0].as_list(), args.len())), + }; for (i, [(key, _), (value, value_type)]) in args .iter() - .zip_eq_debug(&ctx.arg_types) + .zip_eq_debug(arg_types) .array_chunks() .enumerate() { diff --git a/src/expr/macro/src/gen.rs b/src/expr/macro/src/gen.rs index 80f074b14322..f9223da17fca 100644 --- a/src/expr/macro/src/gen.rs +++ b/src/expr/macro/src/gen.rs @@ -519,6 +519,7 @@ impl FunctionAttr { let context = Context { return_type, arg_types: children.iter().map(|c| c.return_type()).collect(), + variadic: #variadic, }; #[derive(Debug)] @@ -869,6 +870,7 @@ impl FunctionAttr { let context = Context { return_type: agg.return_type.clone(), arg_types: agg.args.arg_types().to_owned(), + variadic: false, }; struct Agg { From 443181e177bcc23513e38875eb57718282e2cb0c Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Thu, 25 Jan 2024 18:52:22 +0800 Subject: [PATCH 11/14] fix build and test Signed-off-by: Runji Wang --- src/expr/impl/src/scalar/cast.rs | 9 +++++++++ src/expr/impl/src/scalar/jsonb_build.rs | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/src/expr/impl/src/scalar/cast.rs b/src/expr/impl/src/scalar/cast.rs index dc81e3ab77ba..36f7335d2a14 100644 --- a/src/expr/impl/src/scalar/cast.rs +++ b/src/expr/impl/src/scalar/cast.rs @@ -301,6 +301,7 @@ mod tests { let ctx = Context { arg_types: vec![DataType::Varchar], return_type: DataType::from_str("int[]").unwrap(), + variadic: false, }; assert_eq!( str_to_list("{}", &ctx).unwrap(), @@ -313,6 +314,7 @@ mod tests { let ctx = Context { arg_types: vec![DataType::Varchar], return_type: DataType::from_str("int[]").unwrap(), + variadic: false, }; assert_eq!(str_to_list("{1, 2, 3}", &ctx).unwrap(), list123); @@ -321,6 +323,7 @@ mod tests { let ctx = Context { arg_types: vec![DataType::Varchar], return_type: DataType::from_str("int[][]").unwrap(), + variadic: false, }; assert_eq!(str_to_list("{{1, 2, 3}}", &ctx).unwrap(), nested_list123); @@ -333,6 +336,7 @@ mod tests { let ctx = Context { arg_types: vec![DataType::Varchar], return_type: DataType::from_str("int[][][]").unwrap(), + variadic: false, }; assert_eq!( str_to_list("{{{1, 2, 3}}, {{44, 55, 66}}}", &ctx).unwrap(), @@ -343,6 +347,7 @@ mod tests { let ctx = Context { arg_types: vec![DataType::from_str("int[][]").unwrap()], return_type: DataType::from_str("varchar[][]").unwrap(), + variadic: false, }; let double_nested_varchar_list123_445566 = ListValue::from_iter([ list_cast(nested_list123.as_scalar_ref(), &ctx).unwrap(), @@ -353,6 +358,7 @@ mod tests { let ctx = Context { arg_types: vec![DataType::Varchar], return_type: DataType::from_str("varchar[][][]").unwrap(), + variadic: false, }; assert_eq!( str_to_list("{{{1, 2, 3}}, {{44, 55, 66}}}", &ctx).unwrap(), @@ -366,6 +372,7 @@ mod tests { let ctx = Context { arg_types: vec![DataType::Varchar], return_type: DataType::from_str("int[]").unwrap(), + variadic: false, }; assert!(str_to_list("{{}", &ctx).is_err()); assert!(str_to_list("{}}", &ctx).is_err()); @@ -384,6 +391,7 @@ mod tests { ("a", DataType::Int32), ("b", DataType::Int32), ])), + variadic: false, }; assert_eq!( struct_cast( @@ -419,6 +427,7 @@ mod tests { let ctx_str_to_int16 = Context { arg_types: vec![DataType::Varchar], return_type: DataType::Int16, + variadic: false, }; test_str_to_int16::(|x| str_parse(x, &ctx_str_to_int16).unwrap()).await; } diff --git a/src/expr/impl/src/scalar/jsonb_build.rs b/src/expr/impl/src/scalar/jsonb_build.rs index 63bf7d99b288..5949faf9bc5c 100644 --- a/src/expr/impl/src/scalar/jsonb_build.rs +++ b/src/expr/impl/src/scalar/jsonb_build.rs @@ -71,7 +71,7 @@ fn jsonb_build_array(args: impl Row, ctx: &Context) -> Result { /// query T /// select jsonb_build_object(variadic array['foo', '1', '2', 'bar']); /// ---- -/// {"2": "bar", "foo": 1} +/// {"2": "bar", "foo": "1"} /// ``` #[function("jsonb_build_object(variadic anyarray) -> jsonb")] fn jsonb_build_object(args: impl Row, ctx: &Context) -> Result { From d011ec502db077ae55475679e8f9574a76f64bb8 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Thu, 7 Mar 2024 15:04:46 +0800 Subject: [PATCH 12/14] update proto enum number Signed-off-by: Runji Wang --- proto/expr.proto | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/proto/expr.proto b/proto/expr.proto index 0b6a3c502430..802397f0456f 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -185,7 +185,7 @@ message ExprNode { LEFT = 317; RIGHT = 318; FORMAT = 319; - FORMAT_VARIADIC = 324; + FORMAT_VARIADIC = 326; PGWIRE_SEND = 320; PGWIRE_RECV = 321; CONVERT_FROM = 322; @@ -236,10 +236,10 @@ message ExprNode { // jsonb ->> int, jsonb ->> text that returns text JSONB_ACCESS_STR = 601; // jsonb #> text[] -> jsonb - JSONB_EXTRACT_PATH = 626; + JSONB_EXTRACT_PATH = 627; JSONB_EXTRACT_PATH_VARIADIC = 613; // jsonb #>> text[] -> text - JSONB_EXTRACT_PATH_TEXT = 627; + JSONB_EXTRACT_PATH_TEXT = 628; JSONB_EXTRACT_PATH_TEXT_VARIADIC = 614; JSONB_TYPEOF = 602; JSONB_ARRAY_LENGTH = 603; @@ -267,9 +267,9 @@ message ExprNode { JSONB_STRIP_NULLS = 616; TO_JSONB = 617; JSONB_BUILD_ARRAY = 618; - JSONB_BUILD_ARRAY_VARIADIC = 624; + JSONB_BUILD_ARRAY_VARIADIC = 625; JSONB_BUILD_OBJECT = 619; - JSONB_BUILD_OBJECT_VARIADIC = 625; + JSONB_BUILD_OBJECT_VARIADIC = 626; JSONB_PATH_EXISTS = 620; JSONB_PATH_MATCH = 621; JSONB_PATH_QUERY_ARRAY = 622; From 49b870802885a7fd831e0e24889a045c1d338c4a Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Thu, 7 Mar 2024 15:19:00 +0800 Subject: [PATCH 13/14] add missing match arms Signed-off-by: Runji Wang --- src/frontend/src/optimizer/plan_expr_visitor/strong.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/frontend/src/optimizer/plan_expr_visitor/strong.rs b/src/frontend/src/optimizer/plan_expr_visitor/strong.rs index dca2063dac92..a80a696c2c16 100644 --- a/src/frontend/src/optimizer/plan_expr_visitor/strong.rs +++ b/src/frontend/src/optimizer/plan_expr_visitor/strong.rs @@ -147,7 +147,10 @@ impl Strong { | ExprType::Round | ExprType::Ascii | ExprType::Translate + | ExprType::Concat + | ExprType::ConcatVariadic | ExprType::ConcatWs + | ExprType::ConcatWsVariadic | ExprType::Abs | ExprType::SplitPart | ExprType::ToChar @@ -217,6 +220,7 @@ impl Strong { | ExprType::Left | ExprType::Right | ExprType::Format + | ExprType::FormatVariadic | ExprType::PgwireSend | ExprType::PgwireRecv | ExprType::ConvertFrom @@ -255,7 +259,9 @@ impl Strong { | ExprType::JsonbAccess | ExprType::JsonbAccessStr | ExprType::JsonbExtractPath + | ExprType::JsonbExtractPathVariadic | ExprType::JsonbExtractPathText + | ExprType::JsonbExtractPathTextVariadic | ExprType::JsonbTypeof | ExprType::JsonbArrayLength | ExprType::IsJson @@ -271,7 +277,9 @@ impl Strong { | ExprType::JsonbStripNulls | ExprType::ToJsonb | ExprType::JsonbBuildArray + | ExprType::JsonbBuildArrayVariadic | ExprType::JsonbBuildObject + | ExprType::JsonbBuildObjectVariadic | ExprType::JsonbPathExists | ExprType::JsonbPathMatch | ExprType::JsonbPathQueryArray From 5a900e2645daf17fbc6f91dee4e7177dd7b8b373 Mon Sep 17 00:00:00 2001 From: Runji Wang Date: Thu, 7 Mar 2024 15:26:35 +0800 Subject: [PATCH 14/14] add planner test Signed-off-by: Runji Wang --- .../planner_test/tests/testdata/input/expr.yaml | 12 ++++++++++++ .../planner_test/tests/testdata/output/expr.yaml | 16 ++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/src/frontend/planner_test/tests/testdata/input/expr.yaml b/src/frontend/planner_test/tests/testdata/input/expr.yaml index cc8f104e1ac8..e19ee4a29126 100644 --- a/src/frontend/planner_test/tests/testdata/input/expr.yaml +++ b/src/frontend/planner_test/tests/testdata/input/expr.yaml @@ -454,3 +454,15 @@ expected_outputs: - batch_plan - stream_error +- name: without variadic keyword + sql: | + create table t(a varchar, b varchar); + select concat_ws(',', a, b) from t; + expected_outputs: + - batch_plan +- name: with variadic keyword + sql: | + create table t(a varchar, b varchar); + select concat_ws(',', variadic array[a, b]) from t; + expected_outputs: + - batch_plan diff --git a/src/frontend/planner_test/tests/testdata/output/expr.yaml b/src/frontend/planner_test/tests/testdata/output/expr.yaml index ef6727193f78..4c23cadf7cb4 100644 --- a/src/frontend/planner_test/tests/testdata/output/expr.yaml +++ b/src/frontend/planner_test/tests/testdata/output/expr.yaml @@ -635,3 +635,19 @@ Not supported: streaming nested-loop join HINT: The non-equal join in the query requires a nested-loop join executor, which could be very expensive to run. Consider rewriting the query to use dynamic filter as a substitute if possible. See also: https://github.com/risingwavelabs/rfcs/blob/main/rfcs/0033-dynamic-filter.md +- name: without variadic keyword + sql: | + create table t(a varchar, b varchar); + select concat_ws(',', a, b) from t; + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchProject { exprs: [ConcatWs(',':Varchar, t.a, t.b) as $expr1] } + └─BatchScan { table: t, columns: [t.a, t.b], distribution: SomeShard } +- name: with variadic keyword + sql: | + create table t(a varchar, b varchar); + select concat_ws(',', variadic array[a, b]) from t; + batch_plan: |- + BatchExchange { order: [], dist: Single } + └─BatchProject { exprs: [ConcatWsVariadic(',':Varchar, Array(t.a, t.b)) as $expr1] } + └─BatchScan { table: t, columns: [t.a, t.b], distribution: SomeShard }