Skip to content

Commit

Permalink
TO_NUMBER transpilation cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas committed Mar 11, 2024
1 parent c4e7bbf commit 14c1dad
Show file tree
Hide file tree
Showing 22 changed files with 98 additions and 54 deletions.
1 change: 0 additions & 1 deletion sqlglot/dialects/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ class Generator(Spark.Generator):
]
),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.ToNumber: lambda self, e: self.function_fallback_sql(e),
}

TRANSFORMS.pop(exp.TryCast)
Expand Down
9 changes: 9 additions & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,3 +1079,12 @@ def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> s
unnest = exp.Unnest(expressions=[expression.this])
filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
return self.sql(exp.Array(expressions=[filtered]))


def to_number_with_nls_param(self, expression: exp.ToNumber) -> str:
return self.func(
"TO_NUMBER",
expression.this,
expression.args.get("format"),
expression.args.get("nlsparam"),
)
6 changes: 2 additions & 4 deletions sqlglot/dialects/doris.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Parser(MySQL.Parser):
}

class Generator(MySQL.Generator):
CAST_MAPPING = {}
LAST_DAY_SUPPORTS_DATE_PART = False

TYPE_MAPPING = {
**MySQL.Generator.TYPE_MAPPING,
Expand All @@ -36,9 +36,7 @@ class Generator(MySQL.Generator):
exp.DataType.Type.TIMESTAMPTZ: "DATETIME",
}

LAST_DAY_SUPPORTS_DATE_PART = False
SUPPORTS_TO_NUMBER = False

CAST_MAPPING = {}
TIMESTAMP_FUNC_TYPES = set()

TRANSFORMS = {
Expand Down
1 change: 0 additions & 1 deletion sqlglot/dialects/drill.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,5 +155,4 @@ class Generator(generator.Generator):
e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
exp.TsOrDiToDi: lambda self,
e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
exp.ToNumber: lambda self, e: self.function_fallback_sql(e),
}
3 changes: 2 additions & 1 deletion sqlglot/dialects/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
build_formatted_time,
no_ilike_sql,
rename_func,
to_number_with_nls_param,
trim_sql,
)
from sqlglot.helper import seq_get
Expand Down Expand Up @@ -246,7 +247,7 @@ class Generator(generator.Generator):
exp.TemporaryProperty: lambda _, e: f"{e.name or 'GLOBAL'} TEMPORARY",
exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.ToNumber: lambda self, e: self.function_fallback_sql(e),
exp.ToNumber: to_number_with_nls_param,
exp.Trim: trim_sql,
exp.UnixToTime: lambda self,
e: f"TO_DATE('1970-01-01', 'YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
Expand Down
1 change: 0 additions & 1 deletion sqlglot/dialects/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,6 @@ class Generator(generator.Generator):
exp.VariancePop: rename_func("VAR_POP"),
exp.Variance: rename_func("VAR_SAMP"),
exp.Xor: bool_xor_sql,
exp.ToNumber: lambda self, e: self.function_fallback_sql(e),
}

PROPERTIES_LOCATION = {
Expand Down
4 changes: 0 additions & 4 deletions sqlglot/dialects/prql.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
from sqlglot.dialects.dialect import Dialect
from sqlglot.tokens import TokenType

if t.TYPE_CHECKING:
pass


class PRQL(Dialect):
class Tokenizer(tokens.Tokenizer):
Expand Down Expand Up @@ -109,5 +106,4 @@ def _parse_from(
)

class Generator(generator.Generator):
SUPPORTS_TO_NUMBER = False
pass
1 change: 0 additions & 1 deletion sqlglot/dialects/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,6 @@ class Generator(Postgres.Generator):
exp.TableSample: no_tablesample_sql,
exp.TsOrDsAdd: date_delta_sql("DATEADD"),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
exp.ToNumber: lambda self, e: self.function_fallback_sql(e),
}

# Postgres maps exp.Pivot to no_pivot_sql, but Redshift support pivots
Expand Down
16 changes: 15 additions & 1 deletion sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,12 @@ class Parser(parser.Parser):
"TIMESTAMPDIFF": _build_datediff,
"TIMESTAMPFROMPARTS": _build_timestamp_from_parts,
"TIMESTAMP_FROM_PARTS": _build_timestamp_from_parts,
"TO_NUMBER": lambda args: exp.ToNumber(
this=seq_get(args, 0),
format=seq_get(args, 1),
precision=seq_get(args, 2),
scale=seq_get(args, 3),
),
"TO_TIMESTAMP": _build_to_timestamp,
"TO_VARCHAR": exp.ToChar.from_arg_list,
"ZEROIFNULL": _build_if_from_zeroifnull,
Expand Down Expand Up @@ -785,7 +791,6 @@ class Generator(generator.Generator):
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
exp.Xor: rename_func("BOOLXOR"),
exp.ToNumber: lambda self, e: self.function_fallback_sql(e),
}

