-
Notifications
You must be signed in to change notification settings - Fork 215
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor: merge sqlfluff handler to extractor
- Loading branch information
Showing
24 changed files
with
730 additions
and
912 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from sqlfluff.core.parser import BaseSegment | ||
|
||
from sqllineage.core.holders import StatementLineageHolder | ||
from sqllineage.core.models import Path | ||
from sqllineage.core.parser.sqlfluff.extractors.base import BaseExtractor | ||
from sqllineage.core.parser.sqlfluff.utils import ( | ||
find_from_expression_element, | ||
get_child, | ||
get_children, | ||
list_child_segments, | ||
) | ||
from sqllineage.utils.entities import AnalyzerContext | ||
from sqllineage.utils.helpers import escape_identifier_name | ||
|
||
|
||
class CopyExtractor(BaseExtractor):
    """
    Lineage extractor for COPY-style statements.

    Walks the direct child segments of the statement, recording the table
    that follows a COPY / INTO keyword as the write target and the path or
    storage location that follows FROM (or sits inside a from_clause) as the
    read source.
    """

    SUPPORTED_STMT_TYPES = [
        "copy_statement",
        "copy_into_table_statement",
    ]

    def extract(
        self, statement: BaseSegment, context: AnalyzerContext
    ) -> StatementLineageHolder:
        """
        Build the lineage holder for a single copy statement.

        :param statement: parsed statement segment to analyse
        :param context: analyzer context (not consulted by this extractor)
        :return: holder populated with the read paths and written table
        """
        lineage = StatementLineageHolder()
        # Each flag is armed by a keyword and consumed by the very next
        # non-keyword segment, whether or not that segment yields anything.
        expect_target = False
        expect_source = False

        for child in list_child_segments(statement):
            kind = child.type

            if kind == "keyword":
                word = child.raw_upper
                if word in ("COPY", "INTO"):
                    expect_target = True
                elif word == "FROM":
                    expect_source = True
                # keywords only steer the flags; nothing else to inspect
                continue

            if kind == "from_clause":
                # storage locations nested in the from_clause are read as-is
                element = find_from_expression_element(child)
                if element:
                    for table_expr in get_children(element, "table_expression"):
                        location = get_child(table_expr, "storage_location")
                        if location:
                            lineage.add_read(Path(location.raw))

            if expect_target:
                target = self.find_table(child)
                if target:
                    lineage.add_write(target)
                expect_target = False

            if expect_source:
                if kind in ("literal", "storage_location"):
                    lineage.add_read(Path(escape_identifier_name(child.raw)))
                expect_source = False

        return lineage
