Skip to content

Commit

Permalink
Fix(snowflake)!: preserve star clauses (EXCLUDE, RENAME, REPLACE) (#3477
Browse files Browse the repository at this point in the history
)
  • Loading branch information
georgesittas committed May 14, 2024
1 parent e004d2a commit e3ff67b
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 81 deletions.
3 changes: 1 addition & 2 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ class Generator(generator.Generator):
CAN_IMPLEMENT_ARRAY_ANY = True
SUPPORTS_TO_NUMBER = False
COPY_HAS_INTO_KEYWORD = False
STAR_EXCEPT = "EXCLUDE"

TRANSFORMS = {
**generator.Generator.TRANSFORMS,
Expand Down Expand Up @@ -506,8 +507,6 @@ class Generator(generator.Generator):
exp.DataType.Type.TIMESTAMP_NS: "TIMESTAMP_NS",
}

STAR_MAPPING = {**generator.Generator.STAR_MAPPING, "except": "EXCLUDE"}

UNWRAPPED_INTERVAL_VALUES = (exp.Literal, exp.Paren)

# DuckDB doesn't generally support CREATE TABLE .. properties
Expand Down
7 changes: 1 addition & 6 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,7 +737,6 @@ class Tokenizer(tokens.Tokenizer):
"NCHAR VARYING": TokenType.VARCHAR,
"PUT": TokenType.COMMAND,
"REMOVE": TokenType.COMMAND,
"RENAME": TokenType.REPLACE,
"RM": TokenType.COMMAND,
"SAMPLE": TokenType.TABLE_SAMPLE,
"SQL_DOUBLE": TokenType.DOUBLE,
Expand Down Expand Up @@ -773,6 +772,7 @@ class Generator(generator.Generator):
STRUCT_DELIMITER = ("(", ")")
COPY_PARAMS_ARE_WRAPPED = False
COPY_PARAMS_EQ_REQUIRED = True
STAR_EXCEPT = "EXCLUDE"

TRANSFORMS = {
**generator.Generator.TRANSFORMS,
Expand Down Expand Up @@ -875,11 +875,6 @@ class Generator(generator.Generator):
exp.DataType.Type.STRUCT: "OBJECT",
}

STAR_MAPPING = {
"except": "EXCLUDE",
"replace": "RENAME",
}

PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,
exp.SetProperty: exp.Properties.Location.UNSUPPORTED,
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3820,7 +3820,7 @@ class Where(Expression):


class Star(Expression):
arg_types = {"except": False, "replace": False}
arg_types = {"except": False, "replace": False, "rename": False}

@property
def name(self) -> str:
Expand Down
16 changes: 8 additions & 8 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,9 @@ class Generator(metaclass=_Generator):
# Whether the conditional TRY(expression) function is supported
TRY_SUPPORTED = True

# The keyword to use when generating a star projection with excluded columns
STAR_EXCEPT = "EXCEPT"

TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
Expand All @@ -366,11 +369,6 @@ class Generator(metaclass=_Generator):
exp.DataType.Type.ROWVERSION: "VARBINARY",
}

STAR_MAPPING = {
"except": "EXCEPT",
"replace": "REPLACE",
}

