Skip to content

Commit

Permalink
Fix!: StrToUnix Hive parsing, Presto generation fixes (#3225)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas committed Mar 26, 2024
1 parent 83569ab commit 0919be5
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 11 deletions.
4 changes: 3 additions & 1 deletion sqlglot/dialects/hive.py
Expand Up @@ -319,7 +319,9 @@ class Parser(parser.Parser):
"TO_DATE": build_formatted_time(exp.TsOrDsToDate, "hive"),
"TO_JSON": exp.JSONFormat.from_arg_list,
"UNBASE64": exp.FromBase64.from_arg_list,
"UNIX_TIMESTAMP": build_formatted_time(exp.StrToUnix, "hive", True),
"UNIX_TIMESTAMP": lambda args: build_formatted_time(exp.StrToUnix, "hive", True)(
args or [exp.CurrentTimestamp()]
),
"YEAR": lambda args: exp.Year(this=exp.TsOrDsToDate.from_arg_list(args)),
}

Expand Down
20 changes: 17 additions & 3 deletions sqlglot/dialects/presto.py
Expand Up @@ -26,6 +26,7 @@
timestrtotime_sql,
ts_or_ds_add_cast,
)
from sqlglot.dialects.hive import Hive
from sqlglot.dialects.mysql import MySQL
from sqlglot.helper import apply_index_offset, seq_get
from sqlglot.tokens import TokenType
Expand Down Expand Up @@ -404,9 +405,6 @@ class Generator(generator.Generator):
exp.StrToDate: lambda self, e: f"CAST({_str_to_time_sql(self, e)} AS DATE)",
exp.StrToMap: rename_func("SPLIT_TO_MAP"),
exp.StrToTime: _str_to_time_sql,
exp.StrToUnix: lambda self, e: self.func(
"TO_UNIXTIME", self.func("DATE_PARSE", e.this, self.format_time(e))
),
exp.StructExtract: struct_extract_sql,
exp.Table: transforms.preprocess([_unnest_sequence]),
exp.Timestamp: no_timestamp_sql,
Expand Down Expand Up @@ -439,6 +437,22 @@ class Generator(generator.Generator):
exp.Xor: bool_xor_sql,
}

def strtounix_sql(self, expression: exp.StrToUnix) -> str:
    """Generate the `TO_UNIXTIME(...)` SQL for an `exp.StrToUnix` node.

    `TO_UNIXTIME` requires a `TIMESTAMP`, so the operand (cast to text) has to be
    parsed into one first. `DATE_PARSE` is tried inside a `TRY` call because it can
    fail when there's a timezone involved; `PARSE_DATETIME` is the fallback and is
    given the format produced with Hive's inverse time mapping, since it seems to be
    using the same time mapping as Hive, as per:
    https://joda-time.sourceforge.net/apidocs/org/joda/time/format/DateTimeFormat.html
    """
    text_operand = exp.cast(expression.this, "text")

    # First attempt: the format rendered with this generator's own time mapping.
    date_parse_call = self.func(
        "DATE_PARSE",
        text_operand,
        self.format_time(expression),
    )

    # Fallback: the same format rendered with Hive's mapping and trie.
    hive_style_format = self.format_time(
        expression,
        Hive.INVERSE_TIME_MAPPING,
        Hive.INVERSE_TIME_TRIE,
    )
    parse_datetime_call = self.func("PARSE_DATETIME", text_operand, hive_style_format)

    guarded_parse = self.func("TRY", date_parse_call)
    timestamp_sql = self.func("COALESCE", guarded_parse, parse_datetime_call)
    return self.func("TO_UNIXTIME", timestamp_sql)

