Skip to content

Commit

Permalink
Feat(clickhouse): add support for LIMIT BY clause (#2926)
Browse files Browse the repository at this point in the history
* Feat(clickhouse): add support for LIMIT BY clause

* Improve coverage

* Cover OFFSET modifier case
  • Loading branch information
georgesittas committed Feb 7, 2024
1 parent d684759 commit 1842c96
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 11 deletions.
4 changes: 2 additions & 2 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1796,7 +1796,7 @@ class Lambda(Expression):


class Limit(Expression):
arg_types = {"this": False, "expression": True, "offset": False}
arg_types = {"this": False, "expression": True, "offset": False, "expressions": False}


class Literal(Condition):
Expand Down Expand Up @@ -1969,7 +1969,7 @@ class Final(Expression):


class Offset(Expression):
arg_types = {"this": False, "expression": True}
arg_types = {"this": False, "expression": True, "expressions": False}


class Order(Expression):
Expand Down
15 changes: 9 additions & 6 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1841,15 +1841,18 @@ def limit_sql(self, expression: exp.Limit, top: bool = False) -> str:

args_sql = ", ".join(self.sql(e) for e in args)
args_sql = f"({args_sql})" if any(top and not e.is_number for e in args) else args_sql
return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args_sql}"
expressions = self.expressions(expression, flat=True)
expressions = f" BY {expressions}" if expressions else ""

return f"{this}{self.seg('TOP' if top else 'LIMIT')} {args_sql}{expressions}"

def offset_sql(self, expression: exp.Offset) -> str:
this = self.sql(expression, "this")
expression = expression.expression
expression = (
self._simplify_unless_literal(expression) if self.LIMIT_ONLY_LITERALS else expression
)
return f"{this}{self.seg('OFFSET')} {self.sql(expression)}"
value = expression.expression
value = self._simplify_unless_literal(value) if self.LIMIT_ONLY_LITERALS else value
expressions = self.expressions(expression, flat=True)
expressions = f" BY {expressions}" if expressions else ""
return f"{this}{self.seg('OFFSET')} {self.sql(value)}{expressions}"

def setitem_sql(self, expression: exp.SetItem) -> str:
kind = self.sql(expression, "kind")
Expand Down
23 changes: 20 additions & 3 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2465,8 +2465,14 @@ def _parse_query_modifiers(
this.set(key, expression)
if key == "limit":
offset = expression.args.pop("offset", None)

if offset:
this.set("offset", exp.Offset(expression=offset))
offset = exp.Offset(expression=offset)
this.set("offset", offset)

limit_by_expressions = expression.expressions
expression.set("expressions", None)
offset.set("expressions", limit_by_expressions)
continue
break
return this
Expand Down Expand Up @@ -3341,7 +3347,12 @@ def _parse_limit(
offset = None

limit_exp = self.expression(
exp.Limit, this=this, expression=expression, offset=offset, comments=comments
exp.Limit,
this=this,
expression=expression,
offset=offset,
comments=comments,
expressions=self._parse_limit_by(),
)

return limit_exp
Expand Down Expand Up @@ -3377,7 +3388,13 @@ def _parse_offset(self, this: t.Optional[exp.Expression] = None) -> t.Optional[e

count = self._parse_term()
self._match_set((TokenType.ROW, TokenType.ROWS))
return self.expression(exp.Offset, this=this, expression=count)

return self.expression(
exp.Offset, this=this, expression=count, expressions=self._parse_limit_by()
)

def _parse_limit_by(self) -> t.Optional[t.List[exp.Expression]]:
return self._match_text_seq("BY") and self._parse_csv(self._parse_bitwise)

def _parse_locks(self) -> t.List[exp.Lock]:
locks = []
Expand Down
9 changes: 9 additions & 0 deletions tests/dialects/test_clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,15 @@ def test_clickhouse(self):
self.validate_identity("CAST(x as MEDIUMINT)", "CAST(x AS Int32)")
self.validate_identity("SELECT arrayJoin([1, 2, 3] AS src) AS dst, 'Hello', src")
self.validate_identity("""SELECT JSONExtractString('{"x": {"y": 1}}', 'x', 'y')""")
self.validate_identity("SELECT * FROM table LIMIT 1 BY a, b")
self.validate_identity("SELECT * FROM table LIMIT 2 OFFSET 1 BY a, b")
self.validate_identity(
"SELECT * FROM table LIMIT 1, 2 BY a, b",
"SELECT * FROM table LIMIT 2 OFFSET 1 BY a, b",
)
self.validate_identity(
"SELECT * FROM table LIMIT 1 BY CONCAT(datalayerVariantNo, datalayerProductId, warehouse)"
)
self.validate_identity(
"""SELECT JSONExtractString('{"a": "hello", "b": [-100, 200.0, 300]}', 'a')"""
)
Expand Down

0 comments on commit 1842c96

Please sign in to comment.