TIME_PART_SINGULARS = {
"MICROSECONDS": "MICROSECOND",
"SECONDS": "SECOND",
Expand Down Expand Up @@ -2308,10 +2306,12 @@ def schema_columns_sql(self, expression: exp.Schema) -> str:

def star_sql(self, expression: exp.Star) -> str:
except_ = self.expressions(expression, key="except", flat=True)
except_ = f"{self.seg(self.STAR_MAPPING['except'])} ({except_})" if except_ else ""
except_ = f"{self.seg(self.STAR_EXCEPT)} ({except_})" if except_ else ""
replace = self.expressions(expression, key="replace", flat=True)
replace = f"{self.seg(self.STAR_MAPPING['replace'])} ({replace})" if replace else ""
return f"*{except_}{replace}"
replace = f"{self.seg('REPLACE')} ({replace})" if replace else ""
rename = self.expressions(expression, key="rename", flat=True)
rename = f"{self.seg('RENAME')} ({rename})" if rename else ""
return f"*{except_}{replace}{rename}"

def parameter_sql(self, expression: exp.Parameter) -> str:
this = self.sql(expression, "this")
Expand Down
24 changes: 10 additions & 14 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,7 +713,12 @@ class Parser(metaclass=_Parser):
TokenType.FALSE: lambda self, _: self.expression(exp.Boolean, this=False),
TokenType.SESSION_PARAMETER: lambda self, _: self._parse_session_parameter(),
TokenType.STAR: lambda self, _: self.expression(
exp.Star, **{"except": self._parse_except(), "replace": self._parse_replace()}
exp.Star,
**{
"except": self._parse_star_op("EXCEPT", "EXCLUDE"),
"replace": self._parse_star_op("REPLACE"),
"rename": self._parse_star_op("RENAME"),
},
),
}

Expand Down Expand Up @@ -5677,23 +5682,14 @@ def _parse_placeholder(self) -> t.Optional[exp.Expression]:
self._advance(-1)
return None

def _parse_except(self) -> t.Optional[t.List[exp.Expression]]:
if not self._match(TokenType.EXCEPT):
return None
if self._match(TokenType.L_PAREN, advance=False):
return self._parse_wrapped_csv(self._parse_column)

except_column = self._parse_column()
return [except_column] if except_column else None

def _parse_replace(self) -> t.Optional[t.List[exp.Expression]]:
if not self._match(TokenType.REPLACE):
def _parse_star_op(self, *keywords: str) -> t.Optional[t.List[exp.Expression]]:
if not self._match_texts(keywords):
return None
if self._match(TokenType.L_PAREN, advance=False):
return self._parse_wrapped_csv(self._parse_expression)

replace_expression = self._parse_expression()
return [replace_expression] if replace_expression else None
expression = self._parse_expression()
return [expression] if expression else None

def _parse_csv(
self, parse_method: t.Callable, sep: TokenType = TokenType.COMMA
Expand Down
85 changes: 35 additions & 50 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def test_snowflake(self):
self.validate_identity("SELECT CONVERT_TIMEZONE('UTC', 'America/Los_Angeles', col)")
self.validate_identity("ALTER TABLE a SWAP WITH b")
self.validate_identity("SELECT MATCH_CONDITION")
self.validate_identity("SELECT * REPLACE (CAST(col AS TEXT) AS scol) FROM t")
self.validate_identity(
"MERGE INTO my_db AS ids USING (SELECT new_id FROM my_model WHERE NOT col IS NULL) AS new_ids ON ids.type = new_ids.type AND ids.source = new_ids.source WHEN NOT MATCHED THEN INSERT VALUES (new_ids.new_id)"
)
Expand Down Expand Up @@ -230,6 +231,38 @@ def test_snowflake(self):
"CAST(x AS NCHAR VARYING)",
"CAST(x AS VARCHAR)",
)
self.validate_identity(
"CREATE OR REPLACE TEMPORARY TABLE x (y NUMBER IDENTITY(0, 1))",
"CREATE OR REPLACE TEMPORARY TABLE x (y DECIMAL(38, 0) AUTOINCREMENT START 0 INCREMENT 1)",
)
self.validate_identity(
"CREATE TEMPORARY TABLE x (y NUMBER AUTOINCREMENT(0, 1))",
"CREATE TEMPORARY TABLE x (y DECIMAL(38, 0) AUTOINCREMENT START 0 INCREMENT 1)",
)
self.validate_identity(
"CREATE TABLE x (y NUMBER IDENTITY START 0 INCREMENT 1)",
"CREATE TABLE x (y DECIMAL(38, 0) AUTOINCREMENT START 0 INCREMENT 1)",
)
self.validate_identity(
"ALTER TABLE foo ADD COLUMN id INT identity(1, 1)",
"ALTER TABLE foo ADD COLUMN id INT AUTOINCREMENT START 1 INCREMENT 1",
)
self.validate_identity(
"SELECT DAYOFWEEK('2016-01-02T23:39:20.123-07:00'::TIMESTAMP)",
"SELECT DAYOFWEEK(CAST('2016-01-02T23:39:20.123-07:00' AS TIMESTAMP))",
)
self.validate_identity(
"SELECT * FROM xxx WHERE col ilike '%Don''t%'",
"SELECT * FROM xxx WHERE col ILIKE '%Don\\'t%'",
)
self.validate_identity(
"SELECT * EXCLUDE a, b FROM xxx",
"SELECT * EXCLUDE (a), b FROM xxx",
)
self.validate_identity(
"SELECT * RENAME a AS b, c AS d FROM xxx",
"SELECT * RENAME (a AS b), c AS d FROM xxx",
)