SUPPORTED_JSON_PATH_PARTS = {
Expand All @@ -810,6 +815,15 @@ class Generator(generator.Generator):
exp.VolatileProperty: exp.Properties.Location.UNSUPPORTED,
}

def tonumber_sql(self, expression: exp.ToNumber) -> str:
return self.func(
"TO_NUMBER",
expression.this,
expression.args.get("format"),
expression.args.get("precision"),
expression.args.get("scale"),
)

def timestampfromparts_sql(self, expression: exp.TimestampFromParts) -> str:
milli = expression.args.get("milli")
if milli is not None:
Expand Down
2 changes: 2 additions & 0 deletions sqlglot/dialects/spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def _parse_generated_as_identity(
return this

class Generator(Spark2.Generator):
SUPPORTS_TO_NUMBER = True

TYPE_MAPPING = {
**Spark2.Generator.TYPE_MAPPING,
exp.DataType.Type.MONEY: "DECIMAL(15, 4)",
Expand Down
2 changes: 0 additions & 2 deletions sqlglot/dialects/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,6 @@ class Generator(Hive.Generator):
QUERY_HINTS = True
NVL2_SUPPORTED = True
CAN_IMPLEMENT_ARRAY_ANY = True
SUPPORTS_TO_NUMBER = True

PROPERTIES_LOCATION = {
**Hive.Generator.PROPERTIES_LOCATION,
Expand Down Expand Up @@ -252,7 +251,6 @@ class Generator(Hive.Generator):
exp.WithinGroup: transforms.preprocess(
[transforms.remove_within_group_for_percentiles]
),
exp.ToNumber: lambda self, e: self.function_fallback_sql(e),
}
TRANSFORMS.pop(exp.ArrayJoin)
TRANSFORMS.pop(exp.ArraySort)
Expand Down
1 change: 0 additions & 1 deletion sqlglot/dialects/tableau.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
SUPPORTS_TO_NUMBER = False

TRANSFORMS = {
**generator.Generator.TRANSFORMS,
Expand Down
10 changes: 8 additions & 2 deletions sqlglot/dialects/teradata.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
import typing as t

from sqlglot import exp, generator, parser, tokens, transforms
from sqlglot.dialects.dialect import Dialect, max_or_greatest, min_or_least, rename_func
from sqlglot.dialects.dialect import (
Dialect,
max_or_greatest,
min_or_least,
rename_func,
to_number_with_nls_param,
)
from sqlglot.tokens import TokenType


Expand Down Expand Up @@ -206,8 +212,8 @@ class Generator(generator.Generator):
exp.StrToDate: lambda self,
e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.ToNumber: to_number_with_nls_param,
exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
exp.ToNumber: lambda self, e: self.function_fallback_sql(e),
}

def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
Expand Down
8 changes: 7 additions & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4523,7 +4523,13 @@ class ToChar(Func):
# https://docs.snowflake.com/en/sql-reference/functions/to_decimal
# https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/TO_NUMBER.html
class ToNumber(Func):
arg_types = {"this": True, "format": False, "precision": False, "scale": False}
arg_types = {
"this": True,
"format": False,
"nlsparam": False,
"precision": False,
"scale": False,
}


# https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16#syntax
Expand Down
12 changes: 10 additions & 2 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3276,8 +3276,16 @@ def tochar_sql(self, expression: exp.ToChar) -> str:
return self.sql(exp.cast(expression.this, "text"))

def tonumber_sql(self, expression: exp.ToNumber) -> str:
self.unsupported("Unsupported TO_NUMBER function, converting to a CAST as DOUBLE")
return self.sql(exp.cast(expression.this, "double"))
if not self.SUPPORTS_TO_NUMBER:
self.unsupported("Unsupported TO_NUMBER function")
return self.sql(exp.cast(expression.this, "double"))

fmt = expression.args.get("format")
if not fmt:
self.unsupported("Conversion format is required for TO_NUMBER")
return self.sql(exp.cast(expression.this, "double"))

return self.func("TO_NUMBER", expression.this, fmt)

def dictproperty_sql(self, expression: exp.DictProperty) -> str:
this = self.sql(expression, "this")
Expand Down
6 changes: 0 additions & 6 deletions tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,12 +1051,6 @@ def test_bigquery(self):
"spark": "TO_JSON(x)",
},
)
self.validate_all(
"TO_NUMBER(x)",
write={
"bigquery": "CAST(x AS FLOAT64)",
},
)
self.validate_all(
"""SELECT
`u`.`harness_user_email` AS `harness_user_email`,
Expand Down
6 changes: 0 additions & 6 deletions tests/dialects/test_clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,12 +389,6 @@ def test_clickhouse(self):
write={"clickhouse": "SELECT startsWith('a', 'b'), startsWith('a', 'b')"},
)
self.validate_identity("SYSTEM STOP MERGES foo.bar", check_command_warning=True)
self.validate_all(
"TO_NUMBER(x)",
write={
"clickhouse": "CAST(x AS Float64)",
},
)

def test_cte(self):
self.validate_identity("WITH 'x' AS foo SELECT foo")
Expand Down
49 changes: 41 additions & 8 deletions tests/dialects/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,26 +102,59 @@ def test_oracle(self):
"oracle": "TO_CHAR(x)",
},
)
self.validate_all(
"TO_NUMBER(expr, fmt, nlsparam)",
read={
"teradata": "TO_NUMBER(expr, fmt, nlsparam)",
},
write={
"oracle": "TO_NUMBER(expr, fmt, nlsparam)",
"teradata": "TO_NUMBER(expr, fmt, nlsparam)",
},
)
self.validate_all(
"TO_NUMBER(x)",
write={
"bigquery": "CAST(x AS FLOAT64)",
"doris": "CAST(x AS DOUBLE)",
"presto": "CAST(x AS DOUBLE)",
"drill": "CAST(x AS DOUBLE)",
"duckdb": "CAST(x AS DOUBLE)",
"hive": "CAST(x AS DOUBLE)",
"mysql": "CAST(x AS DOUBLE)",
"starrocks": "CAST(x AS DOUBLE)",
"tableau": "CAST(x AS DOUBLE)",
"oracle": "TO_NUMBER(x)",
"postgres": "CAST(x AS DOUBLE PRECISION)",
"presto": "CAST(x AS DOUBLE)",
"redshift": "CAST(x AS DOUBLE PRECISION)",
"snowflake": "TO_NUMBER(x)",
"spark2": "TO_NUMBER(x)",
"databricks": "TO_NUMBER(x)",
"drill": "TO_NUMBER(x)",
"postgres": "TO_NUMBER(x)",
"redshift": "TO_NUMBER(x)",
"spark": "CAST(x AS DOUBLE)",
"spark2": "CAST(x AS DOUBLE)",
"starrocks": "CAST(x AS DOUBLE)",
"tableau": "CAST(x AS DOUBLE)",
"teradata": "TO_NUMBER(x)",
},
)
self.validate_all(
"TO_NUMBER(x, fmt)",
read={
"databricks": "TO_NUMBER(x, fmt)",
"drill": "TO_NUMBER(x, fmt)",
"postgres": "TO_NUMBER(x, fmt)",
"snowflake": "TO_NUMBER(x, fmt)",
"spark": "TO_NUMBER(x, fmt)",
"redshift": "TO_NUMBER(x, fmt)",
"teradata": "TO_NUMBER(x, fmt)",
},
write={
"databricks": "TO_NUMBER(x, fmt)",
"drill": "TO_NUMBER(x, fmt)",
"oracle": "TO_NUMBER(x, fmt)",
"postgres": "TO_NUMBER(x, fmt)",
"snowflake": "TO_NUMBER(x, fmt)",
"spark": "TO_NUMBER(x, fmt)",
"redshift": "TO_NUMBER(x, fmt)",
"teradata": "TO_NUMBER(x, fmt)",
},
)
self.validate_all(
"SELECT TO_CHAR(TIMESTAMP '1999-12-01 10:00:00')",
write={
Expand Down
1 change: 1 addition & 0 deletions tests/dialects/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def test_snowflake(self):
)""",
)

