
Feat: Supporting RANGE <-> GENERATE_SERIES between DuckDB & SQLite (#3010)

* Transpiling RANGE <-> GENERATE_SERIES between DuckDB & SQLite

* Apply George's suggestions from 1st iteration

Co-authored-by: Jo <46752250+georgesittas@users.noreply.github.com>

* Finish addressing comments from the 1st iteration

* Second iteration comments

* Rewrite args[...] = using set method

---------

Co-authored-by: Jo <46752250+georgesittas@users.noreply.github.com>
VaggelisD and georgesittas committed Feb 22, 2024
1 parent 0e78a79 commit e50609b
Showing 6 changed files with 79 additions and 2 deletions.
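
For context, a minimal sketch of the transpilation this commit enables, assuming only that sqlglot is importable; the input and expected output are taken verbatim from the DuckDB test added below.

import sqlglot

# DuckDB's end-exclusive RANGE(...) table function is parsed into
# exp.GenerateSeries and re-emitted for SQLite, where the named column
# alias is rewritten into a subquery projecting the series' value column.
sql = "WITH t AS (SELECT i, i * i * i * i * i AS i5 FROM RANGE(1, 5) t(i)) SELECT * FROM t"
print(sqlglot.transpile(sql, read="duckdb", write="sqlite")[0])
# WITH t AS (SELECT i, i * i * i * i * i AS i5
#   FROM (SELECT value AS i FROM GENERATE_SERIES(1, 5)) AS t) SELECT * FROM t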
25 changes: 25 additions & 0 deletions sqlglot/dialects/duckdb.py
@@ -79,6 +79,21 @@ def _build_date_diff(args: t.List) -> exp.Expression:
return exp.DateDiff(this=seq_get(args, 2), expression=seq_get(args, 1), unit=seq_get(args, 0))


def _build_generate_series(end_exclusive: bool = False) -> t.Callable[[t.List], exp.GenerateSeries]:
def _builder(args: t.List) -> exp.GenerateSeries:
# Check https://duckdb.org/docs/sql/functions/nested.html#range-functions
if len(args) == 1:
# DuckDB uses 0 as a default for the series' start when it's omitted
args.insert(0, exp.Literal.number("0"))

gen_series = exp.GenerateSeries.from_arg_list(args)
gen_series.set("is_end_exclusive", end_exclusive)

return gen_series

return _builder


def _build_make_timestamp(args: t.List) -> exp.Expression:
if len(args) == 1:
return exp.UnixToTime(this=seq_get(args, 0), scale=exp.UnixToTime.MICROS)
@@ -267,6 +282,8 @@ class Parser(parser.Parser):
"TO_TIMESTAMP": exp.UnixToTime.from_arg_list,
"UNNEST": exp.Explode.from_arg_list,
"XOR": binary_from_function(exp.BitwiseXor),
"GENERATE_SERIES": _build_generate_series(),
"RANGE": _build_generate_series(end_exclusive=True),
}

FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
@@ -548,3 +565,11 @@ def join_sql(self, expression: exp.Join) -> str:
return super().join_sql(expression.on(exp.true()))

return super().join_sql(expression)

def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
# GENERATE_SERIES(a, b) -> [a, b], RANGE(a, b) -> [a, b)
if expression.args.get("is_end_exclusive"):
expression.set("is_end_exclusive", None)
return rename_func("RANGE")(self, expression)

return super().generateseries_sql(expression)
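
A short sketch of the parser-side behavior added above, assuming the same sqlglot entry point; the round-trip strings mirror the new identity tests in tests/dialects/test_duckdb.py.

import sqlglot

# When the series' start is omitted, the builder defaults it to 0, so the
# DuckDB round trip re-emits both bounds explicitly.
print(sqlglot.transpile("SELECT i FROM RANGE(5) AS _(i) ORDER BY i ASC", read="duckdb", write="duckdb")[0])
# SELECT i FROM RANGE(0, 5) AS _(i) ORDER BY i ASC

# GENERATE_SERIES gets the same default start but keeps its own (end-inclusive) name.
print(sqlglot.transpile("SELECT i FROM GENERATE_SERIES(12) AS _(i) ORDER BY i ASC", read="duckdb", write="duckdb")[0])
# SELECT i FROM GENERATE_SERIES(0, 12) AS _(i) ORDER BY i ASC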
16 changes: 16 additions & 0 deletions sqlglot/dialects/sqlite.py
@@ -92,6 +92,7 @@ class Generator(generator.Generator):
NVL2_SUPPORTED = False
JSON_PATH_BRACKETED_KEY_SUPPORTED = False
SUPPORTS_CREATE_TABLE_LIKE = False
SUPPORTS_TABLE_ALIAS_COLUMNS = False

SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathKey,
@@ -173,6 +174,21 @@ def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:

return super().cast_sql(expression)

def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
parent = expression.parent
alias = parent and parent.args.get("alias")

if isinstance(alias, exp.TableAlias) and alias.columns:
column_alias = alias.columns[0]
alias.set("columns", None)
sql = self.sql(
exp.select(exp.alias_("value", column_alias)).from_(expression).subquery()
)
else:
sql = super().generateseries_sql(expression)

return sql

def datediff_sql(self, expression: exp.DateDiff) -> str:
unit = expression.args.get("unit")
unit = unit.name.upper() if unit else "DAY"
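
A sketch of the SQLite-side handling added above, assuming the same sqlglot entry point; the alias-dropping case and its warning come from the new test_warnings test further down.

import sqlglot

# GENERATE_SERIES itself round-trips untouched in SQLite.
print(sqlglot.transpile("SELECT * FROM GENERATE_SERIES(1, 5)", read="sqlite", write="sqlite")[0])
# SELECT * FROM GENERATE_SERIES(1, 5)

# SQLite has no named table-alias columns (SUPPORTS_TABLE_ALIAS_COLUMNS = False),
# so for plain tables the generator drops them and logs a warning instead;
# series calls keep theirs via the subquery rewrite in generateseries_sql above.
print(sqlglot.transpile("SELECT * FROM t AS t(c1, c2)", read="sqlite", write="sqlite")[0])
# SELECT * FROM t AS t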
2 changes: 1 addition & 1 deletion sqlglot/expressions.py
@@ -4434,7 +4434,7 @@ class ToChar(Func):


class GenerateSeries(Func):
arg_types = {"start": True, "end": True, "step": False}
arg_types = {"start": True, "end": True, "step": False, "is_end_exclusive": False}


class ArrayAgg(AggFunc):
4 changes: 4 additions & 0 deletions sqlglot/generator.py
@@ -3454,3 +3454,7 @@ def _ensure_string_if_null(self, values: t.List[exp.Expression]) -> t.List[exp.Expression]:
for value in values
if value
]

def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
expression.set("is_end_exclusive", None)
return self.function_fallback_sql(expression)
22 changes: 21 additions & 1 deletion tests/dialects/test_duckdb.py
@@ -174,7 +174,6 @@ def test_duckdb(self):
},
)

self.validate_identity("SELECT i FROM RANGE(5) AS _(i) ORDER BY i ASC")
self.validate_identity("INSERT INTO x BY NAME SELECT 1 AS y")
self.validate_identity("SELECT 1 AS x UNION ALL BY NAME SELECT 2 AS x")
self.validate_identity("SELECT SUM(x) FILTER (x = 1)", "SELECT SUM(x) FILTER(WHERE x = 1)")
@@ -626,6 +625,27 @@ def test_duckdb(self):
},
)

self.validate_identity("SELECT * FROM RANGE(1, 5, 10)")
self.validate_identity("SELECT * FROM GENERATE_SERIES(2, 13, 4)")

self.validate_all(
"WITH t AS (SELECT i, i * i * i * i * i AS i5 FROM RANGE(1, 5) t(i)) SELECT * FROM t",
write={
"duckdb": "WITH t AS (SELECT i, i * i * i * i * i AS i5 FROM RANGE(1, 5) AS t(i)) SELECT * FROM t",
"sqlite": "WITH t AS (SELECT i, i * i * i * i * i AS i5 FROM (SELECT value AS i FROM GENERATE_SERIES(1, 5)) AS t) SELECT * FROM t",
},
)

self.validate_identity(
"""SELECT i FROM RANGE(5) AS _(i) ORDER BY i ASC""",
"""SELECT i FROM RANGE(0, 5) AS _(i) ORDER BY i ASC""",
)

self.validate_identity(
"""SELECT i FROM GENERATE_SERIES(12) AS _(i) ORDER BY i ASC""",
"""SELECT i FROM GENERATE_SERIES(0, 12) AS _(i) ORDER BY i ASC""",
)

def test_array_index(self):
with self.assertLogs(helper_logger) as cm:
self.validate_all(
12 changes: 12 additions & 0 deletions tests/dialects/test_sqlite.py
@@ -1,5 +1,7 @@
from tests.dialects.test_dialect import Validator

from sqlglot.helper import logger as helper_logger


class TestSQLite(Validator):
dialect = "sqlite"
@@ -76,6 +78,7 @@ def test_sqlite(self):
self.validate_identity(
"""SELECT item AS "item", some AS "some" FROM data WHERE (item = 'value_1' COLLATE NOCASE) AND (some = 't' COLLATE NOCASE) ORDER BY item ASC LIMIT 1 OFFSET 0"""
)
self.validate_identity("SELECT * FROM GENERATE_SERIES(1, 5)")

self.validate_all("SELECT LIKE(y, x)", write={"sqlite": "SELECT x LIKE y"})
self.validate_all("SELECT GLOB('*y*', 'xyz')", write={"sqlite": "SELECT 'xyz' GLOB '*y*'"})
@@ -178,3 +181,12 @@ def test_longvarchar_dtype(self):
"CREATE TABLE foo (bar LONGVARCHAR)",
write={"sqlite": "CREATE TABLE foo (bar TEXT)"},
)

def test_warnings(self):
with self.assertLogs(helper_logger) as cm:
self.validate_identity(
"SELECT * FROM t AS t(c1, c2)",
"SELECT * FROM t AS t",
)

self.assertIn("Named columns are not supported in table alias.", cm.output[0])
