Skip to content

Commit

Permalink
Fix!(optimizer): promote typed div to double if operands are non-float (
Browse files Browse the repository at this point in the history
#3094)

* Fix(optimizer): promote typed div to double if operands are non-float

* Refactor FLOAT_TYPES

* Add a test for the generation changes
  • Loading branch information
georgesittas committed Mar 7, 2024
1 parent 4fb74ff commit 3bbcca2
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 17 deletions.
45 changes: 31 additions & 14 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3778,44 +3778,61 @@ class Type(AutoName):
TEXT_TYPES = {
Type.CHAR,
Type.NCHAR,
Type.VARCHAR,
Type.NVARCHAR,
Type.TEXT,
Type.VARCHAR,
}

INTEGER_TYPES = {
Type.INT,
Type.TINYINT,
Type.SMALLINT,
Type.BIGINT,
Type.BIT,
Type.INT,
Type.INT128,
Type.INT256,
Type.BIT,
Type.MEDIUMINT,
Type.SMALLINT,
Type.TINYINT,
Type.UBIGINT,
Type.UINT,
Type.UINT128,
Type.UINT256,
Type.UMEDIUMINT,
Type.USMALLINT,
Type.UTINYINT,
}

FLOAT_TYPES = {
Type.FLOAT,
Type.DOUBLE,
Type.FLOAT,
}

REAL_TYPES = {
*FLOAT_TYPES,
Type.BIGDECIMAL,
Type.DECIMAL,
Type.MONEY,
Type.SMALLMONEY,
Type.UDECIMAL,
}

NUMERIC_TYPES = {
*INTEGER_TYPES,
*FLOAT_TYPES,
*REAL_TYPES,
}

TEMPORAL_TYPES = {
Type.DATE,
Type.DATE32,
Type.DATETIME,
Type.DATETIME64,
Type.TIME,
Type.TIMETZ,
Type.TIMESTAMP,
Type.TIMESTAMPTZ,
Type.TIMESTAMPLTZ,
Type.TIMESTAMP_S,
Type.TIMESTAMPTZ,
Type.TIMESTAMP_MS,
Type.TIMESTAMP_NS,
Type.DATE,
Type.DATE32,
Type.DATETIME,
Type.DATETIME64,
Type.TIMESTAMP_S,
Type.TIMETZ,
}

@classmethod
Expand Down
4 changes: 1 addition & 3 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2953,9 +2953,7 @@ def div_sql(self, expression: exp.Div) -> str:
r.replace(exp.Nullif(this=r.copy(), expression=exp.Literal.number(0)))

if self.dialect.TYPED_DIVISION and not expression.args.get("typed"):
if not l.is_type(*exp.DataType.FLOAT_TYPES) and not r.is_type(
*exp.DataType.FLOAT_TYPES
):
if not l.is_type(*exp.DataType.REAL_TYPES) and not r.is_type(*exp.DataType.REAL_TYPES):
l.replace(exp.cast(l.copy(), to=exp.DataType.Type.DOUBLE))

elif not self.dialect.TYPED_DIVISION and expression.args.get("typed"):
Expand Down
4 changes: 4 additions & 0 deletions sqlglot/optimizer/annotate_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,10 @@ def _annotate_div(self, expression: exp.Div) -> exp.Div:
self._set_type(expression, exp.DataType.Type.BIGINT)
else:
self._set_type(expression, self._maybe_coerce(left_type, right_type))
if expression.type and expression.type.this not in exp.DataType.REAL_TYPES:
self._set_type(
expression, self._maybe_coerce(expression.type, exp.DataType.Type.DOUBLE)
)

return expression

Expand Down
5 changes: 5 additions & 0 deletions tests/dialects/test_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1719,6 +1719,11 @@ def test_safediv(self):
with self.subTest(f"{expression.__class__.__name__} {dialect} -> {expected}"):
self.assertEqual(expected, expression.sql(dialect=dialect))

self.assertEqual(
parse_one("CAST(x AS DECIMAL) / y", read="mysql").sql(dialect="postgres"),
"CAST(x AS DECIMAL) / NULLIF(y, 0)",
)

def test_limit(self):
self.validate_all(
"SELECT * FROM data LIMIT 10, 20",
Expand Down
8 changes: 8 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,14 @@ def test_typeddiv_annotation(self):
self.assertEqual(expressions[0].type.this, exp.DataType.Type.BIGINT)
self.assertEqual(expressions[1].type.this, exp.DataType.Type.DOUBLE)

expressions = annotate_types(
parse_one("SELECT SUM(2 / 3), CAST(2 AS DECIMAL) / 3", dialect="mysql")
).expressions

self.assertEqual(expressions[0].type.this, exp.DataType.Type.DOUBLE)
self.assertEqual(expressions[0].this.type.this, exp.DataType.Type.DOUBLE)
self.assertEqual(expressions[1].type.this, exp.DataType.Type.DECIMAL)

def test_bracket_annotation(self):
expression = annotate_types(parse_one("SELECT A[:]")).expressions[0]

Expand Down

0 comments on commit 3bbcca2

Please sign in to comment.