130 changes: 130 additions & 0 deletions
130
sqllineage/core/parser/sqlfluff/extractors/create_insert.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,130 @@ | ||
from sqlfluff.core.parser import BaseSegment | ||
|
||
from sqllineage.core.holders import SubQueryLineageHolder | ||
from sqllineage.core.models import Path | ||
from sqllineage.core.parser.sqlfluff.extractors.base import BaseExtractor | ||
from sqllineage.core.parser.sqlfluff.extractors.select import SelectExtractor | ||
from sqllineage.core.parser.sqlfluff.models import SqlFluffColumn, SqlFluffTable | ||
from sqllineage.core.parser.sqlfluff.utils import ( | ||
get_child, | ||
is_set_expression, | ||
list_child_segments, | ||
) | ||
from sqllineage.utils.entities import AnalyzerContext | ||
from sqllineage.utils.helpers import escape_identifier_name | ||
|
||
|
||
class CreateInsertExtractor(BaseExtractor):
    """
    Create statement and Insert statement lineage extractor

    Iterates over the direct child segments of the statement. CTEs and
    SELECT / set expressions are handed to dedicated extractors, a bracketed
    column list is recorded as the target columns, and keyword segments arm
    flags that decide whether the next non-keyword segment is a write target
    or a read source.
    """

    SUPPORTED_STMT_TYPES = [
        "create_table_statement",
        "create_table_as_statement",
        "create_view_statement",
        "insert_statement",
        "insert_overwrite_directory_hive_fmt_statement",
    ]

    def extract(
        self,
        statement: BaseSegment,
        context: AnalyzerContext,
    ) -> SubQueryLineageHolder:
        """
        Extract lineage from a create / insert statement.

        :param statement: parsed statement segment to analyse
        :param context: analyzer context used to initialise the holder
        :return: holder with the collected read / write lineage
        """
        holder = self._init_holder(context)
        # Flags are set by keyword segments and consumed (then reset) by the
        # first non-keyword segment that follows.
        src_flag = tgt_flag = False
        for segment in list_child_segments(statement):
            if segment.type == "with_compound_statement":
                holder |= self.delegate_to_cte(segment, holder)
            elif segment.type == "bracketed" and any(
                s.type == "with_compound_statement" for s in segment.segments
            ):
                # CTE wrapped in parentheses: delegate the CTE statement
                # itself, not the enclosing bracketed segment, so that it is
                # handled exactly like the unbracketed case above.
                for sgmt in segment.segments:
                    if sgmt.type == "with_compound_statement":
                        holder |= self.delegate_to_cte(sgmt, holder)
            elif segment.type in ("select_statement", "set_expression"):
                holder |= self.delegate_to_select(segment, holder)
            elif segment.type == "bracketed" and (
                self.list_subquery(segment) or is_set_expression(segment)
            ):
                # note regular subquery within SELECT statement is handled by SelectExtractor, this is only to handle
                # top-level subquery in DML like: 1) create table foo as (subquery); 2) insert into foo (subquery)
                # subquery here isn't added as read source, and it inherits DML-level write_columns if parsed
                if subquery_segment := get_child(
                    segment, "select_statement", "set_expression"
                ):
                    holder |= self.delegate_to_select(subquery_segment, holder)

            elif segment.type == "bracketed":
                # In case of bracketed column reference, add these target columns to holder
                # so that when we compute the column level lineage
                # we keep these columns into consideration
                sub_segments = list_child_segments(segment)
                if all(
                    sub_segment.type in ["column_reference", "column_definition"]
                    for sub_segment in sub_segments
                ):
                    # target columns only apply to bracketed column_reference and column_definition
                    columns = []
                    for sub_segment in sub_segments:
                        if sub_segment.type == "column_definition":
                            # a column definition contributes its identifier child
                            sub_segment = get_child(sub_segment, "identifier")
                        columns.append(SqlFluffColumn.of(sub_segment))
                    holder.add_write_column(*columns)

            elif segment.type == "keyword":
                if segment.raw_upper in [
                    "INTO",
                    "OVERWRITE",
                    "TABLE",
                    "VIEW",
                    "DIRECTORY",
                ] or (
                    tgt_flag is True and segment.raw_upper in ["IF", "NOT", "EXISTS"]
                ):
                    tgt_flag = True
                elif segment.raw_upper in ["LIKE", "CLONE"]:
                    src_flag = True
                # keywords only steer the flags; skip the flag handling below
                continue

            # Every non-keyword segment falls through to here, so a pending
            # flag is cleared even when the segment is not a usable target/source.
            if tgt_flag:
                if segment.type in ["table_reference", "object_reference"]:
                    write_obj = SqlFluffTable.of(segment)
                    holder.add_write(write_obj)
                elif segment.type == "literal":
                    if segment.raw.isnumeric():
                        # Special Handling for Spark Bucket Table DDL
                        pass
                    else:
                        holder.add_write(Path(escape_identifier_name(segment.raw)))
                tgt_flag = False
            if src_flag:
                if segment.type in ["table_reference", "object_reference"]:
                    holder.add_read(SqlFluffTable.of(segment))
                src_flag = False
        return holder

    def delegate_to_cte(
        self, segment: BaseSegment, holder: SubQueryLineageHolder
    ) -> SubQueryLineageHolder:
        """
        Run CteExtractor on the segment, passing on the holder's cte and write.

        :param segment: segment to hand over to the CTE extractor
        :param holder: current holder whose cte / write seed the new context
        :return: holder produced by the delegated extractor
        """
        # imported here rather than at module level
        # NOTE(review): presumably to avoid a circular import - confirm
        from .cte import CteExtractor

        return self.delegate_to(
            CteExtractor, segment, AnalyzerContext(cte=holder.cte, write=holder.write)
        )

    def delegate_to_select(
        self,
        segment: BaseSegment,
        holder: SubQueryLineageHolder,
    ) -> SubQueryLineageHolder:
        """
        Run SelectExtractor on the segment, passing on the holder's cte,
        write and write_columns.

        :param segment: select statement or set expression segment
        :param holder: current holder whose state seeds the new context
        :return: holder produced by the delegated extractor
        """
        return self.delegate_to(
            SelectExtractor,
            segment,
            AnalyzerContext(
                cte=holder.cte,
                write=holder.write,
                write_columns=holder.write_columns,
            ),
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.