Skip to content

Commit

Permalink
Fix!: traverse union scopes iteratively (#3112)
Browse files Browse the repository at this point in the history
* Fix!: traverse union scopes iteratively

* Fixup

* Fix traverse recursion

* Improve testing

* Simplify Scope.traverse
  • Loading branch information
georgesittas committed Mar 11, 2024
1 parent 80d484c commit b1c8cac
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 38 deletions.
8 changes: 4 additions & 4 deletions sqlglot/optimizer/pushdown_predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,13 +77,13 @@ def pushdown(condition, sources, scope_ref_count, dialect, join_index=None):
pushdown_dnf(predicates, sources, scope_ref_count)


def pushdown_cnf(predicates, scope, scope_ref_count, join_index=None):
def pushdown_cnf(predicates, sources, scope_ref_count, join_index=None):
"""
If the predicates are in CNF like form, we can simply replace each block in the parent.
"""
join_index = join_index or {}
for predicate in predicates:
for node in nodes_for_predicate(predicate, scope, scope_ref_count).values():
for node in nodes_for_predicate(predicate, sources, scope_ref_count).values():
if isinstance(node, exp.Join):
name = node.alias_or_name
predicate_tables = exp.column_table_names(predicate, name)
Expand All @@ -103,7 +103,7 @@ def pushdown_cnf(predicates, scope, scope_ref_count, join_index=None):
node.where(inner_predicate, copy=False)


def pushdown_dnf(predicates, scope, scope_ref_count):
def pushdown_dnf(predicates, sources, scope_ref_count):
"""
If the predicates are in DNF form, we can only push down conditions that are in all blocks.
Additionally, we can't remove predicates from their original form.
Expand All @@ -127,7 +127,7 @@ def pushdown_dnf(predicates, scope, scope_ref_count):
# pushdown all predicates to their respective nodes
for table in sorted(pushdown_tables):
for predicate in predicates:
nodes = nodes_for_predicate(predicate, scope, scope_ref_count)
nodes = nodes_for_predicate(predicate, sources, scope_ref_count)

if table not in nodes:
continue
Expand Down
4 changes: 2 additions & 2 deletions sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ def _expand_stars(
raise OptimizeError(f"Unknown table: {table}")

columns = resolver.get_source_columns(table, only_visible=True)
columns = columns or scope.outer_column_list
columns = columns or scope.outer_columns

if pseudocolumns:
columns = [name for name in columns if name.upper() not in pseudocolumns]
Expand Down Expand Up @@ -517,7 +517,7 @@ def qualify_outputs(scope_or_expression: Scope | exp.Expression) -> None:

new_selections = []
for i, (selection, aliased_column) in enumerate(
itertools.zip_longest(scope.expression.selects, scope.outer_column_list)
itertools.zip_longest(scope.expression.selects, scope.outer_columns)
):
if selection is None:
break
Expand Down
84 changes: 52 additions & 32 deletions sqlglot/optimizer/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ class Scope:
SELECT c FROM x LATERAL VIEW EXPLODE (a) AS c;
The LATERAL VIEW EXPLODE gets x as a source.
cte_sources (dict[str, Scope]): Sources from CTES
outer_column_list (list[str]): If this is a derived table or CTE, and the outer query
defines a column list of it's alias of this scope, this is that list of columns.
outer_columns (list[str]): If this is a derived table or CTE, and the outer query
defines a column list for the alias of this scope, this is that list of columns.
For example:
SELECT * FROM (SELECT ...) AS y(col1, col2)
The inner query would have `["col1", "col2"]` for its `outer_column_list`
The inner query would have `["col1", "col2"]` for its `outer_columns`
parent (Scope): Parent scope
scope_type (ScopeType): Type of this scope, relative to its parent
subquery_scopes (list[Scope]): List of all child scopes for subqueries
Expand All @@ -58,7 +58,7 @@ def __init__(
self,
expression,
sources=None,
outer_column_list=None,
outer_columns=None,
parent=None,
scope_type=ScopeType.ROOT,
lateral_sources=None,
Expand All @@ -70,7 +70,7 @@ def __init__(
self.cte_sources = cte_sources or {}
self.sources.update(self.lateral_sources)
self.sources.update(self.cte_sources)
self.outer_column_list = outer_column_list or []
self.outer_columns = outer_columns or []
self.parent = parent
self.scope_type = scope_type
self.subquery_scopes = []
Expand Down Expand Up @@ -435,11 +435,21 @@ def traverse(self):
Yields:
Scope: scope instances in depth-first-search post-order
"""
for child_scope in itertools.chain(
self.cte_scopes, self.union_scopes, self.table_scopes, self.subquery_scopes
):
yield from child_scope.traverse()
yield self
stack = [self]
result = []
while stack:
scope = stack.pop()
result.append(scope)
stack.extend(
itertools.chain(
scope.cte_scopes,
scope.union_scopes,
scope.table_scopes,
scope.subquery_scopes,
)
)

yield from reversed(result)

def ref_count(self):
"""
Expand Down Expand Up @@ -522,7 +532,9 @@ def _traverse_scope(scope):
if isinstance(scope.expression, exp.Select):
yield from _traverse_select(scope)
elif isinstance(scope.expression, exp.Union):
yield from _traverse_ctes(scope)
yield from _traverse_union(scope)
return
elif isinstance(scope.expression, exp.Subquery):
if scope.is_root:
yield from _traverse_select(scope)
Expand All @@ -548,30 +560,38 @@ def _traverse_select(scope):


def _traverse_union(scope):
yield from _traverse_ctes(scope)
prev_scope = None
union_scope_stack = [scope]
expression_stack = [scope.expression.right, scope.expression.left]

# The last scope to be yielded should be the topmost scope
left = None
for left in _traverse_scope(
scope.branch(
scope.expression.left,
outer_column_list=scope.outer_column_list,
scope_type=ScopeType.UNION,
)
):
yield left
while expression_stack:
expression = expression_stack.pop()
union_scope = union_scope_stack[-1]

right = None
for right in _traverse_scope(
scope.branch(
scope.expression.right,
outer_column_list=scope.outer_column_list,
new_scope = union_scope.branch(
expression,
outer_columns=union_scope.outer_columns,
scope_type=ScopeType.UNION,
)
):
yield right

scope.union_scopes = [left, right]
if isinstance(expression, exp.Union):
yield from _traverse_ctes(new_scope)

union_scope_stack.append(new_scope)
expression_stack.extend([expression.right, expression.left])
continue

for scope in _traverse_scope(new_scope):
yield scope

if prev_scope:
union_scope_stack.pop()
union_scope.union_scopes = [prev_scope, scope]
prev_scope = union_scope

yield union_scope
else:
prev_scope = scope


def _traverse_ctes(scope):
Expand All @@ -595,7 +615,7 @@ def _traverse_ctes(scope):
scope.branch(
cte.this,
cte_sources=sources,
outer_column_list=cte.alias_column_names,
outer_columns=cte.alias_column_names,
scope_type=ScopeType.CTE,
)
):
Expand Down Expand Up @@ -690,7 +710,7 @@ def _traverse_tables(scope):
scope.branch(
expression,
lateral_sources=lateral_sources,
outer_column_list=expression.alias_column_names,
outer_columns=expression.alias_column_names,
scope_type=scope_type,
)
):
Expand Down Expand Up @@ -734,7 +754,7 @@ def _traverse_udtfs(scope):
scope.branch(
expression,
scope_type=ScopeType.DERIVED_TABLE,
outer_column_list=expression.alias_column_names,
outer_columns=expression.alias_column_names,
)
):
yield child_scope
Expand Down
9 changes: 9 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,15 @@ def test_file_schema(self):
)

def test_scope(self):
many_unions = parse_one(" UNION ALL ".join(["SELECT x FROM t"] * 10000))
scopes_using_traverse = list(build_scope(many_unions).traverse())
scopes_using_traverse_scope = traverse_scope(many_unions)
self.assertEqual(len(scopes_using_traverse), len(scopes_using_traverse_scope))
assert all(
x.expression is y.expression
for x, y in zip(scopes_using_traverse, scopes_using_traverse_scope)
)

sql = """
WITH q AS (
SELECT x.b FROM x
Expand Down

0 comments on commit b1c8cac

Please sign in to comment.