Skip to content

Commit

Permalink
Fix(optimizer): constrain UDTF scope boundary (#3226)
Browse files Browse the repository at this point in the history
* Fix(optimizer): constrain UDTF scope boundary

* Replace end-to-end test with a minimal projection pushdown test
  • Loading branch information
georgesittas committed Mar 26, 2024
1 parent 2f6a2f1 commit e6b8d1f
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 3 deletions.
5 changes: 2 additions & 3 deletions sqlglot/optimizer/scope.py
Expand Up @@ -800,10 +800,9 @@ def walk_in_scope(expression, bfs=True, prune=None):
if (
isinstance(node, exp.CTE)
or (
_is_derived_table(node)
and isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
isinstance(node.parent, (exp.From, exp.Join, exp.Subquery))
and (_is_derived_table(node) or isinstance(node, exp.UDTF))
)
or isinstance(node, exp.UDTF)
or isinstance(node, exp.UNWRAPPED_QUERIES)
):
crossed_scope_boundary = True
Expand Down
3 changes: 3 additions & 0 deletions tests/fixtures/optimizer/pushdown_projections.sql
Expand Up @@ -79,6 +79,9 @@ WITH y AS (SELECT MAX(1) AS _ FROM x AS x) SELECT 1 AS "1" FROM y AS y;
WITH y AS (SELECT a FROM x GROUP BY a) SELECT 1 FROM y;
WITH y AS (SELECT 1 AS _ FROM x AS x GROUP BY x.a) SELECT 1 AS "1" FROM y AS y;

WITH cte AS (SELECT col FROM t) SELECT IF(1 IN UNNEST(col), 1, 0) AS col FROM cte;
WITH cte AS (SELECT t.col AS col FROM t AS t) SELECT CASE WHEN 1 IN (SELECT UNNEST(cte.col)) THEN 1 ELSE 0 END AS col FROM cte AS cte;

--------------------------------------
-- Unknown Star Expansion
--------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions tests/test_optimizer.py
Expand Up @@ -505,6 +505,9 @@ def test_file_schema(self):
)

def test_scope(self):
ast = parse_one("SELECT IF(a IN UNNEST(b), 1, 0) AS c FROM t", dialect="bigquery")
self.assertEqual(build_scope(ast).columns, [exp.column("a"), exp.column("b")])

many_unions = parse_one(" UNION ALL ".join(["SELECT x FROM t"] * 10000))
scopes_using_traverse = list(build_scope(many_unions).traverse())
scopes_using_traverse_scope = traverse_scope(many_unions)
Expand Down

0 comments on commit e6b8d1f

Please sign in to comment.