Skip to content

Commit

Permalink
fix(snowflake): COPY postfix (#3398)
Browse files Browse the repository at this point in the history
* fix(snowflake): COPY postfix

* Add missing type hints

* Remove unnecessary ternary

* PR Feedback 1
  • Loading branch information
VaggelisD committed May 3, 2024
1 parent 729b19b commit 2c2a788
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 22 deletions.
15 changes: 13 additions & 2 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ class Parser(parser.Parser):

PROPERTY_PARSERS = {
**parser.Parser.PROPERTY_PARSERS,
"LOCATION": lambda self: self._parse_location(),
"LOCATION": lambda self: self._parse_location_property(),
}

SHOW_PARSERS = {
Expand Down Expand Up @@ -676,10 +676,13 @@ def _parse_alter_table_swap(self) -> exp.SwapTable:
self._match_text_seq("WITH")
return self.expression(exp.SwapTable, this=self._parse_table(schema=True))

def _parse_location(self) -> exp.LocationProperty:
def _parse_location_property(self) -> exp.LocationProperty:
self._match(TokenType.EQ)
return self.expression(exp.LocationProperty, this=self._parse_location_path())

def _parse_file_location(self) -> t.Optional[exp.Expression]:
return self._parse_table_parts()

def _parse_location_path(self) -> exp.Var:
parts = [self._advance_any(ignore_reserved=True)]

Expand Down Expand Up @@ -1037,3 +1040,11 @@ def struct_sql(self, expression: exp.Struct) -> str:
values.append(e)

return self.func("OBJECT_CONSTRUCT", *flatten(zip(keys, values)))

def copyparameter_sql(self, expression: exp.CopyParameter) -> str:
option = self.sql(expression, "this")
if option.upper() == "FILE_FORMAT":
values = self.expressions(expression, key="expression", flat=True, sep=" ")
return f"{option} = ({values})"

return super().copyparameter_sql(expression)
5 changes: 2 additions & 3 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3795,7 +3795,7 @@ def credentials_sql(self, expression: exp.Credentials) -> str:
if isinstance(cred_expr, exp.Literal):
# Redshift case: CREDENTIALS <string>
credentials = self.sql(expression, "credentials")
credentials = f"CREDENTIALS {credentials}"
credentials = f"CREDENTIALS {credentials}" if credentials else ""
else:
# Snowflake case: CREDENTIALS = (...)
credentials = self.expressions(expression, key="credentials", flat=True, sep=" ")
Expand All @@ -3805,7 +3805,7 @@ def credentials_sql(self, expression: exp.Credentials) -> str:
storage = f" {storage}" if storage else ""

encryption = self.expressions(expression, key="encryption", flat=True, sep=" ")
encryption = f"ENCRYPTION = ({encryption})" if encryption else ""
encryption = f" ENCRYPTION = ({encryption})" if encryption else ""

iam_role = self.sql(expression, "iam_role")
iam_role = f"IAM_ROLE {iam_role}" if iam_role else ""
Expand All @@ -3821,7 +3821,6 @@ def copy_sql(self, expression: exp.Copy) -> str:

credentials = self.sql(expression, "credentials")
credentials = f" {credentials}" if credentials else ""

kind = " FROM " if expression.args.get("kind") else " TO "
files = self.expressions(expression, key="files", flat=True)

Expand Down
43 changes: 29 additions & 14 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4323,7 +4323,9 @@ def _parse_primary(self) -> t.Optional[exp.Expression]:

this = self._parse_query_modifiers(seq_get(expressions, 0))

if isinstance(this, exp.UNWRAPPED_QUERIES):
if not this and self._match(TokenType.R_PAREN, advance=False):
this = self.expression(exp.Tuple)
elif isinstance(this, exp.UNWRAPPED_QUERIES):
this = self._parse_set_operations(
self._parse_subquery(this=this, parse_alias=False)
)
Expand Down Expand Up @@ -6313,19 +6315,35 @@ def _parse_with_operator(self) -> t.Optional[exp.Expression]:

return self.expression(exp.WithOperator, this=this, op=op)

def _parse_wrapped_options(self) -> t.List[t.Optional[exp.Expression]]:
opts = []
self._match(TokenType.EQ)
self._match(TokenType.L_PAREN)
while self._curr and not self._match(TokenType.R_PAREN):
opts.append(self._parse_conjunction())
self._match(TokenType.COMMA)
return opts

def _parse_copy_parameters(self) -> t.List[exp.CopyParameter]:
sep = TokenType.COMMA if self.dialect.COPY_PARAMS_ARE_CSV else None

options = []
while self._curr and not self._match(TokenType.R_PAREN, advance=False):
option = self._parse_unquoted_field()
value = None

# Some options are defined as functions with the values as params
if not isinstance(option, exp.Func):
prev = self._prev.text.upper()
# Different dialects might separate options and values by white space, "=" and "AS"
self._match(TokenType.EQ)
self._match(TokenType.ALIAS)
value = self._parse_unquoted_field()

if prev == "FILE_FORMAT" and self._match(TokenType.L_PAREN):
# Snowflake FILE_FORMAT case
value = self._parse_wrapped_options()
else:
value = self._parse_unquoted_field()

param = self.expression(exp.CopyParameter, this=option, expression=value)
options.append(param)
Expand All @@ -6336,32 +6354,29 @@ def _parse_copy_parameters(self) -> t.List[exp.CopyParameter]:
return options

def _parse_credentials(self) -> t.Optional[exp.Credentials]:
def parse_options():
opts = []
self._match(TokenType.EQ)
self._match(TokenType.L_PAREN)
while self._curr and not self._match(TokenType.R_PAREN):
opts.append(self._parse_conjunction())
return opts

expr = self.expression(exp.Credentials)

if self._match_text_seq("STORAGE_INTEGRATION", advance=False):
expr.set("storage", self._parse_conjunction())
if self._match_text_seq("CREDENTIALS"):
# Snowflake supports CREDENTIALS = (...), while Redshift CREDENTIALS <string>
creds = parse_options() if self._match(TokenType.EQ) else self._parse_field()
creds = (
self._parse_wrapped_options() if self._match(TokenType.EQ) else self._parse_field()
)
expr.set("credentials", creds)
if self._match_text_seq("ENCRYPTION"):
expr.set("encryption", parse_options())
expr.set("encryption", self._parse_wrapped_options())
if self._match_text_seq("IAM_ROLE"):
expr.set("iam_role", self._parse_field())
if self._match_text_seq("REGION"):
expr.set("region", self._parse_field())

return expr

def _parse_copy(self):
def _parse_file_location(self) -> t.Optional[exp.Expression]:
return self._parse_field()

def _parse_copy(self) -> exp.Copy | exp.Command:
start = self._prev

self._match(TokenType.INTO)
Expand All @@ -6374,7 +6389,7 @@ def _parse_copy(self):

kind = self._match(TokenType.FROM) or not self._match_text_seq("TO")

files = self._parse_csv(self._parse_conjunction)
files = self._parse_csv(self._parse_file_location)
credentials = self._parse_credentials()

self._match_text_seq("WITH")
Expand Down
17 changes: 14 additions & 3 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,6 @@ def test_snowflake(self):
self.validate_identity(
"SELECT * FROM DATA AS DATA_L ASOF JOIN DATA AS DATA_R MATCH_CONDITION (DATA_L.VAL > DATA_R.VAL) ON DATA_L.ID = DATA_R.ID"
)
self.validate_identity(
"COPY INTO mytable (col1, col2) FROM 's3://mybucket/data/files' FILES = ('file1', 'file2') PATTERN = 'pattern' FILE_FORMAT = (FORMAT_NAME = my_csv_format) PARSE_HEADER = TRUE"
)
self.validate_identity(
"REGEXP_REPLACE('target', 'pattern', '\n')",
"REGEXP_REPLACE('target', 'pattern', '\\n')",
Expand Down Expand Up @@ -1833,3 +1830,17 @@ def test_try_cast(self):

expression = annotate_types(expression)
self.assertEqual(expression.sql(dialect="snowflake"), "SELECT TRY_CAST(FOO() AS TEXT)")

def test_copy(self):
self.validate_identity(
"""COPY INTO mytable (col1, col2) FROM 's3://mybucket/data/files' FILES = ('file1', 'file2') PATTERN = 'pattern' FILE_FORMAT = (FORMAT_NAME = my_csv_format NULL_IF = ('str1', 'str2')) PARSE_HEADER = TRUE"""
)
self.validate_identity(
"""COPY INTO temp FROM @random_stage/path/ FILE_FORMAT = (TYPE = CSV FIELD_DELIMITER = '|' NULL_IF = () FIELD_OPTIONALLY_ENCLOSED_BY = '"' TIMESTAMP_FORMAT = 'TZHTZM YYYY-MM-DD HH24:MI:SS.FF9' DATE_FORMAT = 'TZHTZM YYYY-MM-DD HH24:MI:SS.FF9' BINARY_FORMAT = BASE64) VALIDATION_MODE = 'RETURN_3_ROWS'"""
)
self.validate_identity(
"""COPY INTO load1 FROM @%load1/data1/ FILES = ('test1.csv', 'test2.csv') FORCE = TRUE"""
)
self.validate_identity(
"""COPY INTO mytable FROM 'azure://myaccount.blob.core.windows.net/mycontainer/data/files' CREDENTIALS = (AZURE_SAS_TOKEN = 'token') ENCRYPTION = (TYPE = 'AZURE_CSE' MASTER_KEY = 'kPx...') FILE_FORMAT = (FORMAT_NAME = my_csv_format)"""
)

0 comments on commit 2c2a788

Please sign in to comment.