Skip to content

Commit

Permalink
Fix: correctly generate ArrayJoin in various dialects (#3120)
Browse files Browse the repository at this point in the history
* Fix: correctly generate ArrayJoin in various dialects

* Refactor: rename ArrayJoin to ArrayToString (primary SQL name ARRAY_TO_STRING)
  • Loading branch information
georgesittas committed Mar 11, 2024
1 parent 14c1dad commit c333017
Show file tree
Hide file tree
Showing 11 changed files with 18 additions and 10 deletions.
4 changes: 2 additions & 2 deletions sqlglot/dataframe/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -971,10 +971,10 @@ def array_join(
) -> Column:
if null_replacement is not None:
return Column.invoke_expression_over_column(
col, expression.ArrayJoin, expression=lit(delimiter), null=lit(null_replacement)
col, expression.ArrayToString, expression=lit(delimiter), null=lit(null_replacement)
)
return Column.invoke_expression_over_column(
col, expression.ArrayJoin, expression=lit(delimiter)
col, expression.ArrayToString, expression=lit(delimiter)
)


Expand Down
1 change: 0 additions & 1 deletion sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,6 @@ class Generator(generator.Generator):
if e.expressions and e.expressions[0].find(exp.Select)
else inline_array_sql(self, e)
),
exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
exp.ArrayFilter: rename_func("LIST_FILTER"),
exp.ArraySize: rename_func("ARRAY_LENGTH"),
exp.ArgMax: arg_max_or_min_no_count("ARG_MAX"),
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,7 @@ class Generator(generator.Generator):
exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayJoin: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
exp.ArrayToString: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
exp.ArraySize: rename_func("SIZE"),
exp.ArraySort: _array_sort_sql,
exp.With: no_recursive_cte_sql,
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ class Generator(generator.Generator):
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayContains: rename_func("CONTAINS"),
exp.ArraySize: rename_func("CARDINALITY"),
exp.ArrayToString: rename_func("ARRAY_JOIN"),
exp.ArrayUniqueAgg: rename_func("SET_AGG"),
exp.AtTimeZone: rename_func("AT_TIMEZONE"),
exp.BitwiseAnd: lambda self, e: self.func("BITWISE_AND", e.this, e.expression),
Expand Down
1 change: 0 additions & 1 deletion sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,6 @@ class Generator(generator.Generator):
exp.Array: inline_array_sql,
exp.ArrayConcat: rename_func("ARRAY_CAT"),
exp.ArrayContains: lambda self, e: self.func("ARRAY_CONTAINS", e.expression, e.this),
exp.ArrayJoin: rename_func("ARRAY_TO_STRING"),
exp.AtTimeZone: lambda self, e: self.func(
"CONVERT_TIMEZONE", e.args.get("zone"), e.this
),
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ class Generator(Hive.Generator):
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArraySum: lambda self,
e: f"AGGREGATE({self.sql(e, 'this')}, 0, (acc, x) -> acc + x, acc -> acc)",
exp.ArrayToString: rename_func("ARRAY_JOIN"),
exp.AtTimeZone: lambda self, e: self.func(
"FROM_UTC_TIMESTAMP", e.this, e.args.get("zone")
),
Expand Down Expand Up @@ -252,7 +253,6 @@ class Generator(Hive.Generator):
[transforms.remove_within_group_for_percentiles]
),
}
TRANSFORMS.pop(exp.ArrayJoin)
TRANSFORMS.pop(exp.ArraySort)
TRANSFORMS.pop(exp.ILike)
TRANSFORMS.pop(exp.Left)
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,6 +761,7 @@ class Generator(generator.Generator):
TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.AnyValue: any_value_to_max_sql,
exp.ArrayToString: rename_func("STRING_AGG"),
exp.AutoIncrementColumnConstraint: lambda *_: "IDENTITY",
exp.DateAdd: date_delta_sql("DATEADD"),
exp.DateDiff: date_delta_sql("DATEDIFF"),
Expand Down
4 changes: 2 additions & 2 deletions sqlglot/executor/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def interval(this, unit):


@null_if_any("this", "expression")
def arrayjoin(this, expression, null=None):
def arraytostring(this, expression, null=None):
return expression.join(x for x in (x if x is not None else null for x in this) if x is not None)


Expand Down Expand Up @@ -173,7 +173,7 @@ def jsonextract(this, expression):
"ABS": null_if_any(lambda this: abs(this)),
"ADD": null_if_any(lambda e, this: e + this),
"ARRAYANY": null_if_any(lambda arr, func: any(func(e) for e in arr)),
"ARRAYJOIN": arrayjoin,
"ARRAYTOSTRING": arraytostring,
"BETWEEN": null_if_any(lambda this, low, high: low <= this and this <= high),
"BITWISEAND": null_if_any(lambda this, e: this & e),
"BITWISELEFTSHIFT": null_if_any(lambda this, e: this << e),
Expand Down
4 changes: 2 additions & 2 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4577,9 +4577,9 @@ class ArrayFilter(Func):
_sql_names = ["FILTER", "ARRAY_FILTER"]


class ArrayJoin(Func):
class ArrayToString(Func):
arg_types = {"this": True, "expression": True, "null": False}
_sql_names = ["ARRAY_JOIN", "ARRAY_TO_STRING"]
_sql_names = ["ARRAY_TO_STRING", "ARRAY_JOIN"]


class ArrayOverlaps(Binary, Func):
Expand Down
1 change: 1 addition & 0 deletions tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def test_bigquery(self):
select_with_quoted_udf = self.validate_identity("SELECT `p.d.UdF`(data) FROM `p.d.t`")
self.assertEqual(select_with_quoted_udf.selects[0].name, "p.d.UdF")

self.validate_identity("SELECT ARRAY_TO_STRING(list, '--') AS text")
self.validate_identity("SELECT jsondoc['some_key']")
self.validate_identity("SELECT `p.d.UdF`(data).* FROM `p.d.t`")
self.validate_identity("SELECT * FROM `my-project.my-dataset.my-table`")
Expand Down
7 changes: 7 additions & 0 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,20 @@ def test_duckdb(self):
self.validate_all(
"ARRAY_TO_STRING(arr, delim)",
read={
"bigquery": "ARRAY_TO_STRING(arr, delim)",
"postgres": "ARRAY_TO_STRING(arr, delim)",
"presto": "ARRAY_JOIN(arr, delim)",
"snowflake": "ARRAY_TO_STRING(arr, delim)",
"spark": "ARRAY_JOIN(arr, delim)",
},
write={
"bigquery": "ARRAY_TO_STRING(arr, delim)",
"duckdb": "ARRAY_TO_STRING(arr, delim)",
"postgres": "ARRAY_TO_STRING(arr, delim)",
"presto": "ARRAY_JOIN(arr, delim)",
"snowflake": "ARRAY_TO_STRING(arr, delim)",
"spark": "ARRAY_JOIN(arr, delim)",
"tsql": "STRING_AGG(arr, delim)",
},
)
self.validate_all(
Expand Down

0 comments on commit c333017

Please sign in to comment.