Skip to content

Commit

Permalink
feat: expand wildcard with MetadataProvider (#515)
Browse files Browse the repository at this point in the history
* feat: support parsing wildcard in sqlparse

* refactor: simplify code logic

* refactor: introduce session metadata

* refactor: unify analyzer.analyze method

* docs: add comment on wildcard in identifier list logic

* refactor: put wildcard expansion logic at end

* test: use assert_column_lineage_equal

* feat: support sqlfluff expand wildcard

* refactor: make some holder method private

* refactor: use context manager for metadata session

* refactor: use sqllineage.core.models as metadata param

* test: split test cases

* refactor: optimize column.parent method

* refactor: simplify code

---------

Co-authored-by: lixxvsky <lixxvsky@163.com>
  • Loading branch information
reata and lixxvsky committed Dec 28, 2023
1 parent 8b37464 commit 11a4b4a
Show file tree
Hide file tree
Showing 13 changed files with 542 additions and 58 deletions.
9 changes: 7 additions & 2 deletions sqllineage/core/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List

from sqllineage.core.holders import StatementLineageHolder
from sqllineage.core.metadata_provider import MetaDataProvider


class LineageAnalyzer:
Expand All @@ -13,10 +14,14 @@ class LineageAnalyzer:
SUPPORTED_DIALECTS: List[str] = []

@abstractmethod
def analyze(self, sql: str, silent_mode: bool) -> StatementLineageHolder:
def analyze(
self, sql: str, metadata_provider: MetaDataProvider
) -> StatementLineageHolder:
"""
to analyze single statement sql and store the result into
:class:`sqllineage.core.holders.StatementLineageHolder`
:param silent_mode: skip unsupported statements.
:param sql: single-statement SQL string to be processed
:param metadata_provider: :class:`sqllineage.core.metadata_provider.MetaDataProvider` provides metadata on
tables to help lineage analyzing
"""
95 changes: 83 additions & 12 deletions sqllineage/core/holders.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import itertools
from typing import List, Set, Tuple, Union
from typing import List, Optional, Set, Tuple, Union

import networkx as nx
from networkx import DiGraph
Expand Down Expand Up @@ -73,6 +73,8 @@ def add_read(self, value) -> None:

@property
def write(self) -> Set[Union[SubQuery, Table]]:
    """Return the result of ``_property_getter`` for ``NodeTag.WRITE``, i.e. the write target(s)."""
    # SubQueryLineageHolder.write can return a single SubQuery or Table, or both when __or__ together.
    # This is different from StatementLineageHolder.write, where Table is the only possibility.
    return self._property_getter(NodeTag.WRITE)

def add_write(self, value) -> None:
Expand All @@ -88,12 +90,12 @@ def add_cte(self, value) -> None:
@property
def write_columns(self) -> List[Column]:
"""
in case of DML with column specified, like INSERT INTO tab1 (col1, col2) SELECT ...
write_columns tell us that tab1 has column col1 and col2 in that order.
return a list of columns that write table contains
it's either manually added via `add_write_column` if specified in DML
or automatically added via `add_column_lineage` after parsing from SELECT
"""
tgt_cols = []
if self.write:
tgt_tbl = list(self.write)[0]
if tgt_tbl := self._get_target_table():
tgt_col_with_idx: List[Tuple[Column, int]] = sorted(
[
(col, attr.get(EdgeTag.INDEX, 0))
Expand All @@ -106,6 +108,10 @@ def write_columns(self) -> List[Column]:
return tgt_cols

def add_write_column(self, *tgt_cols: Column) -> None:
"""
in case of DML with column specified, like INSERT INTO tab1 (col1, col2) SELECT col3, col4
this method is called to make sure tab1 has column col1 and col2 instead of col3 and col4
"""
if self.write:
tgt_tbl = list(self.write)[0]
for idx, tgt_col in enumerate(tgt_cols):
Expand All @@ -124,6 +130,74 @@ def add_column_lineage(self, src: Column, tgt: Column) -> None:
# starting NetworkX v2.6, None is not allowed as node, see https://github.com/networkx/networkx/pull/4892
self.graph.add_edge(src.parent, src, type=EdgeType.HAS_COLUMN)

def get_table_columns(self, table: Union[Table, SubQuery]) -> List[Column]:
    """
    return the columns attached to the given table or subquery through a HAS_COLUMN edge,
    leaving out the wildcard column "*"
    """
    columns = []
    for _, node, edge_type in self.graph.out_edges(nbunch=table, data="type"):
        # only outgoing HAS_COLUMN edges that end at a real (non-wildcard) Column are of interest
        if (
            edge_type == EdgeType.HAS_COLUMN
            and isinstance(node, Column)
            and node.raw_name != "*"
        ):
            columns.append(node)
    return columns

def expand_wildcard(self, metadata_provider: MetaDataProvider) -> None:
    """
    replace every wildcard ("*") write column with the concrete columns of its source tables.

    Columns of a SubQuery source are read from the graph; columns of a Table source are
    fetched from ``metadata_provider`` when it evaluates to True. Sources whose columns
    cannot be resolved are left untouched.

    :param metadata_provider: provider used to look up the columns of source tables
    """
    tgt_table = self._get_target_table()
    if not tgt_table:
        return
    for tgt_wildcard in self.write_columns:
        if tgt_wildcard.raw_name != "*":
            continue
        for src_wildcard in self._get_source_columns(tgt_wildcard):
            source_table = src_wildcard.parent
            if not source_table:
                continue
            if isinstance(source_table, SubQuery):
                # the columns of SubQuery can be inferred from graph
                src_table_columns = self.get_table_columns(source_table)
            elif isinstance(source_table, Table) and metadata_provider:
                # search by metadata service
                src_table_columns = metadata_provider.get_table_columns(source_table)
            else:
                src_table_columns = []
            if src_table_columns:
                self._replace_wildcard(
                    tgt_table, src_table_columns, tgt_wildcard, src_wildcard
                )

def _get_target_table(self) -> Optional[Union[SubQuery, Table]]:
    """
    return one element of ``write`` that is not also in ``read``, or None when there is no such element
    """
    write_only = self.write.difference(self.read)
    # next() with a default covers the empty-set case without a separate truthiness check
    return next(iter(write_only), None)

def _get_source_columns(self, node: Column) -> List[Column]:
    """
    return the Column nodes that point at ``node`` through a LINEAGE edge
    """
    sources = []
    for origin, _, edge_type in self.graph.in_edges(nbunch=node, data="type"):
        if edge_type == EdgeType.LINEAGE and isinstance(origin, Column):
            sources.append(origin)
    return sources

def _replace_wildcard(
    self,
    tgt_table: Union[Table, SubQuery],
    src_table_columns: List[Column],
    tgt_wildcard: Column,
    src_wildcard: Column,
) -> None:
    """
    wire each concrete source column to a same-named column of the target table,
    then drop the two wildcard nodes from the graph.

    :param tgt_table: table or subquery that receives the expanded columns
    :param src_table_columns: concrete columns of the source table
    :param tgt_wildcard: wildcard column on the target side, removed afterwards
    :param src_wildcard: wildcard column on the source side, removed afterwards
    """
    # snapshot taken once, before any edge is added below
    existing_columns = self.get_table_columns(tgt_table)
    graph = self.graph
    for src_col in src_table_columns:
        expanded = Column(src_col.raw_name)
        expanded.parent = tgt_table
        # skip columns the target already has, and never re-introduce a wildcard
        if expanded in existing_columns or src_col.raw_name == "*":
            continue
        graph.add_edge(tgt_table, expanded, type=EdgeType.HAS_COLUMN)
        graph.add_edge(src_col.parent, src_col, type=EdgeType.HAS_COLUMN)
        graph.add_edge(src_col, expanded, type=EdgeType.LINEAGE)
    # remove wildcard
    for wildcard in (tgt_wildcard, src_wildcard):
        if graph.has_node(wildcard):
            graph.remove_node(wildcard)


class StatementLineageHolder(SubQueryLineageHolder, ColumnLineageMixin):
"""
Expand Down Expand Up @@ -318,13 +392,10 @@ def _build_digraph(
isinstance(parent, Table)
and str(parent.schema) != Schema.unknown
):
columns = metadata_provider.get_table_columns(
str(parent.schema), str(parent.raw_name)
)
if columns is not None and unresolved_col.raw_name in columns:
src_col = Column(unresolved_col.raw_name)
src_col.parent = parent
src_cols.append(src_col)
columns = metadata_provider.get_table_columns(parent)
for src_col in columns:
if unresolved_col.raw_name == src_col.raw_name:
src_cols.append(src_col)

if len(src_cols) > 1:
raise InvalidSyntaxException(
Expand Down
3 changes: 2 additions & 1 deletion sqllineage/core/metadata/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ class DummyMetaDataProvider(MetaDataProvider):
"""

def __init__(self, metadata: Optional[Dict[str, List[str]]] = None):
    """
    :param metadata: mapping from "<schema>.<table>" to the list of column names of that table;
        an empty mapping is used when omitted
    """
    # run the base class initializer before setting our own state
    super().__init__()
    self.metadata = metadata if metadata is not None else {}

def get_table_columns(self, schema: str, table: str, **kwargs) -> List[str]:
def _get_table_columns(self, schema: str, table: str, **kwargs) -> List[str]:
return self.metadata.get(f"{schema}.{table}", [])

def __bool__(self):
Expand Down
69 changes: 61 additions & 8 deletions sqllineage/core/metadata_provider.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from abc import abstractmethod
from typing import List
from typing import Dict, List

from sqllineage.core.models import Column, Table


class MetaDataProvider:
"""Abstract class used to provide metadata service like table schema.
"""Base class used to provide metadata service like table schema.
When parse below sql:
Insert into db1.table1
Expand All @@ -17,17 +19,68 @@ class MetaDataProvider:
It can help parse column lineage correctly.
"""

def __init__(self) -> None:
    # session-level metadata: str(table) -> column names, filled by register_session_metadata
    # and emptied by deregister_session_metadata
    self._session_metadata: Dict[str, List[str]] = {}

def get_table_columns(self, table: Table, **kwargs) -> List[Column]:
    """Get all columns of a table.

    Session-level metadata registered for the table takes precedence; only when the
    table is not registered in the session is ``_get_table_columns`` consulted.

    :param table: table to look up
    :param kwargs: passed through to ``_get_table_columns``
    :return: a list of columns, each with ``table`` set as its parent
    """
    key = str(table)
    if key in self._session_metadata:
        col_names = self._session_metadata[key]
    else:
        # Query the backend lazily: passing this call as the default of dict.get would
        # evaluate it on every lookup, even for tables already known to the session.
        col_names = self._get_table_columns(
            str(table.schema), table.raw_name, **kwargs
        )
    columns = []
    for col_name in col_names:
        column = Column(col_name)
        column.parent = table
        columns.append(column)
    return columns

@abstractmethod
def get_table_columns(self, schema: str, table: str, **kwargs) -> List[str]:
"""Get all columns of a table.
def _get_table_columns(self, schema: str, table: str, **kwargs) -> List[str]:
"""To be implemented by subclasses."""

:param schema: database name
:param table: table name
:return: a list of column names
"""
def register_session_metadata(self, table: Table, columns: List[Column]) -> None:
    """Store the raw names of ``columns`` as session-level metadata for ``table``,
    e.g. for a temporary table or view created during the session."""
    column_names = []
    for column in columns:
        column_names.append(column.raw_name)
    self._session_metadata[str(table)] = column_names

def deregister_session_metadata(self) -> None:
    """Deregister all session-level metadata at once."""
    # clears the mapping in place, so every table registered so far is forgotten
    self._session_metadata.clear()

def session(self) -> "MetaDataSession":
    """Return a new :class:`MetaDataSession` wrapping this provider."""
    return MetaDataSession(self)

def __bool__(self) -> bool:
    """
    bool value tells whether this provider is ready to provide metadata.
    This base implementation always returns True.
    """
    return True

def open(self) -> None:  # noqa
    """Open metadata connection if needed. This base implementation does nothing."""
    pass

def close(self) -> None:
    """Close metadata connection if needed. This base implementation does nothing."""
    pass


class MetaDataSession:
    """
    Create an analyzer session which can register session-level metadata as a supplement to global metadata.
    This way, tables or views created during the session can be queried.
    All session-level metadata will be deregistered once the session is closed.
    """

    def __init__(self, metadata_provider: MetaDataProvider):
        """
        :param metadata_provider: provider whose connection is opened here and closed on exit
        """
        self.metadata_provider = metadata_provider
        # NOTE: the provider is opened at construction time, not in __enter__
        self.metadata_provider.open()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Drop session-level metadata and close the provider; exceptions are not suppressed."""
        try:
            self.metadata_provider.deregister_session_metadata()
        finally:
            # always release the connection, even if deregistration fails
            self.metadata_provider.close()

    def register_session_metadata(self, table: Table, columns: List[Column]) -> None:
        """Register ``columns`` as session-level metadata for ``table`` on the wrapped provider."""
        self.metadata_provider.register_session_metadata(table, columns)
2 changes: 1 addition & 1 deletion sqllineage/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def __hash__(self):

@property
def parent(self) -> Optional[Union[Path, Table, SubQuery]]:
return list(self._parent)[0] if len(self._parent) == 1 else None
return next(iter(self._parent)) if len(self._parent) == 1 else None

@parent.setter
def parent(self, value: Union[Path, Table, SubQuery]):
Expand Down
12 changes: 8 additions & 4 deletions sqllineage/core/parser/sqlfluff/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from sqllineage.core.analyzer import LineageAnalyzer
from sqllineage.core.holders import StatementLineageHolder
from sqllineage.core.metadata_provider import MetaDataProvider
from sqllineage.core.parser.sqlfluff.extractors.base import BaseExtractor
from sqllineage.exceptions import (
InvalidSyntaxException,
Expand All @@ -20,8 +21,9 @@ class SqlFluffLineageAnalyzer(LineageAnalyzer):
PARSER_NAME = "sqlfluff"
SUPPORTED_DIALECTS = list(dialect.label for dialect in dialect_readout())

def __init__(self, dialect: str):
def __init__(self, dialect: str, silent_mode: bool = False):
self._dialect = dialect
self._silent_mode = silent_mode
self.tsql_split_cache: Dict[str, BaseSegment] = {}

def split_tsql(self, sql: str) -> List[str]:
Expand All @@ -35,7 +37,9 @@ def split_tsql(self, sql: str) -> List[str]:
sqls.append(segment.raw)
return sqls

def analyze(self, sql: str, silent_mode: bool = False) -> StatementLineageHolder:
def analyze(
self, sql: str, metadata_provider: MetaDataProvider
) -> StatementLineageHolder:
if sql in self.tsql_split_cache:
statement_segments = [self.tsql_split_cache[sql]]
else:
Expand All @@ -47,7 +51,7 @@ def analyze(self, sql: str, silent_mode: bool = False) -> StatementLineageHolder
else:
statement_segment = statement_segments[0]
for extractor in [
extractor_cls(self._dialect)
extractor_cls(self._dialect, metadata_provider)
for extractor_cls in BaseExtractor.__subclasses__()
]:
if extractor.can_extract(statement_segment.type):
Expand All @@ -56,7 +60,7 @@ def analyze(self, sql: str, silent_mode: bool = False) -> StatementLineageHolder
)
return StatementLineageHolder.of(lineage_holder)
else:
if silent_mode:
if self._silent_mode:
warnings.warn(
f"SQLLineage doesn't support analyzing statement type [{statement_segment.type}] for SQL:"
f"{sql}"
Expand Down
10 changes: 7 additions & 3 deletions sqllineage/core/parser/sqlfluff/extractors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sqlfluff.core.parser import BaseSegment

from sqllineage.core.holders import SubQueryLineageHolder
from sqllineage.core.metadata_provider import MetaDataProvider
from sqllineage.core.models import SubQuery, Table
from sqllineage.core.parser.sqlfluff.models import SqlFluffSubQuery, SqlFluffTable
from sqllineage.core.parser.sqlfluff.utils import (
Expand All @@ -21,8 +22,9 @@ class BaseExtractor:

SUPPORTED_STMT_TYPES: List[str] = []

def __init__(self, dialect: str):
def __init__(self, dialect: str, metadata_provider: MetaDataProvider):
self.dialect = dialect
self.metadata_provider = metadata_provider

def can_extract(self, statement_type: str) -> bool:
"""
Expand Down Expand Up @@ -99,7 +101,9 @@ def delegate_to(
"""
delegate to another type of extractor to extract
"""
return extractor_cls(self.dialect).extract(segment, context)
return extractor_cls(self.dialect, self.metadata_provider).extract(
segment, context
)

def extract_subquery(
self, subqueries: List[SubQuery], holder: SubQueryLineageHolder
Expand All @@ -116,7 +120,7 @@ def extract_subquery(
if sq.query.get_child("with_compound_statement")
else SelectExtractor
)
holder |= extractor_cls(self.dialect).extract(
holder |= extractor_cls(self.dialect, self.metadata_provider).extract(
sq.query, AnalyzerContext(cte=holder.cte, write={sq})
)

Expand Down
7 changes: 5 additions & 2 deletions sqllineage/core/parser/sqlfluff/extractors/select.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from sqlfluff.core.parser import BaseSegment

from sqllineage.core.holders import SubQueryLineageHolder
from sqllineage.core.metadata_provider import MetaDataProvider
from sqllineage.core.models import Path
from sqllineage.core.parser import SourceHandlerMixin
from sqllineage.core.parser.sqlfluff.extractors.base import BaseExtractor
Expand Down Expand Up @@ -28,8 +29,8 @@ class SelectExtractor(BaseExtractor, SourceHandlerMixin):

SUPPORTED_STMT_TYPES = ["select_statement", "set_expression", "bracketed"]

def __init__(self, dialect: str):
super().__init__(dialect)
def __init__(self, dialect: str, metadata_provider: MetaDataProvider):
super().__init__(dialect, metadata_provider)
self.columns = []
self.tables = []
self.union_barriers = []
Expand Down Expand Up @@ -74,6 +75,8 @@ def extract(

self.extract_subquery(subqueries, holder)

holder.expand_wildcard(self.metadata_provider)

return holder

@staticmethod
Expand Down

0 comments on commit 11a4b4a

Please sign in to comment.