Skip to content

Commit

Permalink
fix: have table exclude this if schema target (#2921)
Browse files Browse the repository at this point in the history
  • Loading branch information
eakmanrq committed Feb 6, 2024
1 parent 326aa31 commit d20d826
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 12 deletions.
6 changes: 4 additions & 2 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,8 +457,10 @@ def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:

return this

def _parse_table_parts(self, schema: bool = False) -> exp.Table:
table = super()._parse_table_parts(schema=schema)
def _parse_table_parts(
self, schema: bool = False, is_db_reference: bool = False
) -> exp.Table:
table = super()._parse_table_parts(schema=schema, is_db_reference=is_db_reference)
if isinstance(table.this, exp.Identifier) and "." in table.name:
catalog, db, this, *rest = (
t.cast(t.Optional[exp.Expression], exp.to_identifier(x))
Expand Down
7 changes: 6 additions & 1 deletion sqlglot/dialects/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,14 @@ def _parse_table(
joins: bool = False,
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
parse_bracket: bool = False,
is_db_reference: bool = False,
) -> t.Optional[exp.Expression]:
this = super()._parse_table(
schema=schema, joins=joins, alias_tokens=alias_tokens, parse_bracket=parse_bracket
schema=schema,
joins=joins,
alias_tokens=alias_tokens,
parse_bracket=parse_bracket,
is_db_reference=is_db_reference,
)

if self._match(TokenType.FINAL):
Expand Down
2 changes: 2 additions & 0 deletions sqlglot/dialects/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def _parse_table(
joins: bool = False,
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
parse_bracket: bool = False,
is_db_reference: bool = False,
) -> t.Optional[exp.Expression]:
# Redshift supports UNPIVOTing SUPER objects, e.g. `UNPIVOT foo.obj[0] AS val AT attr`
unpivot = self._match(TokenType.UNPIVOT)
Expand All @@ -99,6 +100,7 @@ def _parse_table(
joins=joins,
alias_tokens=alias_tokens,
parse_bracket=parse_bracket,
is_db_reference=is_db_reference,
)

return self.expression(exp.Pivot, this=table, unpivot=True) if unpivot else table
Expand Down
10 changes: 7 additions & 3 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,7 +534,9 @@ def _parse_at_before(self, table: exp.Table) -> exp.Table:

return table

def _parse_table_parts(self, schema: bool = False) -> exp.Table:
def _parse_table_parts(
self, schema: bool = False, is_db_reference: bool = False
) -> exp.Table:
# https://docs.snowflake.com/en/user-guide/querying-stage
if self._match(TokenType.STRING, advance=False):
table = self._parse_string()
Expand All @@ -550,7 +552,9 @@ def _parse_table_parts(self, schema: bool = False) -> exp.Table:
self._match(TokenType.L_PAREN)
while self._curr and not self._match(TokenType.R_PAREN):
if self._match_text_seq("FILE_FORMAT", "=>"):
file_format = self._parse_string() or super()._parse_table_parts()
file_format = self._parse_string() or super()._parse_table_parts(
is_db_reference=is_db_reference
)
elif self._match_text_seq("PATTERN", "=>"):
pattern = self._parse_string()
else:
Expand All @@ -560,7 +564,7 @@ def _parse_table_parts(self, schema: bool = False) -> exp.Table:

table = self.expression(exp.Table, this=table, format=file_format, pattern=pattern)
else:
table = super()._parse_table_parts(schema=schema)
table = super()._parse_table_parts(schema=schema, is_db_reference=is_db_reference)

return self._parse_at_before(table)

Expand Down
2 changes: 1 addition & 1 deletion sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2571,7 +2571,7 @@ class HistoricalData(Expression):

class Table(Expression):
arg_types = {
"this": True,
"this": False,
"alias": False,
"db": False,
"catalog": False,
Expand Down
26 changes: 21 additions & 5 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,7 +1352,9 @@ def _parse_drop(self, exists: bool = False) -> exp.Drop | exp.Command:
exp.Drop,
comments=start.comments,
exists=exists or self._parse_exists(),
this=self._parse_table(schema=True),
this=self._parse_table(
schema=True, is_db_reference=self._prev.token_type == TokenType.SCHEMA
),
kind=kind,
temporary=temporary,
materialized=materialized,
Expand Down Expand Up @@ -1440,7 +1442,9 @@ def extend_props(temp_props: t.Optional[exp.Properties]) -> None:
elif create_token.token_type == TokenType.INDEX:
this = self._parse_index(index=self._parse_id_var())
elif create_token.token_type in self.DB_CREATABLES:
table_parts = self._parse_table_parts(schema=True)
table_parts = self._parse_table_parts(
schema=True, is_db_reference=create_token.token_type == TokenType.SCHEMA
)

# exp.Properties.Location.POST_NAME
self._match(TokenType.COMMA)
Expand Down Expand Up @@ -2790,7 +2794,7 @@ def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
or self._parse_placeholder()
)

def _parse_table_parts(self, schema: bool = False) -> exp.Table:
def _parse_table_parts(self, schema: bool = False, is_db_reference: bool = False) -> exp.Table:
catalog = None
db = None
table: t.Optional[exp.Expression | str] = self._parse_table_part(schema=schema)
Expand All @@ -2806,8 +2810,15 @@ def _parse_table_parts(self, schema: bool = False) -> exp.Table:
db = table
table = self._parse_table_part(schema=schema) or ""

if not table:
if is_db_reference:
catalog = db
db = table
table = None

if not table and not is_db_reference:
self.raise_error(f"Expected table name but got {self._curr}")
if not db and is_db_reference:
self.raise_error(f"Expected database name but got {self._curr}")

return self.expression(
exp.Table, this=table, db=db, catalog=catalog, pivots=self._parse_pivots()
Expand All @@ -2819,6 +2830,7 @@ def _parse_table(
joins: bool = False,
alias_tokens: t.Optional[t.Collection[TokenType]] = None,
parse_bracket: bool = False,
is_db_reference: bool = False,
) -> t.Optional[exp.Expression]:
lateral = self._parse_lateral()
if lateral:
Expand All @@ -2841,7 +2853,11 @@ def _parse_table(
bracket = parse_bracket and self._parse_bracket(None)
bracket = self.expression(exp.Table, this=bracket) if bracket else None
this = t.cast(
exp.Expression, bracket or self._parse_bracket(self._parse_table_parts(schema=schema))
exp.Expression,
bracket
or self._parse_bracket(
self._parse_table_parts(schema=schema, is_db_reference=is_db_reference)
),
)

if schema:
Expand Down
34 changes: 34 additions & 0 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,3 +805,37 @@ def test_parse_concat_ws(self):
error_level=ErrorLevel.IGNORE,
)
self.assertEqual(ast[0].sql(), "CONCAT_WS()")

def test_parse_drop_schema(self):
for dialect in [None, "bigquery", "snowflake"]:
with self.subTest(dialect):
ast = parse_one("DROP SCHEMA catalog.schema", dialect=dialect)
self.assertEqual(
ast,
exp.Drop(
this=exp.Table(
this=None,
db=exp.Identifier(this="schema", quoted=False),
catalog=exp.Identifier(this="catalog", quoted=False),
),
kind="SCHEMA",
),
)
self.assertEqual(ast.sql(dialect=dialect), "DROP SCHEMA catalog.schema")

def test_parse_create_schema(self):
for dialect in [None, "bigquery", "snowflake"]:
with self.subTest(dialect):
ast = parse_one("CREATE SCHEMA catalog.schema", dialect=dialect)
self.assertEqual(
ast,
exp.Create(
this=exp.Table(
this=None,
db=exp.Identifier(this="schema", quoted=False),
catalog=exp.Identifier(this="catalog", quoted=False),
),
kind="SCHEMA",
),
)
self.assertEqual(ast.sql(dialect=dialect), "CREATE SCHEMA catalog.schema")

0 comments on commit d20d826

Please sign in to comment.