Skip to content

Commit

Permalink
feat: support subquery in values clause
Browse files Browse the repository at this point in the history
  • Loading branch information
reata committed Aug 27, 2023
1 parent a5ac784 commit 419ac8e
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 9 deletions.
12 changes: 12 additions & 0 deletions sqllineage/core/parser/sqlfluff/extractors/create_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sqllineage.core.parser.sqlfluff.models import SqlFluffColumn, SqlFluffTable
from sqllineage.core.parser.sqlfluff.utils import (
get_child,
get_children,
is_set_expression,
list_child_segments,
)
Expand Down Expand Up @@ -45,6 +46,17 @@ def extract(
holder |= self.delegate_to_cte(segment, holder)
elif segment.type in ("select_statement", "set_expression"):
holder |= self.delegate_to_select(segment, holder)
elif segment.type == "values_clause":
for bracketed in get_children(segment, "bracketed"):
for expression in get_children(bracketed, "expression"):
if sub_bracketed := get_child(expression, "bracketed"):
if sub_expression := get_child(sub_bracketed, "expression"):
if select_statement := get_child(
sub_expression, "select_statement"
):
holder |= self.delegate_to_select(
select_statement, holder
)
elif segment.type == "bracketed" and (
self.list_subquery(segment) or is_set_expression(segment)
):
Expand Down
2 changes: 1 addition & 1 deletion sqllineage/core/parser/sqlparse/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ def _extract_from_dml(
@classmethod
def parse_subquery(cls, token: TokenList) -> List[SubQuery]:
result = []
if isinstance(token, (Identifier, Function, Where)):
if isinstance(token, (Identifier, Function, Where, Values)):
# usually SubQuery is an Identifier, but not all Identifiers are SubQuery
# Function for CTE without AS keyword
result = cls._parse_subquery(token)
Expand Down
10 changes: 8 additions & 2 deletions sqllineage/core/parser/sqlparse/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Identifier,
Parenthesis,
TokenList,
Values,
Where,
)
from sqlparse.tokens import DML, Keyword, Name, Wildcard
Expand Down Expand Up @@ -99,7 +100,7 @@ def is_values_clause(token: Parenthesis) -> bool:


def get_subquery_parentheses(
token: Union[Identifier, Function, Where]
token: Union[Identifier, Function, Values, Where]
) -> List[SubQueryTuple]:
"""
Retrieve subquery list
Expand All @@ -115,7 +116,7 @@ def get_subquery_parentheses(
if isinstance(token, Function):
# CTE without AS: tbl (SELECT 1)
target = token.tokens[-1]
elif isinstance(token, Where):
elif isinstance(token, (Values, Where)):
# WHERE col1 IN (SELECT max(col1) FROM tab2)
target = token
else:
Expand All @@ -131,6 +132,11 @@ def get_subquery_parentheses(
subquery.append(SubQueryTuple(tk.right, tk.right.get_real_name()))
elif is_subquery(tk):
subquery.append(SubQueryTuple(tk, token.get_real_name()))
elif isinstance(target, Values):
for row in target.get_sublists():
for col in row:
if is_subquery(col):
subquery.append(SubQueryTuple(col, col.get_real_name()))
elif is_subquery(target):
target = remove_parenthesis_between_union(target)
subquery = [
Expand Down
38 changes: 33 additions & 5 deletions tests/sql/sqlfluff_only/test_parenthesized_expression.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""
sqlparse implementation is capable of generating the correct result, however with redundant nodes in graph
"""

from sqllineage.utils.entities import ColumnQualifierTuple
from ...helpers import assert_column_lineage_equal
from ...helpers import assert_column_lineage_equal, assert_table_lineage_equal


def test_subquery_expression_without_source_table():
"""
sqlparse implementation is capable of generating the correct result, however with redundant nodes in graph
"""
assert_column_lineage_equal(
"""INSERT INTO foo
SELECT (SELECT col1 + col2 AS result) AS sum_result
Expand All @@ -20,5 +21,32 @@ def test_subquery_expression_without_source_table():
ColumnQualifierTuple("sum_result", "foo"),
),
],
test_sqlparse=False,
skip_graph_check=True,
)


def test_insert_into_values_with_subquery():
assert_table_lineage_equal(
"INSERT INTO tab1 VALUES (1, (SELECT max(id) FROM tab2))",
{"tab2"},
{"tab1"},
skip_graph_check=True,
)


def test_insert_into_values_with_multiple_subquery():
assert_table_lineage_equal(
"INSERT INTO tab1 VALUES ((SELECT max(id) FROM tab2), (SELECT max(id) FROM tab3))",
{"tab2", "tab3"},
{"tab1"},
skip_graph_check=True,
)


def test_insert_into_values_with_multiple_subquery_in_multiple_row():
assert_table_lineage_equal(
"INSERT INTO tab1 VALUES (1, (SELECT max(id) FROM tab2)), (2, (SELECT max(id) FROM tab3))",
{"tab2", "tab3"},
{"tab1"},
skip_graph_check=True,
)
2 changes: 1 addition & 1 deletion tests/sql/table/test_insert.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from ...helpers import assert_table_lineage_equal


def test_insert_into():
def test_insert_into_values():
assert_table_lineage_equal("INSERT INTO tab1 VALUES (1, 2)", set(), {"tab1"})


Expand Down

0 comments on commit 419ac8e

Please sign in to comment.