Skip to content

Commit

Permalink
Merge a1c179b into 0428ac7
Browse files Browse the repository at this point in the history
  • Loading branch information
SuperBo committed Nov 10, 2022
2 parents 0428ac7 + a1c179b commit d95e648
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 4 deletions.
42 changes: 42 additions & 0 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,8 @@ pub enum Expr {
ArraySubquery(Box<Query>),
/// The `LISTAGG` function `SELECT LISTAGG(...) WITHIN GROUP (ORDER BY ...)`
ListAgg(ListAgg),
/// The `ARRAY_AGG` function `SELECT ARRAY_AGG(... ORDER BY ...)`
ArrayAgg(ArrayAgg),
/// The `GROUPING SETS` expr.
GroupingSets(Vec<Vec<Expr>>),
/// The `CUBE` expr.
Expand Down Expand Up @@ -655,6 +657,7 @@ impl fmt::Display for Expr {
Expr::Subquery(s) => write!(f, "({})", s),
Expr::ArraySubquery(s) => write!(f, "ARRAY({})", s),
Expr::ListAgg(listagg) => write!(f, "{}", listagg),
Expr::ArrayAgg(arrayagg) => write!(f, "{}", arrayagg),
Expr::GroupingSets(sets) => {
write!(f, "GROUPING SETS (")?;
let mut sep = "";
Expand Down Expand Up @@ -3036,6 +3039,45 @@ impl fmt::Display for ListAggOnOverflow {
}
}

/// An `ARRAY_AGG` invocation `ARRAY_AGG( [ DISTINCT ] <expr> [ORDER BY <expr>] [LIMIT <n>] )`
/// Or `ARRAY_AGG( [ DISTINCT ] <expr> ) [ WITHIN GROUP ( ORDER BY <expr> ) ]`
/// ORDER BY position is defined differently for BigQuery, Postgres and Snowflake.
///
/// The two spellings share this one node; `within_group` records which of
/// them the `Display` implementation should emit.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct ArrayAgg {
/// `true` when the `DISTINCT` keyword precedes the aggregated expression.
pub distinct: bool,
/// The expression whose values are aggregated.
pub expr: Box<Expr>,
/// The optional `ORDER BY` item, written either inside the call
/// parentheses or inside a trailing `WITHIN GROUP (...)` clause.
pub order_by: Option<Box<OrderByExpr>>,
/// The optional `LIMIT` expression. It is only rendered for the
/// in-parenthesis form (`within_group == false`).
pub limit: Option<Box<Expr>>,
/// Whether `order_by` belongs to a trailing `WITHIN GROUP (ORDER BY ...)`
/// clause (`true`) or sits inside the call parentheses (`false`).
pub within_group: bool,
}

// Renders an `ArrayAgg` back to SQL. The `within_group` flag selects between
// `ARRAY_AGG(<expr> [ORDER BY ..] [LIMIT ..])` and
// `ARRAY_AGG(<expr>) [WITHIN GROUP (ORDER BY ..)]`.
impl fmt::Display for ArrayAgg {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let distinct_kw = if self.distinct { "DISTINCT " } else { "" };
        write!(f, "ARRAY_AGG({}{}", distinct_kw, self.expr)?;

        if self.within_group {
            // The ordering lives after the closing parenthesis; any `limit`
            // is not part of this spelling and is therefore not printed.
            write!(f, ")")?;
            match self.order_by.as_ref() {
                Some(ordering) => write!(f, " WITHIN GROUP (ORDER BY {})", ordering),
                None => Ok(()),
            }
        } else {
            // Ordering and limit both live inside the call parentheses.
            if let Some(ordering) = self.order_by.as_ref() {
                write!(f, " ORDER BY {}", ordering)?;
            }
            if let Some(max_rows) = self.limit.as_ref() {
                write!(f, " LIMIT {}", max_rows)?;
            }
            write!(f, ")")
        }
    }
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ObjectType {
Expand Down
6 changes: 6 additions & 0 deletions src/dialect/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ pub trait Dialect: Debug + Any {
fn supports_filter_during_aggregation(&self) -> bool {
// Trait default: answer `false` unless a dialect overrides this method.
// NOTE(review): presumably governs `agg(...) FILTER (WHERE ...)` parsing — confirm at the call site.
false
}
/// Returns true if the dialect supports ARRAY_AGG() [WITHIN GROUP (ORDER BY)] expressions.
/// Otherwise, the dialect should expect an `ORDER BY` without the `WITHIN GROUP` clause, e.g. `ANSI` [(1)].
///
/// The default implementation returns `false`, i.e. `ORDER BY` is expected
/// inside the `ARRAY_AGG(...)` parentheses.
///
/// [(1)]: https://jakewheat.github.io/sql-overview/sql-2016-foundation-grammar.html#array-aggregate-function
fn supports_within_after_array_aggregation(&self) -> bool {
false
}
/// Dialect-specific prefix parser override
fn parse_prefix(&self, _parser: &mut Parser) -> Option<Result<Expr, ParserError>> {
// return None to fall back to the default behavior
Expand Down
4 changes: 4 additions & 0 deletions src/dialect/snowflake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,4 +28,8 @@ impl Dialect for SnowflakeDialect {
|| ch == '$'
|| ch == '_'
}

// Snowflake writes the ordering of `ARRAY_AGG` in a trailing
// `WITHIN GROUP (ORDER BY ...)` clause rather than inside the call, so opt in.
fn supports_within_after_array_aggregation(&self) -> bool {
true
}
}
49 changes: 49 additions & 0 deletions src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,7 @@ impl<'a> Parser<'a> {
self.expect_token(&Token::LParen)?;
self.parse_array_subquery()
}
Keyword::ARRAY_AGG => self.parse_array_agg_expr(),
Keyword::NOT => self.parse_not(),
// Here `w` is a word, check if it's a part of a multi-part
// identifier, a function call, or a simple identifier:
Expand Down Expand Up @@ -1071,6 +1072,54 @@ impl<'a> Parser<'a> {
}))
}

