Skip to content

Commit

Permalink
Fix(optimizer): fix multiple bugs in unnest_subqueries, clean up test suite (#3464)
Browse files: browse the repository at this point in the history

* Fix(optimizer): fix multiple bugs in unnest_subqueries, clean up test suite

* Fix AttributeError issue with subquery projections

* Fix #3448
  • Loading branch information
georgesittas committed May 13, 2024
1 parent 58d5f2b commit 065281e
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 288 deletions.
31 changes: 19 additions & 12 deletions sqlglot/optimizer/unnest_subqueries.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sqlglot import exp
from sqlglot.helper import name_sequence
from sqlglot.optimizer.scope import ScopeType, traverse_scope
from sqlglot.optimizer.scope import ScopeType, find_in_scope, traverse_scope


def unnest_subqueries(expression):
Expand Down Expand Up @@ -64,7 +64,7 @@ def unnest(select, parent_select, next_alias_name):
(not clause or clause_parent_select is not parent_select)
and (
parent_select.args.get("group")
or any(projection.find(exp.AggFunc) for projection in parent_select.selects)
or any(find_in_scope(select, exp.AggFunc) for select in parent_select.selects)
)
):
column = exp.Max(this=column)
Expand Down Expand Up @@ -101,7 +101,7 @@ def unnest(select, parent_select, next_alias_name):
if group:
if {value.this} != set(group.expressions):
select = (
exp.select(exp.column(value.alias, "_q"))
exp.select(exp.alias_(exp.column(value.alias, "_q"), value.alias))
.from_(select.subquery("_q", copy=False), copy=False)
.group_by(exp.column(value.alias, "_q"), copy=False)
)
Expand Down Expand Up @@ -152,7 +152,9 @@ def decorrelate(select, parent_select, external_columns, next_alias_name):
return

is_subquery_projection = any(
node is select.parent for node in parent_select.selects if isinstance(node, exp.Subquery)
node is select.parent
for node in map(lambda s: s.unalias(), parent_select.selects)
if isinstance(node, exp.Subquery)
)

value = select.selects[0]
Expand Down Expand Up @@ -200,19 +202,25 @@ def decorrelate(select, parent_select, external_columns, next_alias_name):

alias = exp.column(value.alias, table_alias)
other = _other_operand(parent_predicate)
op_type = type(parent_predicate.parent) if parent_predicate else None

if isinstance(parent_predicate, exp.Exists):
alias = exp.column(list(key_aliases.values())[0], table_alias)
parent_predicate = _replace(parent_predicate, f"NOT {alias} IS NULL")
elif isinstance(parent_predicate, exp.All):
assert issubclass(op_type, exp.Binary)
predicate = op_type(this=other, expression=exp.column("_x"))
parent_predicate = _replace(
parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> _x = {other})"
parent_predicate.parent, f"ARRAY_ALL({alias}, _x -> {predicate})"
)
elif isinstance(parent_predicate, exp.Any):
assert issubclass(op_type, exp.Binary)
if value.this in group_by:
parent_predicate = _replace(parent_predicate.parent, f"{other} = {alias}")
predicate = op_type(this=other, expression=alias)
parent_predicate = _replace(parent_predicate.parent, predicate)
else:
parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> _x = {other})")
predicate = op_type(this=other, expression=exp.column("_x"))
parent_predicate = _replace(parent_predicate, f"ARRAY_ANY({alias}, _x -> {predicate})")
elif isinstance(parent_predicate, exp.In):
if value.this in group_by:
parent_predicate = _replace(parent_predicate, f"{other} = {alias}")
Expand All @@ -222,7 +230,7 @@ def decorrelate(select, parent_select, external_columns, next_alias_name):
f"ARRAY_ANY({alias}, _x -> _x = {parent_predicate.this})",
)
else:
if is_subquery_projection:
if is_subquery_projection and select.parent.alias:
alias = exp.alias_(alias, select.parent.alias)

# COUNT always returns 0 on empty datasets, so we need to take that into consideration here
Expand All @@ -236,10 +244,7 @@ def remove_aggs(node):
return exp.null()
return node

alias = exp.Coalesce(
this=alias,
expressions=[value.this.transform(remove_aggs)],
)
alias = exp.Coalesce(this=alias, expressions=[value.this.transform(remove_aggs)])

select.parent.replace(alias)

Expand All @@ -249,6 +254,8 @@ def remove_aggs(node):

if is_subquery_projection:
key.replace(nested)
if not isinstance(predicate, exp.EQ):
parent_select.where(predicate, copy=False)
continue

if key in group_by:
Expand Down
4 changes: 1 addition & 3 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2701,9 +2701,7 @@ def _parse_subquery(
)

def _implicit_unnests_to_explicit(self, this: E) -> E:
from sqlglot.optimizer.normalize_identifiers import (
normalize_identifiers as _norm,
)
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers as _norm

refs = {_norm(this.args["from"].this.copy(), dialect=self.dialect).alias_or_name}
for i, join in enumerate(this.args.get("joins") or []):
Expand Down
30 changes: 5 additions & 25 deletions tests/fixtures/optimizer/tpc-ds/tpc-ds.sql
Original file line number Diff line number Diff line change
Expand Up @@ -1409,31 +1409,11 @@ WITH "_u_0" AS (
"store_sales"."ss_quantity" <= 80 AND "store_sales"."ss_quantity" >= 61
)
SELECT
CASE
WHEN MAX("_u_0"."_col_0") > 3672
THEN MAX("_u_1"."_col_0")
ELSE MAX("_u_2"."_col_0")
END AS "bucket1",
CASE
WHEN MAX("_u_3"."_col_0") > 3392
THEN MAX("_u_4"."_col_0")
ELSE MAX("_u_5"."_col_0")
END AS "bucket2",
CASE
WHEN MAX("_u_6"."_col_0") > 32784
THEN MAX("_u_7"."_col_0")
ELSE MAX("_u_8"."_col_0")
END AS "bucket3",
CASE
WHEN MAX("_u_9"."_col_0") > 26032
THEN MAX("_u_10"."_col_0")
ELSE MAX("_u_11"."_col_0")
END AS "bucket4",
CASE
WHEN MAX("_u_12"."_col_0") > 23982
THEN MAX("_u_13"."_col_0")
ELSE MAX("_u_14"."_col_0")
END AS "bucket5"
CASE WHEN "_u_0"."_col_0" > 3672 THEN "_u_1"."_col_0" ELSE "_u_2"."_col_0" END AS "bucket1",
CASE WHEN "_u_3"."_col_0" > 3392 THEN "_u_4"."_col_0" ELSE "_u_5"."_col_0" END AS "bucket2",
CASE WHEN "_u_6"."_col_0" > 32784 THEN "_u_7"."_col_0" ELSE "_u_8"."_col_0" END AS "bucket3",
CASE WHEN "_u_9"."_col_0" > 26032 THEN "_u_10"."_col_0" ELSE "_u_11"."_col_0" END AS "bucket4",
CASE WHEN "_u_12"."_col_0" > 23982 THEN "_u_13"."_col_0" ELSE "_u_14"."_col_0" END AS "bucket5"
FROM "reason" AS "reason"
CROSS JOIN "_u_0" AS "_u_0"
CROSS JOIN "_u_1" AS "_u_1"
Expand Down

0 comments on commit 065281e

Please sign in to comment.