self.validate_all(
"OBJECT_CONSTRUCT_KEEP_NULL('key_1', 'one', 'key_2', NULL)",
Expand Down Expand Up @@ -550,60 +583,12 @@ def test_snowflake(self):
},
)
self.validate_all(
"CREATE OR REPLACE TEMPORARY TABLE x (y NUMBER IDENTITY(0, 1))",
write={
"snowflake": "CREATE OR REPLACE TEMPORARY TABLE x (y DECIMAL(38, 0) AUTOINCREMENT START 0 INCREMENT 1)",
},
)
self.validate_all(
"CREATE TEMPORARY TABLE x (y NUMBER AUTOINCREMENT(0, 1))",
write={
"snowflake": "CREATE TEMPORARY TABLE x (y DECIMAL(38, 0) AUTOINCREMENT START 0 INCREMENT 1)",
},
)
self.validate_all(
"CREATE TABLE x (y NUMBER IDENTITY START 0 INCREMENT 1)",
write={
"snowflake": "CREATE TABLE x (y DECIMAL(38, 0) AUTOINCREMENT START 0 INCREMENT 1)",
},
)
self.validate_all(
"ALTER TABLE foo ADD COLUMN id INT identity(1, 1)",
write={
"snowflake": "ALTER TABLE foo ADD COLUMN id INT AUTOINCREMENT START 1 INCREMENT 1",
},
)
self.validate_all(
"SELECT DAYOFWEEK('2016-01-02T23:39:20.123-07:00'::TIMESTAMP)",
write={
"snowflake": "SELECT DAYOFWEEK(CAST('2016-01-02T23:39:20.123-07:00' AS TIMESTAMP))",
},
)
self.validate_all(
"SELECT * FROM xxx WHERE col ilike '%Don''t%'",
write={
"snowflake": "SELECT * FROM xxx WHERE col ILIKE '%Don\\'t%'",
},
)
self.validate_all(
"SELECT * EXCLUDE a, b FROM xxx",
write={
"snowflake": "SELECT * EXCLUDE (a), b FROM xxx",
},
)
self.validate_all(
"SELECT * RENAME a AS b, c AS d FROM xxx",
write={
"snowflake": "SELECT * RENAME (a AS b), c AS d FROM xxx",
},
)
self.validate_all(
"SELECT * EXCLUDE (a, b) RENAME (c AS d, E AS F) FROM xxx",
"SELECT * EXCLUDE (a, b) REPLACE (c AS d, E AS F) FROM xxx",
read={
"duckdb": "SELECT * EXCLUDE (a, b) REPLACE (c AS d, E AS F) FROM xxx",
},
write={
"snowflake": "SELECT * EXCLUDE (a, b) RENAME (c AS d, E AS F) FROM xxx",
"snowflake": "SELECT * EXCLUDE (a, b) REPLACE (c AS d, E AS F) FROM xxx",
"duckdb": "SELECT * EXCLUDE (a, b) REPLACE (c AS d, E AS F) FROM xxx",
},
)
Expand Down

0 comments on commit e3ff67b

Please sign in to comment.