Skip to content

Commit

Permalink
Fix!: cast less aggressively (#3302)
Browse files Browse the repository at this point in the history
* Fix: cast less aggressively

* Refactor!: get rid of cast_unless, move its logic into cast

* Cleanup

* Refactor

* Cleanup

* Remove copy=True
  • Loading branch information
georgesittas committed Apr 11, 2024
1 parent 6f73186 commit 75e0c69
Show file tree
Hide file tree
Showing 14 changed files with 66 additions and 80 deletions.
8 changes: 5 additions & 3 deletions sqlglot/dialects/bigquery.py
Expand Up @@ -197,8 +197,8 @@ def _ts_or_ds_add_sql(self: BigQuery.Generator, expression: exp.TsOrDsAdd) -> st


def _ts_or_ds_diff_sql(self: BigQuery.Generator, expression: exp.TsOrDsDiff) -> str:
expression.this.replace(exp.cast(expression.this, "TIMESTAMP", copy=True))
expression.expression.replace(exp.cast(expression.expression, "TIMESTAMP", copy=True))
expression.this.replace(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))
expression.expression.replace(exp.cast(expression.expression, exp.DataType.Type.TIMESTAMP))
unit = unit_to_var(expression)
return self.func("DATE_DIFF", expression.this, expression.expression, unit)

Expand All @@ -214,7 +214,9 @@ def _unix_to_time_sql(self: BigQuery.Generator, expression: exp.UnixToTime) -> s
if scale == exp.UnixToTime.MICROS:
return self.func("TIMESTAMP_MICROS", timestamp)

unix_seconds = exp.cast(exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)), "int64")
unix_seconds = exp.cast(
exp.Div(this=timestamp, expression=exp.func("POW", 10, scale)), exp.DataType.Type.BIGINT
)
return self.func("TIMESTAMP_SECONDS", unix_seconds)


Expand Down
12 changes: 6 additions & 6 deletions sqlglot/dialects/dialect.py
Expand Up @@ -562,7 +562,7 @@ def _if_sql(self: Generator, expression: exp.If) -> str:
def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
this = expression.this
if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
this.replace(exp.cast(this, "json"))
this.replace(exp.cast(this, exp.DataType.Type.JSON))

return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")

Expand Down Expand Up @@ -772,11 +772,11 @@ def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
from sqlglot.optimizer.annotate_types import annotate_types

target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
return self.sql(exp.cast(expression.this, to=target_type))
return self.sql(exp.cast(expression.this, target_type))
if expression.text("expression").lower() in TIMEZONES:
return self.sql(
exp.AtTimeZone(
this=exp.cast(expression.this, to=exp.DataType.Type.TIMESTAMP),
this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
zone=expression.expression,
)
)
Expand Down Expand Up @@ -813,11 +813,11 @@ def right_to_substring_sql(self: Generator, expression: exp.Left) -> str:


