Skip to content

Commit

Permalink
Fix: snowflake object_construct to struct closes #1699
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed May 29, 2023
1 parent 223c58d commit da17c4d
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 3 deletions.
19 changes: 18 additions & 1 deletion sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,19 @@ def _snowflake_to_timestamp(args: t.List) -> t.Union[exp.StrToTime, exp.UnixToTi
return exp.UnixToTime.from_arg_list(args)


def _parse_object_construct(args: t.List) -> t.Union[exp.StarMap, exp.Struct]:
expression = parser.parse_var_map(args)

if isinstance(expression, exp.StarMap):
return expression

return exp.Struct(
expressions=[
t.cast(exp.Condition, k).eq(v) for k, v in zip(expression.keys, expression.values)
]
)


def _unix_to_time_sql(self: generator.Generator, expression: exp.UnixToTime) -> str:
scale = expression.args.get("scale")
timestamp = self.sql(expression, "this")
Expand Down Expand Up @@ -209,7 +222,7 @@ class Parser(parser.Parser):
"DIV0": _div0_to_if,
"IFF": exp.If.from_arg_list,
"NULLIFZERO": _nullifzero_to_if,
"OBJECT_CONSTRUCT": parser.parse_var_map,
"OBJECT_CONSTRUCT": _parse_object_construct,
"RLIKE": exp.RegexpLike.from_arg_list,
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
"TO_ARRAY": exp.Array.from_arg_list,
Expand Down Expand Up @@ -325,6 +338,10 @@ class Generator(generator.Generator):
"POSITION", e.args.get("substr"), e.this, e.args.get("position")
),
exp.StrToTime: lambda self, e: f"TO_TIMESTAMP({self.sql(e, 'this')}, {self.format_time(e)})",
exp.Struct: lambda self, e: self.func(
"OBJECT_CONSTRUCT",
*(arg for expression in e.expressions for arg in expression.flatten()),
),
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.TimeToStr: lambda self, e: self.func(
Expand Down
8 changes: 8 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4183,6 +4183,14 @@ class VarMap(Func):
arg_types = {"keys": True, "values": True}
is_var_len_args = True

@property
def keys(self) -> t.List[Expression]:
return self.args["keys"].expressions

@property
def values(self) -> t.List[Expression]:
return self.args["values"].expressions


# https://dev.mysql.com/doc/refman/8.0/en/fulltext-search.html
class MatchAgainst(Func):
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
logger = logging.getLogger("sqlglot")


def parse_var_map(args: t.List) -> exp.Expression:
def parse_var_map(args: t.List) -> exp.StarMap | exp.VarMap:
if len(args) == 1 and args[0].is_star:
return exp.StarMap(this=args[0])

Expand Down
7 changes: 6 additions & 1 deletion tests/dialects/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,12 @@ def test_hive(self):
"spark": "GET_JSON_OBJECT(x, '$.name')",
},
)
self.validate_all(
"STRUCT(a = b, c = d)",
read={
"snowflake": "OBJECT_CONSTRUCT(a, b, c, d)",
},
)
self.validate_all(
"MAP(a, b, c, d)",
read={
Expand All @@ -568,7 +574,6 @@ def test_hive(self):
"hive": "MAP(a, b, c, d)",
"presto": "MAP(ARRAY[a, c], ARRAY[b, d])",
"spark": "MAP(a, b, c, d)",
"snowflake": "OBJECT_CONSTRUCT(a, b, c, d)",
},
write={
"": "MAP(ARRAY(a, c), ARRAY(b, d))",
Expand Down
10 changes: 10 additions & 0 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,16 @@ def test_snowflake(self):
self.validate_all("CAST(x AS CHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"})
self.validate_all("CAST(x AS CHARACTER VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"})
self.validate_all("CAST(x AS NCHAR VARYING)", write={"snowflake": "CAST(x AS VARCHAR)"})
self.validate_all(
"OBJECT_CONSTRUCT(a, b, c, d)",
read={
"": "STRUCT(a as b, c as d)",
},
write={
"duckdb": "{'a': b, 'c': d}",
"snowflake": "OBJECT_CONSTRUCT(a, b, c, d)",
},
)
self.validate_all(
"SELECT i, p, o FROM qt QUALIFY ROW_NUMBER() OVER (PARTITION BY p ORDER BY o) = 1",
write={
Expand Down

0 comments on commit da17c4d

Please sign in to comment.