Skip to content

Commit

Permalink
Add option to pushdown CTE aliases earlier in qualify (#2726)
Browse files Browse the repository at this point in the history
* Add option to pushdown CTE aliases earlier in qualify

* Use fixtures for pushdown_cte_alias_columns

* Set pushdown_cte_alias_columns rule defaults via dialect

* Move pushdown_cte_alias_columns rule option close to usage

* Add more tests to pushdown_cte_alias_columns

* Rename flag and add better description

* Remove pushdown_cte_alias_columns flag since it's dialect driven
  • Loading branch information
dangoldin committed Dec 28, 2023
1 parent a8582b1 commit 4629410
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 1 deletion.
18 changes: 18 additions & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,24 @@ class Dialect(metaclass=_Dialect):
For example, such columns may be excluded from `SELECT *` queries.
"""

PREFER_CTE_ALIAS_COLUMN = False
"""
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
HAVING clause of the CTE. This flag will cause the CTE alias columns to override
any projection aliases in the subquery.
For example,
WITH y(c) AS (
SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;
will be rewritten as
WITH y(c) AS (
SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
"""

# --- Autofilled ---

tokenizer_class = Tokenizer
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ class Snowflake(Dialect):
TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
SUPPORTS_USER_DEFINED_TYPES = False
SUPPORTS_SEMI_ANTI_JOIN = False
PREFER_CTE_ALIAS_COLUMN = True

TIME_MAPPING = {
"YYYY": "%Y",
Expand Down
6 changes: 5 additions & 1 deletion sqlglot/optimizer/qualify.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import typing as t

from sqlglot import exp
from sqlglot.dialects.dialect import DialectType
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import (
pushdown_cte_alias_columns as pushdown_cte_alias_columns_func,
qualify_columns as qualify_columns_func,
quote_identifiers as quote_identifiers_func,
validate_qualify_columns as validate_qualify_columns_func,
Expand Down Expand Up @@ -66,6 +67,9 @@ def qualify(
if isolate_tables:
expression = isolate_table_selects(expression, schema=schema)

if Dialect.get_or_raise(dialect).PREFER_CTE_ALIAS_COLUMN:
expression = pushdown_cte_alias_columns_func(expression)

if qualify_columns:
expression = qualify_columns_func(
expression, schema, expand_alias_refs=expand_alias_refs, infer_schema=infer_schema
Expand Down
32 changes: 32 additions & 0 deletions sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,38 @@ def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool
)


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown. It is modified in place.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        column_names = cte.alias_column_names
        if not column_names:
            continue

        projections = cte.this.expressions
        new_expressions = []

        # Pair each declared CTE column name with the projection at the same position.
        for _alias, projection in zip(column_names, projections):
            if isinstance(projection, exp.Alias):
                # An existing projection alias is overridden by the CTE column name.
                projection.set("alias", _alias)
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # zip() stops at the shorter input: when the CTE declares fewer column names
        # than the query has projections, keep the unmatched projections as they are
        # instead of silently dropping them from the query.
        new_expressions.extend(projections[len(new_expressions) :])

        cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
"""
Helper for resolving columns.
Expand Down
20 changes: 20 additions & 0 deletions tests/fixtures/optimizer/pushdown_cte_alias_columns.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
WITH y(c) AS (SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0) SELECT c FROM y;
WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y;

WITH y(c) AS (SELECT SUM(a) as d FROM (SELECT 1 a) AS x HAVING c > 0) SELECT c FROM y;
WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y;

WITH x(c) AS (SELECT SUM(1) a HAVING c > 0 LIMIT 1) SELECT * FROM x;
WITH x(c) AS (SELECT SUM(1) AS c HAVING c > 0 LIMIT 1) SELECT * FROM x;

-- Invalid statement in Snowflake but checking more complex structures
WITH x(c) AS ((SELECT 1 a) HAVING c > 0) SELECT * FROM x;
WITH x(c) AS ((SELECT 1 AS a) HAVING c > 0) SELECT * FROM x;

-- Invalid statement in Snowflake but checking more complex structures
WITH x(c) AS ((SELECT SUM(1) a) HAVING c > 0 LIMIT 1) SELECT * FROM x;
WITH x(c) AS ((SELECT SUM(1) AS a) HAVING c > 0 LIMIT 1) SELECT * FROM x;

-- Invalid statement in Snowflake but checking that we don't fail
WITH x(c) AS (SELECT SUM(a) FROM x HAVING c > 0 UNION ALL SELECT SUM(a) FROM y HAVING c > 0) SELECT * FROM x;
WITH x(c) AS (SELECT SUM(a) FROM x HAVING c > 0 UNION ALL SELECT SUM(a) FROM y HAVING c > 0) SELECT * FROM x;
5 changes: 5 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,11 @@ def test_qualify_columns__with_invisible(self):
schema = MappingSchema(self.schema, {"x": {"a"}, "y": {"b"}, "z": {"b"}})
self.check_file("qualify_columns__with_invisible", qualify_columns, schema=schema)

def test_pushdown_cte_alias_columns(self):
    """Run the `pushdown_cte_alias_columns` fixture through the rule of the same name."""
    rule = optimizer.qualify_columns.pushdown_cte_alias_columns
    self.check_file("pushdown_cte_alias_columns", rule)

def test_qualify_columns__invalid(self):
for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"):
with self.subTest(sql):
Expand Down

0 comments on commit 4629410

Please sign in to comment.