Skip to content

Commit

Permalink
feat: configurably support split tsql without semicolon
Browse files Browse the repository at this point in the history
  • Loading branch information
reata committed Sep 2, 2023
1 parent a619ae5 commit 52558c6
Show file tree
Hide file tree
Showing 10 changed files with 136 additions and 59 deletions.
6 changes: 0 additions & 6 deletions sqllineage/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import os


NAME = "sqllineage"
VERSION = "1.4.7"
DEFAULT_LOGGING = {
Expand Down Expand Up @@ -31,9 +28,6 @@
}

STATIC_FOLDER = "build"
DATA_FOLDER = os.environ.get(
"SQLLINEAGE_DIRECTORY", os.path.join(os.path.dirname(__file__), "data")
)
DEFAULT_HOST = "localhost"
DEFAULT_PORT = 5000
SQLPARSE_DIALECT = "non-validating"
Expand Down
6 changes: 2 additions & 4 deletions sqllineage/cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import logging
import logging.config

import warnings

from sqllineage import DEFAULT_DIALECT, DEFAULT_HOST, DEFAULT_LOGGING, DEFAULT_PORT
from sqllineage.drawing import draw_lineage_graph
Expand Down Expand Up @@ -78,9 +78,7 @@ def main(args=None) -> None:
)
args = parser.parse_args(args)
if args.e and args.f:
logging.warning(
"Both -e and -f options are specified. -e option will be ignored"
)
warnings.warn("Both -e and -f options are specified. -e option will be ignored")
if args.f or args.e:
sql = extract_sql_from_args(args)
runner = LineageRunner(
Expand Down
26 changes: 26 additions & 0 deletions sqllineage/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os


class _SQLLineageConfigLoader:
    """
    Load all configurable items from environment variables, otherwise fall back to defaults.

    Every item is resolved on attribute access, reading the environment variable
    named ``SQLLINEAGE_<ITEM>`` and converting it to the type declared in ``config``.
    """

    # inspired by https://github.com/joke2k/django-environ
    config = {
        # for frontend directory drawer
        "DIRECTORY": (str, os.path.join(os.path.dirname(__file__), "data")),
        # to enable tsql no semicolon splitter mode
        "TSQL_NO_SEMICOLON": (bool, False),
    }

    # Spellings (case-insensitive, surrounding whitespace ignored) that switch a bool item off.
    # Any other non-numeric text keeps meaning True, as it did before explicit parsing existed.
    _BOOLEAN_FALSE_STRINGS = frozenset({"", "false", "off", "no", "n"})

    @classmethod
    def _cast(cls, type_, value):
        """
        Convert a raw value (environment string or declared default) to ``type_``.

        ``bool`` needs dedicated handling: ``bool("false")`` and ``bool("0")`` are
        both True because any non-empty string is truthy.
        """
        if type_ is bool and isinstance(value, str):
            try:
                # numeric spellings: "0" disables, any other integer enables
                return int(value) != 0
            except ValueError:
                return value.strip().lower() not in cls._BOOLEAN_FALSE_STRINGS
        return type_(value)

    def __getattr__(self, item):
        """
        Return the value of a configurable item.

        :raises AttributeError: if ``item`` is not declared in ``config``
        """
        if item in self.config:
            type_, default = self.config[item]
            # require SQLLINEAGE_ prefix from environment variable
            return self._cast(type_, os.environ.get("SQLLINEAGE_" + item, default))
        else:
            return super().__getattribute__(item)


SQLLineageConfig = _SQLLineageConfigLoader()
70 changes: 45 additions & 25 deletions sqllineage/core/parser/sqlfluff/analyzer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import warnings
from typing import Dict, List

from sqlfluff.core import Linter, SQLLexError, SQLParseError
from sqlfluff.core.parser import BaseSegment

from sqllineage.core.analyzer import LineageAnalyzer
from sqllineage.core.holders import StatementLineageHolder
Expand All @@ -17,8 +19,46 @@ class SqlFluffLineageAnalyzer(LineageAnalyzer):

def __init__(self, dialect: str):
    """
    :param dialect: dialect name handed to the sqlfluff Linter when parsing
    """
    # raw statement text -> parsed segment, populated by split_tsql and read by analyze
    self.tsql_split_cache: Dict[str, BaseSegment] = {}
    self._dialect = dialect

def split_tsql(self, sql: str) -> List[str]:
    """
    Split tsql into single statements by parsing it with sqlfluff, which also
    covers statements that are not separated by semicolons.

    Each parsed segment is stored in ``tsql_split_cache`` under its raw text, so a
    later ``analyze`` call on that text can reuse the segment instead of parsing again.

    :param sql: sql text that may hold several statements
    :return: raw text of every statement, in parse order
    """
    parsed = [
        (segment.raw, segment)
        for segment in self._list_specific_statement_segment(sql)
    ]
    self.tsql_split_cache.update(parsed)
    return [raw for raw, _ in parsed]

def analyze(self, sql: str) -> StatementLineageHolder:
    """
    Analyze one statement and return its lineage.

    When ``sql`` is a key of ``tsql_split_cache`` the cached segment is used,
    otherwise the sql is parsed. Only the first statement segment is analyzed.

    :param sql: sql text of the statement
    :raises UnsupportedStatementException: if no statement segment is found, or no
        extractor can handle the segment type
    """
    try:
        candidates = [self.tsql_split_cache[sql]]
    except KeyError:
        candidates = self._list_specific_statement_segment(sql)
    if not candidates:
        raise UnsupportedStatementException(
            f"SQLLineage cannot parse SQL:" f"{sql}"
        )  # pragma: no cover

    target = candidates[0]
    # every extractor is instantiated up front, before any of them is queried
    extractors = [
        extractor_cls(self._dialect)
        for extractor_cls in BaseExtractor.__subclasses__()
    ]
    for candidate in extractors:
        if not candidate.can_extract(target.type):
            continue
        holder = candidate.extract(target, AnalyzerContext())
        return StatementLineageHolder.of(holder)
    raise UnsupportedStatementException(
        f"SQLLineage doesn't support analyzing statement type [{target.type}] for SQL:"
        f"{sql}"
    )

def _list_specific_statement_segment(self, sql: str):
parsed = Linter(dialect=self._dialect).parse_string(sql)
violations = [
str(e)
Expand All @@ -32,12 +72,10 @@ def analyze(self, sql: str) -> StatementLineageHolder:
f"{sql}\n"
f"{violation_msg}"
)

statement_segment = None
segments = []
for top_segment in getattr(parsed.tree, "segments", []):
if top_segment.type == "statement":
statement_segment = top_segment.segments[0]
break
segments.append(top_segment.segments[0])
elif top_segment.type == "batch":
statements = top_segment.get_children("statement")
if len(statements) > 1:
Expand All @@ -47,24 +85,6 @@ def analyze(self, sql: str) -> StatementLineageHolder:
SyntaxWarning,
stacklevel=2,
)
statement_segment = statements[0].segments[0]
break
if statement_segment is None:
raise UnsupportedStatementException(
f"SQLLineage cannot parse SQL:" f"{sql}"
) # pragma: no cover
else:
for extractor in [
extractor_cls(self._dialect)
for extractor_cls in BaseExtractor.__subclasses__()
]:
if extractor.can_extract(statement_segment.type):
lineage_holder = extractor.extract(
statement_segment, AnalyzerContext()
)
return StatementLineageHolder.of(lineage_holder)
else:
raise UnsupportedStatementException(
f"SQLLineage doesn't support analyzing statement type [{statement_segment.type}] for SQL:"
f"{sql}"
)
for statement in statements:
segments.append(statement.segments[0])
return segments
13 changes: 4 additions & 9 deletions sqllineage/drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,8 @@
from urllib.parse import urlencode
from wsgiref.simple_server import make_server

