Skip to content

Commit

Permalink
feat: SQLAlchemy-based MetaDataProvider
Browse files Browse the repository at this point in the history
  • Loading branch information
reata committed Dec 30, 2023
1 parent 11a4b4a commit 74885ce
Show file tree
Hide file tree
Showing 8 changed files with 165 additions and 44 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def run(self) -> None:
"sqlparse==0.4.4",
"networkx>=2.4",
"sqlfluff==2.3.5",
"sqlalchemy>=2.0.0",
],
entry_points={"console_scripts": ["sqllineage = sqllineage.cli:main"]},
extras_require={
Expand Down
46 changes: 46 additions & 0 deletions sqllineage/core/metadata/sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import logging
from typing import List

from sqlalchemy import MetaData, Table, create_engine, inspect, make_url
from sqlalchemy.exc import NoSuchModuleError, NoSuchTableError, OperationalError

from sqllineage.core.metadata_provider import MetaDataProvider
from sqllineage.exceptions import MetaDataProviderException


logger = logging.getLogger(__name__)


class SQLAlchemyMetaDataProvider(MetaDataProvider):
"""
SQLAlchemyMetaDataProvider queries metadata from database using SQLAlchemy
"""

def __init__(self, url: str):
super().__init__()
self.metadata_obj = MetaData()
try:
self.engine = create_engine(url)
except NoSuchModuleError as e:
u = make_url(url)
raise MetaDataProviderException(
f"SQLAlchemy dialect driver {u.drivername} is not installed correctly"
) from e
try:
self.engine.connect()
except OperationalError as e:
raise MetaDataProviderException(f"Could not connect to {url}") from e

def _get_table_columns(self, schema: str, table: str, **kwargs) -> List[str]:
columns = []
if inspect(self.engine).has_schema(schema):
try:
sqlalchemy_table = Table(
table, self.metadata_obj, schema=schema, autoload_with=self.engine
)
columns = [c.name for c in sqlalchemy_table.columns]
except NoSuchTableError:
logger.warning("%s does not exist in database %s", table, schema)
else:
logger.warning("Schema %s does not exist", schema)
return columns
18 changes: 4 additions & 14 deletions sqllineage/core/metadata_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ def __init__(self) -> None:
self._session_metadata: Dict[str, List[str]] = {}

def get_table_columns(self, table: Table, **kwargs) -> List[Column]:
cols = self._session_metadata.get(
str(table),
self._get_table_columns(str(table.schema), table.raw_name, **kwargs),
)
if (key := str(table)) in self._session_metadata:
cols = self._session_metadata[key]
else:
cols = self._get_table_columns(str(table.schema), table.raw_name, **kwargs)
columns = []
for col in cols:
column = Column(col)
Expand Down Expand Up @@ -55,14 +55,6 @@ def __bool__(self):
"""
return True

def open(self) -> None: # noqa
"""Open metadata connection if needed."""
pass

def close(self) -> None:
"""Close metadata connection if needed"""
pass


class MetaDataSession:
"""
Expand All @@ -73,14 +65,12 @@ class MetaDataSession:

def __init__(self, metadata_provider: MetaDataProvider):
self.metadata_provider = metadata_provider
self.metadata_provider.open()

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.metadata_provider.deregister_session_metadata()
self.metadata_provider.close()

def register_session_metadata(self, table: Table, columns: List[Column]) -> None:
self.metadata_provider.register_session_metadata(table, columns)
4 changes: 4 additions & 0 deletions sqllineage/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,7 @@ class UnsupportedStatementException(SQLLineageException):

class InvalidSyntaxException(SQLLineageException):
"""Raised for SQL statement that parser cannot parse"""


class MetaDataProviderException(SQLLineageException):
"""Raised for MetaDataProvider errors"""
21 changes: 21 additions & 0 deletions tests/core/test_metadata_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import pytest

from sqllineage.core.metadata.sqlalchemy import SQLAlchemyMetaDataProvider
from sqllineage.exceptions import MetaDataProviderException


def test_sqlalchemy_metadata_provider_connection_fail():
# connect to /root directory as sqlite db, which is not possible. Simulate connection failure
with pytest.raises(MetaDataProviderException):
SQLAlchemyMetaDataProvider("sqlite:////root/")
# use an unknown driver to connect. Simulate driver not installed
with pytest.raises(MetaDataProviderException):
SQLAlchemyMetaDataProvider("sqlite+unknown_driver:///:memory:")


def test_sqlalchemy_metadata_provider_query_fail():
provider = SQLAlchemyMetaDataProvider("sqlite:///:memory:")
assert (
provider._get_table_columns("non_existing_schema", "non_existing_table") == []
)
assert provider._get_table_columns("main", "non_existing_table") == []
38 changes: 38 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
import tempfile
from pathlib import Path
from typing import Optional

import networkx as nx
from sqlalchemy import (
Column as SQLAlchemyColumn,
Integer,
MetaData,
Table as SQLAlchemyTable,
inspect,
text,
)

from sqllineage import SQLPARSE_DIALECT
from sqllineage.core.metadata.dummy import DummyMetaDataProvider
from sqllineage.core.metadata.sqlalchemy import SQLAlchemyMetaDataProvider
from sqllineage.core.metadata_provider import MetaDataProvider
from sqllineage.core.models import Column, Table
from sqllineage.runner import LineageRunner
Expand Down Expand Up @@ -93,3 +104,30 @@ def assert_lr_graphs_match(lr: LineageRunner, lr_sqlfluff: LineageRunner) -> Non
f"\n\tGraph with sqlparse: {lr._sql_holder.graph}\n\t"
f"Graph with sqlfluff: {lr_sqlfluff._sql_holder.graph}"
)


