Skip to content

Commit

Permalink
Fix: drop CLUSTER/DISTRIBUTE/SORT BY modifiers when unsupported (#3069)
Browse files Browse the repository at this point in the history
* Fix: drop CLUSTER/DISTRIBUTE/SORT BY modifiers when unsupported

* Get rid of method
  • Loading branch information
georgesittas committed Mar 3, 2024
1 parent 97b6baa commit 2770ddc
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 14 deletions.
7 changes: 7 additions & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,13 @@ def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[
if enum not in ("", "bigquery"):
klass.generator_class.SELECT_KINDS = ()

if enum not in ("", "databricks", "hive", "spark", "spark2"):
modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
for modifier in ("cluster", "distribute", "sort"):
modifier_transforms.pop(modifier, None)

klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

if not klass.SUPPORTS_SEMI_ANTI_JOIN:
klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
TokenType.ANTI,
Expand Down
27 changes: 13 additions & 14 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,18 @@ class Generator(metaclass=_Generator):
"YEARS": "YEAR",
}

# Generators for the query modifiers that can appear after the HAVING clause.
# Keys are the `exp.Select` arg names; values render that modifier's SQL.
# Dialects that don't support a modifier remove its entry (see dialect.py),
# which makes the generator silently drop it when transpiling.
# NOTE(review): this dict is iterated in insertion order when generating SQL
# (see query_modifiers), so the entry order determines output order — do not
# reorder casually.
AFTER_HAVING_MODIFIER_TRANSFORMS = {
"cluster": lambda self, e: self.sql(e, "cluster"),
"distribute": lambda self, e: self.sql(e, "distribute"),
"qualify": lambda self, e: self.sql(e, "qualify"),
"sort": lambda self, e: self.sql(e, "sort"),
# WINDOW is only emitted when the expression actually has window
# definitions; otherwise contribute an empty string.
"windows": lambda self, e: (
self.seg("WINDOW ") + self.expressions(e, key="windows", flat=True)
if e.args.get("windows")
else ""
),
}

TOKEN_MAPPING: t.Dict[TokenType, str] = {}

STRUCT_DELIMITER = ("<", ">")
Expand Down Expand Up @@ -2097,7 +2109,7 @@ def query_modifiers(self, expression: exp.Expression, *sqls: str) -> str:
self.sql(expression, "where"),
self.sql(expression, "group"),
self.sql(expression, "having"),
*self.after_having_modifiers(expression),
*[gen(self, expression) for gen in self.AFTER_HAVING_MODIFIER_TRANSFORMS.values()],
self.sql(expression, "order"),
*offset_limit_modifiers,
*self.after_limit_modifiers(expression),
Expand All @@ -2116,19 +2128,6 @@ def offset_limit_modifiers(
self.sql(limit) if fetch else self.sql(expression, "offset"),
]

def after_having_modifiers(self, expression: exp.Expression) -> t.List[str]:
    """Return the SQL fragments for modifiers that follow the HAVING clause.

    The fragments are returned in emission order: QUALIFY, WINDOW (only when
    window definitions are present, empty string otherwise), then the
    Hive-family DISTRIBUTE BY / SORT BY / CLUSTER BY modifiers.
    """
    if expression.args.get("windows"):
        windows_sql = self.seg("WINDOW ") + self.expressions(
            expression, key="windows", flat=True
        )
    else:
        windows_sql = ""

    fragments = [self.sql(expression, "qualify"), windows_sql]
    for arg_key in ("distribute", "sort", "cluster"):
        fragments.append(self.sql(expression, arg_key))
    return fragments

def after_limit_modifiers(self, expression: exp.Expression) -> t.List[str]:
locks = self.expressions(expression, key="locks", sep=" ")
locks = f" {locks}" if locks else ""
Expand Down
14 changes: 14 additions & 0 deletions tests/dialects/test_spark.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest import mock

from sqlglot import exp, parse_one
from sqlglot.dialects.dialect import Dialects
from tests.dialects.test_dialect import Validator


Expand Down Expand Up @@ -720,3 +721,16 @@ def test_explode_to_unnest(self):
"presto": "SELECT col, pos, IF(_u_2.pos_2 = _u_3.pos_3, _u_3.col_2) AS col_2, IF(_u_2.pos_2 = _u_3.pos_3, _u_3.pos_3) AS pos_3 FROM _u CROSS JOIN UNNEST(SEQUENCE(1, GREATEST(CARDINALITY(ARRAY[2, 3])))) AS _u_2(pos_2) CROSS JOIN UNNEST(ARRAY[2, 3]) WITH ORDINALITY AS _u_3(col_2, pos_3) WHERE _u_2.pos_2 = _u_3.pos_3 OR (_u_2.pos_2 > CARDINALITY(ARRAY[2, 3]) AND _u_3.pos_3 = CARDINALITY(ARRAY[2, 3]))",
},
)

def test_strip_modifiers(self):
    """CLUSTER BY / DISTRIBUTE BY / SORT BY survive transpilation only for
    the Hive-family dialects; every other target drops them silently."""
    without_modifiers = "SELECT * FROM t"
    with_modifiers = f"{without_modifiers} CLUSTER BY y DISTRIBUTE BY x SORT BY z"
    query = self.parse_one(with_modifiers)

    # Dialects that keep the modifiers ("" is the default/base dialect).
    keeps_modifiers = ("", "databricks", "hive", "spark", "spark2")

    for dialect in Dialects:
        with self.subTest(f"Transpiling query with CLUSTER/DISTRIBUTE/SORT BY to {dialect}"):
            name = dialect.value
            expected = with_modifiers if name in keeps_modifiers else without_modifiers
            self.assertEqual(query.sql(name), expected)

0 comments on commit 2770ddc

Please sign in to comment.