From 1407746646974fffdf67a34c288738a8fc39b22d Mon Sep 17 00:00:00 2001 From: George Sittas Date: Tue, 9 Apr 2024 22:12:12 +0300 Subject: [PATCH] Fix(optimizer): propagate recursive CTE source to children scopes early --- sqlglot/optimizer/scope.py | 12 +++--------- tests/test_optimizer.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 073ced2844..c589e242e9 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -600,7 +600,7 @@ def _traverse_ctes(scope): sources = {} for cte in scope.ctes: - recursive_scope = None + cte_name = cte.alias # if the scope is a recursive cte, it must be in the form of base_case UNION recursive. # thus the recursive scope is the first section of the union. @@ -609,7 +609,7 @@ def _traverse_ctes(scope): union = cte.this if isinstance(union, exp.Union): - recursive_scope = scope.branch(union.this, scope_type=ScopeType.CTE) + sources[cte_name] = scope.branch(union.this, scope_type=ScopeType.CTE) child_scope = None @@ -623,15 +623,9 @@ def _traverse_ctes(scope): ): yield child_scope - alias = cte.alias - sources[alias] = child_scope - - if recursive_scope: - child_scope.add_source(alias, recursive_scope) - child_scope.cte_sources[alias] = recursive_scope - # append the final child_scope yielded if child_scope: + sources[cte_name] = child_scope scope.cte_scopes.append(child_scope) scope.sources.update(sources) diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 0e8ce157ba..e8759955c0 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -228,6 +228,17 @@ def test_normalize(self): @patch("sqlglot.generator.logger") def test_qualify_columns(self, logger): + self.assertEqual( + optimizer.qualify_columns.qualify_columns( + parse_one( + "WITH RECURSIVE t AS (SELECT 1 AS x UNION ALL SELECT x + 1 FROM t AS child WHERE x < 10) SELECT * FROM t" + ), + schema={}, + infer_schema=False, + ).sql(), + "WITH RECURSIVE t AS (SELECT 1 AS x UNION ALL SELECT child.x + 1 AS _col_0 FROM t AS child WHERE child.x < 10) SELECT t.x AS x FROM t", + ) + self.assertEqual( optimizer.qualify_columns.qualify_columns( parse_one("WITH x AS (SELECT a FROM db.y) SELECT * FROM db.x"),