Skip to content

Commit

Permalink
fix: exception for parenthesized where clause
Browse files Browse the repository at this point in the history
  • Loading branch information
reata committed Aug 20, 2023
1 parent fe916f4 commit b32c7da
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 10 deletions.
27 changes: 17 additions & 10 deletions sqllineage/core/parser/sqlfluff/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,22 +169,29 @@ def list_subqueries(segment: BaseSegment) -> List[SubQueryTuple]:
)
]
elif segment.type == "where_clause":
expression_segments = segment.get_child("expression").segments or []
bracketed_segments = [s for s in expression_segments if s.type == "bracketed"]
if bracketed_segments and is_subquery(bracketed_segments[0]):
return [
SubQueryTuple(extract_innermost_bracketed(bracketed_segments[0]), None)
]
bracketeds = []
if expression := segment.get_child("expression"):
bracketeds = expression.get_children("bracketed")
elif bracketed_where := segment.get_child("bracketed"):
if expression := bracketed_where.get_child("expression"):
bracketeds = expression.get_children("bracketed")
subquery = [
SubQueryTuple(extract_innermost_bracketed(bracketed), None)
for bracketed in bracketeds
if is_subquery(bracketed)
]
elif segment.type in ["from_clause", "from_expression"]:
if from_expression_element := find_from_expression_element(segment):
subquery += list_subqueries(from_expression_element)
subquery = list_subqueries(from_expression_element)
for join_clause in list_join_clause(segment):
if from_expression_element := find_from_expression_element(join_clause):
subquery += list_subqueries(from_expression_element)
elif is_set_expression(segment):
for s in list_child_segments(segment, check_bracketed=True):
if s.type == "bracketed" or s.type == "select_statement":
subquery.append(SubQueryTuple(s, None))
subquery = [
SubQueryTuple(s, None)
for s in list_child_segments(segment, check_bracketed=True)
if s.type in ("bracketed", "select_statement")
]
return subquery


Expand Down
6 changes: 6 additions & 0 deletions tests/sql/table/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def test_select_with_where():
)


def test_select_with_parenthesized_where():
assert_table_lineage_equal(
"SELECT * FROM tab1 WHERE (col1 > val1 AND col2 = 'val2')", {"tab1"}
)


def test_select_with_comment():
assert_table_lineage_equal("SELECT -- comment1\n col1 FROM tab1", {"tab1"})

Expand Down

0 comments on commit b32c7da

Please sign in to comment.