Skip to content

Commit

Permalink
Refactor!: improve transpilation of JSON paths across dialects (#2883)
Browse files Browse the repository at this point in the history
* Refactor!: improve transpilation of JSON paths across dialects

* Improve test coverage

* Catch json path parse errors so that the sql parser is more lenient

* Add support for more DuckDB variants

* Add more tests

* Reduce multi-arg MySQL JSON_EXTRACT to single-arg

* Add a single-arg test

* Add support for Postgres' JSON_EXTRACT_PATH variant

* Add support for T-SQL's JSON_QUERY/VALUE

* Fix MySQL, handle SQLite multi-arg variant

* PR feedback

* Get rid of copy

* Refactor JSON representation: create proper AST (subclass of Expression)

* Type hint TRANSFORMS

* Style

* Get rid of special parsing logic for MySQL, SQLite, produce multi-arg JSONExtract instead

* Nested path tests, unsupported tests, json path mapping refactor

* Use lstrip

* Factor out is_int

* Move is_int to helper.py

* arrow_json_extract_sql cleanup

* Cleanup expression arg_types

* Improve test

* Add another test
  • Loading branch information
georgesittas committed Jan 30, 2024
1 parent 2fb1e9f commit b4e8868
Show file tree
Hide file tree
Showing 24 changed files with 905 additions and 450 deletions.
7 changes: 7 additions & 0 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,6 +623,13 @@ class Generator(generator.Generator):
exp.VariancePop: rename_func("VAR_POP"),
}

SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathChild,
exp.JSONPathKey,
exp.JSONPathRoot,
exp.JSONPathSubscript,
}

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BIGDECIMAL: "BIGNUMERIC",
Expand Down
61 changes: 53 additions & 8 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, seq_get
from sqlglot.helper import AutoName, flatten, is_int, seq_get
from sqlglot.jsonpath import generate as generate_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time
from sqlglot.tokens import Token, Tokenizer, TokenType
Expand Down Expand Up @@ -500,14 +501,14 @@ def _if_sql(self: Generator, expression: exp.If) -> str:
return _if_sql


def arrow_json_extract_sql(
    self: Generator, expression: exp.JSONExtract | exp.JSONExtractScalar
) -> str:
    """Render arrow-style JSON extraction: `->` for JSONExtract, `->>` for JSONExtractScalar.

    For dialects where extraction operators require a JSON-typed operand
    (JSON_TYPE_REQUIRED_FOR_EXTRACTION), a plain string literal target is
    wrapped in a CAST to JSON in place before rendering.
    """
    this = expression.this

    # E.g. MySQL can't apply -> / ->> directly to a string literal, so
    # coerce the literal to the JSON type first.
    if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string:
        this.replace(exp.cast(this, "json"))

    return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")


def inline_array_sql(self: Generator, expression: exp.Array) -> str:
Expand Down Expand Up @@ -1023,3 +1024,47 @@ def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
)

return self.merge_sql(expression)


def parse_json_extract_path(
    expr_type: t.Type[E],
    supports_null_if_invalid: bool = False,
) -> t.Callable[[t.List], E]:
    """Build a parser that folds JSON_EXTRACT_PATH-style args into `expr_type`.

    The first argument is the JSON document; each following literal argument
    becomes one JSONPath segment (integer literals become subscripts, other
    literals become child accesses). When `supports_null_if_invalid` is set,
    a non-literal trailing argument is captured as the `null_if_invalid` flag
    (Redshift's JSON_EXTRACT_PATH_TEXT variant).
    """

    def _parse_json_extract_path(args: t.List) -> E:
        null_if_invalid = None
        parts: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]

        for path_arg in args[1:]:
            if not isinstance(path_arg, exp.Literal):
                if supports_null_if_invalid:
                    null_if_invalid = path_arg
                continue

            text = path_arg.name
            parts.append(
                exp.JSONPathSubscript(this=int(text))
                if is_int(text)
                else exp.JSONPathChild(this=text)
            )

        this = seq_get(args, 0)
        jsonpath = exp.JSONPath(expressions=parts)

        # Truncate the arg list so the expression validator doesn't reject the
        # (now folded-away) variadic path arguments on arg-count grounds.
        del args[2:]

        if expr_type is exp.JSONExtractScalar:
            return expr_type(this=this, expression=jsonpath, null_if_invalid=null_if_invalid)

        return expr_type(this=this, expression=jsonpath)

    return _parse_json_extract_path


