Skip to content

Commit

Permalink
Fix: properly parse/generate duckdb MAP {..} syntax, annotate MAPs (#…
Browse files Browse the repository at this point in the history
…3241)

* Fix: properly parse/generate duckdb MAP {..} syntax, annotate MAPs

* Improve test coverage
  • Loading branch information
georgesittas committed Mar 29, 2024
1 parent 12563ae commit 2a3a5cd
Show file tree
Hide file tree
Showing 8 changed files with 158 additions and 93 deletions.
4 changes: 3 additions & 1 deletion sqlglot/dialects/bigquery.py
Expand Up @@ -536,7 +536,9 @@ def _parse_json_object(self, agg=False):

return json_object

def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
def _parse_bracket(
self, this: t.Optional[exp.Expression] = None
) -> t.Optional[exp.Expression]:
bracket = super()._parse_bracket(this)

if this is bracket:
Expand Down
12 changes: 12 additions & 0 deletions sqlglot/dialects/duckdb.py
Expand Up @@ -293,6 +293,11 @@ class Parser(parser.Parser):
FUNCTION_PARSERS = parser.Parser.FUNCTION_PARSERS.copy()
FUNCTION_PARSERS.pop("DECODE")

NO_PAREN_FUNCTION_PARSERS = {
**parser.Parser.NO_PAREN_FUNCTION_PARSERS,
"MAP": lambda self: self._parse_map(),
}

TABLE_ALIAS_TOKENS = parser.Parser.TABLE_ALIAS_TOKENS - {
TokenType.SEMI,
TokenType.ANTI,
Expand All @@ -307,6 +312,13 @@ class Parser(parser.Parser):
),
}

def _parse_map(self) -> exp.ToMap | exp.Map:
if self._match(TokenType.L_BRACE, advance=False):
return self.expression(exp.ToMap, this=self._parse_bracket())

args = self._parse_wrapped_csv(self._parse_conjunction)
return self.expression(exp.Map, keys=seq_get(args, 0), values=seq_get(args, 1))

def _parse_types(
self, check_func: bool = False, schema: bool = False, allow_identifiers: bool = True
) -> t.Optional[exp.Expression]:
Expand Down
5 changes: 5 additions & 0 deletions sqlglot/expressions.py
Expand Up @@ -5355,6 +5355,11 @@ def values(self) -> t.List[Expression]:
return values.expressions if values else []


# Represents the MAP {...} syntax in DuckDB - basically convert a struct to a MAP
class ToMap(Func):
pass


class MapFromEntries(Func):
pass

Expand Down
97 changes: 49 additions & 48 deletions sqlglot/generator.py
Expand Up @@ -137,6 +137,7 @@ class Generator(metaclass=_Generator):
exp.TemporaryProperty: lambda *_: "TEMPORARY",
exp.TitleColumnConstraint: lambda self, e: f"TITLE {self.sql(e, 'this')}",
exp.Timestamp: lambda self, e: self.func("TIMESTAMP", e.this, e.expression),
exp.ToMap: lambda self, e: f"MAP {self.sql(e, 'this')}",
exp.ToTableProperty: lambda self, e: f"TO {self.sql(e.this)}",
exp.TransformModelProperty: lambda self, e: self.func("TRANSFORM", *e.expressions),
exp.TransientProperty: lambda *_: "TRANSIENT",
Expand Down Expand Up @@ -2996,30 +2997,6 @@ def havingmax_sql(self, expression: exp.HavingMax) -> str:
kind = "MAX" if expression.args.get("max") else "MIN"
return f"{this_sql} HAVING {kind} {expression_sql}"

def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str:
if self.IGNORE_NULLS_IN_FUNC and not expression.meta.get("inline"):
# The first modifier here will be the one closest to the AggFunc's arg
mods = sorted(
expression.find_all(exp.HavingMax, exp.Order, exp.Limit),
key=lambda x: 0
if isinstance(x, exp.HavingMax)
else (1 if isinstance(x, exp.Order) else 2),
)

