Skip to content

Commit

Permalink
chore: update sqlfluff with strict type
Browse files Browse the repository at this point in the history
  • Loading branch information
reata committed Sep 30, 2023
1 parent 07e199b commit 18ac748
Show file tree
Hide file tree
Showing 10 changed files with 97 additions and 125 deletions.
8 changes: 5 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
repos:
- repo: https://github.com/psf/black
rev: 23.1.0
rev: 23.9.1
hooks:
- id: black
language_version: python3.8
- repo: https://github.com/pycqa/flake8.git
rev: 6.0.0
rev: 6.1.0
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.0.1
rev: v1.5.1
hooks:
- id: mypy
additional_dependencies:
- sqlfluff
3 changes: 0 additions & 3 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,3 @@ warn_no_return=True
warn_redundant_casts=True
warn_unused_ignores=True
disallow_any_generics=True
[mypy-sqllineage.core.parser.sqlfluff.utils]
disallow_untyped_calls=False
warn_return_any = False
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def run(self) -> None:
install_requires=[
"sqlparse>=0.4.4",
"networkx>=2.4",
"sqlfluff==2.1.4",
"sqlfluff==2.3.2",
],
entry_points={"console_scripts": ["sqllineage = sqllineage.cli:main"]},
extras_require={
Expand Down
3 changes: 1 addition & 2 deletions sqllineage/core/parser/sqlfluff/extractors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from sqllineage.core.models import SubQuery, Table
from sqllineage.core.parser.sqlfluff.models import SqlFluffSubQuery, SqlFluffTable
from sqllineage.core.parser.sqlfluff.utils import (
get_children,
is_subquery,
list_subqueries,
)
Expand Down Expand Up @@ -61,7 +60,7 @@ def list_subquery(cls, segment: BaseSegment) -> List[SubQuery]:
types it returns an empty list.
"""
result: List[SubQuery] = []
identifiers = get_children(segment, "from_expression")
identifiers = segment.get_children("from_expression")
if identifiers and len(identifiers) > 1:
# for SQL89 style of JOIN or multiple CTEs, this is actually SubQueries
return reduce(
Expand Down
10 changes: 4 additions & 6 deletions sqllineage/core/parser/sqlfluff/extractors/copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from sqllineage.core.parser.sqlfluff.extractors.base import BaseExtractor
from sqllineage.core.parser.sqlfluff.utils import (
find_from_expression_element,
get_child,
get_children,
list_child_segments,
)
from sqllineage.utils.entities import AnalyzerContext
Expand All @@ -31,11 +29,11 @@ def extract(
for segment in list_child_segments(statement):
if segment.type == "from_clause":
if from_expression_element := find_from_expression_element(segment):
for table_expression in get_children(
from_expression_element, "table_expression"
for table_expression in from_expression_element.get_children(
"table_expression"
):
if storage_location := get_child(
table_expression, "storage_location"
if storage_location := table_expression.get_child(
"storage_location"
):
holder.add_read(Path(storage_location.raw))
elif segment.type == "keyword":
Expand Down
21 changes: 10 additions & 11 deletions sqllineage/core/parser/sqlfluff/extractors/create_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@
from sqllineage.core.parser.sqlfluff.extractors.select import SelectExtractor
from sqllineage.core.parser.sqlfluff.models import SqlFluffColumn, SqlFluffTable
from sqllineage.core.parser.sqlfluff.utils import (
get_child,
get_children,
is_set_expression,
list_child_segments,
)
Expand Down Expand Up @@ -47,12 +45,12 @@ def extract(
elif segment.type in ("select_statement", "set_expression"):
holder |= self.delegate_to_select(segment, holder)
elif segment.type == "values_clause":
for bracketed in get_children(segment, "bracketed"):
for expression in get_children(bracketed, "expression"):
if sub_bracketed := get_child(expression, "bracketed"):
if sub_expression := get_child(sub_bracketed, "expression"):
if select_statement := get_child(
sub_expression, "select_statement"
for bracketed in segment.get_children("bracketed"):
for expression in bracketed.get_children("expression"):
if sub_bracketed := expression.get_child("bracketed"):
if sub_expression := sub_bracketed.get_child("expression"):
if select_statement := sub_expression.get_child(
"select_statement"
):
holder |= self.delegate_to_select(
select_statement, holder
Expand All @@ -63,8 +61,8 @@ def extract(
# note regular subquery within SELECT statement is handled by SelectExtractor, this is only to handle
# top-level subquery in DML like: 1) create table foo as (subquery); 2) insert into foo (subquery)
# subquery here isn't added as read source, and it inherits DML-level write_columns if parsed
if subquery_segment := get_child(
segment, "select_statement", "set_expression"
if subquery_segment := segment.get_child(
"select_statement", "set_expression"
):
holder |= self.delegate_to_select(subquery_segment, holder)

Expand All @@ -81,7 +79,8 @@ def extract(
columns = []
for sub_segment in sub_segments:
if sub_segment.type == "column_definition":
sub_segment = get_child(sub_segment, "identifier")
if identifier := sub_segment.get_child("identifier"):
sub_segment = identifier
columns.append(SqlFluffColumn.of(sub_segment))
holder.add_write_column(*columns)

Expand Down
4 changes: 2 additions & 2 deletions sqllineage/core/parser/sqlfluff/extractors/cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
)
from sqllineage.core.parser.sqlfluff.extractors.select import SelectExtractor
from sqllineage.core.parser.sqlfluff.models import SqlFluffSubQuery
from sqllineage.core.parser.sqlfluff.utils import get_children, list_child_segments
from sqllineage.core.parser.sqlfluff.utils import list_child_segments
from sqllineage.utils.entities import AnalyzerContext


Expand Down Expand Up @@ -39,7 +39,7 @@ def extract(
elif segment.type == "common_table_expression":
identifier = None
segment_has_alias = any(
s for s in get_children(segment, "keyword") if s.raw_upper == "AS"
s for s in segment.get_children("keyword") if s.raw_upper == "AS"
)
sub_segments = list_child_segments(segment)
for sub_segment in sub_segments:
Expand Down
76 changes: 37 additions & 39 deletions sqllineage/core/parser/sqlfluff/extractors/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
extract_column_qualifier,
extract_identifier,
extract_innermost_bracketed,
get_child,
get_children,
list_child_segments,
)
from sqllineage.utils.entities import AnalyzerContext
Expand All @@ -38,19 +36,19 @@ def extract(
for i, segment in enumerate(segments):
if segment.type == "merge_match":
merge_match = segment
for merge_when_matched_clause in get_children(
merge_match, "merge_when_matched_clause"
for merge_when_matched_clause in merge_match.get_children(
"merge_when_matched_clause"
):
if merge_update_clause := get_child(
merge_when_matched_clause, "merge_update_clause"
if merge_update_clause := merge_when_matched_clause.get_child(
"merge_update_clause"
):
if set_clause_list := get_child(
merge_update_clause, "set_clause_list"
if set_clause_list := merge_update_clause.get_child(
"set_clause_list"
):
for set_clause in get_children(
set_clause_list, "set_clause"
for set_clause in set_clause_list.get_children(
"set_clause"
):
columns = get_children(set_clause, "column_reference")
columns = set_clause.get_children("column_reference")
if len(columns) == 2:
src_col = tgt_col = None
if src_cqt := extract_column_qualifier(columns[1]):
Expand All @@ -61,37 +59,37 @@ def extract(
tgt_col.parent = list(holder.write)[0]
if src_col is not None and tgt_col is not None:
holder.add_column_lineage(src_col, tgt_col)
for merge_when_not_matched_clause in get_children(
merge_match, "merge_when_not_matched_clause"
for merge_when_not_matched_clause in merge_match.get_children(
"merge_when_not_matched_clause"
):
merge_insert = get_child(
merge_when_not_matched_clause, "merge_insert_clause"
)
insert_columns = []
if bracketed := get_child(merge_insert, "bracketed"):
for column_reference in get_children(
bracketed, "column_reference"
):
if cqt := extract_column_qualifier(column_reference):
tgt_col = Column(cqt.column)
tgt_col.parent = list(holder.write)[0]
insert_columns.append(tgt_col)
if values_clause := get_child(merge_insert, "values_clause"):
if bracketed := get_child(values_clause, "bracketed"):
for j, e in enumerate(
get_children(bracketed, "literal", "expression")
):
if column_reference := get_child(
e, "column_reference"
if merge_insert := merge_when_not_matched_clause.get_child(
"merge_insert_clause"
):
insert_columns = []
if bracketed := merge_insert.get_child("bracketed"):
for column_reference in bracketed.get_children(
"column_reference"
):
if cqt := extract_column_qualifier(column_reference):
tgt_col = Column(cqt.column)
tgt_col.parent = list(holder.write)[0]
insert_columns.append(tgt_col)
if values_clause := merge_insert.get_child("values_clause"):
if bracketed := values_clause.get_child("bracketed"):
for j, e in enumerate(
bracketed.get_children("literal", "expression")
):
if cqt := extract_column_qualifier(
column_reference
if column_reference_optional := e.get_child(
"column_reference"
):
src_col = Column(cqt.column)
src_col.parent = direct_source
holder.add_column_lineage(
src_col, insert_columns[j]
)
if cqt := extract_column_qualifier(
column_reference_optional
):
src_col = Column(cqt.column)
src_col.parent = direct_source
holder.add_column_lineage(
src_col, insert_columns[j]
)
elif segment.type == "keyword":
if segment.raw_upper in ["MERGE", "INTO"]:
tgt_flag = True
Expand Down
22 changes: 10 additions & 12 deletions sqllineage/core/parser/sqlfluff/extractors/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
from sqllineage.core.parser.sqlfluff.utils import (
find_from_expression_element,
find_table_identifier,
get_child,
get_children,
is_set_expression,
list_child_segments,
list_join_clause,
Expand Down Expand Up @@ -76,14 +74,14 @@ def _handle_swap_partition(segment: BaseSegment, holder: SubQueryLineageHolder):
A handler for swap_partitions_between_tables function
"""
if segment.type == "select_clause":
if select_clause_element := get_child(segment, "select_clause_element"):
if function := get_child(select_clause_element, "function"):
if select_clause_element := segment.get_child("select_clause_element"):
if function := select_clause_element.get_child("function"):
if (
function.first_non_whitespace_segment_raw_upper
== "SWAP_PARTITIONS_BETWEEN_TABLES"
):
if bracketed := get_child(function, "bracketed"):
expressions = get_children(bracketed, "expression")
if bracketed := function.get_child("bracketed"):
expressions = bracketed.get_children("expression")
holder.add_read(
SqlFluffTable(
escape_identifier_name(expressions[0].raw)
Expand Down Expand Up @@ -111,7 +109,7 @@ def _handle_table(
handle from_clause or join_clause, join_clause is a child node of from_clause.
"""
if segment.type in ["from_clause", "join_clause"]:
from_expressions = get_children(segment, "from_expression")
from_expressions = segment.get_children("from_expression")
if len(from_expressions) > 1:
# SQL89 style of join
for from_expression in from_expressions:
Expand All @@ -134,7 +132,7 @@ def _handle_column(self, segment: BaseSegment) -> None:
Column handler method
"""
if segment.type == "select_clause":
for sub_segment in get_children(segment, "select_clause_element"):
for sub_segment in segment.get_children("select_clause_element"):
self.columns.append(SqlFluffColumn.of(sub_segment))

def _handle_set(self, segment: BaseSegment) -> None:
Expand Down Expand Up @@ -167,14 +165,14 @@ def _add_dataset_from_expression_element(
all_segments = [
seg for seg in list_child_segments(segment) if seg.type != "keyword"
]
if table_expression := get_child(segment, "table_expression"):
if get_child(table_expression, "function"):
if table_expression := segment.get_child("table_expression"):
if table_expression.get_child("function"):
# for UNNEST or generator function, no dataset involved
return
first_segment = all_segments[0]
if first_segment.type == "bracketed":
if table_expression := get_child(first_segment, "table_expression"):
if get_child(table_expression, "values_clause"):
if table_expression := first_segment.get_child("table_expression"):
if table_expression.get_child("values_clause"):
# (VALUES ...) AS alias, no dataset involved
return
subqueries = list_subqueries(segment)
Expand Down

0 comments on commit 18ac748

Please sign in to comment.