def generate_metadata_providers(test_schemas):
dummy_provider = DummyMetaDataProvider(test_schemas)

sqlite3_sqlalchemy_provider = SQLAlchemyMetaDataProvider("sqlite:///:memory:")
metadata = MetaData()
with tempfile.TemporaryDirectory() as tmpdirname:
for full_table_name, columns_names in test_schemas.items():
schema, table = full_table_name.split(".")
if schema not in ("main", "temp") and not inspect(
sqlite3_sqlalchemy_provider.engine
).has_schema(schema):
db_file_path = Path(tmpdirname).joinpath(f"{schema}.db")
with sqlite3_sqlalchemy_provider.engine.connect() as conn:
conn.execute(
text(f"ATTACH DATABASE '{db_file_path}' AS '{schema}'")
)
SQLAlchemyTable(
table,
metadata,
*[SQLAlchemyColumn(c, Integer) for c in columns_names],
schema=schema,
)
metadata.create_all(bind=sqlite3_sqlalchemy_provider.engine)

return [dummy_provider, sqlite3_sqlalchemy_provider]
32 changes: 19 additions & 13 deletions tests/sql/column/test_metadata_unqualified_column.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
import pytest
from tests.helpers import assert_column_lineage_equal
from tests.helpers import assert_column_lineage_equal, generate_metadata_providers

from sqllineage.core.metadata.dummy import DummyMetaDataProvider
from sqllineage.core.metadata_provider import MetaDataProvider
from sqllineage.exceptions import InvalidSyntaxException
from sqllineage.runner import LineageRunner
from sqllineage.utils.entities import ColumnQualifierTuple


test_schemas = {
"db1.table1": ["id", "a", "b", "c", "d"],
"db2.table2": ["id", "h", "i", "j", "k"],
"db3.table3": ["pk", "p", "q", "r"],
}
provider = DummyMetaDataProvider(test_schemas)
providers = generate_metadata_providers(
{
"db1.table1": ["id", "a", "b", "c", "d"],
"db2.table2": ["id", "h", "i", "j", "k"],
"db3.table3": ["pk", "p", "q", "r"],
}
)


def test_select_column_from_tables():
@pytest.mark.parametrize("provider", providers)
def test_select_column_from_tables(provider: MetaDataProvider):
sql = """insert into db.tbl
select t1.id, a as x, b, h, i as y
from db1.table1 t1
Expand Down Expand Up @@ -90,7 +92,8 @@ def test_select_column_from_tables():
)


def test_select_column_from_subqueries():
@pytest.mark.parametrize("provider", providers)
def test_select_column_from_subqueries(provider: MetaDataProvider):
sql = """insert into db.tbl
select a, b as x, h, y
from (select a, b, c from db1.table1) t1
Expand Down Expand Up @@ -158,7 +161,8 @@ def test_select_column_from_subqueries():
)


def test_select_column_from_table_subquery():
@pytest.mark.parametrize("provider", providers)
def test_select_column_from_table_subquery(provider: MetaDataProvider):
sql = """insert into db.tbl
select a as x, b, q, pk as y
from db1.table1 t1
Expand Down Expand Up @@ -226,7 +230,8 @@ def test_select_column_from_table_subquery():
)


def test_select_column_from_tempview_view_subquery():
@pytest.mark.parametrize("provider", providers)
def test_select_column_from_tempview_view_subquery(provider: MetaDataProvider):
sql = """
create or replace view test_view
as
Expand Down Expand Up @@ -271,7 +276,8 @@ def test_select_column_from_tempview_view_subquery():
)


