Skip to content

Commit

Permalink
feat!: allow unions to be limited directly and stop subquerying since… (
Browse files Browse the repository at this point in the history
#3301)

* feat!: allow unions to be limited directly and stop subquerying since it is not necessary

closes #3300

* clickhouse and tsql overrides
  • Loading branch information
tobymao committed Apr 11, 2024
1 parent b523bb6 commit 32cdc36
Show file tree
Hide file tree
Showing 6 changed files with 128 additions and 107 deletions.
1 change: 1 addition & 0 deletions sqlglot/dialects/clickhouse.py
Expand Up @@ -668,6 +668,7 @@ class Generator(generator.Generator):
TABLE_HINTS = False
EXPLICIT_UNION = True
GROUPINGS_SEP = ""
OUTER_UNION_MODIFIERS = False

# there's no list in docs, but it can be found in Clickhouse code
# see `ClickHouse/src/Parsers/ParserCreate*.cpp`
Expand Down
8 changes: 1 addition & 7 deletions sqlglot/dialects/tsql.py
Expand Up @@ -726,6 +726,7 @@ class Generator(generator.Generator):
SUPPORTS_SELECT_INTO = True
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
SUPPORTS_TO_NUMBER = False
OUTER_UNION_MODIFIERS = False

EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Delete,
Expand Down Expand Up @@ -882,13 +883,6 @@ def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:

return rename_func("DATETIMEFROMPARTS")(self, expression)

def set_operations(self, expression: exp.Union) -> str:
limit = expression.args.get("limit")
if limit:
return self.sql(expression.limit(limit.pop(), copy=False))

return super().set_operations(expression)

def setitem_sql(self, expression: exp.SetItem) -> str:
this = expression.this
if isinstance(this, exp.EQ) and not isinstance(this.left, exp.Parameter):
Expand Down
185 changes: 88 additions & 97 deletions sqlglot/expressions.py
Expand Up @@ -1028,14 +1028,14 @@ def subquery(self, alias: t.Optional[ExpOrStr] = None, copy: bool = True) -> Sub
return Subquery(this=instance, alias=alias)

def limit(
self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
) -> Select:
self: Q, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
) -> Q:
"""
Adds a LIMIT clause to this query.
Example:
>>> select("1").union(select("1")).limit(1).sql()
'SELECT * FROM (SELECT 1 UNION SELECT 1) AS _l_0 LIMIT 1'
'SELECT 1 UNION SELECT 1 LIMIT 1'
Args:
expression: the SQL code string to parse.
Expand All @@ -1049,10 +1049,90 @@ def limit(
Returns:
A limited Select expression.
"""
return (
select("*")
.from_(self.subquery(alias="_l_0", copy=copy))
.limit(expression, dialect=dialect, copy=False, **opts)
return _apply_builder(
expression=expression,
instance=self,
arg="limit",
into=Limit,
prefix="LIMIT",
dialect=dialect,
copy=copy,
into_arg="expression",
**opts,
)

def offset(
self: Q, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
) -> Q:
"""
Set the OFFSET expression.
Example:
>>> Select().from_("tbl").select("x").offset(10).sql()
'SELECT x FROM tbl OFFSET 10'
Args:
expression: the SQL code string to parse.
This can also be an integer.
If a `Offset` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `Offset`.
dialect: the dialect used to parse the input expression.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.
Returns:
The modified Select expression.
"""
return _apply_builder(
expression=expression,
instance=self,
arg="offset",
into=Offset,
prefix="OFFSET",
dialect=dialect,
copy=copy,
into_arg="expression",
**opts,
)

def order_by(
self: Q,
*expressions: t.Optional[ExpOrStr],
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
*opts,
) -> Q:
"""
Set the ORDER BY expression.
Example:
>>> Select().from_("tbl").select("x").order_by("x DESC").sql()
'SELECT x FROM tbl ORDER BY x DESC'
Args:
*expressions: the SQL code strings to parse.
If a `Group` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `Order`.
append: if `True`, add to any existing expressions.
Otherwise, this flattens all the `Order` expression into a single expression.
dialect: the dialect used to parse the input expression.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.
Returns:
The modified Select expression.
"""
return _apply_child_list_builder(
*expressions,
instance=self,
arg="order",
append=append,
copy=copy,
prefix="ORDER BY",
into=Order,
dialect=dialect,
**opts,
)

@property
Expand Down Expand Up @@ -3078,46 +3158,6 @@ def group_by(
**opts,
)

def order_by(
self,
*expressions: t.Optional[ExpOrStr],
append: bool = True,
dialect: DialectType = None,
copy: bool = True,
**opts,
) -> Select:
"""
Set the ORDER BY expression.
Example:
>>> Select().from_("tbl").select("x").order_by("x DESC").sql()
'SELECT x FROM tbl ORDER BY x DESC'
Args:
*expressions: the SQL code strings to parse.
If a `Group` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `Order`.
append: if `True`, add to any existing expressions.
Otherwise, this flattens all the `Order` expression into a single expression.
dialect: the dialect used to parse the input expression.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.
Returns:
The modified Select expression.
"""
return _apply_child_list_builder(
*expressions,
instance=self,
arg="order",
append=append,
copy=copy,
prefix="ORDER BY",
into=Order,
dialect=dialect,
**opts,
)

def sort_by(
self,
*expressions: t.Optional[ExpOrStr],
Expand Down Expand Up @@ -3198,55 +3238,6 @@ def cluster_by(
**opts,
)

def limit(
self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
) -> Select:
return _apply_builder(
expression=expression,
instance=self,
arg="limit",
into=Limit,
prefix="LIMIT",
dialect=dialect,
copy=copy,
into_arg="expression",
**opts,
)

def offset(
self, expression: ExpOrStr | int, dialect: DialectType = None, copy: bool = True, **opts
) -> Select:
"""
Set the OFFSET expression.
Example:
>>> Select().from_("tbl").select("x").offset(10).sql()
'SELECT x FROM tbl OFFSET 10'
Args:
expression: the SQL code string to parse.
This can also be an integer.
If a `Offset` instance is passed, this is used as-is.
If another `Expression` instance is passed, it will be wrapped in a `Offset`.
dialect: the dialect used to parse the input expression.
copy: if `False`, modify this expression instance in-place.
opts: other options to use to parse the input expressions.
Returns:
The modified Select expression.
"""
return _apply_builder(
expression=expression,
instance=self,
arg="offset",
into=Offset,
prefix="OFFSET",
dialect=dialect,
copy=copy,
into_arg="expression",
**opts,
)

def select(
self,
*expressions: t.Optional[ExpOrStr],
Expand Down Expand Up @@ -6800,7 +6791,7 @@ def subquery(
A new Select instance with the subquery expression included.
"""

expression = maybe_parse(expression, dialect=dialect, **opts).subquery(alias)
expression = maybe_parse(expression, dialect=dialect, **opts).subquery(alias, **opts)
return Select().from_(expression, dialect=dialect, **opts)


Expand Down
18 changes: 18 additions & 0 deletions sqlglot/generator.py
Expand Up @@ -334,6 +334,11 @@ class Generator(metaclass=_Generator):
# Whether the function TO_NUMBER is supported
SUPPORTS_TO_NUMBER = True

# Whether or not union modifiers apply to the outer union or select.
# SELECT * FROM x UNION SELECT * FROM y LIMIT 1
# True means limit 1 happens after the union, False means it it happens on y.
OUTER_UNION_MODIFIERS = True

TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
Expand Down Expand Up @@ -2312,6 +2317,19 @@ def qualify_sql(self, expression: exp.Qualify) -> str:
return f"{self.seg('QUALIFY')}{self.sep()}{this}"

def set_operations(self, expression: exp.Union) -> str:
if not self.OUTER_UNION_MODIFIERS:
limit = expression.args.get("limit")
order = expression.args.get("order")

if limit or order:
select = exp.subquery(expression, "_l_0", copy=False).select("*", copy=False)

if limit:
select = select.limit(limit.pop(), copy=False)
if order:
select = select.order_by(order.pop(), copy=False)
return self.sql(select)

sqls: t.List[str] = []
stack: t.List[t.Union[str, exp.Expression]] = [expression]

Expand Down
9 changes: 9 additions & 0 deletions tests/dialects/test_dialect.py
Expand Up @@ -1332,6 +1332,15 @@ def test_lateral_subquery(self):
)

def test_set_operators(self):
self.validate_all(
"SELECT * FROM a UNION SELECT * FROM b ORDER BY x LIMIT 1",
write={
"": "SELECT * FROM a UNION SELECT * FROM b ORDER BY x LIMIT 1",
"clickhouse": "SELECT * FROM (SELECT * FROM a UNION DISTINCT SELECT * FROM b) AS _l_0 ORDER BY x NULLS FIRST LIMIT 1",
"tsql": "SELECT TOP 1 * FROM (SELECT * FROM a UNION SELECT * FROM b) AS _l_0 ORDER BY x",
},
)

self.validate_all(
"SELECT * FROM a UNION SELECT * FROM b",
read={
Expand Down
14 changes: 11 additions & 3 deletions tests/test_build.py
Expand Up @@ -300,6 +300,10 @@ def test_build(self):
lambda: select("x").from_("tbl").order_by("y"),
"SELECT x FROM tbl ORDER BY y",
),
(
lambda: parse_one("select * from x union select * from y").order_by("y"),
"SELECT * FROM x UNION SELECT * FROM y ORDER BY y",
),
(
lambda: select("x").from_("tbl").cluster_by("y"),
"SELECT x FROM tbl CLUSTER BY y",
Expand Down Expand Up @@ -505,15 +509,19 @@ def test_build(self):
(lambda: parse_one("(SELECT 1)").select("2"), "(SELECT 1, 2)"),
(
lambda: parse_one("(SELECT 1)").limit(1),
"SELECT * FROM ((SELECT 1)) AS _l_0 LIMIT 1",
"(SELECT 1) LIMIT 1",
),
(
lambda: parse_one("WITH t AS (SELECT 1) (SELECT 1)").limit(1),
"SELECT * FROM (WITH t AS (SELECT 1) (SELECT 1)) AS _l_0 LIMIT 1",
"WITH t AS (SELECT 1) (SELECT 1) LIMIT 1",
),
(
lambda: parse_one("(SELECT 1 LIMIT 2)").limit(1),
"SELECT * FROM ((SELECT 1 LIMIT 2)) AS _l_0 LIMIT 1",
"(SELECT 1 LIMIT 2) LIMIT 1",
),
(
lambda: parse_one("SELECT 1 UNION SELECT 2").limit(5).offset(2),
"SELECT 1 UNION SELECT 2 LIMIT 5 OFFSET 2",
),
(lambda: parse_one("(SELECT 1)").subquery(), "((SELECT 1))"),
(lambda: parse_one("(SELECT 1)").subquery("alias"), "((SELECT 1)) AS alias"),
Expand Down

0 comments on commit 32cdc36

Please sign in to comment.