if mods:
mod = mods[0]
this = expression.__class__(this=mod.this.copy())
this.meta["inline"] = True
mod.this.replace(this)
return self.sql(expression.this)

agg_func = expression.find(exp.AggFunc)

if agg_func:
return self.sql(agg_func)[:-1] + f" {text})"

return f"{self.sql(expression, 'this')} {text}"

def intdiv_sql(self, expression: exp.IntDiv) -> str:
return self.sql(
exp.Cast(
Expand Down Expand Up @@ -3572,30 +3549,6 @@ def arrayany_sql(self, expression: exp.ArrayAny) -> str:

return self.function_fallback_sql(expression)

def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str:
this = expression.this
if isinstance(this, exp.JSONPathWildcard):
this = self.json_path_part(this)
return f".{this}" if this else ""

if exp.SAFE_IDENTIFIER_RE.match(this):
return f".{this}"

this = self.json_path_part(this)
return f"[{this}]" if self.JSON_PATH_BRACKETED_KEY_SUPPORTED else f".{this}"

def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str:
this = self.json_path_part(expression.this)
return f"[{this}]" if this else ""

def _simplify_unless_literal(self, expression: E) -> E:
if not isinstance(expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify

expression = simplify(expression, dialect=self.dialect)

return expression

def generateseries_sql(self, expression: exp.GenerateSeries) -> str:
expression.set("is_end_exclusive", None)
return self.function_fallback_sql(expression)
Expand Down Expand Up @@ -3682,3 +3635,51 @@ def convert_sql(self, expression: exp.Convert) -> str:
transformed = cast(this=value, to=to, safe=safe)

return self.sql(transformed)

def _jsonpathkey_sql(self, expression: exp.JSONPathKey) -> str:
this = expression.this
if isinstance(this, exp.JSONPathWildcard):
this = self.json_path_part(this)
return f".{this}" if this else ""

if exp.SAFE_IDENTIFIER_RE.match(this):
return f".{this}"

this = self.json_path_part(this)
return f"[{this}]" if self.JSON_PATH_BRACKETED_KEY_SUPPORTED else f".{this}"

def _jsonpathsubscript_sql(self, expression: exp.JSONPathSubscript) -> str:
this = self.json_path_part(expression.this)
return f"[{this}]" if this else ""

def _simplify_unless_literal(self, expression: E) -> E:
if not isinstance(expression, exp.Literal):
from sqlglot.optimizer.simplify import simplify

expression = simplify(expression, dialect=self.dialect)

return expression

def _embed_ignore_nulls(self, expression: exp.IgnoreNulls | exp.RespectNulls, text: str) -> str:
if self.IGNORE_NULLS_IN_FUNC and not expression.meta.get("inline"):
# The first modifier here will be the one closest to the AggFunc's arg
mods = sorted(
expression.find_all(exp.HavingMax, exp.Order, exp.Limit),
key=lambda x: 0
if isinstance(x, exp.HavingMax)
else (1 if isinstance(x, exp.Order) else 2),
)

if mods:
mod = mods[0]
this = expression.__class__(this=mod.this.copy())
this.meta["inline"] = True
mod.this.replace(this)
return self.sql(expression.this)

agg_func = expression.find(exp.AggFunc)

if agg_func:
return self.sql(agg_func)[:-1] + f" {text})"

return f"{self.sql(expression, 'this')} {text}"
114 changes: 71 additions & 43 deletions sqlglot/optimizer/annotate_types.py
Expand Up @@ -271,22 +271,23 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Literal: lambda self, e: self._annotate_literal(e),
exp.Map: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
exp.Map: lambda self, e: self._annotate_map(e),
exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
exp.Struct: lambda self, e: self._annotate_by_args(e, "expressions", struct=True),
exp.Struct: lambda self, e: self._annotate_struct(e),
exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
exp.Timestamp: lambda self, e: self._annotate_with_type(
e,
exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
),
exp.ToMap: lambda self, e: self._annotate_to_map(e),
exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
exp.Unnest: lambda self, e: self._annotate_unnest(e),
exp.VarMap: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.MAP),
exp.VarMap: lambda self, e: self._annotate_map(e),
}

NESTED_TYPES = {
Expand Down Expand Up @@ -428,23 +429,13 @@ def _maybe_coerce(
if exp.DataType.Type.UNKNOWN in (type1_value, type2_value):
return exp.DataType.Type.UNKNOWN

if type1_value in self.NESTED_TYPES:
return type1
if type2_value in self.NESTED_TYPES:
return type2
return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value

return type2_value if type2_value in self.coerces_to.get(type1_value, {}) else type1_value # type: ignore

# Note: the following "no_type_check" decorators were added because mypy was yelling due
# to assigning Type values to expression.type (since its getter returns Optional[DataType]).
# This is a known mypy issue: https://github.com/python/mypy/issues/3004

@t.no_type_check
def _annotate_binary(self, expression: B) -> B:
self._annotate_args(expression)

left, right = expression.left, expression.right
left_type, right_type = left.type.this, right.type.this
left_type, right_type = left.type.this, right.type.this # type: ignore

if isinstance(expression, exp.Connector):
if left_type == exp.DataType.Type.NULL and right_type == exp.DataType.Type.NULL:
Expand All @@ -465,7 +456,6 @@ def _annotate_binary(self, expression: B) -> B:

return expression

@t.no_type_check
def _annotate_unary(self, expression: E) -> E:
self._annotate_args(expression)

Expand All @@ -476,7 +466,6 @@ def _annotate_unary(self, expression: E) -> E:

return expression

@t.no_type_check
def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:
if expression.is_string:
self._set_type(expression, exp.DataType.Type.VARCHAR)
Expand All @@ -487,33 +476,17 @@ def _annotate_literal(self, expression: exp.Literal) -> exp.Literal:

return expression

@t.no_type_check
def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) -> E:
self._set_type(expression, target_type)
return self._annotate_args(expression)

@t.no_type_check
def _annotate_struct_value(
self, expression: exp.Expression
) -> t.Optional[exp.DataType] | exp.ColumnDef:
alias = expression.args.get("alias")
if alias:
return exp.ColumnDef(this=alias.copy(), kind=expression.type)

# Case: key = value or key := value
if expression.expression:
return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)

return expression.type

@t.no_type_check
def _annotate_by_args(
self,
expression: E,
*args: str,
promote: bool = False,
array: bool = False,
struct: bool = False,
) -> E:
self._annotate_args(expression)

Expand Down Expand Up @@ -549,16 +522,6 @@ def _annotate_by_args(
),
)