def json_path_segments(self: Generator, expression: exp.JSONPath) -> t.List[str]:
    """Render each part of a JSONPath as a dialect-quoted segment string.

    Parts the dialect can't represent are reported through the generator's
    `unsupported` callback; parts that render to an empty string (e.g. the
    root part) are dropped from the result.
    """
    quote_start = self.dialect.QUOTE_START
    quote_end = self.dialect.QUOTE_END

    rendered = (
        generate_json_path(
            part, mapping=self._JSON_PATH_MAPPING, unsupported_callback=self.unsupported
        )
        for part in expression.expressions
    )

    return [f"{quote_start}{text}{quote_end}" for text in rendered if text]
16 changes: 12 additions & 4 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
NormalizationStrategy,
approx_count_distinct_sql,
arg_max_or_min_no_count,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
binary_from_function,
bool_xor_sql,
Expand Down Expand Up @@ -229,6 +228,9 @@ class Parser(parser.Parser):
this=seq_get(args, 0), scale=exp.UnixToTime.MILLIS
),
"JSON": exp.ParseJSON.from_arg_list,
"JSON_EXTRACT_PATH": parser.parse_extract_json_with_path(exp.JSONExtract),
"JSON_EXTRACT_STRING": parser.parse_extract_json_with_path(exp.JSONExtractScalar),
"JSON_EXTRACT_PATH_TEXT": parser.parse_extract_json_with_path(exp.JSONExtractScalar),
"LIST_HAS": exp.ArrayContains.from_arg_list,
"LIST_REVERSE_SORT": _sort_array_reverse,
"LIST_SORT": exp.SortArray.from_arg_list,
Expand Down Expand Up @@ -358,10 +360,8 @@ class Generator(generator.Generator):
exp.IntDiv: lambda self, e: self.binary(e, "//"),
exp.IsInf: rename_func("ISINF"),
exp.IsNan: rename_func("ISNAN"),
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.JSONFormat: _json_format_sql,
exp.LogicalOr: rename_func("BOOL_OR"),
exp.LogicalAnd: rename_func("BOOL_AND"),
Expand Down Expand Up @@ -423,6 +423,14 @@ class Generator(generator.Generator):
exp.Xor: bool_xor_sql,
}

SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathChild,
exp.JSONPathKey,
exp.JSONPathRoot,
exp.JSONPathSubscript,
exp.JSONPathWildcard,
}

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BINARY: "BLOB",
Expand Down
8 changes: 8 additions & 0 deletions sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,14 @@ class Generator(generator.Generator):
exp.Union,
}

SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathChild,
exp.JSONPathKey,
exp.JSONPathRoot,
exp.JSONPathSubscript,
exp.JSONPathWildcard,
}

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BIT: "BOOLEAN",
Expand Down
5 changes: 3 additions & 2 deletions sqlglot/dialects/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from sqlglot.dialects.dialect import (
Dialect,
NormalizationStrategy,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
date_add_interval_sql,
datestrtodate_sql,
format_time_lambda,
Expand Down Expand Up @@ -630,6 +630,7 @@ class Generator(generator.Generator):
VALUES_AS_TABLE = False
NVL2_SUPPORTED = False
LAST_DAY_SUPPORTS_DATE_PART = False
JSON_TYPE_REQUIRED_FOR_EXTRACTION = True
JSON_KEY_VALUE_PAIR_SEP = ","

TRANSFORMS = {
Expand All @@ -649,7 +650,7 @@ class Generator(generator.Generator):
exp.GetPath: path_to_jsonpath(),
exp.GroupConcat: lambda self, e: f"""GROUP_CONCAT({self.sql(e, "this")} SEPARATOR {self.sql(e, "separator") or "','"})""",
exp.ILike: no_ilike_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.Max: max_or_greatest,
exp.Min: min_or_least,
exp.Month: _remove_ts_or_ds_to_date(),
Expand Down
37 changes: 32 additions & 5 deletions sqlglot/dialects/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,10 @@
DATE_ADD_OR_SUB,
Dialect,
any_value_to_max_sql,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
bool_xor_sql,
datestrtodate_sql,
format_time_lambda,
json_path_segments,
max_or_greatest,
merge_without_target_sql,
min_or_least,
Expand All @@ -20,6 +19,7 @@
no_paren_current_date_sql,
no_pivot_sql,
no_trycast_sql,
parse_json_extract_path,
parse_timestamp_trunc,
rename_func,
str_position_sql,
Expand Down Expand Up @@ -188,6 +188,20 @@ def _to_timestamp(args: t.List) -> exp.Expression:
return format_time_lambda(exp.StrToTime, "postgres")(args)


def _json_extract_sql(
    self: Postgres.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar
) -> str:
    """Render JSON extraction as Postgres' JSON_EXTRACT_PATH / JSON_EXTRACT_PATH_TEXT call."""
    if isinstance(expression, exp.JSONExtract):
        func_name = "JSON_EXTRACT_PATH"
    else:
        # JSONExtractScalar maps to the text-returning variant.
        func_name = "JSON_EXTRACT_PATH_TEXT"

    return self.func(func_name, expression.this, *json_path_segments(self, expression.expression))


class Postgres(Dialect):
INDEX_OFFSET = 1
TYPED_DIVISION = True
Expand Down Expand Up @@ -292,6 +306,8 @@ class Parser(parser.Parser):
**parser.Parser.FUNCTIONS,
"DATE_TRUNC": parse_timestamp_trunc,
"GENERATE_SERIES": _generate_series,
"JSON_EXTRACT_PATH": parse_json_extract_path(exp.JSONExtract),
"JSON_EXTRACT_PATH_TEXT": parse_json_extract_path(exp.JSONExtractScalar),
"MAKE_TIME": exp.TimeFromParts.from_arg_list,
"MAKE_TIMESTAMP": exp.TimestampFromParts.from_arg_list,
"NOW": exp.CurrentTimestamp.from_arg_list,
Expand Down Expand Up @@ -375,9 +391,20 @@ class Generator(generator.Generator):
TABLESAMPLE_SIZE_IS_ROWS = False
TABLESAMPLE_SEED_KEYWORD = "REPEATABLE"
SUPPORTS_SELECT_INTO = True
# https://www.postgresql.org/docs/current/sql-createtable.html
JSON_TYPE_REQUIRED_FOR_EXTRACTION = True
SUPPORTS_UNLOGGED_TABLES = True

JSON_PATH_MAPPING = {
exp.JSONPathChild: lambda n, **kwargs: n.name,
exp.JSONPathKey: lambda n, **kwargs: n.name,
exp.JSONPathRoot: lambda n, **kwargs: "",
exp.JSONPathSubscript: lambda n, **kwargs: generator.generate_json_path(
n.this, **kwargs
),
}

SUPPORTED_JSON_PATH_PARTS = set(JSON_PATH_MAPPING)

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TINYINT: "SMALLINT",
Expand Down Expand Up @@ -412,8 +439,8 @@ class Generator(generator.Generator):
exp.DateSub: _date_add_sql("-"),
exp.Explode: rename_func("UNNEST"),
exp.GroupConcat: _string_agg_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONExtract: _json_extract_sql,
exp.JSONExtractScalar: _json_extract_sql,
exp.JSONBExtract: lambda self, e: self.binary(e, "#>"),
exp.JSONBExtractScalar: lambda self, e: self.binary(e, "#>>"),
exp.JSONBContains: lambda self, e: self.binary(e, "?"),
Expand Down
27 changes: 23 additions & 4 deletions sqlglot/dialects/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
concat_ws_to_dpipe_sql,
date_delta_sql,
generatedasidentitycolumnconstraint_sql,
json_path_segments,
no_tablesample_sql,
parse_json_extract_path,
rename_func,
)
from sqlglot.dialects.postgres import Postgres
Expand All @@ -20,8 +22,15 @@
from sqlglot._typing import E


def _json_sql(self: Redshift.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar) -> str:
return f'{self.sql(expression, "this")}."{expression.expression.name}"'
def _json_extract_sql(
    self: Redshift.Generator, expression: exp.JSONExtract | exp.JSONExtractScalar
) -> str:
    """Render JSON extraction as Redshift's JSON_EXTRACT_PATH_TEXT call.

    The optional trailing `null_if_invalid` argument is forwarded as-is.
    """
    call_args = [expression.this]
    call_args.extend(json_path_segments(self, expression.expression))
    call_args.append(expression.args.get("null_if_invalid"))

    return self.func("JSON_EXTRACT_PATH_TEXT", *call_args)


def _parse_date_delta(expr_type: t.Type[E]) -> t.Callable[[t.List], E]:
Expand Down Expand Up @@ -62,6 +71,9 @@ class Parser(Postgres.Parser):
"DATE_ADD": _parse_date_delta(exp.TsOrDsAdd),
"DATEDIFF": _parse_date_delta(exp.TsOrDsDiff),
"DATE_DIFF": _parse_date_delta(exp.TsOrDsDiff),
"JSON_EXTRACT_PATH_TEXT": parse_json_extract_path(
exp.JSONExtractScalar, supports_null_if_invalid=True
),
"LISTAGG": exp.GroupConcat.from_arg_list,
"STRTOL": exp.FromBase.from_arg_list,
}
Expand Down Expand Up @@ -197,8 +209,8 @@ class Generator(Postgres.Generator):
exp.DistStyleProperty: lambda self, e: self.naked_property(e),
exp.FromBase: rename_func("STRTOL"),
exp.GeneratedAsIdentityColumnConstraint: generatedasidentitycolumnconstraint_sql,
exp.JSONExtract: _json_sql,
exp.JSONExtractScalar: _json_sql,
exp.JSONExtract: _json_extract_sql,
exp.JSONExtractScalar: _json_extract_sql,
exp.GroupConcat: rename_func("LISTAGG"),
exp.ParseJSON: rename_func("JSON_PARSE"),
exp.Select: transforms.preprocess(
Expand Down Expand Up @@ -228,6 +240,13 @@ def with_properties(self, properties: exp.Properties) -> str:
"""Redshift doesn't have `WITH` as part of their with_properties so we remove it"""
return self.properties(properties, prefix=" ", suffix="")

def cast_sql(self, expression: exp.Cast, safe_prefix: t.Optional[str] = None) -> str:
    """Render a cast, treating casts to JSON as a no-op (Redshift has no JSON type)."""
    if not expression.is_type(exp.DataType.Type.JSON):
        return super().cast_sql(expression, safe_prefix=safe_prefix)

    # Drop the CAST wrapper and emit only the operand.
    return self.sql(expression, "this")

def datatype_sql(self, expression: exp.DataType) -> str:
"""
Redshift converts the `TEXT` data type to `VARCHAR(255)` by default when people more generally mean
Expand Down
11 changes: 11 additions & 0 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,17 @@ class Generator(generator.Generator):
exp.Xor: rename_func("BOOLXOR"),
}

SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathChild,
exp.JSONPathKey,
exp.JSONPathRoot,
exp.JSONPathSubscript,
}

JSON_PATH_MAPPING = {
exp.JSONPathRoot: lambda n, **kwargs: "",
}

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.TIMESTAMP: "TIMESTAMPNTZ",
Expand Down
20 changes: 15 additions & 5 deletions sqlglot/dialects/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
Dialect,
NormalizationStrategy,
any_value_to_max_sql,
arrow_json_extract_scalar_sql,
arrow_json_extract_sql,
concat_to_dpipe_sql,
count_if_to_sum,
Expand All @@ -28,6 +27,12 @@ def _date_add_sql(self: SQLite.Generator, expression: exp.DateAdd) -> str:
return self.func("DATE", expression.this, modifier)


def _json_extract_sql(self: SQLite.Generator, expression: exp.JSONExtract) -> str:
    """Render JSON_EXTRACT: multi-path calls keep function syntax, single-path uses `->`."""
    return (
        self.function_fallback_sql(expression)
        if expression.expressions
        else arrow_json_extract_sql(self, expression)
    )


def _transform_create(expression: exp.Expression) -> exp.Expression:
"""Move primary key to a column and enforce auto_increment on primary keys."""
schema = expression.this
Expand Down Expand Up @@ -86,6 +91,13 @@ class Generator(generator.Generator):
QUERY_HINTS = False
NVL2_SUPPORTED = False

SUPPORTED_JSON_PATH_PARTS = {
exp.JSONPathChild,
exp.JSONPathKey,
exp.JSONPathRoot,
exp.JSONPathSubscript,
}

TYPE_MAPPING = {
**generator.Generator.TYPE_MAPPING,
exp.DataType.Type.BOOLEAN: "INTEGER",
Expand Down Expand Up @@ -120,10 +132,8 @@ class Generator(generator.Generator):
exp.DateAdd: _date_add_sql,
exp.DateStrToDate: lambda self, e: self.sql(e, "this"),
exp.ILike: no_ilike_sql,
exp.JSONExtract: arrow_json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONBExtract: arrow_json_extract_sql,
exp.JSONBExtractScalar: arrow_json_extract_scalar_sql,
exp.JSONExtract: _json_extract_sql,
exp.JSONExtractScalar: arrow_json_extract_sql,
exp.Levenshtein: rename_func("EDITDIST3"),
exp.LogicalOr: rename_func("MAX"),
exp.LogicalAnd: rename_func("MIN"),
Expand Down

0 comments on commit b4e8868

Please sign in to comment.