Skip to content

Commit

Permalink
fix: column lineage for subquery in nested parenthesis
Browse files Browse the repository at this point in the history
  • Loading branch information
reata committed Mar 27, 2022
1 parent bef4d3a commit c62d82a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 8 deletions.
23 changes: 15 additions & 8 deletions sqllineage/utils/sqlparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,21 @@
from sqllineage.utils.entities import SubQueryTuple


def get_innermost_parenthesis(token: Parenthesis):
# in case of subquery in nested parenthesis, find the innermost one first
while True:
idx, sub_paren = token.token_next_by(i=Parenthesis)
if sub_paren is not None and idx == 1:
token = sub_paren
else:
break
return token


def is_subquery(token: TokenList) -> bool:
flag = False
if isinstance(token, Parenthesis):
# in case of subquery in nested parenthesis, find the innermost one first
while True:
idx, sub_paren = token.token_next_by(i=Parenthesis)
if sub_paren is not None and idx == 1:
token = sub_paren
else:
break
token = get_innermost_parenthesis(token)
# check if innermost parenthesis contains SELECT
_, sub_token = token.token_next_by(m=(DML, "SELECT"))
if sub_token is not None:
Expand Down Expand Up @@ -68,7 +73,9 @@ def get_subquery_parentheses(
elif is_subquery(tk):
subquery.append(SubQueryTuple(tk, token.get_real_name()))
if is_subquery(target):
subquery = [SubQueryTuple(target, token.get_real_name())]
subquery = [
SubQueryTuple(get_innermost_parenthesis(target), token.get_real_name())
]
return subquery


Expand Down
10 changes: 10 additions & 0 deletions tests/test_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,16 @@ def test_select_column_in_subquery():
)


def test_select_column_in_subquery_with_two_parenthesis():
sql = """INSERT OVERWRITE TABLE tab1
SELECT col1
FROM ((SELECT col1 FROM tab2)) dt"""
assert_column_lineage_equal(
sql,
[(ColumnQualifierTuple("col1", "tab2"), ColumnQualifierTuple("col1", "tab1"))],
)


def test_select_column_from_table_join():
sql = """INSERT OVERWRITE TABLE tab1
SELECT tab2.col1,
Expand Down

0 comments on commit c62d82a

Please sign in to comment.