Skip to content

Commit

Permalink
Feat: improve transpilation of TO_NUMBER
Browse files Browse the repository at this point in the history
* add to_number from oracle to doris

* Change to double

* fix make style error

* modify the use of to_number

* add tests for all dialects

* fix style error

* fix mypy error

* fix err
  • Loading branch information
codeDing18 committed Mar 11, 2024
1 parent 804af34 commit c4e7bbf
Show file tree
Hide file tree
Showing 26 changed files with 77 additions and 0 deletions.
1 change: 1 addition & 0 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,7 @@ class Generator(generator.Generator):
IGNORE_NULLS_IN_FUNC = True
JSON_PATH_SINGLE_QUOTE_ESCAPE = True
CAN_IMPLEMENT_ARRAY_ANY = True
SUPPORTS_TO_NUMBER = False
NAMED_PLACEHOLDER_TOKEN = "@"

TRANSFORMS = {
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ class Generator(generator.Generator):
TABLESAMPLE_KEYWORDS = "SAMPLE"
LAST_DAY_SUPPORTS_DATE_PART = False
CAN_IMPLEMENT_ARRAY_ANY = True
SUPPORTS_TO_NUMBER = False

STRING_TYPE_MAPPING = {
exp.DataType.Type.CHAR: "String",
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ class Generator(Spark.Generator):
]
),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.ToNumber: lambda self, e: self.function_fallback_sql(e),
}

TRANSFORMS.pop(exp.TryCast)
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/doris.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class Generator(MySQL.Generator):
}

LAST_DAY_SUPPORTS_DATE_PART = False
SUPPORTS_TO_NUMBER = False

TIMESTAMP_FUNC_TYPES = set()

Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/drill.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,4 +155,5 @@ class Generator(generator.Generator):
e: f"DATE_ADD(CAST({self.sql(e, 'this')} AS DATE), {self.sql(exp.Interval(this=e.expression, unit=exp.var('DAY')))})",
exp.TsOrDiToDi: lambda self,
e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
exp.ToNumber: lambda self, e: self.function_fallback_sql(e),
}
1 change: 1 addition & 0 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ class Generator(generator.Generator):
SUPPORTS_CREATE_TABLE_LIKE = False
MULTI_ARG_DISTINCT = False
CAN_IMPLEMENT_ARRAY_ANY = True
SUPPORTS_TO_NUMBER = False

