Skip to content

Commit

Permalink
Fix(duckdb): fix JSON pointer path parsing, reduce warning noise (#2911)
Browse files Browse the repository at this point in the history
* Fix(duckdb): fix JSON pointer path parsing, reduce warning noise

* Rename
  • Loading branch information
georgesittas committed Feb 2, 2024
1 parent 3b533c4 commit f3bdcb0
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 20 deletions.
18 changes: 17 additions & 1 deletion sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce
Expand All @@ -8,7 +9,7 @@
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import generate as generate_json_path
from sqlglot.jsonpath import generate as generate_json_path, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
Expand All @@ -20,6 +21,8 @@
if t.TYPE_CHECKING:
from sqlglot._typing import B, E

logger = logging.getLogger("sqlglot")


class Dialects(str, Enum):
"""Dialects supported by SQLGLot."""
Expand Down Expand Up @@ -441,6 +444,19 @@ def quote_identifier(self, expression: E, identify: bool = True) -> E:

return expression

def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if isinstance(path, exp.Literal):
path_text = path.name
if path.is_number:
path_text = f"[{path_text}]"

try:
return exp.JSONPath(expressions=parse_json_path(path_text))
except ParseError:
logger.warning(f"Invalid JSON path syntax: {path_text}")

return path

def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
return self.parser(**opts).parse(self.tokenize(sql), sql)

Expand Down
12 changes: 12 additions & 0 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,18 @@ class DuckDB(Dialect):
# https://duckdb.org/docs/sql/introduction.html#creating-a-new-table
NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE

def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if isinstance(path, exp.Literal):
# DuckDB also supports the JSON pointer syntax, where every path starts with a `/`.
# Additionally, it allows accessing the back of lists using the `[#-i]` syntax.
# This check ensures we'll avoid trying to parse these as JSON paths, which can
# either result in a noisy warning or in an invalid representation of the path.
path_text = path.name
if path_text.startswith("/") or "[#" in path_text:
return path

return super().to_json_path(path)

class Tokenizer(tokens.Tokenizer):
KEYWORDS = {
**tokens.Tokenizer.KEYWORDS,
Expand Down
26 changes: 7 additions & 19 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from sqlglot import exp
from sqlglot.errors import ErrorLevel, ParseError, concat_messages, merge_errors
from sqlglot.helper import apply_index_offset, ensure_list, seq_get
from sqlglot.jsonpath import parse as _parse_json_path
from sqlglot.time import format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import TrieResult, in_trie, new_trie
Expand Down Expand Up @@ -61,22 +60,11 @@ def parse_logarithm(args: t.List, dialect: Dialect) -> exp.Func:
return (exp.Ln if dialect.parser_class.LOG_DEFAULTS_TO_LN else exp.Log)(this=this)


def parse_json_path(path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
if isinstance(path, exp.Literal):
path_text = path.name
if path.is_number:
path_text = f"[{path_text}]"
try:
return exp.JSONPath(expressions=_parse_json_path(path_text))
except ParseError:
logger.warning(f"Invalid JSON path syntax: {path_text}")

return path


def parse_extract_json_with_path(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
def _parser(args: t.List) -> E:
expression = expr_type(this=seq_get(args, 0), expression=parse_json_path(seq_get(args, 1)))
def parse_extract_json_with_path(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]:
def _parser(args: t.List, dialect: Dialect) -> E:
expression = expr_type(
this=seq_get(args, 0), expression=dialect.to_json_path(seq_get(args, 1))
)
if len(args) > 2 and expr_type is exp.JSONExtract:
expression.set("expressions", args[2:])

Expand Down Expand Up @@ -558,12 +546,12 @@ class Parser(metaclass=_Parser):
TokenType.ARROW: lambda self, this, path: self.expression(
exp.JSONExtract,
this=this,
expression=parse_json_path(path),
expression=self.dialect.to_json_path(path),
),
TokenType.DARROW: lambda self, this, path: self.expression(
exp.JSONExtractScalar,
this=this,
expression=parse_json_path(path),
expression=self.dialect.to_json_path(path),
),
TokenType.HASH_ARROW: lambda self, this, path: self.expression(
exp.JSONBExtract,
Expand Down
8 changes: 8 additions & 0 deletions tests/dialects/test_duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@ def test_duckdb(self):
},
)

self.validate_identity("""SELECT '{"duck": [1, 2, 3]}' -> '$.duck[#-1]'""")
self.validate_all(
"""SELECT JSON_EXTRACT('{"duck": [1, 2, 3]}', '/duck/0')""",
write={
"": """SELECT JSON_EXTRACT('{"duck": [1, 2, 3]}', '/duck/0')""",
"duckdb": """SELECT '{"duck": [1, 2, 3]}' -> '/duck/0'""",
},
)
self.validate_all(
"""SELECT JSON('{"fruit":"banana"}') -> 'fruit'""",
write={
Expand Down

0 comments on commit f3bdcb0

Please sign in to comment.