def bracket_sql(self, expression: exp.Bracket) -> str:
if expression.args.get("safe"):
return self.func(
Expand Down
11 changes: 8 additions & 3 deletions sqlglot/generator.py
Expand Up @@ -3192,11 +3192,16 @@ def format_args(self, *args: t.Optional[str | exp.Expression]) -> str:
def text_width(self, args: t.Iterable) -> int:
    """Return the combined length of every item in `args`."""
    return sum(map(len, args))

def format_time(
    self,
    expression: exp.Expression,
    inverse_time_mapping: t.Optional[t.Dict[str, str]] = None,
    inverse_time_trie: t.Optional[t.Dict] = None,
) -> t.Optional[str]:
    """Convert the `format` arg of `expression` using an inverse time mapping.

    Args:
        expression: The expression whose `format` arg is rendered and converted.
        inverse_time_mapping: Optional mapping to use instead of the dialect's
            `INVERSE_TIME_MAPPING`. A falsy value falls back to the dialect's.
        inverse_time_trie: Optional trie to use alongside the mapping.

    Returns:
        Whatever the module-level `format_time` helper returns for the rendered
        format string.
    """
    if inverse_time_mapping:
        mapping = inverse_time_mapping
        # Do not pair a caller-supplied mapping with the dialect's trie: the two
        # would be built from different key sets.
        # NOTE(review): assumes the module-level `format_time` builds a trie from
        # `mapping` when given `None` -- confirm against sqlglot.time.
        trie = inverse_time_trie
    else:
        mapping = self.dialect.INVERSE_TIME_MAPPING
        trie = inverse_time_trie or self.dialect.INVERSE_TIME_TRIE

    return format_time(self.sql(expression, "format"), mapping, trie)

def expressions(
Expand Down
2 changes: 1 addition & 1 deletion tests/dialects/test_dialect.py
Expand Up @@ -611,7 +611,7 @@ def test_time(self):
write={
"duckdb": "EPOCH(STRPTIME('2020-01-01', '%Y-%m-%d'))",
"hive": "UNIX_TIMESTAMP('2020-01-01', 'yyyy-MM-dd')",
"presto": "TO_UNIXTIME(DATE_PARSE('2020-01-01', '%Y-%m-%d'))",
"presto": "TO_UNIXTIME(COALESCE(TRY(DATE_PARSE(CAST('2020-01-01' AS VARCHAR), '%Y-%m-%d')), PARSE_DATETIME(CAST('2020-01-01' AS VARCHAR), 'yyyy-MM-dd')))",
"starrocks": "UNIX_TIMESTAMP('2020-01-01', '%Y-%m-%d')",
"doris": "UNIX_TIMESTAMP('2020-01-01', '%Y-%m-%d')",
},
Expand Down
2 changes: 1 addition & 1 deletion tests/dialects/test_hive.py
Expand Up @@ -369,7 +369,7 @@ def test_time(self):
"UNIX_TIMESTAMP(x)",
write={
"duckdb": "EPOCH(STRPTIME(x, '%Y-%m-%d %H:%M:%S'))",
"presto": "TO_UNIXTIME(DATE_PARSE(x, '%Y-%m-%d %T'))",
"presto": "TO_UNIXTIME(COALESCE(TRY(DATE_PARSE(CAST(x AS VARCHAR), '%Y-%m-%d %T')), PARSE_DATETIME(CAST(x AS VARCHAR), 'yyyy-MM-dd HH:mm:ss')))",
"hive": "UNIX_TIMESTAMP(x)",
"spark": "UNIX_TIMESTAMP(x)",
"": "STR_TO_UNIX(x, '%Y-%m-%d %H:%M:%S')",
Expand Down
5 changes: 4 additions & 1 deletion tests/dialects/test_spark.py
Expand Up @@ -246,12 +246,15 @@ def test_spark(self):
self.validate_identity("SELECT TRANSFORM(ARRAY(1, 2, 3), (x, i) -> x + i)")
self.validate_identity("REFRESH TABLE a.b.c")
self.validate_identity("INTERVAL -86 DAYS")
self.validate_identity("SELECT UNIX_TIMESTAMP()")
self.validate_identity("TRIM(' SparkSQL ')")
self.validate_identity("TRIM(BOTH 'SL' FROM 'SSparkSQLS')")
self.validate_identity("TRIM(LEADING 'SL' FROM 'SSparkSQLS')")
self.validate_identity("TRIM(TRAILING 'SL' FROM 'SSparkSQLS')")
self.validate_identity("SPLIT(str, pattern, lim)")
self.validate_identity(
"SELECT UNIX_TIMESTAMP()",
"SELECT UNIX_TIMESTAMP(CURRENT_TIMESTAMP())",
)
self.validate_identity(
"SELECT CAST('2023-01-01' AS TIMESTAMP) + INTERVAL 23 HOUR + 59 MINUTE + 59 SECONDS",
"SELECT CAST('2023-01-01' AS TIMESTAMP) + INTERVAL '23' HOUR + INTERVAL '59' MINUTE + INTERVAL '59' SECONDS",
Expand Down
6 changes: 5 additions & 1 deletion tests/test_transpile.py
Expand Up @@ -702,7 +702,11 @@ def test_time(self):
)

self.validate("STR_TO_TIME('x', 'y')", "DATE_PARSE('x', 'y')", write="presto")
self.validate("STR_TO_UNIX('x', 'y')", "TO_UNIXTIME(DATE_PARSE('x', 'y'))", write="presto")
self.validate(
"STR_TO_UNIX('x', 'y')",
"TO_UNIXTIME(COALESCE(TRY(DATE_PARSE(CAST('x' AS VARCHAR), 'y')), PARSE_DATETIME(CAST('x' AS VARCHAR), 'y')))",
write="presto",
)
self.validate("TIME_TO_STR(x, 'y')", "DATE_FORMAT(x, 'y')", write="presto")
self.validate("TIME_TO_UNIX(x)", "TO_UNIXTIME(x)", write="presto")
self.validate(
Expand Down

0 comments on commit 0919be5

Please sign in to comment.