Skip to content

Commit

Permalink
feat: add typing for explode closes #2927
Browse files Browse the repository at this point in the history
  • Loading branch information
tobymao committed Feb 8, 2024
1 parent 1842c96 commit 9241858
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
16 changes: 10 additions & 6 deletions sqlglot/optimizer/annotate_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ class TypeAnnotator(metaclass=_TypeAnnotator):
exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
exp.Div: lambda self, e: self._annotate_div(e),
exp.Explode: lambda self, e: self._annotate_explode(e),
exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
Expand Down Expand Up @@ -333,9 +334,9 @@ def __init__(
self._visited: t.Set[int] = set()

def _set_type(
self, expression: exp.Expression, target_type: exp.DataType | exp.DataType.Type
self, expression: exp.Expression, target_type: t.Optional[exp.DataType | exp.DataType.Type]
) -> None:
expression.type = target_type # type: ignore
expression.type = target_type or exp.DataType.Type.UNKNOWN # type: ignore
self._visited.add(id(expression))

def annotate(self, expression: E) -> E:
Expand Down Expand Up @@ -564,13 +565,11 @@ def _annotate_bracket(self, expression: exp.Bracket) -> exp.Bracket:
if isinstance(bracket_arg, exp.Slice):
self._set_type(expression, this.type)
elif this.type.is_type(exp.DataType.Type.ARRAY):
contained_type = seq_get(this.type.expressions, 0) or exp.DataType.Type.UNKNOWN
self._set_type(expression, contained_type)
self._set_type(expression, seq_get(this.type.expressions, 0))
elif isinstance(this, (exp.Map, exp.VarMap)) and bracket_arg in this.keys:
index = this.keys.index(bracket_arg)
value = seq_get(this.values, index)
value_type = value.type if value else exp.DataType.Type.UNKNOWN
self._set_type(expression, value_type or exp.DataType.Type.UNKNOWN)
self._set_type(expression, value.type if value else None)
else:
self._set_type(expression, exp.DataType.Type.UNKNOWN)

Expand All @@ -591,3 +590,8 @@ def _annotate_div(self, expression: exp.Div) -> exp.Div:
self._set_type(expression, self._maybe_coerce(left_type, right_type))

return expression

def _annotate_explode(self, expression: exp.Explode) -> exp.Explode:
    """Type an explode-style call with the element type of its argument.

    The expression's arguments are annotated first, then the first type
    parameter of ``expression.this.type`` (e.g. ``INT`` for ``ARRAY<INT>``)
    is assigned as the type of the expression itself.

    Args:
        expression: The ``exp.Explode`` node to annotate.

    Returns:
        The same ``expression`` instance, with its type set.
    """
    # Annotate children first; presumably this populates
    # `expression.this.type` -- confirm against `_annotate_args`.
    self._annotate_args(expression)

    container_type = expression.this.type
    # First nested type parameter of the argument's type. `seq_get` appears
    # to yield a falsy value when there is none -- verify against its
    # definition.
    element_type = seq_get(container_type.expressions, 0)

    self._set_type(expression, element_type)
    return expression
6 changes: 6 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,12 @@ def test_nested_type_annotation(self):

self.assertEqual(expression.selects[0].type.sql(dialect="bigquery"), "STRUCT<`f` STRING>")

expression = annotate_types(
parse_one("SELECT unnest(t.x) FROM t AS t", dialect="postgres"),
schema={"t": {"x": "array<int>"}},
)
self.assertTrue(expression.selects[0].is_type("int"))

def test_type_annotation_cache(self):
sql = "SELECT 1 + 1"
expression = annotate_types(parse_one(sql))
Expand Down

0 comments on commit 9241858

Please sign in to comment.