Skip to content

Commit

Permalink
Feat(prql): add support for AGGREGATE (#3395)
Browse files Browse the repository at this point in the history
* aggregate?

* update handle for this keyword

* update sum for prql

* add test

* update

* update

* update

* update

* Update sqlglot/dialects/prql.py

Co-authored-by: Jo <46752250+georgesittas@users.noreply.github.com>

* Update sqlglot/dialects/prql.py

Co-authored-by: Jo <46752250+georgesittas@users.noreply.github.com>

---------

Co-authored-by: Jo <46752250+georgesittas@users.noreply.github.com>
  • Loading branch information
fool1280 and georgesittas committed May 2, 2024
1 parent 00ff877 commit 16691f9
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 3 deletions.
40 changes: 37 additions & 3 deletions sqlglot/dialects/prql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from sqlglot import exp, parser, tokens
from sqlglot.dialects.dialect import Dialect
from sqlglot.helper import seq_get
from sqlglot.tokens import TokenType


Expand Down Expand Up @@ -53,6 +54,15 @@ class Parser(parser.Parser):
_select_all(self._parse_table()), distinct=False, copy=False
),
"SORT": lambda self, query: self._parse_order_by(query),
"AGGREGATE": lambda self, query: self._parse_selection(
query, parse_method=self._parse_aggregate, append=False
),
}

FUNCTIONS = {
**parser.Parser.FUNCTIONS,
"AVERAGE": exp.Avg.from_arg_list,
"SUM": lambda args: exp.func("COALESCE", exp.Sum(this=seq_get(args, 0)), 0),
}

def _parse_equality(self) -> t.Optional[exp.Expression]:
Expand Down Expand Up @@ -87,14 +97,20 @@ def _parse_query(self) -> t.Optional[exp.Query]:

return query

def _parse_selection(self, query: exp.Query, append: bool = True) -> exp.Query:
def _parse_selection(
self,
query: exp.Query,
parse_method: t.Optional[t.Callable] = None,
append: bool = True,
) -> exp.Query:
parse_method = parse_method if parse_method else self._parse_expression
if self._match(TokenType.L_BRACE):
selects = self._parse_csv(self._parse_expression)
selects = self._parse_csv(parse_method)

if not self._match(TokenType.R_BRACE, expression=query):
self.raise_error("Expecting }")
else:
expression = self._parse_expression()
expression = parse_method()
selects = [expression] if expression else []

projections = {
Expand Down Expand Up @@ -136,6 +152,24 @@ def _parse_order_by(self, query: exp.Select) -> t.Optional[exp.Query]:
self.raise_error("Expecting }")
return query.order_by(self.expression(exp.Order, expressions=expressions), copy=False)

def _parse_aggregate(self) -> t.Optional[exp.Expression]:
alias = None
if self._next and self._next.token_type == TokenType.ALIAS:
alias = self._parse_id_var(any_token=True)
self._match(TokenType.ALIAS)

name = self._curr and self._curr.text.upper()
func_builder = self.FUNCTIONS.get(name)
if func_builder:
self._advance()
args = self._parse_column()
func = func_builder([args])
else:
self.raise_error(f"Unsupported aggregation function {name}")
if alias:
return self.expression(exp.Alias, this=func, alias=alias)
return func

def _parse_expression(self) -> t.Optional[exp.Expression]:
if self._next and self._next.token_type == TokenType.ALIAS:
alias = self._parse_id_var(True)
Expand Down
13 changes: 13 additions & 0 deletions tests/dialects/test_prql.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,16 @@ def test_prql(self):
"from x filter (a > 1 || null != b || c != null)",
"SELECT * FROM x WHERE (a > 1 OR NOT b IS NULL OR NOT c IS NULL)",
)
self.validate_identity("from a aggregate { average x }", "SELECT AVG(x) FROM a")
self.validate_identity(
"from a aggregate { average x, min y, ct = sum z }",
"SELECT AVG(x), MIN(y), COALESCE(SUM(z), 0) AS ct FROM a",
)
self.validate_identity(
"from a aggregate { average x, min y, sum z }",
"SELECT AVG(x), MIN(y), COALESCE(SUM(z), 0) FROM a",
)
self.validate_identity(
"from a aggregate { min y, b = stddev x, max z }",
"SELECT MIN(y), STDDEV(x) AS b, MAX(z) FROM a",
)

0 comments on commit 16691f9

Please sign in to comment.