def timestrtotime_sql(self: Generator, expression: exp.TimeStrToTime) -> str:
return self.sql(exp.cast(expression.this, "timestamp"))
return self.sql(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
return self.sql(exp.cast(expression.this, "date"))
return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
Expand Down Expand Up @@ -1030,7 +1030,7 @@ def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

return self.sql(exp.cast(minus_one_day, "date"))
return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
Expand Down
4 changes: 2 additions & 2 deletions sqlglot/dialects/drill.py
Expand Up @@ -19,7 +19,7 @@ def _str_to_date(self: Drill.Generator, expression: exp.StrToDate) -> str:
this = self.sql(expression, "this")
time_format = self.format_time(expression)
if time_format == Drill.DATE_FORMAT:
return self.sql(exp.cast(this, "date"))
return self.sql(exp.cast(this, exp.DataType.Type.DATE))
return self.func("TO_DATE", this, time_format)


Expand Down Expand Up @@ -134,7 +134,7 @@ class Generator(generator.Generator):
[transforms.eliminate_distinct_on, transforms.eliminate_semi_and_anti_joins]
),
exp.StrToTime: lambda self, e: self.func("TO_TIMESTAMP", e.this, self.format_time(e)),
exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, "date")),
exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, exp.DataType.Type.DATE)),
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)),
Expand Down
14 changes: 8 additions & 6 deletions sqlglot/dialects/duckdb.py
Expand Up @@ -421,8 +421,8 @@ class Generator(generator.Generator):
exp.MonthsBetween: lambda self, e: self.func(
"DATEDIFF",
"'month'",
exp.cast(e.expression, "timestamp", copy=True),
exp.cast(e.this, "timestamp", copy=True),
exp.cast(e.expression, exp.DataType.Type.TIMESTAMP, copy=True),
exp.cast(e.this, exp.DataType.Type.TIMESTAMP, copy=True),
),
exp.ParseJSON: rename_func("JSON"),
exp.PercentileCont: _rename_unless_within_group("PERCENTILE_CONT", "QUANTILE_CONT"),
Expand Down Expand Up @@ -457,9 +457,11 @@ class Generator(generator.Generator):
"DATE_DIFF", exp.Literal.string(e.unit), e.expression, e.this
),
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, "date")),
exp.TimeStrToDate: lambda self, e: self.sql(exp.cast(e.this, exp.DataType.Type.DATE)),
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeStrToUnix: lambda self, e: self.func("EPOCH", exp.cast(e.this, "timestamp")),
exp.TimeStrToUnix: lambda self, e: self.func(
"EPOCH", exp.cast(e.this, exp.DataType.Type.TIMESTAMP)
),
exp.TimeToStr: lambda self, e: self.func("STRFTIME", e.this, self.format_time(e)),
exp.TimeToUnix: rename_func("EPOCH"),
exp.TsOrDiToDi: lambda self,
Expand All @@ -468,8 +470,8 @@ class Generator(generator.Generator):
exp.TsOrDsDiff: lambda self, e: self.func(
"DATE_DIFF",
f"'{e.args.get('unit') or 'DAY'}'",
exp.cast(e.expression, "TIMESTAMP"),
exp.cast(e.this, "TIMESTAMP"),
exp.cast(e.expression, exp.DataType.Type.TIMESTAMP),
exp.cast(e.this, exp.DataType.Type.TIMESTAMP),
),
exp.UnixToStr: lambda self, e: self.func(
"STRFTIME", self.func("TO_TIMESTAMP", e.this), self.format_time(e)
Expand Down
4 changes: 3 additions & 1 deletion sqlglot/dialects/mysql.py
Expand Up @@ -710,7 +710,9 @@ class Generator(generator.Generator):
),
exp.TimestampSub: date_add_interval_sql("DATE", "SUB"),
exp.TimeStrToUnix: rename_func("UNIX_TIMESTAMP"),
exp.TimeStrToTime: lambda self, e: self.sql(exp.cast(e.this, "datetime", copy=True)),
exp.TimeStrToTime: lambda self, e: self.sql(
exp.cast(e.this, exp.DataType.Type.DATETIME, copy=True)
),
exp.TimeToStr: _remove_ts_or_ds_to_date(
lambda self, e: self.func("DATE_FORMAT", e.this, self.format_time(e))
),
Expand Down
12 changes: 7 additions & 5 deletions sqlglot/dialects/presto.py
Expand Up @@ -90,8 +90,10 @@ def _str_to_time_sql(
def _ts_or_ds_to_date_sql(self: Presto.Generator, expression: exp.TsOrDsToDate) -> str:
time_format = self.format_time(expression)
if time_format and time_format not in (Presto.TIME_FORMAT, Presto.DATE_FORMAT):
return self.sql(exp.cast(_str_to_time_sql(self, expression), "DATE"))
return self.sql(exp.cast(exp.cast(expression.this, "TIMESTAMP"), "DATE"))
return self.sql(exp.cast(_str_to_time_sql(self, expression), exp.DataType.Type.DATE))
return self.sql(
exp.cast(exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), exp.DataType.Type.DATE)
)