TRANSFORMS = {
**generator.Generator.TRANSFORMS,
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,7 @@ class Generator(generator.Generator):
NVL2_SUPPORTED = False
LAST_DAY_SUPPORTS_DATE_PART = False
JSON_PATH_SINGLE_QUOTE_ESCAPE = True
SUPPORTS_TO_NUMBER = False

EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Insert,
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,7 @@ class Generator(generator.Generator):
JSON_TYPE_REQUIRED_FOR_EXTRACTION = True
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
JSON_KEY_VALUE_PAIR_SEP = ","
SUPPORTS_TO_NUMBER = False

TRANSFORMS = {
**generator.Generator.TRANSFORMS,
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ class Generator(generator.Generator):
exp.TemporaryProperty: lambda _, e: f"{e.name or 'GLOBAL'} TEMPORARY",
exp.TimeToStr: lambda self, e: self.func("TO_CHAR", e.this, self.format_time(e)),
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.ToNumber: lambda self, e: self.function_fallback_sql(e),
exp.Trim: trim_sql,
exp.UnixToTime: lambda self,
e: f"TO_DATE('1970-01-01', 'YYYY-MM-DD') + ({self.sql(e, 'this')} / 86400)",
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,7 @@ class Generator(generator.Generator):
exp.VariancePop: rename_func("VAR_POP"),
exp.Variance: rename_func("VAR_SAMP"),
exp.Xor: bool_xor_sql,
exp.ToNumber: lambda self, e: self.function_fallback_sql(e),
}

PROPERTIES_LOCATION = {
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ class Generator(generator.Generator):
SUPPORTS_SINGLE_ARG_CONCAT = False
LIKE_PROPERTY_INSIDE_SCHEMA = True
MULTI_ARG_DISTINCT = False
SUPPORTS_TO_NUMBER = False

PROPERTIES_LOCATION = {
**generator.Generator.PROPERTIES_LOCATION,
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/prql.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,4 +109,5 @@ def _parse_from(
)

class Generator(generator.Generator):
    # This generator is flagged as not supporting the TO_NUMBER function.
    # The attribute alone makes the class body non-empty, so the former
    # placeholder `pass` statement is no longer needed.
    SUPPORTS_TO_NUMBER = False
1 change: 1 addition & 0 deletions sqlglot/dialects/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ class Generator(Postgres.Generator):
exp.TableSample: no_tablesample_sql,
exp.TsOrDsAdd: date_delta_sql("DATEADD"),
exp.TsOrDsDiff: date_delta_sql("DATEDIFF"),
exp.ToNumber: lambda self, e: self.function_fallback_sql(e),
}

# Postgres maps exp.Pivot to no_pivot_sql, but Redshift supports pivots
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,7 @@ class Generator(generator.Generator):
exp.VarMap: lambda self, e: var_map_sql(self, e, "OBJECT_CONSTRUCT"),
exp.WeekOfYear: rename_func("WEEKOFYEAR"),
exp.Xor: rename_func("BOOLXOR"),
exp.ToNumber: lambda self, e: self.function_fallback_sql(e),
}

SUPPORTED_JSON_PATH_PARTS = {
Expand Down
2 changes: 2 additions & 0 deletions sqlglot/dialects/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ class Generator(Hive.Generator):
QUERY_HINTS = True
NVL2_SUPPORTED = True
CAN_IMPLEMENT_ARRAY_ANY = True
SUPPORTS_TO_NUMBER = True

PROPERTIES_LOCATION = {
**Hive.Generator.PROPERTIES_LOCATION,
Expand Down Expand Up @@ -251,6 +252,7 @@ class Generator(Hive.Generator):
exp.WithinGroup: transforms.preprocess(
[transforms.remove_within_group_for_percentiles]
),
exp.ToNumber: lambda self, e: self.function_fallback_sql(e),
}
TRANSFORMS.pop(exp.ArrayJoin)
TRANSFORMS.pop(exp.ArraySort)
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class Generator(generator.Generator):
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
SUPPORTS_CREATE_TABLE_LIKE = False
SUPPORTS_TABLE_ALIAS_COLUMNS = False
SUPPORTS_TO_NUMBER = False

SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/tableau.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class Generator(generator.Generator):
JOIN_HINTS = False
TABLE_HINTS = False
QUERY_HINTS = False
SUPPORTS_TO_NUMBER = False

TRANSFORMS = {
**generator.Generator.TRANSFORMS,
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/teradata.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ class Generator(generator.Generator):
e: f"CAST({self.sql(e, 'this')} AS DATE FORMAT {self.format_time(e)})",
exp.ToChar: lambda self, e: self.function_fallback_sql(e),
exp.Use: lambda self, e: f"DATABASE {self.sql(e, 'this')}",
exp.ToNumber: lambda self, e: self.function_fallback_sql(e),
}

def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,7 @@ class Generator(generator.Generator):
TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
SUPPORTS_SELECT_INTO = True
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
SUPPORTS_TO_NUMBER = False

EXPRESSIONS_WITHOUT_NESTED_CTES = {
exp.Delete,
Expand Down
6 changes: 6 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4520,6 +4520,12 @@ class ToChar(Func):
arg_types = {"this": True, "format": False, "nlsparam": False}


# https://docs.snowflake.com/en/sql-reference/functions/to_decimal
# https://docs.oracle.com/en/database/oracle/oracle-database/23/sqlrf/TO_NUMBER.html
class ToNumber(Func):
    """Expression node for the TO_NUMBER conversion function."""

    # Argument slots accepted by this node. Only `this` is flagged True;
    # `format`, `precision` and `scale` are flagged False
    # (presumably required vs. optional - confirm against the base class).
    arg_types = {
        "this": True,
        "format": False,
        "precision": False,
        "scale": False,
    }


# https://learn.microsoft.com/en-us/sql/t-sql/functions/cast-and-convert-transact-sql?view=sql-server-ver16#syntax
class Convert(Func):
arg_types = {"this": True, "expression": True, "style": False}
Expand Down
7 changes: 7 additions & 0 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,9 @@ class Generator(metaclass=_Generator):
# Whether any(f(x) for x in array) can be implemented by this dialect
CAN_IMPLEMENT_ARRAY_ANY = False

# Whether the TO_NUMBER function is supported by this dialect
SUPPORTS_TO_NUMBER = True

TYPE_MAPPING = {
exp.DataType.Type.NCHAR: "CHAR",
exp.DataType.Type.NVARCHAR: "VARCHAR",
Expand Down Expand Up @@ -3272,6 +3275,10 @@ def tochar_sql(self, expression: exp.ToChar) -> str:

return self.sql(exp.cast(expression.this, "text"))

def tonumber_sql(self, expression: exp.ToNumber) -> str:
    """Generate SQL for a TO_NUMBER expression.

    When the generator's SUPPORTS_TO_NUMBER flag is set, the expression is
    rendered through `function_fallback_sql`, which is the same rendering
    the dialects with a TO_NUMBER entry in their TRANSFORMS use. Otherwise
    an "unsupported" warning is recorded and the operand is cast to DOUBLE.

    Args:
        expression: The TO_NUMBER expression to generate SQL for.

    Returns:
        The generated SQL string.
    """
    # Honor the capability flag instead of downgrading unconditionally.
    if self.SUPPORTS_TO_NUMBER:
        return self.function_fallback_sql(expression)

    self.unsupported("Unsupported TO_NUMBER function, converting to a CAST as DOUBLE")
    return self.sql(exp.cast(expression.this, "double"))

def dictproperty_sql(self, expression: exp.DictProperty) -> str:
this = self.sql(expression, "this")
kind = self.sql(expression, "kind")
Expand Down
6 changes: 6 additions & 0 deletions tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,12 @@ def test_bigquery(self):
"spark": "TO_JSON(x)",
},
)
self.validate_all(
"TO_NUMBER(x)",
write={
"bigquery": "CAST(x AS FLOAT64)",
},
)
self.validate_all(
"""SELECT
`u`.`harness_user_email` AS `harness_user_email`,
Expand Down
6 changes: 6 additions & 0 deletions tests/dialects/test_clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,12 @@ def test_clickhouse(self):
write={"clickhouse": "SELECT startsWith('a', 'b'), startsWith('a', 'b')"},
)
self.validate_identity("SYSTEM STOP MERGES foo.bar", check_command_warning=True)
self.validate_all(
"TO_NUMBER(x)",
write={
"clickhouse": "CAST(x AS Float64)",
},
)

def test_cte(self):
self.validate_identity("WITH 'x' AS foo SELECT foo")
Expand Down
20 changes: 20 additions & 0 deletions tests/dialects/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,26 @@ def test_oracle(self):
"oracle": "TO_CHAR(x)",
},
)
self.validate_all(
"TO_NUMBER(x)",
write={
"doris": "CAST(x AS DOUBLE)",
"presto": "CAST(x AS DOUBLE)",
"duckdb": "CAST(x AS DOUBLE)",
"hive": "CAST(x AS DOUBLE)",
"mysql": "CAST(x AS DOUBLE)",
"starrocks": "CAST(x AS DOUBLE)",
"tableau": "CAST(x AS DOUBLE)",
"oracle": "TO_NUMBER(x)",
"snowflake": "TO_NUMBER(x)",
"spark2": "TO_NUMBER(x)",
"databricks": "TO_NUMBER(x)",
"drill": "TO_NUMBER(x)",
"postgres": "TO_NUMBER(x)",
"redshift": "TO_NUMBER(x)",
"teradata": "TO_NUMBER(x)",
},
)
self.validate_all(
"SELECT TO_CHAR(TIMESTAMP '1999-12-01 10:00:00')",
write={
Expand Down
6 changes: 6 additions & 0 deletions tests/dialects/test_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,12 @@ def test_sqlite(self):
read={"snowflake": "LEAST(x, y, z)"},
write={"snowflake": "LEAST(x, y, z)"},
)
self.validate_all(
"TO_NUMBER(x)",
write={
"sqlite": "CAST(x AS REAL)",
},
)

def test_datediff(self):
self.validate_all(
Expand Down
6 changes: 6 additions & 0 deletions tests/dialects/test_tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,12 @@ def test_tsql(self):
)
self.validate_identity("HASHBYTES('MD2', 'x')")
self.validate_identity("LOG(n, b)")
self.validate_all(
"TO_NUMBER(x)",
write={
"tsql": "CAST(x AS FLOAT)",
},
)

def test_option(self):
possible_options = [
Expand Down

0 comments on commit c4e7bbf

Please sign in to comment.