from sqllineage import (
DATA_FOLDER,
DEFAULT_DIALECT,
DEFAULT_HOST,
DEFAULT_PORT,
)
from sqllineage import STATIC_FOLDER
from sqllineage import DEFAULT_DIALECT, DEFAULT_HOST, DEFAULT_PORT, STATIC_FOLDER
from sqllineage.config import SQLLineageConfig
from sqllineage.exceptions import SQLLineageException
from sqllineage.utils.constant import LineageLevel
from sqllineage.utils.helpers import extract_sql_from_args
Expand All @@ -34,7 +29,7 @@
class SQLLineageApp:
def __init__(self) -> None:
self.routes: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]] = {}
self.root_path = Path(DATA_FOLDER)
self.root_path = Path(SQLLineageConfig.DIRECTORY)

def route(self, path: str):
def wrapper(handler):
Expand Down Expand Up @@ -193,7 +188,7 @@ def directory(payload):
elif payload.get("d"):
root = Path(payload["d"])
else:
root = Path(DATA_FOLDER)
root = Path(SQLLineageConfig.DIRECTORY)
data = {
"id": str(root),
"name": root.name,
Expand Down
11 changes: 10 additions & 1 deletion sqllineage/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Dict, List, Optional, Tuple

from sqllineage import DEFAULT_DIALECT, SQLPARSE_DIALECT
from sqllineage.config import SQLLineageConfig
from sqllineage.core.holders import SQLLineageHolder
from sqllineage.core.models import Column, Table
from sqllineage.core.parser.sqlfluff.analyzer import SqlFluffLineageAnalyzer
Expand Down Expand Up @@ -168,12 +169,20 @@ def print_table_lineage(self) -> None:
print(str(self))

def _eval(self):
self._stmt = split(self._sql.strip())
analyzer = (
SqlParseLineageAnalyzer()
if self._dialect == SQLPARSE_DIALECT
else SqlFluffLineageAnalyzer(self._dialect)
)
if SQLLineageConfig.TSQL_NO_SEMICOLON and self._dialect == "tsql":
self._stmt = analyzer.split_tsql(self._sql.strip())
else:
if SQLLineageConfig.TSQL_NO_SEMICOLON and self._dialect != "tsql":
warnings.warn(
f"Dialect={self._dialect}, TSQL_NO_SEMICOLON will be ignored unless dialect is tsql"
)
self._stmt = split(self._sql.strip())

self._stmt_holders = [analyzer.analyze(stmt) for stmt in self._stmt]
self._sql_holder = SQLLineageHolder.of(*self._stmt_holders)
self._evaluated = True
4 changes: 2 additions & 2 deletions tests/core/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@

import pytest

from sqllineage import DATA_FOLDER
from sqllineage.cli import main
from sqllineage.config import SQLLineageConfig


@patch("socketserver.BaseServer.serve_forever")
def test_cli_dummy(_):
main([])
main(["-e", "select * from dual"])
main(["-e", "insert into foo select * from dual", "-l", "column"])
for dirname, _, files in os.walk(DATA_FOLDER):
for dirname, _, files in os.walk(SQLLineageConfig.DIRECTORY):
if len(files) > 0:
sql_file = str(Path(dirname).joinpath(Path(files[0])))
main(["-f", sql_file])
Expand Down
16 changes: 12 additions & 4 deletions tests/core/test_drawing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from http import HTTPStatus
from io import StringIO

from sqllineage import DATA_FOLDER
from sqllineage.config import SQLLineageConfig
from sqllineage.drawing import app


Expand Down Expand Up @@ -39,10 +39,16 @@ def mock_request(method, path, body=None):
mock_request("POST", "/script", {"e": "SELECT * FROM dual", "p": 5000})
assert container.status.startswith(str(HTTPStatus.OK.value))
mock_request(
"POST", "/directory", {"f": os.path.join(DATA_FOLDER, "tpcds/query01.sql")}
"POST",
"/directory",
{"f": os.path.join(SQLLineageConfig.DIRECTORY, "tpcds/query01.sql")},
)
assert container.status.startswith(str(HTTPStatus.OK.value))
mock_request("POST", "/directory", {"d": os.path.join(DATA_FOLDER, "tpcds/")})
mock_request(
"POST",
"/directory",
{"d": os.path.join(SQLLineageConfig.DIRECTORY, "tpcds/")},
)
assert container.status.startswith(str(HTTPStatus.OK.value))
mock_request("POST", "/directory", {})
assert container.status.startswith(str(HTTPStatus.OK.value))
Expand All @@ -62,7 +68,9 @@ def mock_request(method, path, body=None):
mock_request("GET", "/static")
assert container.status.startswith(str(HTTPStatus.NOT_FOUND.value))
mock_request(
"POST", "/script", {"f": os.path.join(DATA_FOLDER, "tpcds/query100.sql")}
"POST",
"/script",
{"f": os.path.join(SQLLineageConfig.DIRECTORY, "tpcds/query100.sql")},
)
assert container.status.startswith(str(HTTPStatus.NOT_FOUND.value))
mock_request("POST", "/non-exist-resource", {"e": "SELECT * FROM where foo='bar'"})
Expand Down
22 changes: 14 additions & 8 deletions tests/core/test_exception.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import warnings
from unittest.mock import patch

import pytest

Expand Down Expand Up @@ -43,19 +43,25 @@ def test_unsupported_query_type_in_sqlfluff():
)._eval()


