Skip to content

Commit

Permalink
Feat: improve transpilation of Doris' MONTHS_ADD (#3012)
Browse files Browse the repository at this point in the history
  • Loading branch information
georgesittas committed Feb 22, 2024
1 parent 2a88e40 commit d2e15ed
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 1 deletion.
2 changes: 1 addition & 1 deletion sqlglot/dataframe/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def date_diff(end: ColumnOrName, start: ColumnOrName) -> Column:


def add_months(start: ColumnOrName, months: t.Union[ColumnOrName, int]) -> Column:
return Column.invoke_anonymous_function(start, "ADD_MONTHS", months)
return Column.invoke_expression_over_column(start, expression.AddMonths, expression=months)


def months_between(
Expand Down
2 changes: 2 additions & 0 deletions sqlglot/dialects/doris.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class Parser(MySQL.Parser):
**MySQL.Parser.FUNCTIONS,
"COLLECT_SET": exp.ArrayUniqueAgg.from_arg_list,
"DATE_TRUNC": build_timestamp_trunc,
"MONTHS_ADD": exp.AddMonths.from_arg_list,
"REGEXP": exp.RegexpLike.from_arg_list,
"TO_DATE": exp.TsOrDsToDate.from_arg_list,
}
Expand All @@ -41,6 +42,7 @@ class Generator(MySQL.Generator):

TRANSFORMS = {
**MySQL.Generator.TRANSFORMS,
exp.AddMonths: rename_func("MONTHS_ADD"),
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArgMax: rename_func("MAX_BY"),
exp.ArgMin: rename_func("MIN_BY"),
Expand Down
4 changes: 4 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5197,6 +5197,10 @@ class Month(Func):
pass


class AddMonths(Func):
arg_types = {"this": True, "expression": True}


class Nvl2(Func):
arg_types = {"this": True, "true": True, "false": False}

Expand Down
10 changes: 10 additions & 0 deletions tests/dialects/test_doris.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,16 @@ def test_doris(self):
"doris": "SELECT ARRAY_SUM(x -> x * x, ARRAY(2, 3))",
},
)
self.validate_all(
"MONTHS_ADD(d, n)",
read={
"oracle": "ADD_MONTHS(d, n)",
},
write={
"doris": "MONTHS_ADD(d, n)",
"oracle": "ADD_MONTHS(d, n)",
},
)

def test_identity(self):
self.validate_identity("COALECSE(a, b, c, d)")
Expand Down
1 change: 1 addition & 0 deletions tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,7 @@ def test_functions(self):
self.assertIsInstance(parse_one("TO_HEX(foo)", read="bigquery"), exp.Hex)
self.assertIsInstance(parse_one("TO_HEX(MD5(foo))", read="bigquery"), exp.MD5)
self.assertIsInstance(parse_one("TRANSFORM(a, b)", read="spark"), exp.Transform)
self.assertIsInstance(parse_one("ADD_MONTHS(a, b)"), exp.AddMonths)

def test_column(self):
column = parse_one("a.b.c.d")
Expand Down

0 comments on commit d2e15ed

Please sign in to comment.