def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:
Expand All @@ -101,8 +103,8 @@ def _ts_or_ds_add_sql(self: Presto.Generator, expression: exp.TsOrDsAdd) -> str:


def _ts_or_ds_diff_sql(self: Presto.Generator, expression: exp.TsOrDsDiff) -> str:
this = exp.cast(expression.this, "TIMESTAMP")
expr = exp.cast(expression.expression, "TIMESTAMP")
this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMP)
expr = exp.cast(expression.expression, exp.DataType.Type.TIMESTAMP)
unit = unit_to_str(expression)
return self.func("DATE_DIFF", unit, expr, this)

Expand Down Expand Up @@ -447,7 +449,7 @@ def strtounix_sql(self, expression: exp.StrToUnix) -> str:
# timezone involved, we wrap it in a `TRY` call and use `PARSE_DATETIME` as a fallback,
# which seems to be using the same time mapping as Hive, as per:
# https://joda-time.sourceforge.net/apidocs/org/joda/time/format/DateTimeFormat.html
value_as_text = exp.cast(expression.this, "text")
value_as_text = exp.cast(expression.this, exp.DataType.Type.TEXT)
parse_without_tz = self.func("DATE_PARSE", value_as_text, self.format_time(expression))
parse_with_tz = self.func(
"PARSE_DATETIME",
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/snowflake.py
Expand Up @@ -818,7 +818,7 @@ class Generator(generator.Generator):
exp.TimestampTrunc: timestamptrunc_sql,
exp.TimeStrToTime: timestrtotime_sql,
exp.TimeToStr: lambda self, e: self.func(
"TO_CHAR", exp.cast(e.this, "timestamp"), self.format_time(e)
"TO_CHAR", exp.cast(e.this, exp.DataType.Type.TIMESTAMP), self.format_time(e)
),
exp.TimeToUnix: lambda self, e: f"EXTRACT(epoch_second FROM {self.sql(e, 'this')})",
exp.ToArray: rename_func("TO_ARRAY"),
Expand Down
14 changes: 3 additions & 11 deletions sqlglot/dialects/spark2.py
Expand Up @@ -48,7 +48,7 @@ def _unix_to_time_sql(self: Spark2.Generator, expression: exp.UnixToTime) -> str
timestamp = expression.this

if scale is None:
return self.sql(exp.cast(exp.func("from_unixtime", timestamp), "timestamp"))
return self.sql(exp.cast(exp.func("from_unixtime", timestamp), exp.DataType.Type.TIMESTAMP))
if scale == exp.UnixToTime.SECONDS:
return self.func("TIMESTAMP_SECONDS", timestamp)
if scale == exp.UnixToTime.MILLIS:
Expand Down Expand Up @@ -129,11 +129,7 @@ class Parser(Hive.Parser):
"DOUBLE": _build_as_cast("double"),
"FLOAT": _build_as_cast("float"),
"FROM_UTC_TIMESTAMP": lambda args: exp.AtTimeZone(
this=exp.cast_unless(
seq_get(args, 0) or exp.Var(this=""),
exp.DataType.build("timestamp"),
exp.DataType.build("timestamp"),
),
this=exp.cast(seq_get(args, 0) or exp.Var(this=""), exp.DataType.Type.TIMESTAMP),
zone=seq_get(args, 1),
),
"INT": _build_as_cast("int"),
Expand All @@ -150,11 +146,7 @@ class Parser(Hive.Parser):
),
"TO_UNIX_TIMESTAMP": exp.StrToUnix.from_arg_list,
"TO_UTC_TIMESTAMP": lambda args: exp.FromTimeZone(
this=exp.cast_unless(
seq_get(args, 0) or exp.Var(this=""),
exp.DataType.build("timestamp"),
exp.DataType.build("timestamp"),
),
this=exp.cast(seq_get(args, 0) or exp.Var(this=""), exp.DataType.Type.TIMESTAMP),
zone=seq_get(args, 1),
),
"TRUNC": lambda args: exp.DateTrunc(unit=seq_get(args, 1), this=seq_get(args, 0)),
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/teradata.py
Expand Up @@ -311,7 +311,7 @@ def extract_sql(self, expression: exp.Extract) -> str:
return super().extract_sql(expression)