/// Parses the remainder of an `ARRAY_AGG` call, starting at its opening
/// parenthesis (the `ARRAY_AGG` keyword itself has already been consumed).
///
/// Two spellings are recognised, selected by
/// `Dialect::supports_within_after_array_aggregation`:
///
/// * `false` (ANSI SQL, BigQuery, ...):
///   `ARRAY_AGG( [DISTINCT] <expr> [ORDER BY <expr>] [LIMIT <n>] )`
/// * `true` (Snowflake):
///   `ARRAY_AGG( [DISTINCT] <expr> ) [WITHIN GROUP (ORDER BY <expr>)]`
///
/// Returns `Expr::ArrayAgg` on success, or the `ParserError` raised by the
/// first token/sub-expression that does not match the expected form.
pub fn parse_array_agg_expr(&mut self) -> Result<Expr, ParserError> {
    self.expect_token(&Token::LParen)?;
    let distinct = self.parse_keyword(Keyword::DISTINCT);
    let expr = Box::new(self.parse_expr()?);

    if self.dialect.supports_within_after_array_aggregation() {
        // Snowflake defines ORDER BY in a WITHIN GROUP clause that follows the
        // call, so the argument list must close right after the expression.
        self.expect_token(&Token::RParen)?;

        let mut order_by = None;
        if self.parse_keywords(&[Keyword::WITHIN, Keyword::GROUP]) {
            self.expect_token(&Token::LParen)?;
            self.expect_keywords(&[Keyword::ORDER, Keyword::BY])?;
            let ordering = self.parse_order_by_expr()?;
            self.expect_token(&Token::RParen)?;
            order_by = Some(Box::new(ordering));
        }

        return Ok(Expr::ArrayAgg(ArrayAgg {
            distinct,
            expr,
            order_by,
            limit: None,
            within_group: true,
        }));
    }

    // ANSI SQL and BigQuery define ORDER BY (and LIMIT) inside the function
    // call, before the closing parenthesis.
    let mut order_by = None;
    if self.parse_keywords(&[Keyword::ORDER, Keyword::BY]) {
        order_by = Some(Box::new(self.parse_order_by_expr()?));
    }

    let mut limit = None;
    if self.parse_keyword(Keyword::LIMIT) {
        limit = self.parse_limit()?.map(Box::new);
    }

    self.expect_token(&Token::RParen)?;

    Ok(Expr::ArrayAgg(ArrayAgg {
        distinct,
        expr,
        order_by,
        limit,
        within_group: false,
    }))
}

// This function parses date/time fields for the EXTRACT function-like
// operator, interval qualifiers, and the ceil/floor operations.
// EXTRACT supports a wider set of date/time fields than interval qualifiers,
Expand Down
11 changes: 11 additions & 0 deletions tests/sqlparser_bigquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,17 @@ fn parse_similar_to() {
chk(true);
}