if struct:
self._set_type(
expression,
exp.DataType(
this=exp.DataType.Type.STRUCT,
expressions=[self._annotate_struct_value(expr) for expr in expressions],
nested=True,
),
)

return expression

def _annotate_timeunit(
Expand Down Expand Up @@ -638,3 +601,68 @@ def _annotate_unnest(self, expression: exp.Unnest) -> exp.Unnest:
child = seq_get(expression.expressions, 0)
self._set_type(expression, child and seq_get(child.type.expressions, 0))
return expression

def _annotate_struct_value(
self, expression: exp.Expression
) -> t.Optional[exp.DataType] | exp.ColumnDef:
alias = expression.args.get("alias")
if alias:
return exp.ColumnDef(this=alias.copy(), kind=expression.type)

# Case: key = value or key := value
if expression.expression:
return exp.ColumnDef(this=expression.this.copy(), kind=expression.expression.type)

return expression.type

def _annotate_struct(self, expression: exp.Struct) -> exp.Struct:
self._annotate_args(expression)
self._set_type(
expression,
exp.DataType(
this=exp.DataType.Type.STRUCT,
expressions=[self._annotate_struct_value(expr) for expr in expression.expressions],
nested=True,
),
)
return expression

@t.overload
def _annotate_map(self, expression: exp.Map) -> exp.Map: ...

@t.overload
def _annotate_map(self, expression: exp.VarMap) -> exp.VarMap: ...

def _annotate_map(self, expression):
self._annotate_args(expression)

keys = expression.args.get("keys")
values = expression.args.get("values")

map_type = exp.DataType(this=exp.DataType.Type.MAP)
if isinstance(keys, exp.Array) and isinstance(values, exp.Array):
key_type = seq_get(keys.type.expressions, 0) or exp.DataType.Type.UNKNOWN
value_type = seq_get(values.type.expressions, 0) or exp.DataType.Type.UNKNOWN

if key_type != exp.DataType.Type.UNKNOWN and value_type != exp.DataType.Type.UNKNOWN:
map_type.set("expressions", [key_type, value_type])
map_type.set("nested", True)

self._set_type(expression, map_type)
return expression

def _annotate_to_map(self, expression: exp.ToMap) -> exp.ToMap:
self._annotate_args(expression)

map_type = exp.DataType(this=exp.DataType.Type.MAP)
arg = expression.this
if arg.is_type(exp.DataType.Type.STRUCT):
for coldef in arg.type.expressions:
kind = coldef.kind
if kind != exp.DataType.Type.UNKNOWN:
map_type.set("expressions", [exp.DataType.build("varchar"), kind])
map_type.set("nested", True)
break

self._set_type(expression, map_type)
return expression
2 changes: 1 addition & 1 deletion sqlglot/parser.py
Expand Up @@ -4789,7 +4789,7 @@ def _parse_primary_key(
def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]:
return self._parse_slice(self._parse_alias(self._parse_conjunction(), explicit=True))

def _parse_bracket(self, this: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
def _parse_bracket(self, this: t.Optional[exp.Expression] = None) -> t.Optional[exp.Expression]:
if not self._match_set((TokenType.L_BRACKET, TokenType.L_BRACE)):
return this

Expand Down
2 changes: 2 additions & 0 deletions tests/dialects/test_duckdb.py
Expand Up @@ -238,6 +238,8 @@ def test_duckdb(self):
parse_one("a // b", read="duckdb").assert_is(exp.IntDiv).sql(dialect="duckdb"), "a // b"
)

self.validate_identity("SELECT MAP(['key1', 'key2', 'key3'], [10, 20, 30])")
self.validate_identity("SELECT MAP {'x': 1}")
self.validate_identity("SELECT df1.*, df2.* FROM df1 POSITIONAL JOIN df2")
self.validate_identity("MAKE_TIMESTAMP(1992, 9, 20, 13, 34, 27.123456)")
self.validate_identity("MAKE_TIMESTAMP(1667810584123456)")
Expand Down
15 changes: 15 additions & 0 deletions tests/test_optimizer.py
Expand Up @@ -1119,6 +1119,21 @@ def test_unnest_annotation(self):
exp.DataType.build("date"),
)

def test_map_annotation(self):
# ToMap annotation
expression = annotate_types(parse_one("SELECT MAP {'x': 1}", read="duckdb"))
self.assertEqual(expression.selects[0].type, exp.DataType.build("MAP(VARCHAR, INT)"))

# Map annotation
expression = annotate_types(
parse_one("SELECT MAP(['key1', 'key2', 'key3'], [10, 20, 30])", read="duckdb")
)
self.assertEqual(expression.selects[0].type, exp.DataType.build("MAP(VARCHAR, INT)"))

# VarMap annotation
expression = annotate_types(parse_one("SELECT MAP('a', 'b')", read="spark"))
self.assertEqual(expression.selects[0].type, exp.DataType.build("MAP(VARCHAR, VARCHAR)"))

def test_recursive_cte(self):
query = parse_one(
"""
Expand Down

0 comments on commit 2a3a5cd

Please sign in to comment.