Skip to content

Commit

Permalink
fix: use get_type() to determine if a statement is DROP; Filter endin…
Browse files Browse the repository at this point in the history
…g empty statement. (#24)
  • Loading branch information
reata committed Aug 11, 2019
1 parent 245a835 commit c7ea5d8
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
12 changes: 6 additions & 6 deletions sqllineage/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import sqlparse
from sqlparse.sql import Function, Identifier, Parenthesis, Statement, TokenList
from sqlparse.tokens import DDL, Keyword, Token
from sqlparse.tokens import Keyword, Token

SOURCE_TABLE_TOKENS = ('FROM', 'JOIN', 'INNER JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'LEFT OUTER JOIN', 'RIGHT OUTER JOIN',
'FULL OUTER JOIN', 'CROSS JOIN')
Expand All @@ -18,9 +18,9 @@ def __init__(self, sql: str, encoding=None):
self._encoding = encoding
self._source_tables = set()
self._target_tables = set()
self._stmt = sqlparse.parse(sql, self._encoding)
self._stmt = sqlparse.parse(sql.strip(), self._encoding)
for stmt in self._stmt:
if stmt.token_first().ttype == DDL and stmt.token_first().normalized == "DROP":
if stmt.get_type() == "DROP":
self._target_tables -= {t.get_real_name() for t in stmt.tokens if isinstance(t, Identifier)}
else:
self._extract_from_token(stmt)
Expand Down Expand Up @@ -76,7 +76,7 @@ def _extract_from_token(self, token: Token):
continue
else:
assert isinstance(sub_token, Identifier)
if isinstance(sub_token.token_first(), Parenthesis):
if isinstance(sub_token.token_first(skip_cm=True), Parenthesis):
# SELECT col1 FROM (SELECT col2 FROM tab1) dt, the subquery will be parsed as Identifier
# and this Identifier's get_real_name method would return alias name dt
# referring https://github.com/andialbrecht/sqlparse/issues/218 for further information
Expand All @@ -90,8 +90,8 @@ def _extract_from_token(self, token: Token):
elif isinstance(sub_token, Function):
# insert into tab (col1, col2), tab (col1, col2) will be parsed as Function
# referring https://github.com/andialbrecht/sqlparse/issues/483 for further information
assert isinstance(sub_token.token_first(), Identifier)
self._target_tables.add(sub_token.token_first().get_real_name())
assert isinstance(sub_token.token_first(skip_cm=True), Identifier)
self._target_tables.add(sub_token.token_first(skip_cm=True).get_real_name())
target_table_token_flag = False
else:
assert isinstance(sub_token, Identifier)
Expand Down
10 changes: 10 additions & 0 deletions tests/test_others.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ def test_drop():
helper("DROP TABLE IF EXISTS tab1", None, None)


def test_drop_with_comment():
helper("""--comment
DROP TABLE IF EXISTS tab1""", None, None)


def test_drop_after_create():
helper("CREATE TABLE IF NOT EXISTS tab1 (col1 STRING);DROP TABLE IF EXISTS tab1", None, None)

Expand All @@ -33,3 +38,8 @@ def test_create_select():
def test_split_statements():
sql = "SELECT * FROM tab1; SELECT * FROM tab2;"
assert len(LineageParser(sql).statements) == 2


def test_split_statements_with_heading_and_ending_new_line():
sql = "\nSELECT * FROM tab1;\nSELECT * FROM tab2;\n"
assert len(LineageParser(sql).statements) == 2

0 comments on commit c7ea5d8

Please sign in to comment.