def test_sqlparse_exception():
@pytest.mark.parametrize("provider", providers)
def test_sqlparse_exception(provider: MetaDataProvider):
sql = """insert into table db.tbl
select id
from db1.table1 t1
Expand Down
49 changes: 32 additions & 17 deletions tests/sql/column/test_metadata_wildcard.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
from tests.helpers import assert_column_lineage_equal
import pytest
from tests.helpers import assert_column_lineage_equal, generate_metadata_providers

from sqllineage.core.metadata.dummy import DummyMetaDataProvider
from sqllineage.core.metadata_provider import MetaDataProvider
from sqllineage.utils.entities import ColumnQualifierTuple


test_schemas = {
"db.tbl_x": ["id", "a", "b"],
"db.tbl_y": ["id", "h", "i"],
"db.tbl_z": ["pk", "s", "t"],
}
provider = DummyMetaDataProvider(test_schemas)
providers = generate_metadata_providers(
{
"db.tbl_x": ["id", "a", "b"],
"db.tbl_y": ["id", "h", "i"],
"db.tbl_z": ["pk", "s", "t"],
}
)


def test_select_single_table_wildcard():
@pytest.mark.parametrize("provider", providers)
def test_select_single_table_wildcard(provider: MetaDataProvider):
sql = """insert into test_v
select *
from db.tbl_x
Expand All @@ -37,7 +40,8 @@ def test_select_single_table_wildcard():
)


def test_select_single_table_wildcard_in_subquery():
@pytest.mark.parametrize("provider", providers)
def test_select_single_table_wildcard_in_subquery(provider: MetaDataProvider):
sql = """insert into test_v
select id, b
from (
Expand All @@ -63,7 +67,8 @@ def test_select_single_table_wildcard_in_subquery():
)


def test_select_single_table_partial_wildcard_in_subquery():
@pytest.mark.parametrize("provider", providers)
def test_select_single_table_partial_wildcard_in_subquery(provider: MetaDataProvider):
sql = """insert into test_v
select a, b
from (
Expand Down Expand Up @@ -110,7 +115,8 @@ def test_select_single_table_partial_wildcard_in_subquery():
)


def test_select_table_join_partial_wildcard_at_last():
@pytest.mark.parametrize("provider", providers)
def test_select_table_join_partial_wildcard_at_last(provider: MetaDataProvider):
sql = """insert into test_v
select a, id, h
from (
Expand Down Expand Up @@ -138,7 +144,8 @@ def test_select_table_join_partial_wildcard_at_last():
)


def test_select_table_join_partial_wildcard_at_beginning():
@pytest.mark.parametrize("provider", providers)
def test_select_table_join_partial_wildcard_at_beginning(provider: MetaDataProvider):
sql = """insert into test_v
select a, id, h
from (
Expand Down Expand Up @@ -166,7 +173,8 @@ def test_select_table_join_partial_wildcard_at_beginning():
)


def test_select_table_join_multiple_wildcards():
@pytest.mark.parametrize("provider", providers)
def test_select_table_join_multiple_wildcards(provider: MetaDataProvider):
sql = """insert into test_v
select a, h
from (
Expand All @@ -190,7 +198,10 @@ def test_select_table_join_multiple_wildcards():
)


def test_select_table_join_multiple_wildcards_merged_at_top_level():
@pytest.mark.parametrize("provider", providers)
def test_select_table_join_multiple_wildcards_merged_at_top_level(
provider: MetaDataProvider,
):
sql = """insert into test_v
select *
from (
Expand Down Expand Up @@ -230,7 +241,10 @@ def test_select_table_join_multiple_wildcards_merged_at_top_level():
)


def test_select_table_join_multiple_wildcards_at_different_level():
@pytest.mark.parametrize("provider", providers)
def test_select_table_join_multiple_wildcards_at_different_level(
provider: MetaDataProvider,
):
sql = """insert into test_v
select x.*, y.h, y.i
from (select * from db.tbl_x where a = 0) x
Expand Down Expand Up @@ -270,7 +284,8 @@ def test_select_table_join_multiple_wildcards_at_different_level():
)


def test_wildcard_reference_from_previous_statements():
@pytest.mark.parametrize("provider", providers)
def test_wildcard_reference_from_previous_statements(provider: MetaDataProvider):
sql = """create table test_x as
select a, b
from (
Expand Down

0 comments on commit 74885ce

Please sign in to comment.