Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add option to pushdown CTE aliases earlier in qualify #2726

Merged
merged 7 commits into from
Dec 28, 2023
18 changes: 18 additions & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,24 @@ class Dialect(metaclass=_Dialect):
For example, such columns may be excluded from `SELECT *` queries.
"""

PREFER_CTE_ALIAS_COLUMN = False
tobymao marked this conversation as resolved.
Show resolved Hide resolved
"""
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
HAVING clause of the CTE. This flag will cause the CTE alias columns to override
any projection aliases in the subquery.

For example,
WITH y(c) AS (
SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;

will be rewritten as

WITH y(c) AS (
SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
"""

# --- Autofilled ---

tokenizer_class = Tokenizer
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ class Snowflake(Dialect):
TIME_FORMAT = "'YYYY-MM-DD HH24:MI:SS'"
SUPPORTS_USER_DEFINED_TYPES = False
SUPPORTS_SEMI_ANTI_JOIN = False
PREFER_CTE_ALIAS_COLUMN = True

TIME_MAPPING = {
"YYYY": "%Y",
Expand Down
14 changes: 13 additions & 1 deletion sqlglot/optimizer/qualify.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import typing as t

from sqlglot import exp
from sqlglot.dialects.dialect import DialectType
from sqlglot.dialects.dialect import Dialect, DialectType
from sqlglot.optimizer.isolate_table_selects import isolate_table_selects
from sqlglot.optimizer.normalize_identifiers import normalize_identifiers
from sqlglot.optimizer.qualify_columns import (
pushdown_cte_alias_columns as pushdown_cte_alias_columns_func,
qualify_columns as qualify_columns_func,
quote_identifiers as quote_identifiers_func,
validate_qualify_columns as validate_qualify_columns_func,
Expand All @@ -28,6 +29,7 @@ def qualify(
validate_qualify_columns: bool = True,
quote_identifiers: bool = True,
identify: bool = True,
pushdown_cte_alias_columns: t.Optional[bool] = None,
dangoldin marked this conversation as resolved.
Show resolved Hide resolved
) -> exp.Expression:
"""
Rewrite sqlglot AST to have normalized and qualified tables and columns.
Expand Down Expand Up @@ -55,6 +57,8 @@ def qualify(
This step is necessary to ensure correctness for case sensitive queries.
But this flag is provided in case this step is performed at a later time.
identify: If True, quote all identifiers, else only necessary ones.
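pushdown_cte_alias_columns: Whether or not to push CTE alias columns down into the CTE's projections
before columns are qualified. Needed for dialects (e.g. Snowflake) that can reference CTE alias
columns in the HAVING clause. If None, the dialect's PREFER_CTE_ALIAS_COLUMN setting is used.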

Returns:
The qualified expression.
Expand All @@ -66,6 +70,14 @@ def qualify(
if isolate_tables:
expression = isolate_table_selects(expression, schema=schema)

pushdown_cte_alias_columns = (
dangoldin marked this conversation as resolved.
Show resolved Hide resolved
Dialect.get_or_raise(dialect).PREFER_CTE_ALIAS_COLUMN
if pushdown_cte_alias_columns is None
else pushdown_cte_alias_columns
)
if pushdown_cte_alias_columns:
expression = pushdown_cte_alias_columns_func(expression)

if qualify_columns:
expression = qualify_columns_func(
expression, schema, expand_alias_refs=expand_alias_refs, infer_schema=infer_schema
Expand Down
32 changes: 32 additions & 0 deletions sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,38 @@ def quote_identifiers(expression: E, dialect: DialectType = None, identify: bool
)


def pushdown_cte_alias_columns(expression: exp.Expression) -> exp.Expression:
    """
    Pushes down the CTE alias columns into the projection.

    This step is useful in Snowflake where the CTE alias columns can be referenced in the HAVING.

    The expression is modified in place. CTEs whose body exposes no direct projection list
    (e.g. set operations or parenthesized subqueries) are left untouched, and projections
    beyond the last CTE alias column are preserved as they are.

    Example:
        >>> import sqlglot
        >>> expression = sqlglot.parse_one("WITH y (c) AS (SELECT SUM(a) FROM ( SELECT 1 a ) AS x HAVING c > 0) SELECT c FROM y")
        >>> pushdown_cte_alias_columns(expression).sql()
        'WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y'

    Args:
        expression: Expression to pushdown.

    Returns:
        The expression with the CTE aliases pushed down into the projection.
    """
    for cte in expression.find_all(exp.CTE):
        column_names = cte.alias_column_names
        projections = list(cte.this.expressions)

        # Nothing to push down, or nothing to push it into.
        if not column_names or not projections:
            continue

        new_expressions = []
        for _alias, projection in zip(column_names, projections):
            if isinstance(projection, exp.Alias):
                # The CTE alias column overrides an existing projection alias.
                projection.set("alias", _alias)
            else:
                projection = alias(projection, alias=_alias)
            new_expressions.append(projection)

        # zip() stops at the shorter input; keep any projections that have no
        # matching CTE alias column instead of silently dropping them.
        new_expressions.extend(projections[len(new_expressions) :])

        cte.this.set("expressions", new_expressions)

    return expression


class Resolver:
"""
Helper for resolving columns.
Expand Down
20 changes: 20 additions & 0 deletions tests/fixtures/optimizer/pushdown_cte_alias_columns.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
WITH y(c) AS (SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0) SELECT c FROM y;
WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y;

WITH y(c) AS (SELECT SUM(a) as d FROM (SELECT 1 a) AS x HAVING c > 0) SELECT c FROM y;
WITH y(c) AS (SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0) SELECT c FROM y;

WITH x(c) AS (SELECT SUM(1) a HAVING c > 0 LIMIT 1) SELECT * FROM x;
WITH x(c) AS (SELECT SUM(1) AS c HAVING c > 0 LIMIT 1) SELECT * FROM x;

-- Invalid statement in Snowflake but checking more complex structures
WITH x(c) AS ((SELECT 1 a) HAVING c > 0) SELECT * FROM x;
WITH x(c) AS ((SELECT 1 AS a) HAVING c > 0) SELECT * FROM x;

-- Invalid statement in Snowflake but checking more complex structures
WITH x(c) AS ((SELECT SUM(1) a) HAVING c > 0 LIMIT 1) SELECT * FROM x;
WITH x(c) AS ((SELECT SUM(1) AS a) HAVING c > 0 LIMIT 1) SELECT * FROM x;

-- Invalid statement in Snowflake but checking that we don't fail
WITH x(c) AS (SELECT SUM(a) FROM x HAVING c > 0 UNION ALL SELECT SUM(a) FROM y HAVING c > 0) SELECT * FROM x;
WITH x(c) AS (SELECT SUM(a) FROM x HAVING c > 0 UNION ALL SELECT SUM(a) FROM y HAVING c > 0) SELECT * FROM x;
5 changes: 5 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,11 @@ def test_qualify_columns__with_invisible(self):
schema = MappingSchema(self.schema, {"x": {"a"}, "y": {"b"}, "z": {"b"}})
self.check_file("qualify_columns__with_invisible", qualify_columns, schema=schema)

def test_pushdown_cte_alias_columns(self):
    """Check the "pushdown_cte_alias_columns" fixture against the pushdown step."""
    pushdown = optimizer.qualify_columns.pushdown_cte_alias_columns
    self.check_file("pushdown_cte_alias_columns", pushdown)

def test_qualify_columns__invalid(self):
for sql in load_sql_fixtures("optimizer/qualify_columns__invalid.sql"):
with self.subTest(sql):
Expand Down