Skip to content

Commit

Permalink
feat(transpiler): handle different hex behavior for dialects (#3463)
Browse files Browse the repository at this point in the history
* feat(transpiler): handle different hex behavior for dialects

* Update sqlglot/dialects/bigquery.py

Co-authored-by: Jo <46752250+georgesittas@users.noreply.github.com>

* remove UpperHex, add LowerHex

* clean

---------

Co-authored-by: Jo <46752250+georgesittas@users.noreply.github.com>
  • Loading branch information
viplazylmht and georgesittas committed May 15, 2024
1 parent e3ff67b commit 2433993
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 3 deletions.
7 changes: 5 additions & 2 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def _build_date(args: t.List) -> exp.Date | exp.DateFromParts:
def _build_to_hex(args: t.List) -> exp.Hex | exp.MD5:
    """Build the expression for a `TO_HEX(...)` call.

    Args:
        args: Parsed function arguments; only the first one is used.

    Returns:
        An `exp.MD5` when the argument is an `exp.MD5Digest`, otherwise an
        `exp.LowerHex` wrapping the argument.
    """
    # TO_HEX(MD5(..)) is common in BigQuery, so it's parsed into MD5 to simplify its transpilation
    arg = seq_get(args, 0)
    if isinstance(arg, exp.MD5Digest):
        return exp.MD5(this=arg.this)

    # LowerHex rather than Hex: the stale `exp.Hex` return that used to precede
    # this one made the LowerHex result unreachable.
    return exp.LowerHex(this=arg)


def _array_contains_sql(self: BigQuery.Generator, expression: exp.ArrayContains) -> str:
Expand Down Expand Up @@ -245,6 +245,8 @@ class BigQuery(Dialect):
# https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table
PSEUDOCOLUMNS = {"_PARTITIONTIME", "_PARTITIONDATE"}

HEX_LOWERCASE = True

def normalize_identifier(self, expression: E) -> E:
if isinstance(expression, exp.Identifier):
parent = expression.parent
Expand Down Expand Up @@ -603,7 +605,8 @@ class Generator(generator.Generator):
),
exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
exp.GroupConcat: rename_func("STRING_AGG"),
exp.Hex: rename_func("TO_HEX"),
exp.Hex: lambda self, e: self.func("UPPER", self.func("TO_HEX", self.sql(e, "this"))),
exp.LowerHex: rename_func("TO_HEX"),
exp.If: if_sql(false_value="NULL"),
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
Expand Down
16 changes: 16 additions & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,22 @@ class Dialect(metaclass=_Dialect):
) SELECT c FROM y;
"""

HEX_LOWERCASE = False
"""
Dialects differ in the case of the string returned by their `HEX` function: some
produce lowercase digits, others produce uppercase digits. `HEX_LOWERCASE` indicates
which case the current dialect produces, so that `HEX` can be wrapped in an additional
`LOWER` or `UPPER` call when transpiling to a dialect with the opposite behavior.
For example,
`SELECT TO_HEX(x)`;
in Bigquery will be rewritten as the following one in Presto and Trino
`SELECT LOWER(TO_HEX(x))`;
In another example,
`SELECT TO_HEX(x)`;
in Presto will be rewritten as the following one in Bigquery
`SELECT UPPER(TO_HEX(x))`;
"""

# --- Autofilled ---

tokenizer_class = Tokenizer
Expand Down
3 changes: 3 additions & 0 deletions sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,9 @@ class Generator(generator.Generator):
"ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")
),
exp.Hex: rename_func("TO_HEX"),
exp.LowerHex: lambda self, e: self.func(
"LOWER", self.func("TO_HEX", self.sql(e, "this"))
),
exp.If: if_sql(),
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
Expand Down
4 changes: 4 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5197,6 +5197,10 @@ class Hex(Func):
pass


# Hex encoding whose output is lowercase (e.g. what LOWER(HEX(x)) is parsed into).
# It subclasses Hex, so isinstance(node, Hex) checks also match LowerHex nodes.
class LowerHex(Hex):
pass


class Xor(Connector, Func):
arg_types = {"this": False, "expression": False, "expressions": False}

Expand Down
13 changes: 13 additions & 0 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1302,6 +1302,19 @@ def identifier_sql(self, expression: exp.Identifier) -> str:
text = f"{self.dialect.IDENTIFIER_START}{text}{self.dialect.IDENTIFIER_END}"
return text

def hex_sql(self, expression: exp.Hex) -> str:
    """Generate SQL for `exp.Hex`, the uppercase hex encoding of its operand.

    Args:
        expression: The `exp.Hex` node to render.

    Returns:
        `HEX(<this>)`, wrapped in `UPPER(...)` when the target dialect's native
        hex function produces lowercase output.
    """
    text = self.func("HEX", self.sql(expression, "this"))
    if self.dialect.HEX_LOWERCASE:
        # The native function already yields lowercase here, so LOWER would be a
        # no-op and the uppercase meaning of exp.Hex would be lost; force UPPER.
        text = self.func("UPPER", text)

    return text

def lowerhex_sql(self, expression: exp.LowerHex) -> str:
    """Generate SQL for `exp.LowerHex`, the lowercase hex encoding of its operand.

    Args:
        expression: The `exp.LowerHex` node to render.

    Returns:
        `HEX(<this>)` as-is when the dialect's `HEX_LOWERCASE` flag is set,
        otherwise `LOWER(HEX(<this>))`.
    """
    operand_sql = self.sql(expression, "this")
    hex_call = self.func("HEX", operand_sql)

    # Dialects flagged as lowercase need no extra wrapping.
    if self.dialect.HEX_LOWERCASE:
        return hex_call

    return self.func("LOWER", hex_call)

def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str:
input_format = self.sql(expression, "input_format")
input_format = f"INPUTFORMAT {input_format}" if input_format else ""
Expand Down
20 changes: 20 additions & 0 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,23 @@ def build_logarithm(args: t.List, dialect: Dialect) -> exp.Func:
return (exp.Ln if dialect.parser_class.LOG_DEFAULTS_TO_LN else exp.Log)(this=this)


def build_hex(args: t.List, dialect: Dialect) -> exp.Hex | exp.LowerHex:
    """Build the node for a `HEX(...)` call according to the dialect's output case.

    Args:
        args: Parsed function arguments; only the first one is used.
        dialect: The dialect being parsed; its `HEX_LOWERCASE` flag selects the node type.

    Returns:
        `exp.LowerHex` when `dialect.HEX_LOWERCASE` is truthy, otherwise `exp.Hex`.
    """
    operand = seq_get(args, 0)
    if dialect.HEX_LOWERCASE:
        return exp.LowerHex(this=operand)
    return exp.Hex(this=operand)


def build_lower(args: t.List) -> exp.Lower | exp.Hex:
    """Build the node for a `LOWER(...)` call, folding a wrapped hex call into `LowerHex`.

    Args:
        args: Parsed function arguments; only the first one is used.

    Returns:
        `exp.LowerHex` over the inner operand when the argument is an `exp.Hex`
        (including subclasses), otherwise a plain `exp.Lower` over the argument.
    """
    inner = seq_get(args, 0)
    if not isinstance(inner, exp.Hex):
        return exp.Lower(this=inner)

    # LOWER(HEX(..)) is collapsed into LowerHex to simplify its transpilation
    return exp.LowerHex(this=inner.this)


def build_upper(args: t.List) -> exp.Upper | exp.Hex:
    """Build the node for an `UPPER(...)` call, folding a wrapped hex call into `Hex`.

    Args:
        args: Parsed function arguments; only the first one is used.

    Returns:
        `exp.Hex` over the inner operand when the argument is an `exp.Hex`
        (including subclasses), otherwise a plain `exp.Upper` over the argument.
    """
    inner = seq_get(args, 0)
    if not isinstance(inner, exp.Hex):
        return exp.Upper(this=inner)

    # UPPER(HEX(..)) is collapsed into Hex to simplify its transpilation
    return exp.Hex(this=inner.this)


def build_extract_json_with_path(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]:
def _builder(args: t.List, dialect: Dialect) -> E:
expression = expr_type(
Expand Down Expand Up @@ -148,6 +165,9 @@ class Parser(metaclass=_Parser):
length=exp.Literal.number(10),
),
"VAR_MAP": build_var_map,
"LOWER": build_lower,
"UPPER": build_upper,
"HEX": build_hex,
}

NO_PAREN_FUNCTIONS = {
Expand Down
55 changes: 55 additions & 0 deletions tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,58 @@ def test_bigquery(self):
"mysql": "SELECT DATE_SUB(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL 10 MINUTE)",
},
)
self.validate_all(
"LOWER(TO_HEX(x))",
write={
"": "LOWER(HEX(x))",
"bigquery": "TO_HEX(x)",
"presto": "LOWER(TO_HEX(x))",
"trino": "LOWER(TO_HEX(x))",
"clickhouse": "LOWER(HEX(x))",
"hive": "LOWER(HEX(x))",
"spark": "LOWER(HEX(x))",
},
)
self.validate_all(
"TO_HEX(x)",
read={
"": "LOWER(HEX(x))",
"presto": "LOWER(TO_HEX(x))",
"trino": "LOWER(TO_HEX(x))",
"clickhouse": "LOWER(HEX(x))",
"hive": "LOWER(HEX(x))",
"spark": "LOWER(HEX(x))",
},
write={
"": "LOWER(HEX(x))",
"bigquery": "TO_HEX(x)",
"presto": "LOWER(TO_HEX(x))",
"trino": "LOWER(TO_HEX(x))",
"clickhouse": "LOWER(HEX(x))",
"hive": "LOWER(HEX(x))",
"spark": "LOWER(HEX(x))",
},
)
self.validate_all(
"UPPER(TO_HEX(x))",
read={
"": "HEX(x)",
"presto": "TO_HEX(x)",
"trino": "TO_HEX(x)",
"clickhouse": "HEX(x)",
"hive": "HEX(x)",
"spark": "HEX(x)",
},
write={
"": "HEX(x)",
"bigquery": "UPPER(TO_HEX(x))",
"presto": "TO_HEX(x)",
"trino": "TO_HEX(x)",
"clickhouse": "HEX(x)",
"hive": "HEX(x)",
"spark": "HEX(x)",
},
)
self.validate_all(
"MD5(x)",
read={
Expand All @@ -653,6 +705,9 @@ def test_bigquery(self):
read={
"duckdb": "SELECT MD5(some_string)",
"spark": "SELECT MD5(some_string)",
"clickhouse": "SELECT LOWER(HEX(MD5(some_string)))",
"presto": "SELECT LOWER(TO_HEX(MD5(some_string)))",
"trino": "SELECT LOWER(TO_HEX(MD5(some_string)))",
},
write={
"": "SELECT MD5(some_string)",
Expand Down
4 changes: 3 additions & 1 deletion tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,9 @@ def test_functions(self):
self.assertIsInstance(parse_one("STANDARD_HASH('hello', 'sha256')"), exp.StandardHash)
self.assertIsInstance(parse_one("DATE(foo)"), exp.Date)
self.assertIsInstance(parse_one("HEX(foo)"), exp.Hex)
self.assertIsInstance(parse_one("TO_HEX(foo)", read="bigquery"), exp.Hex)
self.assertIsInstance(parse_one("LOWER(HEX(foo))"), exp.LowerHex)
self.assertIsInstance(parse_one("TO_HEX(foo)", read="bigquery"), exp.LowerHex)
self.assertIsInstance(parse_one("UPPER(TO_HEX(foo))", read="bigquery"), exp.Hex)
self.assertIsInstance(parse_one("TO_HEX(MD5(foo))", read="bigquery"), exp.MD5)
self.assertIsInstance(parse_one("TRANSFORM(a, b)", read="spark"), exp.Transform)
self.assertIsInstance(parse_one("ADD_MONTHS(a, b)"), exp.AddMonths)
Expand Down

0 comments on commit 2433993

Please sign in to comment.