Skip to content

Commit

Permalink
fix: qualified wildcard column recognition
Browse files Browse the repository at this point in the history
  • Loading branch information
reata committed Aug 13, 2023
1 parent bff222c commit 13d1197
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 35 deletions.
37 changes: 22 additions & 15 deletions sqllineage/core/parser/sqlfluff/extractors/dml_merge_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from sqllineage.core.parser.sqlfluff.models import SqlFluffSubQuery, SqlFluffTable
from sqllineage.core.parser.sqlfluff.utils import (
extract_column_qualifier,
extract_identifier,
extract_innermost_bracketed,
get_child,
Expand Down Expand Up @@ -93,11 +94,15 @@ def extract(
for set_clause in get_children(set_clause_list, "set_clause"):
columns = get_children(set_clause, "column_reference")
if len(columns) == 2:
src_col = Column(extract_identifier(columns[1]))
src_col.parent = direct_source
tgt_col = Column(extract_identifier(columns[0]))
tgt_col.parent = list(holder.write)[0]
holder.add_column_lineage(src_col, tgt_col)
src_col = tgt_col = None
if src_cqt := extract_column_qualifier(columns[1]):
src_col = Column(src_cqt.column)
src_col.parent = direct_source
if tgt_cqt := extract_column_qualifier(columns[0]):
tgt_col = Column(tgt_cqt.column)
tgt_col.parent = list(holder.write)[0]
if src_col is not None and tgt_col is not None:
holder.add_column_lineage(src_col, tgt_col)
for merge_when_not_matched_clause in get_children(
merge_match, "merge_when_not_matched_clause"
):
Expand All @@ -107,20 +112,22 @@ def extract(
insert_columns = []
if bracketed := get_child(merge_insert, "bracketed"):
for column_reference in get_children(bracketed, "column_reference"):
tgt_col = Column(extract_identifier(column_reference))
tgt_col.parent = list(holder.write)[0]
insert_columns.append(tgt_col)
if cqt := extract_column_qualifier(column_reference):
tgt_col = Column(cqt.column)
tgt_col.parent = list(holder.write)[0]
insert_columns.append(tgt_col)
if values_clause := get_child(merge_insert, "values_clause"):
if bracketed := get_child(values_clause, "bracketed"):
for j, e in enumerate(
get_children(bracketed, "literal", "expression")
):
if column_reference := get_child(e, "column_reference"):
src_col = Column(
extract_identifier(column_reference)
)
src_col.parent = direct_source
holder.add_column_lineage(
src_col, insert_columns[j]
)
if cqt := extract_column_qualifier(
column_reference
):
src_col = Column(cqt.column)
src_col.parent = direct_source
holder.add_column_lineage(
src_col, insert_columns[j]
)
return holder
35 changes: 16 additions & 19 deletions sqllineage/core/parser/sqlfluff/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from sqllineage import SQLPARSE_DIALECT
from sqllineage.core.models import Column, Schema, SubQuery, Table
from sqllineage.core.parser.sqlfluff.utils import (
extract_column_qualifier,
extract_identifier,
is_subquery,
is_wildcard,
Expand Down Expand Up @@ -111,8 +112,11 @@ def of(column: BaseSegment, **kwargs) -> Column:
if source_columns:
column_name = None
for sub_segment in list_child_segments(column):
if sub_segment.type == "column_reference":
column_name = extract_identifier(sub_segment)
if sub_segment.type == "column_reference" or is_wildcard(
sub_segment
):
if cqt := extract_column_qualifier(sub_segment):
column_name = cqt.column
elif sub_segment.type == "expression":
# special handling for postgres style type cast, col as target column name instead of col::type
if len(sub2_segments := list_child_segments(sub_segment)) == 1:
Expand All @@ -130,7 +134,10 @@ def of(column: BaseSegment, **kwargs) -> Column:
if (
sub3_segment := sub3_segments[0]
).type == "column_reference":
column_name = extract_identifier(sub3_segment)
if cqt := extract_column_qualifier(
sub3_segment
):
column_name = cqt.column
return Column(
column.raw if column_name is None else column_name,
source_columns=source_columns,
Expand All @@ -149,12 +156,11 @@ def _extract_source_columns(segment: BaseSegment) -> List[ColumnQualifierTuple]:
:param segment: segment to be processed
:return: list of extracted source columns
"""
if segment.type == "identifier" or is_wildcard(segment):
return [ColumnQualifierTuple(segment.raw, None)]
if segment.type == "column_reference":
parent, column = SqlFluffColumn._get_column_and_parent(segment)
return [ColumnQualifierTuple(column, parent)]
if segment.type in NON_IDENTIFIER_OR_COLUMN_SEGMENT_TYPE:
col_list = []
if segment.type in ("identifier", "column_reference") or is_wildcard(segment):
if cqt := extract_column_qualifier(segment):
col_list = [cqt]
elif segment.type in NON_IDENTIFIER_OR_COLUMN_SEGMENT_TYPE:
sub_segments = list_child_segments(segment)
col_list = []
for sub_segment in sub_segments:
Expand All @@ -172,8 +178,7 @@ def _extract_source_columns(segment: BaseSegment) -> List[ColumnQualifierTuple]:
):
res = SqlFluffColumn._extract_source_columns(sub_segment)
col_list.extend(res)
return col_list
return []
return col_list

@staticmethod
def _get_column_from_subquery(
Expand Down Expand Up @@ -229,12 +234,4 @@ def _get_column_and_alias(
):
res = SqlFluffColumn._extract_source_columns(sub_segment)
columns += res if res else []

return columns, alias

@staticmethod
def _get_column_and_parent(col_segment: BaseSegment) -> Tuple[Optional[str], str]:
    """
    Split a column reference segment into its parent and column names.

    NOTE: despite the method name, the returned tuple is ordered
    ``(parent, column)``.

    :param col_segment: segment whose child segments (as returned by
        ``list_child_segments``) make up the possibly qualified column name
    :return: raw text of the second-to-last child (or None when there is only
        one child) followed by raw text of the last child
    """
    parts = list_child_segments(col_segment)
    # only a reference with more than one part carries a qualifier
    parent = parts[-2].raw if len(parts) > 1 else None
    return parent, parts[-1].raw
19 changes: 18 additions & 1 deletion sqllineage/core/parser/sqlfluff/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from sqlfluff.core.parser import BaseSegment

from sqllineage.utils.entities import SubQueryTuple
from sqllineage.utils.entities import ColumnQualifierTuple, SubQueryTuple


def is_negligible(segment: BaseSegment) -> bool:
Expand Down Expand Up @@ -228,6 +228,23 @@ def extract_as_and_target_segment(
return as_segment, target


def extract_column_qualifier(segment: BaseSegment) -> Optional[ColumnQualifierTuple]:
    """
    Build a ColumnQualifierTuple (column, qualifier) out of a segment.

    Three kinds of segment are recognised, checked in this order:

    - wildcard: the raw text is split on "."; the last piece is the column and
      the piece right before it, if any, is the qualifier (``tab.*`` gives
      ``("*", "tab")``, a bare ``*`` gives ``("*", None)``)
    - column_reference: the last child segment is the column and the child
      right before it, if any, is the qualifier
    - identifier: the raw text is the column and there is no qualifier

    :param segment: segment to be processed
    :return: the extracted tuple, or None for any other kind of segment
    """
    if is_wildcard(segment):
        # str.split always yields at least one piece, so unpacking is safe
        *qualifiers, column = segment.raw.split(".")
        return ColumnQualifierTuple(column, qualifiers[-1] if qualifiers else None)
    if segment.type == "column_reference":
        parts = list_child_segments(segment)
        column = parts[-1].raw
        qualifier = parts[-2].raw if len(parts) > 1 else None
        return ColumnQualifierTuple(column, qualifier)
    if segment.type == "identifier":
        return ColumnQualifierTuple(segment.raw, None)
    return None


def extract_innermost_bracketed(bracketed_segment: BaseSegment) -> BaseSegment:
# in case of subquery in nested parenthesis like: `SELECT * FROM ((table))`, find the innermost one first
while True:
Expand Down
21 changes: 21 additions & 0 deletions tests/sql/column/test_column_select_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,27 @@ def test_select_column_wildcard():
)


def test_select_column_wildcard_with_qualifier():
    """
    A wildcard qualified by either the table name or the table alias is
    expected to produce the same tab2.* -> tab1.* column lineage.
    """
    statements = (
        # wildcard qualified by the table name
        """INSERT INTO tab1
SELECT tab2.*
FROM tab2 a
INNER JOIN tab3 b
ON a.id = b.id""",
        # wildcard qualified by the table alias
        """INSERT INTO tab1
SELECT a.*
FROM tab2 a
INNER JOIN tab3 b
ON a.id = b.id""",
    )
    for sql in statements:
        assert_column_lineage_equal(
            sql,
            [(ColumnQualifierTuple("*", "tab2"), ColumnQualifierTuple("*", "tab1"))],
        )


def test_select_distinct_column():
sql = """INSERT INTO tab1
SELECT DISTINCT col1
Expand Down

0 comments on commit 13d1197

Please sign in to comment.