def test_deprecated_warning_in_sqlparse():
with warnings.catch_warnings(record=True) as w:
def test_deprecation_warning_in_sqlparse():
with pytest.warns(DeprecationWarning):
LineageRunner("SELECT * FROM DUAL", dialect="non-validating")._eval()
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)


def test_syntax_warning_no_semicolon_in_tsql():
with warnings.catch_warnings(record=True) as w:
with pytest.warns(SyntaxWarning):
LineageRunner(
"""SELECT * FROM foo
SELECT * FROM bar""",
dialect="tsql",
)._eval()
assert len(w) == 1
assert issubclass(w[0].category, SyntaxWarning)


# patch.dict overlays one key on the real os.environ and restores it afterwards;
# replacing os.environ wholesale would hide every other variable from the code under test.
@patch.dict("os.environ", {"SQLLINEAGE_TSQL_NO_SEMICOLON": "TRUE"})
def test_user_warning_enable_tsql_no_semicolon_with_other_dialect():
    """
    TSQL_NO_SEMICOLON enabled together with a non-tsql dialect emits a UserWarning.
    """
    with pytest.warns(UserWarning):
        LineageRunner(
            """SELECT * FROM foo;
SELECT * FROM bar""",
            dialect="ansi",
        )._eval()
21 changes: 21 additions & 0 deletions tests/sql/sqlfluff_only/test_tsql_no_semicolon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from unittest.mock import patch

import pytest
from tests.helpers import assert_table_lineage_equal


# patch.dict overlays one key on the real os.environ and restores it afterwards;
# replacing os.environ wholesale would hide every other variable from the code under test.
@patch.dict("os.environ", {"SQLLINEAGE_TSQL_NO_SEMICOLON": "TRUE"})
@pytest.mark.parametrize("dialect", ["tsql"])
def test_tsql_multi_statement_no_semicolon(dialect: str):
    """
    tsql multiple statements without explicit semicolon as splitter.
    """
    sql = """insert into tab1 select * from foo
insert into tab2 select * from bar"""
    assert_table_lineage_equal(
        sql,
        {"foo", "bar"},
        {"tab1", "tab2"},
        dialect=dialect,
        test_sqlparse=False,
    )

0 comments on commit 52558c6

Please sign in to comment.