to_char = exp.func("to_char", expression.expression, exp.Literal.string("Q"))
return self.sql(exp.cast(to_char, "int"))
return self.sql(exp.cast(to_char, exp.DataType.Type.INT))

def interval_sql(self, expression: exp.Interval) -> str:
multiplier = 0
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/tsql.py
Expand Up @@ -109,7 +109,7 @@ def _builder(args: t.List) -> E:
assert len(args) == 2

return exp_class(
this=exp.cast(args[1], "datetime"),
this=exp.cast(args[1], exp.DataType.Type.DATETIME),
format=exp.Literal.string(
format_time(
args[0].name.lower(),
Expand Down
34 changes: 9 additions & 25 deletions sqlglot/expressions.py
Expand Up @@ -6885,11 +6885,16 @@ def cast(expression: ExpOrStr, to: DATA_TYPE, copy: bool = True, **opts) -> Cast
Returns:
The new Cast instance.
"""
expression = maybe_parse(expression, copy=copy, **opts)
expr = maybe_parse(expression, copy=copy, **opts)
data_type = DataType.build(to, copy=copy, **opts)
expression = Cast(this=expression, to=data_type)
expression.type = data_type
return expression

if expr.is_type(data_type):
return expr

expr = Cast(this=expr, to=data_type)
expr.type = data_type

return expr


def table_(
Expand Down Expand Up @@ -7418,27 +7423,6 @@ def case(
return Case(this=this, ifs=[])


def cast_unless(
expression: ExpOrStr,
to: DATA_TYPE,
*types: DATA_TYPE,
**opts: t.Any,
) -> Expression | Cast:
"""
Cast an expression to a data type unless it is a specified type.
Args:
expression: The expression to cast.
to: The data type to cast to.
**types: The types to exclude from casting.
**opts: Extra keyword arguments for parsing `expression`
"""
expr = maybe_parse(expression, **opts)
if expr.is_type(*types):
return expr
return cast(expr, to, **opts)


def array(
*expressions: ExpOrStr, copy: bool = True, dialect: DialectType = None, **kwargs
) -> Array:
Expand Down
25 changes: 15 additions & 10 deletions sqlglot/generator.py
Expand Up @@ -2522,7 +2522,7 @@ def convert_concat_args(self, expression: exp.Concat | exp.ConcatWs) -> t.List[e
args = args[1:] # Skip the delimiter

if self.dialect.STRICT_STRING_CONCAT and expression.args.get("safe"):
args = [exp.cast(e, "text") for e in args]
args = [exp.cast(e, exp.DataType.Type.TEXT) for e in args]

if not self.dialect.CONCAT_COALESCE and expression.args.get("coalesce"):
args = [exp.func("coalesce", e, exp.Literal.string("")) for e in args]
Expand Down Expand Up @@ -3048,7 +3048,9 @@ def intdiv_sql(self, expression: exp.IntDiv) -> str:

def dpipe_sql(self, expression: exp.DPipe) -> str:
if self.dialect.STRICT_STRING_CONCAT and expression.args.get("safe"):
return self.func("CONCAT", *(exp.cast(e, "text") for e in expression.flatten()))
return self.func(
"CONCAT", *(exp.cast(e, exp.DataType.Type.TEXT) for e in expression.flatten())
)
return self.binary(expression, "||")

def div_sql(self, expression: exp.Div) -> str:
Expand Down Expand Up @@ -3351,17 +3353,17 @@ def tochar_sql(self, expression: exp.ToChar) -> str:
if expression.args.get("format"):
self.unsupported("Format argument unsupported for TO_CHAR/TO_VARCHAR function")

return self.sql(exp.cast(expression.this, "text"))
return self.sql(exp.cast(expression.this, exp.DataType.Type.TEXT))

def tonumber_sql(self, expression: exp.ToNumber) -> str:
if not self.SUPPORTS_TO_NUMBER:
self.unsupported("Unsupported TO_NUMBER function")
return self.sql(exp.cast(expression.this, "double"))
return self.sql(exp.cast(expression.this, exp.DataType.Type.DOUBLE))

fmt = expression.args.get("format")
if not fmt:
self.unsupported("Conversion format is required for TO_NUMBER")
return self.sql(exp.cast(expression.this, "double"))
return self.sql(exp.cast(expression.this, exp.DataType.Type.DOUBLE))

return self.func("TO_NUMBER", expression.this, fmt)

Expand Down Expand Up @@ -3532,35 +3534,38 @@ def tsordstotime_sql(self, expression: exp.TsOrDsToTime) -> str:
if isinstance(this, exp.TsOrDsToTime) or this.is_type(exp.DataType.Type.TIME):
return self.sql(this)

return self.sql(exp.cast(this, "time"))
return self.sql(exp.cast(this, exp.DataType.Type.TIME))

def tsordstotimestamp_sql(self, expression: exp.TsOrDsToTimestamp) -> str:
this = expression.this
if isinstance(this, exp.TsOrDsToTimestamp) or this.is_type(exp.DataType.Type.TIMESTAMP):
return self.sql(this)

return self.sql(exp.cast(this, "timestamp"))
return self.sql(exp.cast(this, exp.DataType.Type.TIMESTAMP))

def tsordstodate_sql(self, expression: exp.TsOrDsToDate) -> str:
this = expression.this
time_format = self.format_time(expression)

if time_format and time_format not in (self.dialect.TIME_FORMAT, self.dialect.DATE_FORMAT):
return self.sql(
exp.cast(exp.StrToTime(this=this, format=expression.args["format"]), "date")
exp.cast(
exp.StrToTime(this=this, format=expression.args["format"]),
exp.DataType.Type.DATE,
)
)

if isinstance(this, exp.TsOrDsToDate) or this.is_type(exp.DataType.Type.DATE):
return self.sql(this)

return self.sql(exp.cast(this, "date"))
return self.sql(exp.cast(this, exp.DataType.Type.DATE))

def unixdate_sql(self, expression: exp.UnixDate) -> str:
return self.sql(
exp.func(
"DATEDIFF",
expression.this,
exp.cast(exp.Literal.string("1970-01-01"), "date"),
exp.cast(exp.Literal.string("1970-01-01"), exp.DataType.Type.DATE),
"day",
)
)
Expand Down
3 changes: 3 additions & 0 deletions tests/dialects/test_databricks.py
Expand Up @@ -25,6 +25,9 @@ def test_databricks(self):
self.validate_identity("CREATE FUNCTION a AS b")
self.validate_identity("SELECT ${x} FROM ${y} WHERE ${z} > 1")
self.validate_identity("CREATE TABLE foo (x DATE GENERATED ALWAYS AS (CAST(y AS DATE)))")
self.validate_identity(
"SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t"
)
self.validate_identity(
"SELECT * FROM sales UNPIVOT INCLUDE NULLS (sales FOR quarter IN (q1 AS `Jan-Mar`))"
)
Expand Down
10 changes: 2 additions & 8 deletions tests/test_build.py
Expand Up @@ -673,14 +673,8 @@ def test_build(self):
"(x, y) IN ((1, 2), (3, 4))",
"postgres",
),
(
lambda: exp.cast_unless("CAST(x AS INT)", "int", "int"),
"CAST(x AS INT)",
),
(
lambda: exp.cast_unless("CAST(x AS TEXT)", "int", "int"),
"CAST(CAST(x AS TEXT) AS INT)",
),
(lambda: exp.cast("CAST(x AS INT)", "int"), "CAST(x AS INT)"),
(lambda: exp.cast("CAST(x AS TEXT)", "int"), "CAST(CAST(x AS TEXT) AS INT)"),
(
lambda: exp.rename_column("table1", "c1", "c2", True),
"ALTER TABLE table1 RENAME COLUMN IF EXISTS c1 TO c2",
Expand Down

0 comments on commit 75e0c69

Please sign in to comment.