Skip to content

Commit

Permalink
refactor: merge sqlfluff handler to extractor
Browse files Browse the repository at this point in the history
  • Loading branch information
reata committed Aug 26, 2023
1 parent 2eae7da commit aad9fb8
Show file tree
Hide file tree
Showing 24 changed files with 730 additions and 912 deletions.
6 changes: 2 additions & 4 deletions sqllineage/core/parser/sqlfluff/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@

from sqllineage.core.analyzer import LineageAnalyzer
from sqllineage.core.holders import StatementLineageHolder
from sqllineage.core.parser.sqlfluff.extractors.lineage_holder_extractor import (
LineageHolderExtractor,
)
from sqllineage.core.parser.sqlfluff.extractors.base import BaseExtractor
from sqllineage.exceptions import (
InvalidSyntaxException,
UnsupportedStatementException,
Expand Down Expand Up @@ -58,7 +56,7 @@ def analyze(self, sql: str) -> StatementLineageHolder:
else:
for extractor in [
extractor_cls(self._dialect)
for extractor_cls in LineageHolderExtractor.__subclasses__()
for extractor_cls in BaseExtractor.__subclasses__()
]:
if extractor.can_extract(statement_segment.type):
lineage_holder = extractor.extract(
Expand Down
2 changes: 1 addition & 1 deletion sqllineage/core/parser/sqlfluff/extractors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
import os
import pkgutil

# import each module so that LineageHolderExtractor's __subclasses__ will work
# import each module so that BaseExtractor's __subclasses__ will work
# Eagerly import every module that lives next to this file, so that each
# extractor class is defined (and therefore visible through
# BaseExtractor.__subclasses__()) by the time the package has been imported.
for module in pkgutil.iter_modules(path=[os.path.dirname(__file__)]):
    importlib.import_module(f"{__name__}.{module.name}")
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from functools import reduce
from operator import add
from typing import List
from typing import List, Optional, Type

from sqlfluff.core.parser import BaseSegment

from sqllineage.core.holders import SubQueryLineageHolder
from sqllineage.core.models import SubQuery
from sqllineage.core.parser.sqlfluff.models import SqlFluffSubQuery
from sqllineage.core.models import SubQuery, Table
from sqllineage.core.parser.sqlfluff.models import SqlFluffSubQuery, SqlFluffTable
from sqllineage.core.parser.sqlfluff.utils import (
get_children,
is_subquery,
Expand All @@ -15,7 +15,7 @@
from sqllineage.utils.entities import AnalyzerContext, SubQueryTuple


class LineageHolderExtractor:
class BaseExtractor:
"""
Abstract class implementation for extract 'SubQueryLineageHolder' from different statement types
"""
Expand Down Expand Up @@ -46,7 +46,14 @@ def extract(
raise NotImplementedError

@classmethod
def parse_subquery(cls, segment: BaseSegment) -> List[SubQuery]:
def find_table(cls, segment: BaseSegment) -> Optional[Table]:
    """
    Build a table out of the segment when it is a table or object reference.

    :param segment: segment to inspect
    :return: the result of 'SqlFluffTable.of' when the segment type is 'table_reference'
             or 'object_reference', otherwise None
    """
    if segment.type not in ("table_reference", "object_reference"):
        return None
    return SqlFluffTable.of(segment)

@classmethod
def list_subquery(cls, segment: BaseSegment) -> List[SubQuery]:
"""
The parse_subquery function takes a segment as an argument.
:param segment: segment to determine if it is a subquery
Expand Down Expand Up @@ -84,6 +91,17 @@ def _parse_subquery(cls, subqueries: List[SubQueryTuple]) -> List[SubQuery]:
for bracketed_segment, alias in subqueries
]

def delegate_to(
    self,
    extractor_cls: "Type[BaseExtractor]",
    segment: BaseSegment,
    context: AnalyzerContext,
) -> SubQueryLineageHolder:
    """
    Hand the segment over to another type of extractor.

    A fresh instance of the given extractor class is created with this extractor's dialect.

    :param extractor_cls: extractor class to instantiate
    :param segment: segment passed to the delegate's 'extract'
    :param context: 'AnalyzerContext' passed to the delegate's 'extract'
    :return: whatever the delegate's 'extract' returns
    """
    delegate = extractor_cls(self.dialect)
    return delegate.extract(segment, context)

@staticmethod
def _init_holder(context: AnalyzerContext) -> SubQueryLineageHolder:
"""
Expand Down
58 changes: 58 additions & 0 deletions sqllineage/core/parser/sqlfluff/extractors/copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from sqlfluff.core.parser import BaseSegment

from sqllineage.core.holders import StatementLineageHolder
from sqllineage.core.models import Path
from sqllineage.core.parser.sqlfluff.extractors.base import BaseExtractor
from sqllineage.core.parser.sqlfluff.utils import (
find_from_expression_element,
get_child,
get_children,
list_child_segments,
)
from sqllineage.utils.entities import AnalyzerContext
from sqllineage.utils.helpers import escape_identifier_name


class CopyExtractor(BaseExtractor):
    """
    Lineage extractor for copy statements
    """

    SUPPORTED_STMT_TYPES = [
        "copy_statement",
        "copy_into_table_statement",
    ]

    def extract(
        self, statement: BaseSegment, context: AnalyzerContext
    ) -> StatementLineageHolder:
        """
        Walk the child segments of a copy statement and collect what it reads and writes.

        :param statement: a sqlfluff segment with a copy statement
        :param context: 'AnalyzerContext' (not read by this extractor)
        :return: 'StatementLineageHolder' object
        """
        holder = StatementLineageHolder()
        # each flag means "the next non-keyword segment is expected to be the target / the source"
        expect_target = False
        expect_source = False
        for child in list_child_segments(statement):
            kind = child.type
            if kind == "keyword":
                keyword = child.raw_upper
                if keyword in ("COPY", "INTO"):
                    expect_target = True
                elif keyword == "FROM":
                    expect_source = True
                # keywords only arm the flags, they never consume them
                continue

            if kind == "from_clause":
                from_expression_element = find_from_expression_element(child)
                if from_expression_element:
                    for table_expression in get_children(
                        from_expression_element, "table_expression"
                    ):
                        storage_location = get_child(
                            table_expression, "storage_location"
                        )
                        if storage_location:
                            holder.add_read(Path(storage_location.raw))

            # a pending flag is consumed by the first non-keyword segment, whether it matches or not
            if expect_target:
                expect_target = False
                table = self.find_table(child)
                if table:
                    holder.add_write(table)
            if expect_source:
                expect_source = False
                if kind in ("literal", "storage_location"):
                    holder.add_read(Path(escape_identifier_name(child.raw)))

        return holder
130 changes: 130 additions & 0 deletions sqllineage/core/parser/sqlfluff/extractors/create_insert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from sqlfluff.core.parser import BaseSegment

from sqllineage.core.holders import SubQueryLineageHolder
from sqllineage.core.models import Path
from sqllineage.core.parser.sqlfluff.extractors.base import BaseExtractor
from sqllineage.core.parser.sqlfluff.extractors.select import SelectExtractor
from sqllineage.core.parser.sqlfluff.models import SqlFluffColumn, SqlFluffTable
from sqllineage.core.parser.sqlfluff.utils import (
get_child,
is_set_expression,
list_child_segments,
)
from sqllineage.utils.entities import AnalyzerContext
from sqllineage.utils.helpers import escape_identifier_name


class CreateInsertExtractor(BaseExtractor):
    """
    Create statement and Insert statement lineage extractor
    """

    SUPPORTED_STMT_TYPES = [
        "create_table_statement",
        "create_table_as_statement",
        "create_view_statement",
        "insert_statement",
        "insert_overwrite_directory_hive_fmt_statement",
    ]

    def extract(
        self,
        statement: BaseSegment,
        context: AnalyzerContext,
    ) -> SubQueryLineageHolder:
        """
        Extract lineage for a given create or insert statement.

        CTE and query segments are delegated to the dedicated extractors, bracketed column lists are
        recorded as write columns, and keywords arm flags telling whether the next non-keyword segment
        is the write target or the read source.

        :param statement: a sqlfluff segment with a statement
        :param context: 'AnalyzerContext'
        :return: 'SubQueryLineageHolder' object
        """
        holder = self._init_holder(context)
        src_flag = tgt_flag = False
        for segment in list_child_segments(statement):
            if segment.type == "with_compound_statement":
                holder |= self.delegate_to_cte(segment, holder)
            elif segment.type == "bracketed" and any(
                s.type == "with_compound_statement" for s in segment.segments
            ):
                for sgmt in segment.segments:
                    if sgmt.type == "with_compound_statement":
                        # delegate the CTE segment itself, not the bracketed wrapper around it:
                        # 'with_compound_statement' is the type the CTE extractor declares support for
                        holder |= self.delegate_to_cte(sgmt, holder)
            elif segment.type in ("select_statement", "set_expression"):
                holder |= self.delegate_to_select(segment, holder)
            elif segment.type == "bracketed" and (
                self.list_subquery(segment) or is_set_expression(segment)
            ):
                # note regular subquery within SELECT statement is handled by SelectExtractor, this is only to handle
                # top-level subquery in DML like: 1) create table foo as (subquery); 2) insert into foo (subquery)
                # subquery here isn't added as read source, and it inherits DML-level write_columns if parsed
                if subquery_segment := get_child(
                    segment, "select_statement", "set_expression"
                ):
                    holder |= self.delegate_to_select(subquery_segment, holder)

            elif segment.type == "bracketed":
                # In case of bracketed column reference, add these target columns to holder
                # so that when we compute the column level lineage
                # we keep these columns into consideration
                sub_segments = list_child_segments(segment)
                if all(
                    sub_segment.type in ["column_reference", "column_definition"]
                    for sub_segment in sub_segments
                ):
                    # target columns only apply to bracketed column_reference and column_definition
                    columns = []
                    for sub_segment in sub_segments:
                        if sub_segment.type == "column_definition":
                            # for a column definition the column is built from its identifier child
                            sub_segment = get_child(sub_segment, "identifier")
                        columns.append(SqlFluffColumn.of(sub_segment))
                    holder.add_write_column(*columns)

            elif segment.type == "keyword":
                if segment.raw_upper in [
                    "INTO",
                    "OVERWRITE",
                    "TABLE",
                    "VIEW",
                    "DIRECTORY",
                ] or (
                    # keep the target flag armed across IF NOT EXISTS
                    tgt_flag is True
                    and segment.raw_upper in ["IF", "NOT", "EXISTS"]
                ):
                    tgt_flag = True
                elif segment.raw_upper in ["LIKE", "CLONE"]:
                    src_flag = True
                # keywords only arm the flags, they never consume them
                continue

            # a pending flag is consumed by the first non-keyword segment, whether it matches or not
            if tgt_flag:
                if segment.type in ["table_reference", "object_reference"]:
                    write_obj = SqlFluffTable.of(segment)
                    holder.add_write(write_obj)
                elif segment.type == "literal":
                    if segment.raw.isnumeric():
                        # Special Handling for Spark Bucket Table DDL
                        pass
                    else:
                        holder.add_write(Path(escape_identifier_name(segment.raw)))
                tgt_flag = False
            if src_flag:
                if segment.type in ["table_reference", "object_reference"]:
                    holder.add_read(SqlFluffTable.of(segment))
                src_flag = False
        return holder

    def delegate_to_cte(
        self, segment: BaseSegment, holder: SubQueryLineageHolder
    ) -> SubQueryLineageHolder:
        """
        Delegate a segment to CteExtractor, passing along the holder's cte and write.

        :param segment: segment to extract from
        :param holder: holder whose 'cte' and 'write' seed the delegate's context
        :return: 'SubQueryLineageHolder' object produced by the delegate
        """
        # imported here rather than at module level: the cte module imports this module
        from .cte import CteExtractor

        return self.delegate_to(
            CteExtractor, segment, AnalyzerContext(cte=holder.cte, write=holder.write)
        )

    def delegate_to_select(
        self,
        segment: BaseSegment,
        holder: SubQueryLineageHolder,
    ) -> SubQueryLineageHolder:
        """
        Delegate a segment to SelectExtractor, passing along the holder's cte, write and write_columns.

        :param segment: segment to extract from
        :param holder: holder whose 'cte', 'write' and 'write_columns' seed the delegate's context
        :return: 'SubQueryLineageHolder' object produced by the delegate
        """
        return self.delegate_to(
            SelectExtractor,
            segment,
            AnalyzerContext(
                cte=holder.cte,
                write=holder.write,
                write_columns=holder.write_columns,
            ),
        )
Original file line number Diff line number Diff line change
@@ -1,58 +1,43 @@
from sqlfluff.core.parser import BaseSegment

from sqllineage.core.holders import SubQueryLineageHolder
from sqllineage.core.parser.sqlfluff.extractors.dml_insert_extractor import (
DmlInsertExtractor,
)
from sqllineage.core.parser.sqlfluff.extractors.dml_select_extractor import (
DmlSelectExtractor,
)
from sqllineage.core.parser.sqlfluff.extractors.lineage_holder_extractor import (
LineageHolderExtractor,
from sqllineage.core.parser.sqlfluff.extractors.base import BaseExtractor
from sqllineage.core.parser.sqlfluff.extractors.create_insert import (
CreateInsertExtractor,
)
from sqllineage.core.parser.sqlfluff.extractors.select import SelectExtractor
from sqllineage.core.parser.sqlfluff.models import SqlFluffSubQuery
from sqllineage.core.parser.sqlfluff.utils import get_children, list_child_segments
from sqllineage.utils.entities import AnalyzerContext


class DmlCteExtractor(LineageHolderExtractor):
class CteExtractor(BaseExtractor):
"""
DML CTE queries lineage extractor
CTE queries lineage extractor
"""

SUPPORTED_STMT_TYPES = ["with_compound_statement"]

def __init__(self, dialect: str):
super().__init__(dialect)

def extract(
self,
statement: BaseSegment,
context: AnalyzerContext,
) -> SubQueryLineageHolder:
"""
Extract lineage for a given statement.
:param statement: a sqlfluff segment with a statement
:param context: 'AnalyzerContext'
:return 'SubQueryLineageHolder' object
"""
holder = self._init_holder(context)
subqueries = []
for segment in list_child_segments(statement):
if segment.type in ["select_statement", "set_expression"]:
holder |= DmlSelectExtractor(self.dialect).extract(
holder |= self.delegate_to(
SelectExtractor,
segment,
AnalyzerContext(cte=holder.cte, write=holder.write),
)

if segment.type == "insert_statement":
holder |= DmlInsertExtractor(self.dialect).extract(
segment,
AnalyzerContext(cte=holder.cte),
elif segment.type == "insert_statement":
holder |= self.delegate_to(
CreateInsertExtractor, segment, AnalyzerContext(cte=holder.cte)
)

identifier = None
if segment.type == "common_table_expression":
elif segment.type == "common_table_expression":
identifier = None
segment_has_alias = any(
s for s in get_children(segment, "keyword") if s.raw_upper == "AS"
)
Expand All @@ -63,7 +48,7 @@ def extract(
if not segment_has_alias:
holder.add_cte(SqlFluffSubQuery.of(sub_segment, identifier))
if sub_segment.type == "bracketed":
for sq in self.parse_subquery(sub_segment):
for sq in self.list_subquery(sub_segment):
if identifier:
sq.alias = identifier
subqueries.append(sq)
Expand All @@ -72,7 +57,7 @@ def extract(

# By recursively extracting each extractor of the parent and merge, we're doing Depth-first search
for sq in subqueries:
holder |= DmlSelectExtractor(self.dialect).extract(
holder |= SelectExtractor(self.dialect).extract(
sq.query,
AnalyzerContext(cte=holder.cte, write={sq}),
)
Expand Down

0 comments on commit aad9fb8

Please sign in to comment.