diff --git a/docs/source/configuration.rst b/docs/source/configuration.rst index 9a1704eedf4..7fe5b33aa0f 100644 --- a/docs/source/configuration.rst +++ b/docs/source/configuration.rst @@ -374,6 +374,37 @@ to use them for templated. In the above example, you might define a file at return "GROUP BY 1,2" +If an `__init__.py` is detected, it will be loaded alongside any modules and +submodules found within the library path. + +.. code-block:: jinja + + SELECT + {{ custom_sum('foo', 'bar') }}, + {{ foo.bar.another_sum('foo', 'bar') }} + FROM + baz + +`sqlfluff_libs/__init__.py`: + +.. code-block:: python + + def custom_sum(a: str, b: str) -> str: + return a + b + +`sqlfluff_libs/foo/__init__.py`: + +.. code-block:: python + + # empty file + +`sqlfluff_libs/foo/bar.py`: + +.. code-block:: python + + def another_sum(a: str, b: str) -> str: + return a + b + dbt Project Configuration ------------------------- diff --git a/src/sqlfluff/core/parser/segments/base.py b/src/sqlfluff/core/parser/segments/base.py index 1e936f8c0ca..8164ff116c3 100644 --- a/src/sqlfluff/core/parser/segments/base.py +++ b/src/sqlfluff/core/parser/segments/base.py @@ -978,6 +978,10 @@ def apply_fixes(self, fixes): "create_before", "create_after", ): + if f.edit_type == "create_after": + # in the case of a creation after, also add this segment before the edit. + seg_buffer.append(seg) + # We're doing a replacement (it could be a single segment or an iterable) if isinstance(f.edit, BaseSegment): seg_buffer.append(f.edit) # pragma: no cover TODO? @@ -988,9 +992,7 @@ def apply_fixes(self, fixes): if f.edit_type == "create_before": # in the case of a creation before, also add this segment on the end seg_buffer.append(seg) - elif f.edit_type == "create_after": - # in the case of a creation after, also add this segment to the start - seg_buffer.insert(0, seg) + else: # pragma: no cover raise ValueError( "Unexpected edit_type: {!r} in {!r}".format( diff --git a/src/sqlfluff/core/templaters/jinja.py b/src/sqlfluff/core/templaters/jinja.py index ab9b2b0683f..be85ff3cf3f 100644 --- a/src/sqlfluff/core/templaters/jinja.py +++ b/src/sqlfluff/core/templaters/jinja.py @@ -1,26 +1,25 @@ """Defines the templaters.""" - -import os.path import logging -import importlib.util -from typing import Callable, Dict, List, Tuple, Optional +import os.path +import pkgutil +from functools import reduce +from typing import Callable, Dict, List, Optional, Tuple -from jinja2 import Environment +import jinja2.nodes +from jinja2 import Environment, TemplateError, TemplateSyntaxError, meta from jinja2.environment import Template from jinja2.sandbox import SandboxedEnvironment -from jinja2 import meta, TemplateSyntaxError, TemplateError -import jinja2.nodes from sqlfluff.core.errors import SQLTemplaterError - from sqlfluff.core.templaters.base import ( - TemplatedFile, RawFileSlice, + TemplatedFile, TemplatedFileSlice, ) from sqlfluff.core.templaters.python import PythonTemplater from sqlfluff.core.templaters.slicers.tracer import JinjaTracer + # Instantiate the templater logger templater_logger = logging.getLogger("sqlfluff.templater") @@ -33,6 +32,11 @@ class JinjaTemplater(PythonTemplater): name = "jinja" + class Libraries: + """Mock namespace for user-defined Jinja library.""" + + pass + @staticmethod def _extract_macros_from_template(template, env, ctx): """Take a template string and extract any macros from it. @@ -106,19 +110,45 @@ def _extract_libraries_from_config(self, config): if not library_path: return {} - libraries = {} - for file_name in os.listdir(library_path): - file_path = os.path.join(library_path, file_name) - if not os.path.isfile(file_path) or not file_name.endswith(".py"): + libraries = JinjaTemplater.Libraries() + + # If library_path hash __init__.py we parse it as a one module, else we parse it a set of modules + is_library_module = os.path.exists(os.path.join(library_path, "__init__.py")) + library_module_name = os.path.basename(library_path) + + # Need to go one level up to parse as a module correctly + walk_path = ( + os.path.join(library_path, "..") if is_library_module else library_path + ) + + for loader, module_name, is_pkg in pkgutil.walk_packages([walk_path]): + # skip other modules that can be near module_dir + if is_library_module and not module_name.startswith(library_module_name): continue - module_name = os.path.splitext(file_name)[0] - spec = importlib.util.spec_from_file_location(module_name, file_path) - lib = importlib.util.module_from_spec(spec) - spec.loader.exec_module(lib) - libraries[module_name] = lib + module = loader.find_module(module_name).load_module(module_name) + + if "." in module_name: # nested modules have `.` in module_name + *module_path, last_module_name = module_name.split(".") + # find parent module recursively + parent_module = reduce( + lambda res, path_part: getattr(res, path_part), + module_path, + libraries, + ) + + # set attribute on module object to make jinja working correctly + setattr(parent_module, last_module_name, module) + else: + # set attr on `libraries` obj to make it work in jinja nicely + setattr(libraries, module_name, module) + + if is_library_module: + # when library is module we have one more root module in hierarchy and we remove it + libraries = getattr(libraries, library_module_name) - return libraries + # remove magic methods from result + return {k: v for k, v in libraries.__dict__.items() if not k.startswith("__")} @staticmethod def _generate_dbt_builtins(): diff --git a/src/sqlfluff/dialects/dialect_bigquery.py b/src/sqlfluff/dialects/dialect_bigquery.py index 4b6ba7044fd..a462c3eac86 100644 --- a/src/sqlfluff/dialects/dialect_bigquery.py +++ b/src/sqlfluff/dialects/dialect_bigquery.py @@ -8,6 +8,7 @@ import itertools +from sqlfluff.core.dialects import load_raw_dialect from sqlfluff.core.parser import ( Anything, BaseSegment, @@ -33,9 +34,6 @@ Matchable, ) from sqlfluff.core.parser.segments.base import BracketedSegment - -from sqlfluff.core.dialects import load_raw_dialect - from sqlfluff.dialects.dialect_bigquery_keywords import ( bigquery_reserved_keywords, bigquery_unreserved_keywords, @@ -766,7 +764,7 @@ class DeclareStatementSegment(BaseSegment): parse_grammar = Sequence( "DECLARE", Delimited(Ref("NakedIdentifierSegment")), - Ref("DatatypeIdentifierSegment"), + Ref("DatatypeSegment"), Sequence( "DEFAULT", OneOf( @@ -813,6 +811,7 @@ class SetStatementSegment(BaseSegment): ) ) ), + Ref("ArrayLiteralSegment"), ) ) ), diff --git a/src/sqlfluff/rules/L008.py b/src/sqlfluff/rules/L008.py index 111944ef0fb..91a86d2d72a 100644 --- a/src/sqlfluff/rules/L008.py +++ b/src/sqlfluff/rules/L008.py @@ -1,11 +1,13 @@ """Implementation of Rule L008.""" -from typing import Optional +from typing import Optional, Tuple from sqlfluff.core.parser import WhitespaceSegment from sqlfluff.core.rules.base import BaseRule, LintResult, LintFix, RuleContext from sqlfluff.core.rules.doc_decorators import document_fix_compatible +from sqlfluff.core.parser.segments.base import BaseSegment + @document_fix_compatible class Rule_L008(BaseRule): @@ -34,45 +36,75 @@ class Rule_L008(BaseRule): WHERE a IN ('plop',•'zoo') """ - def _eval(self, context: RuleContext) -> Optional[LintResult]: - """Commas should be followed by a single whitespace unless followed by a comment. + def _get_subsequent_whitespace( + self, + context, + ) -> Tuple[Optional[BaseSegment], Optional[BaseSegment]]: + """Search forwards through the raw segments for subsequent whitespace. - This is a slightly odd one, because we'll almost always evaluate from a point a few places - after the problem site. NB: We need at least two segments behind us for this to work. + Return a tuple of both the trailing whitespace segment and the + first non-whitespace segment discovered. """ - if len(context.raw_stack) < 1: - return None + subsequent_whitespace = None + # Get all raw segments and find position of the current comma within the list. + file_segment = context.parent_stack[0] + raw_segments = file_segment.get_raw_segments() + # Raw stack is appropriate as the only segments we can care about are + # comma, whitespace, newline, and comment, which are all raw. + # Using the raw_segments allows us to account for possible unexpected + # parse tree structures resulting from other rule fixes. + next_raw_index = raw_segments.index(context.segment) + 1 + # Iterate forwards over raw segments to find both the whitespace segment and + # the first non-whitespace segment. + for s in raw_segments[next_raw_index:]: + if s.is_meta: + continue + elif s.is_type("whitespace"): + # Capture the whitespace segment. + subsequent_whitespace = s + else: + # We've found a non-whitespace (and non-meta) segment. + # Therefore return the stored whitespace segment + # and this segment for analysis. + return subsequent_whitespace, s + + # If we find ourselves here it's all + # whitespace (or nothing) to the end of the file. + # This can only happen in bigquery (see test_pass_bigquery_trailing_comma). + return subsequent_whitespace, None - # Get the first element of this segment. - first_elem = context.segment.get_raw_segments()[0] - - cm1 = context.raw_stack[-1] - if cm1.name == "comma": - # comma followed by something that isn't whitespace? - if first_elem.name not in ["whitespace", "newline", "Dedent"]: - self.logger.debug( - "Comma followed by something other than whitespace: %s", first_elem - ) - ins = WhitespaceSegment(raw=" ") - return LintResult( - anchor=cm1, - fixes=[LintFix("edit", context.segment, [ins, context.segment])], - ) - - if len(context.raw_stack) < 2: + def _eval(self, context: RuleContext) -> Optional[LintResult]: + """Commas should be followed by a single whitespace unless followed by a comment.""" + # We only care about commas. + if context.segment.name != "comma": return None - cm2 = context.raw_stack[-2] - if cm2.name == "comma": - # comma followed by too much whitespace? - if ( - cm1.is_whitespace # Must be whitespace - and cm1.raw != " " # ...and not a single one - and cm1.name != "newline" # ...and not a newline - and not first_elem.is_comment # ...and not followed by a comment - ): - self.logger.debug("Comma followed by too much whitespace: %s", cm1) - repl = WhitespaceSegment(raw=" ") - return LintResult(anchor=cm1, fixes=[LintFix("edit", cm1, repl)]) - # Otherwise we're fine + # Get subsequent whitespace segment and the first non-whitespace segment. + subsequent_whitespace, first_non_whitespace = self._get_subsequent_whitespace( + context + ) + + if ( + (subsequent_whitespace is None) + and (first_non_whitespace is not None) + and (not first_non_whitespace.is_type("newline")) + ): + # No trailing whitespace and not followed by a newline, + # therefore create a whitespace after the comma. + return LintResult( + anchor=first_non_whitespace, + fixes=[LintFix("create_after", context.segment, WhitespaceSegment())], + ) + elif ( + (subsequent_whitespace is not None) + and (subsequent_whitespace.raw != " ") + and (first_non_whitespace is not None) + and (not first_non_whitespace.is_comment) + ): + # Excess trailing whitespace therefore edit to only be one space long. + return LintResult( + anchor=subsequent_whitespace, + fixes=[LintFix("edit", subsequent_whitespace, WhitespaceSegment())], + ) + return None diff --git a/test/core/parser/segments_base_test.py b/test/core/parser/segments_base_test.py index b3b6d36501b..1b1667e4b88 100644 --- a/test/core/parser/segments_base_test.py +++ b/test/core/parser/segments_base_test.py @@ -135,3 +135,9 @@ def test__parser__base_segments_file(raw_seg_list): assert base_seg.file_path == "/some/dir/file.sql" assert base_seg.can_start_end_non_code assert base_seg.allow_empty + + +def test__parser__raw_get_raw_segments(raw_seg_list): + """Test niche case of calling get_raw_segments on a raw segment.""" + for s in raw_seg_list: + assert s.get_raw_segments() == [s] diff --git a/test/core/templaters/jinja_test.py b/test/core/templaters/jinja_test.py index 83033eb5963..efd7febc386 100644 --- a/test/core/templaters/jinja_test.py +++ b/test/core/templaters/jinja_test.py @@ -450,6 +450,8 @@ def assert_structure(yaml_loader, path, code_only=True, include_meta=False): # Placeholders and metas ("jinja_l_metas/001", False, True), ("jinja_l_metas/002", False, True), + # Library Loading from a folder when library is module + ("jinja_m_libraries_module/jinja", True, False), ], ) def test__templater_full(subpath, code_only, include_meta, yaml_loader, caplog): diff --git a/test/fixtures/dialects/bigquery/declare_variable.sql b/test/fixtures/dialects/bigquery/declare_variable.sql index 02ab45ebd4c..d6dad600aa3 100644 --- a/test/fixtures/dialects/bigquery/declare_variable.sql +++ b/test/fixtures/dialects/bigquery/declare_variable.sql @@ -1,2 +1,3 @@ declare var1 int64; declare var2, var3 string; +declare var1 array; diff --git a/test/fixtures/dialects/bigquery/declare_variable.yml b/test/fixtures/dialects/bigquery/declare_variable.yml index 8ab4723ca8e..abd4fc73d51 100644 --- a/test/fixtures/dialects/bigquery/declare_variable.yml +++ b/test/fixtures/dialects/bigquery/declare_variable.yml @@ -3,13 +3,14 @@ # computed by SQLFluff when running the tests. Please run # `python test/generate_parse_fixture_yml.py` to generate them after adding or # altering SQL files. -_hash: 607654fc5d760c0d39fa3c9d98f9a0f9ae6cdf1105a8216d6cb8d018e5a220f8 +_hash: ffccb95575561408b5510a30ff40b55e12c56a173fa3fc57ef3f5e05e5ee2488 file: - statement: declare_segment: keyword: declare identifier: var1 - data_type_identifier: int64 + data_type: + data_type_identifier: int64 - statement_terminator: ; - statement: declare_segment: @@ -17,5 +18,17 @@ file: - identifier: var2 - comma: ',' - identifier: var3 - - data_type_identifier: string + - data_type: + data_type_identifier: string +- statement_terminator: ; +- statement: + declare_segment: + keyword: declare + identifier: var1 + data_type: + keyword: array + start_angle_bracket: < + data_type: + data_type_identifier: string + end_angle_bracket: '>' - statement_terminator: ; diff --git a/test/fixtures/dialects/bigquery/declare_variable_with_default.yml b/test/fixtures/dialects/bigquery/declare_variable_with_default.yml index b9452d277ee..0ff7b4d4076 100644 --- a/test/fixtures/dialects/bigquery/declare_variable_with_default.yml +++ b/test/fixtures/dialects/bigquery/declare_variable_with_default.yml @@ -3,13 +3,14 @@ # computed by SQLFluff when running the tests. Please run # `python test/generate_parse_fixture_yml.py` to generate them after adding or # altering SQL files. -_hash: ecfb4efbfe87fb634c571f328fa2324ab7de5c78f81c0cf38a353de93b7f1ec5 +_hash: b89a787f3df1badc8f46432acc28fbe4e1550a3e1a28a8be058344ffd9f33878 file: - statement: declare_segment: - keyword: declare - identifier: var5 - - data_type_identifier: date + - data_type: + data_type_identifier: date - keyword: default - function: function_name: @@ -22,7 +23,8 @@ file: declare_segment: - keyword: declare - identifier: var4 - - data_type_identifier: int64 + - data_type: + data_type_identifier: int64 - keyword: default - literal: '1' - statement_terminator: ; @@ -30,7 +32,8 @@ file: declare_segment: - keyword: declare - identifier: var3 - - data_type_identifier: string + - data_type: + data_type_identifier: string - keyword: default - bracketed: start_bracket: ( diff --git a/test/fixtures/dialects/bigquery/set_variable_single.sql b/test/fixtures/dialects/bigquery/set_variable_single.sql index 0d87478d2e6..2d7eebce094 100644 --- a/test/fixtures/dialects/bigquery/set_variable_single.sql +++ b/test/fixtures/dialects/bigquery/set_variable_single.sql @@ -1 +1,2 @@ set var1 = 5; +set var1 = ['one', 'two']; diff --git a/test/fixtures/dialects/bigquery/set_variable_single.yml b/test/fixtures/dialects/bigquery/set_variable_single.yml index bad84d1cf7d..3b1a48585fe 100644 --- a/test/fixtures/dialects/bigquery/set_variable_single.yml +++ b/test/fixtures/dialects/bigquery/set_variable_single.yml @@ -3,12 +3,26 @@ # computed by SQLFluff when running the tests. Please run # `python test/generate_parse_fixture_yml.py` to generate them after adding or # altering SQL files. -_hash: 9b00e078abcb912ddcc763475dedd48d46a99ffc501e403c20e960fbc092eb98 +_hash: 335175c732ab030ddf93fa3500fb827a1e936aab795bfcb62ce9540c2de6209e file: - statement: +- statement: set_segment: keyword: set identifier: var1 comparison_operator: '=' literal: '5' - statement_terminator: ; +- statement_terminator: ; +- statement: + set_segment: + keyword: set + identifier: var1 + comparison_operator: '=' + array_literal: + - start_square_bracket: '[' + - expression: + literal: "'one'" + - comma: ',' + - expression: + literal: "'two'" + - end_square_bracket: ']' +- statement_terminator: ; diff --git a/test/fixtures/linter/autofix/ansi/016_no_fix_in_template_loops/after.sql b/test/fixtures/linter/autofix/ansi/016_no_fix_in_template_loops/after.sql index 2bb55a6f43c..8dd376fb91e 100644 --- a/test/fixtures/linter/autofix/ansi/016_no_fix_in_template_loops/after.sql +++ b/test/fixtures/linter/autofix/ansi/016_no_fix_in_template_loops/after.sql @@ -7,7 +7,7 @@ SELECT 10; SELECT - 1, + 1, {%- for _ in [1, 2, 3] %} 2{%endfor %}; SELECT diff --git a/test/fixtures/rules/std_rule_cases/L008.yml b/test/fixtures/rules/std_rule_cases/L008.yml index daf87fba5a7..54a219e2656 100644 --- a/test/fixtures/rules/std_rule_cases/L008.yml +++ b/test/fixtures/rules/std_rule_cases/L008.yml @@ -14,3 +14,6 @@ test_fail_no_whitespace_after_comma: test_fail_no_whitespace_after_comma_2: fail_str: SELECT FLOOR(dt) ,count(*) FROM test fix_str: SELECT FLOOR(dt) , count(*) FROM test + +test_pass_bigquery_trailing_comma: + pass_str: SELECT 1, 2, diff --git a/test/fixtures/templater/jinja_j_libraries/jinja.yml b/test/fixtures/templater/jinja_j_libraries/jinja.yml index efbaf4c4a60..519b88617b8 100644 --- a/test/fixtures/templater/jinja_j_libraries/jinja.yml +++ b/test/fixtures/templater/jinja_j_libraries/jinja.yml @@ -20,4 +20,4 @@ file: column_reference: identifier: x comparison_operator: '=' - literal: '23' \ No newline at end of file + literal: '23' diff --git a/test/fixtures/templater/jinja_m_libraries_module/.sqlfluff b/test/fixtures/templater/jinja_m_libraries_module/.sqlfluff new file mode 100644 index 00000000000..26fd6849858 --- /dev/null +++ b/test/fixtures/templater/jinja_m_libraries_module/.sqlfluff @@ -0,0 +1,2 @@ +[sqlfluff:templater:jinja] +library_path=libs diff --git a/test/fixtures/templater/jinja_m_libraries_module/jinja.sql b/test/fixtures/templater/jinja_m_libraries_module/jinja.sql new file mode 100644 index 00000000000..1e82de1fa05 --- /dev/null +++ b/test/fixtures/templater/jinja_m_libraries_module/jinja.sql @@ -0,0 +1,3 @@ +SELECT 56 +FROM {{ foo.schema }}.{{ foo.table("xyz") }} +WHERE {{ foo.bar.baz.equals("x", 23) }} and {{ root_equals("y", 42) }} diff --git a/test/fixtures/templater/jinja_m_libraries_module/jinja.yml b/test/fixtures/templater/jinja_m_libraries_module/jinja.yml new file mode 100644 index 00000000000..2cd20606d7d --- /dev/null +++ b/test/fixtures/templater/jinja_m_libraries_module/jinja.yml @@ -0,0 +1,28 @@ +file: + statement: + select_statement: + select_clause: + keyword: SELECT + select_clause_element: + literal: '56' + from_clause: + keyword: FROM + from_expression: + from_expression_element: + table_expression: + table_reference: + - identifier: sch1 + - dot: . + - identifier: foo_xyz + where_clause: + keyword: WHERE + expression: + - column_reference: + identifier: x + - comparison_operator: '=' + - literal: '23' + - binary_operator: and + - column_reference: + identifier: y + - comparison_operator: '=' + - literal: '42' diff --git a/test/fixtures/templater/jinja_m_libraries_module/libs/__init__.py b/test/fixtures/templater/jinja_m_libraries_module/libs/__init__.py new file mode 100644 index 00000000000..a77865563dc --- /dev/null +++ b/test/fixtures/templater/jinja_m_libraries_module/libs/__init__.py @@ -0,0 +1,6 @@ +"""Module used to test __init__.py within the jinja template.""" + + +def root_equals(col: str, val: str) -> str: + """Return a string that has col = val.""" + return f"{col} = {val}" diff --git a/test/fixtures/templater/jinja_m_libraries_module/libs/foo/__init__.py b/test/fixtures/templater/jinja_m_libraries_module/libs/foo/__init__.py new file mode 100644 index 00000000000..1f8e603ebab --- /dev/null +++ b/test/fixtures/templater/jinja_m_libraries_module/libs/foo/__init__.py @@ -0,0 +1,7 @@ +"""Module used to test foo within the jinja template.""" +schema = "sch1" + + +def table(name): + """Return the parameter with foo_ in front of it.""" + return f"foo_{name}" diff --git a/test/fixtures/templater/jinja_m_libraries_module/libs/foo/bar/__init__.py b/test/fixtures/templater/jinja_m_libraries_module/libs/foo/bar/__init__.py new file mode 100644 index 00000000000..5fc1d13c65f --- /dev/null +++ b/test/fixtures/templater/jinja_m_libraries_module/libs/foo/bar/__init__.py @@ -0,0 +1 @@ +"""Module used to create module hierarchy.""" diff --git a/test/fixtures/templater/jinja_m_libraries_module/libs/foo/bar/baz.py b/test/fixtures/templater/jinja_m_libraries_module/libs/foo/bar/baz.py new file mode 100644 index 00000000000..b6650886fb2 --- /dev/null +++ b/test/fixtures/templater/jinja_m_libraries_module/libs/foo/bar/baz.py @@ -0,0 +1,6 @@ +"""Module used to test bar within the jinja template.""" + + +def equals(col, val): + """Return a string that has col = val.""" + return f"{col} = {val}" diff --git a/test/fixtures/templater/jinja_m_libraries_module/libs/not_python.txt b/test/fixtures/templater/jinja_m_libraries_module/libs/not_python.txt new file mode 100644 index 00000000000..542aa9f9bdf --- /dev/null +++ b/test/fixtures/templater/jinja_m_libraries_module/libs/not_python.txt @@ -0,0 +1 @@ +I am just a text file \ No newline at end of file diff --git a/test/fixtures/templater/jinja_m_libraries_module/other/__init__.py b/test/fixtures/templater/jinja_m_libraries_module/other/__init__.py new file mode 100644 index 00000000000..3fb4aac3ff6 --- /dev/null +++ b/test/fixtures/templater/jinja_m_libraries_module/other/__init__.py @@ -0,0 +1,3 @@ +"""Module that should not be loaded.""" + +raise Exception("this file should not be loaded") diff --git a/test/rules/std_L008_test.py b/test/rules/std_L008_test.py new file mode 100644 index 00000000000..a513d8df31a --- /dev/null +++ b/test/rules/std_L008_test.py @@ -0,0 +1,19 @@ +"""Tests the python routines within L008.""" +import sqlfluff + + +def test__rules__std_L008_single_raise() -> None: + """Test case for multiple L008 errors raised when no post comma whitespace.""" + # This query used to triple count L008. Added memory to log previously fixed commas (issue #2001). + sql = """ + SELECT + col_a AS a + ,col_b AS b + FROM foo; + """ + result = sqlfluff.lint(sql, rules=["L008", "L019"]) + + results_L008 = [r for r in result if r["code"] == "L008"] + results_L019 = [r for r in result if r["code"] == "L019"] + assert len(results_L008) == 1 + assert len(results_L019) == 1