self.validate_identity("TO_DECIMAL(expr, fmt, precision, scale)")
self.validate_identity("ALTER TABLE authors ADD CONSTRAINT c1 UNIQUE (id, email)")
self.validate_identity("RM @parquet_stage", check_command_warning=True)
self.validate_identity("REMOVE @parquet_stage", check_command_warning=True)
Expand Down
6 changes: 0 additions & 6 deletions tests/dialects/test_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,6 @@ def test_sqlite(self):
read={"snowflake": "LEAST(x, y, z)"},
write={"snowflake": "LEAST(x, y, z)"},
)
self.validate_all(
"TO_NUMBER(x)",
write={
"sqlite": "CAST(x AS REAL)",
},
)

def test_datediff(self):
self.validate_all(
Expand Down
1 change: 1 addition & 0 deletions tests/dialects/test_teradata.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ class TestTeradata(Validator):
dialect = "teradata"

def test_teradata(self):
self.validate_identity("TO_NUMBER(expr, fmt, nlsparam)")
self.validate_identity("SELECT TOP 10 * FROM tbl")
self.validate_identity("SELECT * FROM tbl SAMPLE 5")
self.validate_identity(
Expand Down
6 changes: 0 additions & 6 deletions tests/dialects/test_tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,12 +357,6 @@ def test_tsql(self):
)
self.validate_identity("HASHBYTES('MD2', 'x')")
self.validate_identity("LOG(n, b)")
self.validate_all(
"TO_NUMBER(x)",
write={
"tsql": "CAST(x AS FLOAT)",
},
)

def test_option(self):
possible_options = [
Expand Down

0 comments on commit 14c1dad

Please sign in to comment.