Skip to content

Commit

Permalink
Fix: ensure UDF identifier quotes are preserved (#3057)
Browse files Browse the repository at this point in the history
* Fix: ensure UDF identifier quotes are preserved

* Formatting changes due to updated ruff version

* Fixup

* Don't use sql method

* Fixups
  • Loading branch information
georgesittas committed Feb 29, 2024
1 parent 39b3813 commit 6a9501f
Show file tree
Hide file tree
Showing 11 changed files with 58 additions and 8 deletions.
2 changes: 1 addition & 1 deletion sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def quote_identifier(self, expression: E, identify: bool = True) -> E:
identify: If set to `False`, the quotes will only be added if the identifier is deemed
"unsafe", with respect to its characters and this dialect's normalization strategy.
"""
if isinstance(expression, exp.Identifier):
if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
name = expression.this
expression.set(
"quoted",
Expand Down
3 changes: 0 additions & 3 deletions sqlglot/dialects/drill.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,3 @@ class Generator(generator.Generator):
exp.TsOrDiToDi: lambda self,
e: f"CAST(SUBSTR(REPLACE(CAST({self.sql(e, 'this')} AS VARCHAR), '-', ''), 1, 8) AS INT)",
}

def normalize_func(self, name: str) -> str:
return name if exp.SAFE_IDENTIFIER_RE.match(name) else f"`{name}`"
8 changes: 8 additions & 0 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4005,6 +4005,10 @@ class Overlaps(Binary):


class Dot(Binary):
@property
def is_star(self) -> bool:
return self.expression.is_star

@property
def name(self) -> str:
return self.expression.name
Expand Down Expand Up @@ -4400,6 +4404,10 @@ class Anonymous(Func):
arg_types = {"this": True, "expressions": False}
is_var_len_args = True

@property
def name(self) -> str:
return self.this if isinstance(self.this, str) else self.this.name


class AnonymousAggFunc(AggFunc):
arg_types = {"this": True, "expressions": False}
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2629,7 +2629,7 @@ def reference_sql(self, expression: exp.Reference) -> str:
return f"REFERENCES {this}{expressions}{options}"

def anonymous_sql(self, expression: exp.Anonymous) -> str:
return self.func(expression.name, *expression.expressions)
return self.func(self.sql(expression, "this"), *expression.expressions)

def paren_sql(self, expression: exp.Paren) -> str:
if isinstance(expression.unnest(), exp.Select):
Expand Down
4 changes: 2 additions & 2 deletions sqlglot/optimizer/qualify_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def _expand_stars(
tables = list(scope.selected_sources)
_add_except_columns(expression, tables, except_columns)
_add_replace_columns(expression, tables, replace_columns)
elif expression.is_star:
elif expression.is_star and not isinstance(expression, exp.Dot):
tables = [expression.table]
_add_except_columns(expression.this, tables, except_columns)
_add_replace_columns(expression.this, tables, replace_columns)
Expand Down Expand Up @@ -439,7 +439,7 @@ def _expand_stars(

if pivot_columns:
new_selections.extend(
exp.alias_(exp.column(name, table=pivot.alias), name, copy=False)
alias(exp.column(name, table=pivot.alias), name, copy=False)
for name in pivot_columns
if name not in columns_to_exclude
)
Expand Down
16 changes: 15 additions & 1 deletion sqlglot/optimizer/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,7 +1185,7 @@ def gen(expression: t.Any) -> str:
GEN_MAP = {
exp.Add: lambda e: _binary(e, "+"),
exp.And: lambda e: _binary(e, "AND"),
exp.Anonymous: lambda e: f"{e.this.upper()} {','.join(gen(e) for e in e.expressions)}",
exp.Anonymous: lambda e: _anonymous(e),
exp.Between: lambda e: f"{gen(e.this)} BETWEEN {gen(e.args.get('low'))} AND {gen(e.args.get('high'))}",
exp.Boolean: lambda e: "TRUE" if e.this else "FALSE",
exp.Bracket: lambda e: f"{gen(e.this)}[{gen(e.expressions)}]",
Expand Down Expand Up @@ -1219,6 +1219,20 @@ def gen(expression: t.Any) -> str:
}


def _anonymous(e: exp.Anonymous) -> str:
    """Render an anonymous function call as a flat string.

    The function name is upper-cased when `e.this` is a plain string or an
    unquoted identifier; a quoted identifier keeps its exact casing and is
    wrapped in double quotes. Arguments are rendered with `gen` and joined
    with commas, separated from the name by a single space.

    Raises:
        ValueError: if `e.this` is neither a `str` nor an `exp.Identifier`.
    """
    func = e.this

    # Reject unsupported payloads up front so the branches below only deal
    # with the two accepted shapes.
    if not isinstance(func, (str, exp.Identifier)):
        raise ValueError(
            f"Anonymous.this expects a str or an Identifier, got '{func.__class__.__name__}'."
        )

    if isinstance(func, str):
        name = func.upper()
    elif func.quoted:
        # Quoted identifiers are case-sensitive, so keep them verbatim.
        name = f'"{func.name}"'
    else:
        name = func.name.upper()

    args = ",".join(gen(arg) for arg in e.expressions)
    return f"{name} {args}"


def _binary(e: exp.Binary, op: str) -> str:
    """Render a binary expression as `<left> <op> <right>`, using `gen` for both operands."""
    lhs = gen(e.left)
    rhs = gen(e.right)
    return f"{lhs} {op} {rhs}"

Expand Down
2 changes: 2 additions & 0 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4158,6 +4158,8 @@ def _parse_function_call(

this = func
else:
if token_type == TokenType.IDENTIFIER:
this = exp.Identifier(this=this, quoted=True)
this = self.expression(exp.Anonymous, this=this, expressions=args)

if isinstance(this, exp.Expression):
Expand Down
4 changes: 4 additions & 0 deletions tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ def test_bigquery(self):
self.assertEqual(exp.to_table("`x.y.z`", dialect="bigquery").sql("bigquery"), "`x.y.z`")
self.assertEqual(exp.to_table("`x`.`y`", dialect="bigquery").sql("bigquery"), "`x`.`y`")

select_with_quoted_udf = self.validate_identity("SELECT `p.d.UdF`(data) FROM `p.d.t`")
self.assertEqual(select_with_quoted_udf.selects[0].name, "p.d.UdF")

self.validate_identity("SELECT `p.d.UdF`(data).* FROM `p.d.t`")
self.validate_identity("SELECT * FROM `my-project.my-dataset.my-table`")
self.validate_identity("CREATE OR REPLACE TABLE `a.b.c` CLONE `a.b.d`")
self.validate_identity("SELECT x, 1 AS y GROUP BY 1 ORDER BY 1")
Expand Down
4 changes: 4 additions & 0 deletions tests/fixtures/optimizer/qualify_columns.sql
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ SELECT _q_0._col_0 AS _col_0, _q_0._col_1 AS _col_1 FROM (VALUES (1, 2)) AS _q_0
select * from (values (1, 2)) x;
SELECT x._col_0 AS _col_0, x._col_1 AS _col_1 FROM (VALUES (1, 2)) AS x(_col_0, _col_1);

# execute: false
SELECT SOME_UDF(data).* FROM t;
SELECT SOME_UDF(t.data).* FROM t AS t;

--------------------------------------
-- Derived tables
--------------------------------------
Expand Down
4 changes: 4 additions & 0 deletions tests/fixtures/optimizer/quote_identifiers.sql
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,7 @@ SELECT "dual" FROM "t";
# dialect: snowflake
SELECT * FROM t AS dual;
SELECT * FROM "t" AS "dual";

# dialect: bigquery
SELECT `p.d.udf`(data).* FROM `p.d.t`;
SELECT `p.d.udf`(`data`).* FROM `p.d.t`;
17 changes: 17 additions & 0 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,23 @@ def test_simplify(self):
self.assertEqual("CONCAT('a', x, 'bc')", simplified_concat.sql(dialect="presto"))
self.assertEqual("CONCAT('a', x, 'bc')", simplified_safe_concat.sql())

anon_unquoted_str = parse_one("anonymous(x, y)")
self.assertEqual(optimizer.simplify.gen(anon_unquoted_str), "ANONYMOUS x,y")

anon_unquoted_identifier = exp.Anonymous(
this=exp.to_identifier("anonymous"), expressions=[exp.column("x"), exp.column("y")]
)
self.assertEqual(optimizer.simplify.gen(anon_unquoted_identifier), "ANONYMOUS x,y")

anon_quoted = parse_one('"anonymous"(x, y)')
self.assertEqual(optimizer.simplify.gen(anon_quoted), '"anonymous" x,y')

with self.assertRaises(ValueError) as e:
invalid_anonymous = exp.Anonymous(this=5)
optimizer.simplify.gen(invalid_anonymous)

self.assertIn("Anonymous.this expects a str or an Identifier, got 'int'.", str(e.exception))

def test_unnest_subqueries(self):
self.check_file(
"unnest_subqueries",
Expand Down

0 comments on commit 6a9501f

Please sign in to comment.