Skip to content

Commit

Permalink
refactor: remove holder_utils
Browse files Browse the repository at this point in the history
  • Loading branch information
reata committed Aug 27, 2023
1 parent cbd915c commit 7f8ae5b
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 120 deletions.
44 changes: 35 additions & 9 deletions sqllineage/core/parser/sqlfluff/extractors/select.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from typing import Union

from sqlfluff.core.parser import BaseSegment

from sqllineage.core.holders import SubQueryLineageHolder
from sqllineage.core.models import Path, SubQuery, Table
from sqllineage.core.models import Path
from sqllineage.core.parser import SourceHandlerMixin
from sqllineage.core.parser.sqlfluff.extractors.base import BaseExtractor
from sqllineage.core.parser.sqlfluff.holder_utils import retrieve_holder_data_from
from sqllineage.core.parser.sqlfluff.models import (
SqlFluffColumn,
SqlFluffSubQuery,
Expand Down Expand Up @@ -167,7 +164,6 @@ def _add_dataset_from_expression_element(
Append tables and subqueries identified in the 'from_expression_element' type segment to the table and
holder extra subqueries sets
"""
dataset: Union[Path, Table, SubQuery]
all_segments = [
seg for seg in list_child_segments(segment) if seg.type != "keyword"
]
Expand All @@ -190,7 +186,37 @@ def _add_dataset_from_expression_element(
else:
table_identifier = find_table_identifier(segment)
if table_identifier:
dataset = retrieve_holder_data_from(
all_segments, holder, table_identifier
)
self.tables.append(dataset)
subquery_flag = False
alias = None
if len(all_segments) > 1 and all_segments[1].type == "alias_expression":
all_segments = list_child_segments(all_segments[1])
alias = str(
all_segments[1].raw
if len(all_segments) > 1
else all_segments[0].raw
)
if "." not in table_identifier.raw:
cte_dict = {s.alias: s for s in holder.cte}
cte = cte_dict.get(table_identifier.raw)
if cte is not None:
# could reference CTE with or without alias
self.tables.append(
SqlFluffSubQuery.of(
cte.query,
alias or table_identifier.raw,
)
)
subquery_flag = True
if subquery_flag is False:
if table_identifier.type == "file_reference":
self.tables.append(
Path(
escape_identifier_name(
table_identifier.segments[-1].raw
)
)
)
else:
self.tables.append(
SqlFluffTable.of(table_identifier, alias=alias)
)
44 changes: 0 additions & 44 deletions sqllineage/core/parser/sqlfluff/holder_utils.py

This file was deleted.

15 changes: 8 additions & 7 deletions sqllineage/core/parser/sqlparse/analyzer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import reduce
from operator import add
from typing import List, Union
from typing import List, Optional, Union

import sqlparse
from sqlparse.sql import (
Expand All @@ -17,12 +17,11 @@

from sqllineage.core.analyzer import LineageAnalyzer
from sqllineage.core.holders import StatementLineageHolder, SubQueryLineageHolder
from sqllineage.core.models import Column, SubQuery
from sqllineage.core.models import Column, SubQuery, Table
from sqllineage.core.parser.sqlparse.handlers.base import (
CurrentTokenBaseHandler,
NextTokenBaseHandler,
)
from sqllineage.core.parser.sqlparse.holder_utils import get_dataset_from_identifier
from sqllineage.core.parser.sqlparse.models import SqlParseSubQuery, SqlParseTable
from sqllineage.core.parser.sqlparse.utils import (
get_subquery_parentheses,
Expand Down Expand Up @@ -103,7 +102,7 @@ def _extract_from_dml_merge(cls, stmt: Statement) -> StatementLineageHolder:
holder = StatementLineageHolder()
src_flag = tgt_flag = update_flag = insert_flag = False
insert_columns = []
direct_source = None
direct_source: Optional[Union[Table, SubQuery]] = None
for token in stmt.tokens:
if is_token_negligible(token):
continue
Expand All @@ -122,18 +121,20 @@ def _extract_from_dml_merge(cls, stmt: Statement) -> StatementLineageHolder:

if tgt_flag:
if isinstance(token, Identifier):
holder.add_write(get_dataset_from_identifier(token, holder))
holder.add_write(SqlParseTable.of(token))
tgt_flag = False
elif src_flag:
if isinstance(token, Identifier):
direct_source = get_dataset_from_identifier(token, holder)
holder.add_read(direct_source)
if subqueries := cls.parse_subquery(token):
for sq in subqueries:
direct_source = SqlParseSubQuery.of(sq.query, sq.alias)
holder |= cls._extract_from_dml(
sq.query,
AnalyzerContext(cte=holder.cte, write={sq}),
)
else:
direct_source = SqlParseTable.of(token)
holder.add_read(direct_source)
src_flag = False
elif update_flag:
comparisons = []
Expand Down
40 changes: 37 additions & 3 deletions sqllineage/core/parser/sqlparse/handlers/source.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
from typing import Optional, Union

from sqlparse.sql import (
Case,
Expand All @@ -12,15 +13,16 @@
from sqlparse.tokens import Literal, Wildcard

from sqllineage.core.holders import SubQueryLineageHolder
from sqllineage.core.models import Path
from sqllineage.core.models import Path, SubQuery, Table
from sqllineage.core.parser import SourceHandlerMixin
from sqllineage.core.parser.sqlparse.handlers.base import NextTokenBaseHandler
from sqllineage.core.parser.sqlparse.holder_utils import get_dataset_from_identifier
from sqllineage.core.parser.sqlparse.models import (
SqlParseColumn,
SqlParseSubQuery,
SqlParseTable,
)
from sqllineage.core.parser.sqlparse.utils import (
get_subquery_parentheses,
is_subquery,
is_values_clause,
)
Expand Down Expand Up @@ -123,5 +125,37 @@ def _handle_column(self, token: Token) -> None:
def _add_dataset_from_identifier(
    self, identifier: Identifier, holder: SubQueryLineageHolder
) -> None:
    """
    Resolve ``identifier`` to the dataset it reads from and append it to ``self.tables``.

    Resolution order:

    1. If the identifier's first token is a ``Function``, or a ``Parenthesis`` that
       ``is_values_clause`` accepts, nothing is added.
    2. If the identifier text matches ``parquet|csv|json`` followed by a
       backtick-quoted string, a ``Path`` built from the quoted part is added.
    3. If ``get_subquery_parentheses`` finds a subquery, the first
       ``(parenthesis, alias)`` pair becomes a ``SqlParseSubQuery``.
    4. If the identifier text contains no ``.`` and its real name equals the alias
       of an entry in ``holder.cte``, a ``SqlParseSubQuery`` over that CTE's query
       is added, using the identifier's alias when present, else its real name.
    5. Otherwise the identifier is added as a ``SqlParseTable``.

    :param identifier: the sqlparse identifier found in a source position
    :param holder: lineage holder whose ``cte`` collection is consulted for CTE lookup
    """
    first_token = identifier.token_first(skip_cm=True)
    if isinstance(first_token, Function):
        # function() as alias, no dataset involved
        return None
    elif isinstance(first_token, Parenthesis) and is_values_clause(first_token):
        # (VALUES ...) AS alias, no dataset involved
        return None
    dataset: Union[Table, SubQuery, Path]
    path_match = re.match(r"(parquet|csv|json)\.`(.*)`", identifier.value)
    if path_match is not None:
        # second regex group is the text between the backticks
        dataset = Path(path_match.groups()[1])
    else:
        read: Optional[Union[Table, SubQuery]] = None
        subqueries = get_subquery_parentheses(identifier)
        if len(subqueries) > 0:
            # In "SELECT col1 FROM (SELECT col2 FROM tab1) dt" the subquery is parsed
            # as an Identifier; see https://github.com/andialbrecht/sqlparse/issues/218
            # for further information
            parenthesis, alias = subqueries[0]
            read = SqlParseSubQuery.of(parenthesis, alias)
        else:
            cte_dict = {s.alias: s for s in holder.cte}
            if "." not in identifier.value:
                # only names without a "." are looked up among the CTE aliases
                cte = cte_dict.get(identifier.get_real_name())
                if cte is not None:
                    # could reference CTE with or without alias
                    read = SqlParseSubQuery.of(
                        cte.query,
                        identifier.get_alias() or identifier.get_real_name(),
                    )
            if read is None:
                # neither a subquery nor a CTE reference: treat as a plain table
                read = SqlParseTable.of(identifier)
        dataset = read
    if dataset:
        self.tables.append(dataset)
57 changes: 0 additions & 57 deletions sqllineage/core/parser/sqlparse/holder_utils.py

This file was deleted.

0 comments on commit 7f8ae5b

Please sign in to comment.