Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(transpiler): handle different hex behavior for dialects #3463

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sqlglot/dataframe/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,6 +908,10 @@ def hex(col: ColumnOrName) -> Column:
return Column.invoke_expression_over_column(col, expression.Hex)


def upperhex(col: ColumnOrName) -> Column:
return Column.invoke_expression_over_column(col, expression.UpperHex)
georgesittas marked this conversation as resolved.
Show resolved Hide resolved


def unhex(col: ColumnOrName) -> Column:
return Column.invoke_expression_over_column(col, expression.Unhex)

Expand Down
5 changes: 5 additions & 0 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,8 @@ class BigQuery(Dialect):
# https://cloud.google.com/bigquery/docs/querying-partitioned-tables#query_an_ingestion-time_partitioned_table
PSEUDOCOLUMNS = {"_PARTITIONTIME", "_PARTITIONDATE"}

HEX_LOWERCASE: t.Optional[bool] = True
viplazylmht marked this conversation as resolved.
Show resolved Hide resolved

def normalize_identifier(self, expression: E) -> E:
if isinstance(expression, exp.Identifier):
parent = expression.parent
Expand Down Expand Up @@ -604,6 +606,9 @@ class Generator(generator.Generator):
exp.GenerateSeries: rename_func("GENERATE_ARRAY"),
exp.GroupConcat: rename_func("STRING_AGG"),
exp.Hex: rename_func("TO_HEX"),
exp.UpperHex: lambda self, e: self.func(
"UPPER", self.func("TO_HEX", self.sql(e, "this"))
),
exp.If: if_sql(false_value="NULL"),
exp.ILike: no_ilike_sql,
exp.IntDiv: rename_func("DIV"),
Expand Down
17 changes: 17 additions & 0 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,23 @@ class Dialect(metaclass=_Dialect):
) SELECT c FROM y;
"""

HEX_LOWERCASE: t.Optional[bool] = False
viplazylmht marked this conversation as resolved.
Show resolved Hide resolved
"""
Different dialect, `HEX` function will producing a different string, some are in
lowercase, other are in uppercase. To prevent this mismatch behavior, we can use
HEX_LOWERCASE property to determine the case which current dialect use. In that
case, `HEX` can be wrapped by an additional lower or upper function to convert
the output to exact dialect.
For example,
`SELECT TO_HEX(x)`;
in Bigquery will be rewritten as the following one in Presto and Trino
`SELECT LOWER(TO_HEX(x))`;
In another example,
`SELECT TO_HEX(x)`;
in Presto will be rewritten as the following one in Bigquery
`SELECT UPPER(TO_HEX(x))`;
viplazylmht marked this conversation as resolved.
Show resolved Hide resolved
"""

# --- Autofilled ---

tokenizer_class = Tokenizer
Expand Down
5 changes: 3 additions & 2 deletions sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ class Parser(parser.Parser):
this=seq_get(args, 0), substr=seq_get(args, 1), instance=seq_get(args, 2)
),
"TO_CHAR": _build_to_char,
"TO_HEX": exp.Hex.from_arg_list,
"TO_HEX": exp.UpperHex.from_arg_list,
georgesittas marked this conversation as resolved.
Show resolved Hide resolved
"TO_UNIXTIME": exp.TimeToUnix.from_arg_list,
"TO_UTF8": lambda args: exp.Encode(
this=seq_get(args, 0), charset=exp.Literal.string("utf-8")
Expand Down Expand Up @@ -384,7 +384,8 @@ class Generator(generator.Generator):
exp.GroupConcat: lambda self, e: self.func(
"ARRAY_JOIN", self.func("ARRAY_AGG", e.this), e.args.get("separator")
),
exp.Hex: rename_func("TO_HEX"),
exp.Hex: lambda self, e: self.func("LOWER", self.func("TO_HEX", self.sql(e, "this"))),
exp.UpperHex: rename_func("TO_HEX"),
exp.If: if_sql(),
exp.ILike: no_ilike_sql,
exp.Initcap: _initcap_sql,
Expand Down
4 changes: 4 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5193,6 +5193,10 @@ class Hex(Func):
pass


class UpperHex(Func):
pass
viplazylmht marked this conversation as resolved.
Show resolved Hide resolved


class Xor(Connector, Func):
arg_types = {"this": False, "expression": False, "expressions": False}

Expand Down
13 changes: 13 additions & 0 deletions sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1301,6 +1301,19 @@ def identifier_sql(self, expression: exp.Identifier) -> str:
text = f"{self.dialect.IDENTIFIER_START}{text}{self.dialect.IDENTIFIER_END}"
return text

def hex_sql(self, expression: exp.Hex) -> str:
text = self.func("HEX", self.sql(expression, "this"))
if not self.dialect.HEX_LOWERCASE:
text = self.func("LOWER", text)

return text

def upperhex_sql(self, expression: exp.UpperHex) -> str:
text = self.func("HEX", self.sql(expression, "this"))
if self.dialect.HEX_LOWERCASE:
text = self.func("UPPER", text)
return text

def inputoutputformat_sql(self, expression: exp.InputOutputFormat) -> str:
input_format = self.sql(expression, "input_format")
input_format = f"INPUTFORMAT {input_format}" if input_format else ""
Expand Down
30 changes: 30 additions & 0 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,33 @@ def build_logarithm(args: t.List, dialect: Dialect) -> exp.Func:
return (exp.Ln if dialect.parser_class.LOG_DEFAULTS_TO_LN else exp.Log)(this=this)


def build_hex(args: t.List, dialect: Dialect) -> exp.Hex | exp.UpperHex:
arg = seq_get(args, 0)
return exp.Hex(this=arg) if dialect.HEX_LOWERCASE else exp.UpperHex(this=arg)


def build_lower(args: t.List, dialect: Dialect) -> exp.Lower | exp.Hex:
viplazylmht marked this conversation as resolved.
Show resolved Hide resolved
# if the dialect provides HEX_UPPERCASE, LOWER(HEX(..)) can be simplified to Hex to simplify its transpilation
arg = seq_get(args, 0)
return (
exp.Hex(this=arg.this)
if (not dialect.HEX_LOWERCASE and isinstance(arg, exp.UpperHex))
or (dialect.HEX_LOWERCASE and isinstance(arg, exp.Hex))
else exp.Lower(this=arg)
)


def build_upper(args: t.List, dialect: Dialect) -> exp.Upper | exp.UpperHex:
# if the dialect provides HEX_UPPERCASE, UPPER(HEX(..)) can be simplified to UpperHex to simplify its transpilation
arg = seq_get(args, 0)
return (
exp.UpperHex(this=arg.this)
if (not dialect.HEX_LOWERCASE and isinstance(arg, exp.UpperHex))
or (dialect.HEX_LOWERCASE and isinstance(arg, exp.Hex))
else exp.Upper(this=arg)
)


def build_extract_json_with_path(expr_type: t.Type[E]) -> t.Callable[[t.List, Dialect], E]:
def _builder(args: t.List, dialect: Dialect) -> E:
expression = expr_type(
Expand Down Expand Up @@ -148,6 +175,9 @@ class Parser(metaclass=_Parser):
length=exp.Literal.number(10),
),
"VAR_MAP": build_var_map,
"LOWER": build_lower,
"UPPER": build_upper,
"HEX": build_hex,
}

NO_PAREN_FUNCTIONS = {
Expand Down
8 changes: 7 additions & 1 deletion tests/dataframe/unit/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1188,8 +1188,14 @@ def test_bin(self):

def test_hex(self):
col_str = SF.hex("cola")
self.assertEqual("HEX(cola)", col_str.sql())
self.assertEqual("LOWER(HEX(cola))", col_str.sql())
viplazylmht marked this conversation as resolved.
Show resolved Hide resolved
col = SF.hex(SF.col("cola"))
self.assertEqual("LOWER(HEX(cola))", col.sql())

def test_upperhex(self):
col_str = SF.upperhex("cola")
self.assertEqual("HEX(cola)", col_str.sql())
col = SF.upperhex(SF.col("cola"))
self.assertEqual("HEX(cola)", col.sql())

def test_unhex(self):
Expand Down
56 changes: 56 additions & 0 deletions tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,59 @@ def test_bigquery(self):
"mysql": "SELECT DATE_SUB(TIMESTAMP('2008-12-25 15:30:00+00'), INTERVAL 10 MINUTE)",
},
)
self.validate_all(
"LOWER(TO_HEX(x))",
write={
"": "LOWER(HEX(x))",
"bigquery": "TO_HEX(x)",
"presto": "LOWER(TO_HEX(x))",
"trino": "LOWER(TO_HEX(x))",
"clickhouse": "LOWER(HEX(x))",
"hive": "LOWER(HEX(x))",
"spark": "LOWER(HEX(x))",
},
)
self.validate_all(
"TO_HEX(x)",
read={
"": "LOWER(HEX(x))",
"presto": "LOWER(TO_HEX(x))",
"trino": "LOWER(TO_HEX(x))",
"clickhouse": "LOWER(HEX(x))",
"hive": "LOWER(HEX(x))",
"spark": "LOWER(HEX(x))",
},
write={
"": "LOWER(HEX(x))",
"bigquery": "TO_HEX(x)",
"presto": "LOWER(TO_HEX(x))",
"trino": "LOWER(TO_HEX(x))",
"clickhouse": "LOWER(HEX(x))",
"hive": "LOWER(HEX(x))",
"spark": "LOWER(HEX(x))",
},
)
self.validate_all(
"UPPER(TO_HEX(x))",
read={
"": "HEX(x)",
"bigquery": "UPPER(TO_HEX(x))",
viplazylmht marked this conversation as resolved.
Show resolved Hide resolved
"presto": "TO_HEX(x)",
"trino": "TO_HEX(x)",
"clickhouse": "HEX(x)",
"hive": "HEX(x)",
"spark": "HEX(x)",
},
write={
"": "HEX(x)",
"bigquery": "UPPER(TO_HEX(x))",
"presto": "TO_HEX(x)",
"trino": "TO_HEX(x)",
"clickhouse": "HEX(x)",
"hive": "HEX(x)",
"spark": "HEX(x)",
},
)
self.validate_all(
"MD5(x)",
read={
Expand All @@ -653,6 +706,9 @@ def test_bigquery(self):
read={
"duckdb": "SELECT MD5(some_string)",
"spark": "SELECT MD5(some_string)",
"clickhouse": "SELECT LOWER(HEX(MD5(some_string)))",
"presto": "SELECT LOWER(TO_HEX(MD5(some_string)))",
"trino": "SELECT LOWER(TO_HEX(MD5(some_string)))",
},
write={
"": "SELECT MD5(some_string)",
Expand Down
4 changes: 3 additions & 1 deletion tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,8 +673,10 @@ def test_functions(self):
self.assertIsInstance(parse_one("ARRAY(time, foo)"), exp.Array)
self.assertIsInstance(parse_one("STANDARD_HASH('hello', 'sha256')"), exp.StandardHash)
self.assertIsInstance(parse_one("DATE(foo)"), exp.Date)
self.assertIsInstance(parse_one("HEX(foo)"), exp.Hex)
self.assertIsInstance(parse_one("LOWER(HEX(foo))"), exp.Hex)
self.assertIsInstance(parse_one("HEX(foo)"), exp.UpperHex)
self.assertIsInstance(parse_one("TO_HEX(foo)", read="bigquery"), exp.Hex)
self.assertIsInstance(parse_one("UPPER(TO_HEX(foo))", read="bigquery"), exp.UpperHex)
self.assertIsInstance(parse_one("TO_HEX(MD5(foo))", read="bigquery"), exp.MD5)
self.assertIsInstance(parse_one("TRANSFORM(a, b)", read="spark"), exp.Transform)
self.assertIsInstance(parse_one("ADD_MONTHS(a, b)"), exp.AddMonths)
Expand Down
Loading