diff --git a/examples/01_basic_api_usage.py b/examples/01_basic_api_usage.py
index 51318056215..95221996b69 100644
--- a/examples/01_basic_api_usage.py
+++ b/examples/01_basic_api_usage.py
@@ -6,10 +6,9 @@
 my_bad_query = "SeLEct *, 1, blah as fOO from myTable"
-# Lint the given string and get a list of violations found.
-result = sqlfluff.lint(my_bad_query, dialect="bigquery")
-
-# result =
+# Lint the given string and return an array of violations in JSON representation.
+lint_result = sqlfluff.lint(my_bad_query, dialect="bigquery")
+# lint_result =
 # [
 #     {"code": "L010", "line_no": 1, "line_pos": 1, "description": "Keywords must be consistently upper case."}
 #     ...
@@ -18,33 +17,19 @@
 # -------- FIXING ----------
 # Fix the given string and get a string back which has been fixed.
-result = sqlfluff.fix(my_bad_query, dialect="bigquery")
-# result = 'SELECT *, 1, blah AS foo FROM mytable\n'
+fix_result_1 = sqlfluff.fix(my_bad_query, dialect="bigquery")
+# fix_result_1 = 'SELECT *, 1, blah AS foo FROM mytable\n'
 # We can also fix just specific rules.
-result = sqlfluff.fix(my_bad_query, rules=["L010"])
-# result = 'SELECT *, 1, blah AS fOO FROM myTable'
+fix_result_2 = sqlfluff.fix(my_bad_query, rules=["L010"])
+# fix_result_2 = 'SELECT *, 1, blah AS fOO FROM myTable'
 # Or a subset of rules...
-result = sqlfluff.fix(my_bad_query, rules=["L010", "L014"])
-# result = 'SELECT *, 1, blah AS fOO FROM mytable'
+fix_result_3 = sqlfluff.fix(my_bad_query, rules=["L010", "L014"])
+# fix_result_3 = 'SELECT *, 1, blah AS fOO FROM mytable'
 # -------- PARSING ----------
-# NOTE: sqlfluff is still in a relatively early phase of its
-# development and so until version 1.0.0 will offer no guarantee
-# that the names and structure of the objects returned by these
-# parse commands won't change between releases. Use with care
-# and keep updated with the changelog for the project for any
-# changes in this space.
-
-parsed = sqlfluff.parse(my_bad_query)
-
-# Get the structure of the query
-structure = parsed.tree.to_tuple(show_raw=True, code_only=True)
-# structure = ('file', (('statement', (('select_statement', (('select_clause', (('keyword', 'SeLEct'), ...
-
-# Extract certain elements
-keywords = [keyword.raw for keyword in parsed.tree.recursive_crawl("keyword")]
-# keywords = ['SeLEct', 'as', 'from']
-tbl_refs = [tbl_ref.raw for tbl_ref in parsed.tree.recursive_crawl("table_reference")]
-# tbl_refs == ["myTable"]
+
+# Parse the given string and return a JSON representation of the parsed tree.
+parse_result = sqlfluff.parse(my_bad_query)
+# parse_result = {'file': {'statement': {...}, 'newline': '\n'}}
diff --git a/examples/03_extracting_references.py b/examples/03_extracting_references.py
deleted file mode 100644
index 19d0b76dc69..00000000000
--- a/examples/03_extracting_references.py
+++ /dev/null
@@ -1,20 +0,0 @@
-"""This is an example of how to extract table names."""
-
-import sqlfluff
-
-query_with_ctes = """
-WITH foo AS (SELECT * FROM bar.bar),
-baz AS (SELECT * FROM bap)
-SELECT * FROM foo
-INNER JOIN baz USING (user_id)
-INNER JOIN ban USING (user_id)
-"""
-
-# -------- PARSING ----------
-parsed = sqlfluff.parse(query_with_ctes)
-
-# -------- EXTRACTION ----------
-# Under the hood we look for all of the table references
-# which aren't also CTE aliases.
-external_tables = parsed.tree.get_table_references()
-# external_tables == {'bar.bar', 'bap', 'ban'}
diff --git a/examples/04_getting_rules_and_dialects.py b/examples/03_getting_rules_and_dialects.py
similarity index 100%
rename from examples/04_getting_rules_and_dialects.py
rename to examples/03_getting_rules_and_dialects.py
diff --git a/src/sqlfluff/api/simple.py b/src/sqlfluff/api/simple.py
index b9b01d280b2..bce9148ff14 100644
--- a/src/sqlfluff/api/simple.py
+++ b/src/sqlfluff/api/simple.py
@@ -8,7 +8,6 @@
     SQLBaseError,
     SQLFluffUserError,
 )
-from sqlfluff.core.linter import ParsedString
 def get_simple_config(
@@ -132,7 +131,7 @@ def parse(
     sql: str,
     dialect: str = "ansi",
     config_path: Optional[str] = None,
-) -> ParsedString:
+) -> Dict[str, Any]:
     """Parse a SQL string.
     Args:
@@ -143,7 +142,7 @@
             Defaults to None.
     Returns:
-        :obj:`ParsedString` containing the parsed structure.
+        :obj:`Dict[str, Any]` JSON containing the parsed structure.
     """
     cfg = get_simple_config(
         dialect=dialect,
@@ -155,4 +154,7 @@
     # If we encounter any parsing errors, raise them in a combined issue.
     if parsed.violations:
         raise APIParsingError(parsed.violations)
-    return parsed
+    # Return a JSON representation of the parse tree.
+    if parsed.tree is None:  # pragma: no cover
+        return {}
+    return parsed.tree.as_record(show_raw=True)
diff --git a/test/api/simple_test.py b/test/api/simple_test.py
index 3a0ce5b82fc..21516444061 100644
--- a/test/api/simple_test.py
+++ b/test/api/simple_test.py
@@ -1,10 +1,11 @@
 """Tests for simple use cases of the public api."""
+import json
+
 import pytest
 import sqlfluff
 from sqlfluff.core.errors import SQLFluffUserError
-from sqlfluff.core.linter import ParsedString
 my_bad_query = "SeLEct *, 1, blah as fOO from myTable"
@@ -178,21 +179,16 @@ def test__api__fix_string_specific_exclude():
 def test__api__parse_string():
     """Basic checking of parse functionality."""
     parsed = sqlfluff.parse(my_bad_query)
-    # Check we can call `to_tuple` on the result
-    assert isinstance(parsed, ParsedString)
-    # Check we can iterate objects within it
-    keywords = [keyword.raw for keyword in parsed.tree.recursive_crawl("keyword")]
-    assert keywords == ["SeLEct", "as", "from"]
-    # Check we can get columns from it
-    col_refs = [
-        col_ref.raw for col_ref in parsed.tree.recursive_crawl("column_reference")
-    ]
-    assert col_refs == ["blah"]
-    # Check we can get table from it
-    tbl_refs = [
-        tbl_ref.raw for tbl_ref in parsed.tree.recursive_crawl("table_reference")
-    ]
-    assert tbl_refs == ["myTable"]
+
+    # Check a JSON object is returned.
+    assert isinstance(parsed, dict)
+
+    # Load in expected result.
+    with open("test/fixtures/api/parse_test/parse_test.json", "r") as f:
+        expected_parsed = json.load(f)
+
+    # Compare JSON from parse to expected result.
+    assert parsed == expected_parsed
 def test__api__parse_fail():
@@ -217,24 +213,27 @@ def test__api__config_path():
     """Test that we can load a specified config file in the Simple API."""
     # Load test SQL file.
-    with open("test/fixtures/api/api_config_test.sql", "r") as f:
+    with open("test/fixtures/api/config_path_test/config_path_test.sql", "r") as f:
         sql = f.read()
     # Pass a config path to the Simple API.
-    res = sqlfluff.parse(
+    parsed = sqlfluff.parse(
         sql,
-        config_path="test/fixtures/api/extra_configs/.sqlfluff",
+        config_path="test/fixtures/api/config_path_test/extra_configs/.sqlfluff",
     )
-    # Check there are no errors and the template is rendered correctly.
-    assert len(res.violations) == 0
-    assert res.tree.raw == "SELECT foo FROM bar;\n"
+    # Load in expected result.
+    with open("test/fixtures/api/config_path_test/config_path_test.json", "r") as f:
+        expected_parsed = json.load(f)
+
+    # Compare JSON from parse to expected result.
+    assert parsed == expected_parsed
 def test__api__invalid_dialect():
     """Test that SQLFluffUserError is raised for a bad dialect."""
     # Load test SQL file.
-    with open("test/fixtures/api/api_config_test.sql", "r") as f:
+    with open("test/fixtures/api/config_path_test/config_path_test.sql", "r") as f:
         sql = f.read()
     # Pass a fake dialect to the API and test the correct error is raised.
@@ -242,7 +241,7 @@
         sqlfluff.parse(
             sql,
             dialect="not_a_real_dialect",
-            config_path="test/fixtures/api/extra_configs/.sqlfluff",
+            config_path="test/fixtures/api/config_path_test/extra_configs/.sqlfluff",
         )
     assert str(err.value) == "Error: Unknown dialect 'not_a_real_dialect'"
diff --git a/test/api/util_test.py b/test/api/util_test.py
deleted file mode 100644
index 0ef9a668a41..00000000000
--- a/test/api/util_test.py
+++ /dev/null
@@ -1,29 +0,0 @@
-"""Test using sqlfluff to extract elements of queries."""
-
-import pytest
-
-import sqlfluff
-
-my_bad_query = "SeLEct *, 1, blah as fOO from myTable"
-
-query_with_ctes = """
-WITH foo AS (SELECT * FROM bar.bar),
-baz AS (SELECT * FROM bap)
-SELECT * FROM foo
-INNER JOIN baz USING (user_id)
-INNER JOIN ban USING (user_id)
-"""
-
-
-@pytest.mark.parametrize(
-    "sql,table_refs,dialect",
-    [
-        (my_bad_query, {"myTable"}, None),
-        (query_with_ctes, {"bar.bar", "bap", "ban"}, "snowflake"),
-    ],
-)
-def test__api__util_get_table_references(sql, table_refs, dialect):
-    """Basic checking of lint functionality."""
-    parsed = sqlfluff.parse(sql, dialect=dialect)
-    external_tables = parsed.tree.get_table_references()
-    assert external_tables == table_refs
diff --git a/test/core/linter_test.py b/test/core/linter_test.py
index 8d86599aa36..f1e57b9fabe 100644
--- a/test/core/linter_test.py
+++ b/test/core/linter_test.py
@@ -859,3 +859,29 @@ def test_safe_create_replace_file(case, tmp_path):
         pass
     actual = p.read_text(encoding=case["encoding"])
     assert case["expected"] == actual
+
+
+def test_advanced_api_methods():
+    """Test advanced API methods on segments."""
+    # These aren't used by the simple API, which returns
+    # a simple JSON representation of the parse tree, but
+    # are available for advanced API usage and within rules.
+    sql = """
+    WITH cte AS (
+        SELECT * FROM tab_a
+    )
+    SELECT
+        cte.col_a,
+        tab_b.col_b
+    FROM cte
+    INNER JOIN tab_b;
+    """
+    linter = Linter()
+    parsed = linter.parse_string(sql)
+
+    # CTEDefinitionSegment.get_identifier
+    cte_segment = next(parsed.tree.recursive_crawl("common_table_expression"))
+    assert cte_segment.get_identifier().raw == "cte"
+
+    # BaseFileSegment.get_table_references & StatementSegment.get_table_references
+    assert parsed.tree.get_table_references() == {"tab_a", "tab_b"}
diff --git a/test/fixtures/api/config_path_test/config_path_test.json b/test/fixtures/api/config_path_test/config_path_test.json
new file mode 100644
index 00000000000..2369093f697
--- /dev/null
+++ b/test/fixtures/api/config_path_test/config_path_test.json
@@ -0,0 +1,33 @@
+{
+    "file": {
+        "statement": {
+            "select_statement": {
+                "select_clause": {
+                    "keyword": "SELECT",
+                    "whitespace": " ",
+                    "select_clause_element": {
+                        "column_reference": {
+                            "identifier": "foo"
+                        }
+                    }
+                },
+                "whitespace": " ",
+                "from_clause": {
+                    "keyword": "FROM",
+                    "whitespace": " ",
+                    "from_expression": {
+                        "from_expression_element": {
+                            "table_expression": {
+                                "table_reference": {
+                                    "identifier": "bar"
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        },
+        "statement_terminator": ";",
+        "newline": "\n"
+    }
+}
diff --git a/test/fixtures/api/api_config_test.sql b/test/fixtures/api/config_path_test/config_path_test.sql
similarity index 100%
rename from test/fixtures/api/api_config_test.sql
rename to test/fixtures/api/config_path_test/config_path_test.sql
diff --git a/test/fixtures/api/extra_configs/.sqlfluff b/test/fixtures/api/config_path_test/extra_configs/.sqlfluff
similarity index 100%
rename from test/fixtures/api/extra_configs/.sqlfluff
rename to test/fixtures/api/config_path_test/extra_configs/.sqlfluff
diff --git a/test/fixtures/api/parse_test/parse_test.json b/test/fixtures/api/parse_test/parse_test.json
new file mode 100644
index 00000000000..ec4b2b1a754
--- /dev/null
+++ b/test/fixtures/api/parse_test/parse_test.json
@@ -0,0 +1,69 @@
+{
+    "file": {
+        "statement": {
+            "select_statement": {
+                "select_clause": [
+                    {
+                        "keyword": "SeLEct"
+                    },
+                    {
+                        "whitespace": " "
+                    },
+                    {
+                        "select_clause_element": {
+                            "wildcard_expression": {
+                                "wildcard_identifier": {
+                                    "star": "*"
+                                }
+                            }
+                        }
+                    },
+                    {
+                        "comma": ","
+                    },
+                    {
+                        "whitespace": " "
+                    },
+                    {
+                        "select_clause_element": {
+                            "literal": "1"
+                        }
+                    },
+                    {
+                        "comma": ","
+                    },
+                    {
+                        "whitespace": " "
+                    },
+                    {
+                        "select_clause_element": {
+                            "column_reference": {
+                                "identifier": "blah"
+                            },
+                            "whitespace": " ",
+                            "alias_expression": {
+                                "keyword": "as",
+                                "whitespace": " ",
+                                "identifier": "fOO"
+                            }
+                        }
+                    }
+                ],
+                "whitespace": " ",
+                "from_clause": {
+                    "keyword": "from",
+                    "whitespace": " ",
+                    "from_expression": {
+                        "from_expression_element": {
+                            "table_expression": {
+                                "table_reference": {
+                                    "identifier": "myTable"
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    }
+}
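For reference, a minimal usage sketch of the API surface after this change (a hypothetical snippet, not part of the patch; it only exercises calls shown in the diff, and the commented values are illustrative):

# Sketch only: assumes the post-change sqlfluff package from this diff.
import sqlfluff
from sqlfluff.core import Linter

my_bad_query = "SeLEct *, 1, blah as fOO from myTable"

# Simple API: parse() now returns a plain dict (JSON-style record of the tree).
parse_result = sqlfluff.parse(my_bad_query)
assert isinstance(parse_result, dict)

# Advanced API: segment methods such as get_table_references() remain
# available via the Linter and the parsed tree, as exercised in
# test_advanced_api_methods above.
parsed = Linter().parse_string(my_bad_query)
table_refs = parsed.tree.get_table_references()
# table_refs == {"myTable"}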