// `ARRAY_AGG` with `ORDER BY` / `LIMIT` inside the parentheses must round-trip
// unchanged through the BigQuery dialect.
#[test]
fn parse_array_agg_func() {
    let queries = [
        "SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T",
        "SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl",
        "SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl",
    ];

    for &query in queries.iter() {
        bigquery().verified_stmt(query);
    }
}

fn bigquery() -> TestedDialects {
TestedDialects {
dialects: vec![Box::new(BigQueryDialect {})],
Expand Down
21 changes: 21 additions & 0 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1777,6 +1777,27 @@ fn parse_listagg() {
);
}

// `ARRAY_AGG` with `ORDER BY` / `LIMIT` inside the parentheses must round-trip
// unchanged through every dialect listed below.
#[test]
fn parse_array_agg_func() {
    let queries = [
        "SELECT ARRAY_AGG(x ORDER BY x) AS a FROM T",
        "SELECT ARRAY_AGG(x ORDER BY x LIMIT 2) FROM tbl",
        "SELECT ARRAY_AGG(DISTINCT x ORDER BY x LIMIT 2) FROM tbl",
    ];

    let supported_dialects = TestedDialects {
        dialects: vec![
            Box::new(GenericDialect {}),
            Box::new(PostgreSqlDialect {}),
            Box::new(MsSqlDialect {}),
            Box::new(AnsiDialect {}),
            Box::new(HiveDialect {}),
        ],
    };

    for &query in queries.iter() {
        supported_dialects.verified_stmt(query);
    }
}

#[test]
fn parse_create_table() {
let sql = "CREATE TABLE uk_cities (\
Expand Down
8 changes: 4 additions & 4 deletions tests/sqlparser_hive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,17 +281,17 @@ fn parse_create_function() {
// Passes a SELECT whose aggregate calls are followed by `FILTER (WHERE ...)`
// clauses to `hive().verified_stmt` and prints the statement it returns.
#[test]
fn filtering_during_aggregation() {
let rename = "SELECT \
array_agg(name) FILTER (WHERE name IS NOT NULL), \
array_agg(name) FILTER (WHERE name LIKE 'a%') \
ARRAY_AGG(name) FILTER (WHERE name IS NOT NULL), \
ARRAY_AGG(name) FILTER (WHERE name LIKE 'a%') \
FROM region";
println!("{}", hive().verified_stmt(rename));
}

// Same shape as the un-aliased test, but each `FILTER (WHERE ...)` aggregate
// carries an `AS` alias; the statement returned by `hive().verified_stmt` is printed.
#[test]
fn filtering_during_aggregation_aliased() {
let rename = "SELECT \
array_agg(name) FILTER (WHERE name IS NOT NULL) AS agg1, \
array_agg(name) FILTER (WHERE name LIKE 'a%') AS agg2 \
ARRAY_AGG(name) FILTER (WHERE name IS NOT NULL) AS agg1, \
ARRAY_AGG(name) FILTER (WHERE name LIKE 'a%') AS agg2 \
FROM region";
println!("{}", hive().verified_stmt(rename));
}
Expand Down
19 changes: 19 additions & 0 deletions tests/sqlparser_snowflake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,25 @@ fn parse_similar_to() {
chk(true);
}

// Snowflake accepts `ARRAY_AGG(...) WITHIN GROUP (ORDER BY ...)` and rejects an
// `ORDER BY` placed inside the call parentheses.
#[test]
fn test_array_agg_func() {
    let accepted = [
        "SELECT ARRAY_AGG(x) WITHIN GROUP (ORDER BY x) AS a FROM T",
        "SELECT ARRAY_AGG(DISTINCT x) WITHIN GROUP (ORDER BY x ASC) FROM tbl",
    ];
    for &query in accepted.iter() {
        snowflake().verified_stmt(query);
    }

    // The in-parenthesis ordering must fail at the `order` token.
    let outcome = snowflake().parse_sql_statements("select array_agg(x order by x) as a from T");
    assert_eq!(
        outcome,
        Err(ParserError::ParserError(
            "Expected ), found: order".to_string()
        ))
    );
}

fn snowflake() -> TestedDialects {
TestedDialects {
dialects: vec![Box::new(SnowflakeDialect {})],
Expand Down

0 comments on commit d95e648

Please sign in to comment.