Skip to content

Commit

Permalink
fix: remove parenthesis around subquery between union
Browse files Browse the repository at this point in the history
  • Loading branch information
reata committed Jun 19, 2022
1 parent c942f7f commit 662996a
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 2 deletions.
45 changes: 43 additions & 2 deletions sqllineage/utils/sqlparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,47 @@ def is_token_negligible(token: TokenList) -> bool:
return token.is_whitespace or isinstance(token, Comment)


def get_innermost_parenthesis(token: Parenthesis):
def remove_parenthesis_between_union(token: Parenthesis) -> Parenthesis:
"""
remove parenthesis around subqueries between union
Example:
INPUT: ((SELECT col1 FROM tab2) UNION ALL (SELECT col1 FROM tab3))
OUTPUT: (SELECT col1 FROM tab2 UNION ALL SELECT col1 FROM tab3)
"""
# pre-compute the idx of UNION/UNION ALL
offsets = [-1]
while True:
offset, _ = token.token_next_by(
m=[(Keyword, "UNION ALL"), (Keyword, "UNION")], idx=offsets[-1]
)
if offset is not None:
offsets.append(offset)
else:
break
expand_parenthesis_idx = set()
for i, offset in enumerate(offsets[1:]):
if i == 0:
# first union idx, we need to handle the previous parenthesis
pre_idx, sub_token = token.token_prev(idx=offset)
if isinstance(sub_token, Parenthesis):
expand_parenthesis_idx.add(pre_idx)
# then all post parenthesis handled for each union
post_idx, sub_token = token.token_next(idx=offset)
if isinstance(sub_token, Parenthesis):
expand_parenthesis_idx.add(post_idx)
if expand_parenthesis_idx:
updated_tokens = []
for i, sub_token in enumerate(token.tokens):
if i in expand_parenthesis_idx:
inner_paren = get_innermost_parenthesis(sub_token)
updated_tokens.extend(inner_paren.tokens[1:-1])
else:
updated_tokens.append(sub_token)
token.tokens = updated_tokens
return token


def get_innermost_parenthesis(token: Parenthesis) -> Parenthesis:
# in case of subquery in nested parenthesis, find the innermost one first
while True:
idx, sub_paren = token.token_next_by(i=Parenthesis)
Expand Down Expand Up @@ -82,7 +122,8 @@ def get_subquery_parentheses(
subquery.append(SubQueryTuple(tk.right, tk.right.get_real_name()))
elif is_subquery(tk):
subquery.append(SubQueryTuple(tk, token.get_real_name()))
if is_subquery(target):
elif is_subquery(target):
target = remove_parenthesis_between_union(target)
subquery = [
SubQueryTuple(get_innermost_parenthesis(target), token.get_real_name())
]
Expand Down
46 changes: 46 additions & 0 deletions tests/test_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,52 @@ def test_select_column_in_subquery_with_two_parenthesis_and_blank_in_between():
)


def test_select_column_in_subquery_with_two_parenthesis_and_union():
sql = """INSERT OVERWRITE TABLE tab1
SELECT col1
FROM (
(SELECT col1 FROM tab2)
UNION ALL
(SELECT col1 FROM tab3)
) dt"""
assert_column_lineage_equal(
sql,
[
(
ColumnQualifierTuple("col1", "tab2"),
ColumnQualifierTuple("col1", "tab1"),
),
(
ColumnQualifierTuple("col1", "tab3"),
ColumnQualifierTuple("col1", "tab1"),
),
],
)


def test_select_column_in_subquery_with_two_parenthesis_and_union_v2():
sql = """INSERT OVERWRITE TABLE tab1
SELECT col1
FROM (
SELECT col1 FROM tab2
UNION ALL
SELECT col1 FROM tab3
) dt"""
assert_column_lineage_equal(
sql,
[
(
ColumnQualifierTuple("col1", "tab2"),
ColumnQualifierTuple("col1", "tab1"),
),
(
ColumnQualifierTuple("col1", "tab3"),
ColumnQualifierTuple("col1", "tab1"),
),
],
)


def test_select_column_from_table_join():
sql = """INSERT OVERWRITE TABLE tab1
SELECT tab2.col1,
Expand Down

0 comments on commit 662996a

Please sign in to comment.