Fix: properly parse/generate duckdb MAP {..} syntax, annotate MAPs #3241

Merged (2 commits) on Mar 29, 2024
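In practice, the change lets DuckDB's brace-style MAP literal parse, round-trip, and pick up a type. A rough sketch of the intended behavior (assuming a sqlglot build that includes this PR; the annotate_types import path is an assumption based on the optimizer tests):

from sqlglot import exp, parse_one
from sqlglot.optimizer.annotate_types import annotate_types

# MAP {...} parses into the new exp.ToMap node and generates back unchanged.
query = parse_one("SELECT MAP {'x': 1}", read="duckdb")
assert isinstance(query.selects[0], exp.ToMap)
assert query.sql(dialect="duckdb") == "SELECT MAP {'x': 1}"

# The annotator infers MAP(VARCHAR, INT) for the brace literal.
annotated = annotate_types(query)
assert annotated.selects[0].type == exp.DataType.build("MAP(VARCHAR, INT)")
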
4 changes: 3 additions & 1 deletion sqlglot/dialects/bigquery.py
@@ -536,7 +536,9 @@ def _parse_json_object(self, agg=False):

return json_object

def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
def _parse_bracket(
self, this: t.Optional[exp.Expression] = None
) -> t.Optional[exp.Expression]:
bracket = super()._parse_bracket(this)

if this is bracket:
12 changes: 12 additions & 0 deletions sqlglot/dialects/duckdb.py
@@ -293,6 +293,11 @@ class Parser(parser.Parser):
FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("DECODE")

NO_PAREN_FUNCTION_PARSERS = {
**parser.Parser.NO_PAREN_FUNCTION_PARSERS,
"MAP": lambda self: self._parse_map(),
}

TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
TokenType.SEMI,
TokenType.ANTI,
@@ -307,6 +312,13 @@ class Parser(parser.Parser):
),
}

def _parse_map(self) -> exp.ToMap | exp.Map:
if self._match(TokenType.L_BRACE, advance=False):
return self.expression(exp.ToMap, this=self._parse_bracket())

args = self._parse_wrapped_csv(self._parse_conjunction)
return self.expression(exp.Map, keys=seq_get(args, 0), values=seq_get(args, 1))

def _parse_types(
self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
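A small usage sketch of the dispatch above (hypothetical, assuming this branch): a brace after MAP yields the new exp.ToMap node, while the parenthesized form still becomes exp.Map.

from sqlglot import exp, parse_one

# Brace form -> exp.ToMap wrapping the struct-like payload.
assert isinstance(parse_one("SELECT MAP {'x': 1}", read="duckdb").selects[0], exp.ToMap)

# Function form -> exp.Map with separate key and value arrays.
assert isinstance(
    parse_one("SELECT MAP(['key1', 'key2'], [10, 20])", read="duckdb").selects[0],
    exp.Map,
)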
5 changes: 5 additions & 0 deletions sqlglot/expressions.py
@@ -5355,6 +5355,11 @@ def values(self) -> t.List[Expression]:
return values.expressions if values else []


# Represents the MAP {...} syntax in DuckDB - basically convert a struct to a MAP
class ToMap(Func):
pass


class MapFromEntries(Func):
pass
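
For reference, a sketch of the new node's shape (assuming this branch): ToMap is a thin wrapper that stores the brace literal's struct payload in its "this" arg, so it can also be built programmatically.

from sqlglot import exp, parse_one

node = parse_one("SELECT MAP {'x': 1}", read="duckdb").selects[0]
assert isinstance(node, exp.ToMap) and isinstance(node.this, exp.Struct)

# Equivalent construction by hand, reusing a copy of the parsed struct payload.
rebuilt = exp.ToMap(this=node.this.copy())
assert rebuilt.sql(dialect="duckdb") == "MAP {'x': 1}"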

97 changes: 49 additions & 48 deletions sqlglot/generator.py
@@ -137,6 +137,7 @@ class Generator(metaclass=_Generator):
exp.TemporaryProperty: lambda *_: "TEMPORARY",
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.expression),
exp.ToMap: lambda self, e: f"MAP {self.sql(e, 'this')}",
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions),
exp.TransientProperty: lambda *_: "TRANSIENT",
@@ -2996,30 +2997,6 @@ def havingmax_sql(self, expression: exp.HavingMax) -> str:
kind = "MAX" if expression.args.get("max") else "MIN"
return f"{this_sql} HAVING {kind} {expression_sql}"

def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str:
if self.IGNORE_NULLS_IN_FUNC and not expression.meta.get("inline"):
# The first modifier here will be the one closest to the AggFunc's arg
mods = sorted(
expression.find_all(exp.HavingMax, exp.Order, exp.Limit),
key=lambda x: 0
if isinstance(x, exp.HavingMax)
else (1 if isinstance(x, exp.Order) else 2),
)

if mods:
mod = mods[0]
this = expression.__class__(this=mod.this.copy())
this.meta["inline"] = True
mod.this.replace(this)
return self.sql(expression.this)

agg_func = expression.find(exp.AggFunc)

if agg_func:
return self.sql(agg_func)[:-1] + f" {text})"

return f"{self.sql(expression, 'this')} {text}"

def intdiv_sql(self, expression: exp.IntDiv) -> str:
return self.sql(
exp.Cast(
@@ -3565,30 +3542,6 @@ def arrayany_sql(self, expression: exp.ArrayAny) -> str:

return self.function_fallback_sql(expression)

def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str:
this = expression.this
if isinstance(this, exp.JSONPathWildcard):
this = self.json_path_part(this)
return f".{this}" if this else ""

if exp.SAFE_IDENTIFIER_RE.match(this):
return f".{this}"

this = self.json_path_part(this)
return f"[{this}]" if self.JSON_PATH_BRACKETED_KEY_SUPPORTED else f".{this}"

def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str:
this = self.json_path_part(expression.this)
return f"[{this}]" if this else ""

def _simplify_unless_literal(self, expression: E) -> E:
if not isinstance(expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify

expression = simplify(expression, dialect=self.dialect)

return expression

def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
expression.set("is_end_exclusive", None)
return self.function_fallback_sql(expression)
@@ -3675,3 +3628,51 @@ def convert_sql(self, expression: exp.Convert) -> str:
transformed = cast(this=value, to=to, safe=safe)

return self.sql(transformed)

def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str:
Comment from the author (Collaborator): Nothing to see here - just moved these underscore methods to the bottom of the generator file so they're grouped like helper methods. The only change is the TRANSFORMS entry.

this = expression.this
if isinstance(this, exp.JSONPathWildcard):
this = self.json_path_part(this)
return f".{this}" if this else ""

if exp.SAFE_IDENTIFIER_RE.match(this):
return f".{this}"

this = self.json_path_part(this)
return f"[{this}]" if self.JSON_PATH_BRACKETED_KEY_SUPPORTED else f".{this}"

def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str:
this = self.json_path_part(expression.this)
return f"[{this}]" if this else ""

def _simplify_unless_literal(self, expression: E) -> E:
if not isinstance(expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify

expression = simplify(expression, dialect=self.dialect)

return expression

def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str:
if self.IGNORE_NULLS_IN_FUNC and not expression.meta.get("inline"):
# The first modifier here will be the one closest to the AggFunc's arg
mods = sorted(
expression.find_all(exp.HavingMax, exp.Order, exp.Limit),
key=lambda x: 0
if isinstance(x, exp.HavingMax)
else (1 if isinstance(x, exp.Order) else 2),
)

if mods:
mod = mods[0]
this = expression.__class__(this=mod.this.copy())
this.meta["inline"] = True
mod.this.replace(this)
return self.sql(expression.this)

agg_func = expression.find(exp.AggFunc)

if agg_func:
return self.sql(agg_func)[:-1] + f" {text})"

return f"{self.sql(expression, 'this')} {text}"
114 changes: 71 additions & 43 deletions sqlglot/optimizer/annotate_types.py
@@ -271,22 +271,23 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Literal: lambda self, e: self._annotate_literal(e),
exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
exp.Map: lambda self, e: self._annotate_map(e),
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
exp.Struct: lambda self, e: self._annotate_struct(e),
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.Timestamp: lambda self, e: self._annotate_with_type(
e,
exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
),
exp.ToMap: lambda self, e: self._annotate_to_map(e),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.Unnest: lambda self, e: self._annotate_unnest(e),
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
exp.VarMap: lambda self, e: self._annotate_map(e),
}

NESTED_TYPES = {
@@ -428,23 +429,13 @@ def _maybe_coerce(
if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
return exp.DataType.Type.UNKNOWN

if type1_value in self.NESTED_TYPES:
return type1
if type2_value in self.NESTED_TYPES:
return type2
return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value

return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value # type: ignore

# Note: the following "no_type_check" decorators were added because mypy was yelling due
# to assigning Type values to expression.type (since its getter returns Optional[DataType]).
# This is a known mypy issue: https://github.com/python/mypy/issues/3004

@t.no_type_check
def _annotate_binary(self, expression: B) -> B:
self._annotate_args(expression)

left, right = expression.left, expression.right
left_type, right_type = left.type.this, right.type.this
left_type, right_type = left.type.this, right.type.this # type: ignore

if isinstance(expression, exp.Connector):
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
@@ -465,7 +456,6 @@ def _annotate_binary(self, expression: B) -> B:

return expression

@t.no_type_check
def _annotate_unary(self, expression: E) -> E:
self._annotate_args(expression)

@@ -476,7 +466,6 @@

return expression

@t.no_type_check
def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
if expression.is_string:
self._set_type(expression, exp.DataType.Type.VARCHAR)
Expand All @@ -487,33 +476,17 @@ def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:

return expression

@t.no_type_check
def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
self._set_type(expression, target_type)
return self._annotate_args(expression)

@t.no_type_check
def _annotate_struct_value(
self, expression: exp.Expression
) -> t.Optional[exp.DataType] | exp.ColumnDef:
alias = expression.args.get("alias")
if alias:
return exp.ColumnDef(this=alias.copy(), kind=expression.type)

# Case: key = value or key := value
if expression.expression:
return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)

return expression.type

@t.no_type_check
def _annotate_by_args(
self,
expression: E,
*args: str,
promote: bool = False,
array: bool = False,
struct: bool = False,
) -> E:
self._annotate_args(expression)

@@ -549,16 +522,6 @@ def _annotate_by_args(
),
)

if struct:
self._set_type(
expression,
exp.DataType(
this=exp.DataType.Type.STRUCT,
expressions=[self._annotate_struct_value(expr) for expr in expressions],
nested=True,
),
)

return expression

def _annotate_timeunit(
@@ -638,3 +601,68 @@ def _annotate_unnest(self, expression: exp.Unnest) -> exp.Unnest:
child = seq_get(expression.expressions, 0)
self._set_type(expression, child and seq_get(child.type.expressions, 0))
return expression

def _annotate_struct_value(
self, expression: exp.Expression
) -> t.Optional[exp.DataType] | exp.ColumnDef:
alias = expression.args.get("alias")
if alias:
return exp.ColumnDef(this=alias.copy(), kind=expression.type)

# Case: key = value or key := value
if expression.expression:
return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)

return expression.type

def _annotate_struct(self, expression: exp.Struct) -> exp.Struct:
self._annotate_args(expression)
self._set_type(
expression,
exp.DataType(
this=exp.DataType.Type.STRUCT,
expressions=[self._annotate_struct_value(expr) for expr in expression.expressions],
nested=True,
),
)
return expression

@t.overload
def _annotate_map(self, expression: exp.Map) -> exp.Map: ...

@t.overload
def _annotate_map(self, expression: exp.VarMap) -> exp.VarMap: ...

def _annotate_map(self, expression):
self._annotate_args(expression)

keys = expression.args.get("keys")
values = expression.args.get("values")

map_type = exp.DataType(this=exp.DataType.Type.MAP)
if isinstance(keys, exp.Array) and isinstance(values, exp.Array):
key_type = seq_get(keys.type.expressions, 0) or exp.DataType.Type.UNKNOWN
value_type = seq_get(values.type.expressions, 0) or exp.DataType.Type.UNKNOWN

if key_type != exp.DataType.Type.UNKNOWN and value_type != exp.DataType.Type.UNKNOWN:
map_type.set("expressions", [key_type, value_type])
map_type.set("nested", True)

self._set_type(expression, map_type)
return expression

def _annotate_to_map(self, expression: exp.ToMap) -> exp.ToMap:
self._annotate_args(expression)

map_type = exp.DataType(this=exp.DataType.Type.MAP)
arg = expression.this
if arg.is_type(exp.DataType.Type.STRUCT):
for coldef in arg.type.expressions:
kind = coldef.kind
if kind != exp.DataType.Type.UNKNOWN:
map_type.set("expressions", [exp.DataType.build("varchar"), kind])
map_type.set("nested", True)
break

self._set_type(expression, map_type)
return expression
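
To see the two new annotators side by side, a sketch mirroring the tests added below: the brace form always gets VARCHAR keys (struct field names), while the two-array form takes its key and value types from the array elements.

from sqlglot import exp, parse_one
from sqlglot.optimizer.annotate_types import annotate_types

brace = annotate_types(parse_one("SELECT MAP {'x': 1}", read="duckdb"))
arrays = annotate_types(parse_one("SELECT MAP(['key1', 'key2'], [10, 20])", read="duckdb"))

assert brace.selects[0].type == exp.DataType.build("MAP(VARCHAR, INT)")
assert arrays.selects[0].type == exp.DataType.build("MAP(VARCHAR, INT)")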
2 changes: 1 addition & 1 deletion sqlglot/parser.py
@@ -4789,7 +4789,7 @@ def _parse_primary_key(
def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]:
return self._parse_slice(self._parse_alias(self._parse_conjunction(), explicit=True))

def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
def _parse_bracket(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)):
return this
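
One note on why the new default matters (an observation, assuming this branch): DuckDB's _parse_map calls self._parse_bracket() with no left-hand expression, so the brace literal after MAP must parse on its own rather than as a subscript of an existing expression.

from sqlglot import exp, parse_one

# With no preceding expression, the {...} after MAP parses as a standalone struct payload.
payload = parse_one("SELECT MAP {'x': 1, 'y': 2}", read="duckdb").selects[0].this
assert isinstance(payload, exp.Struct)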

2 changes: 2 additions & 0 deletions tests/dialects/test_duckdb.py
@@ -238,6 +238,8 @@ def test_duckdb(self):
parse_one("a // b", read="duckdb").assert_is(exp.IntDiv).sql(dialect="duckdb"), "a // b"
)

self.validate_identity("SELECT MAP(['key1', 'key2', 'key3'], [10, 20, 30])")
self.validate_identity("SELECT MAP {'x': 1}")
self.validate_identity("SELECT df1.*, df2.* FROM df1 POSITIONAL JOIN df2")
self.validate_identity("MAKE_TIMESTAMP(1992, 9, 20, 13, 34, 27.123456)")
self.validate_identity("MAKE_TIMESTAMP(1667810584123456)")
15 changes: 15 additions & 0 deletions tests/test_optimizer.py
@@ -1119,6 +1119,21 @@ def test_unnest_annotation(self):
exp.DataType.build("date"),
)

def test_map_annotation(self):
# ToMap annotation
expression = annotate_types(parse_one("SELECT MAP {'x': 1}", read="duckdb"))
self.assertEqual(expression.selects[0].type, exp.DataType.build("MAP(VARCHAR, INT)"))

# Map annotation
expression = annotate_types(
parse_one("SELECT MAP(['key1', 'key2', 'key3'], [10, 20, 30])", read="duckdb")
)
self.assertEqual(expression.selects[0].type, exp.DataType.build("MAP(VARCHAR, INT)"))

# VarMap annotation
expression = annotate_types(parse_one("SELECT MAP('a', 'b')", read="spark"))
self.assertEqual(expression.selects[0].type, exp.DataType.build("MAP(VARCHAR, VARCHAR)"))

def test_recursive_cte(self):
query = parse_one(
"""