From f6f6c2f13a9ccf03258b0d17990e02cc893a62a9 Mon Sep 17 00:00:00 2001 From: Barry Hart Date: Tue, 15 Mar 2022 11:20:13 -0400 Subject: [PATCH] L009: Handle adding newline after `{% endif %}` at end of file (#2862) * L009: Handle adding newline after {% endif %} at end of file * Lexer should add trailing placeholders if needed * Fix L009 to run on the actual final segment, *even if meta* * Update iter_patches() to return EnrichedFixPatch for better fix placement * Fix broken test * Comments, type hints * First draft of refactor how we handle FixPatch vs EnrichedFixPatch * Second phase of refactoring FixPatch Co-authored-by: Barry Hart --- src/sqlfluff/core/linter/common.py | 18 ----- src/sqlfluff/core/linter/linted_file.py | 35 +++------ src/sqlfluff/core/parser/lexer.py | 25 ++++++ src/sqlfluff/core/parser/segments/base.py | 76 ++++++++++++++++--- src/sqlfluff/core/rules/base.py | 16 ++-- .../core/templaters/slicers/tracer.py | 11 ++- src/sqlfluff/rules/L009.py | 2 +- test/api/simple_test.py | 10 +-- test/core/templaters/jinja_test.py | 20 ++++- test/fixtures/rules/std_rule_cases/L009.yml | 6 ++ 10 files changed, 150 insertions(+), 69 deletions(-) diff --git a/src/sqlfluff/core/linter/common.py b/src/sqlfluff/core/linter/common.py index db2c7a38fac..6ec3be56db9 100644 --- a/src/sqlfluff/core/linter/common.py +++ b/src/sqlfluff/core/linter/common.py @@ -67,21 +67,3 @@ class ParsedString(NamedTuple): config: FluffConfig fname: str source_str: str - - -class EnrichedFixPatch(NamedTuple): - """An edit patch for a source file.""" - - source_slice: slice - templated_slice: slice - fixed_raw: str - # The patch category, functions mostly for debugging and explanation - # than for function. It allows traceability of *why* this patch was - # generated. - patch_category: str - templated_str: str - source_str: str - - def dedupe_tuple(self): - """Generate a tuple of this fix for deduping.""" - return (self.source_slice, self.fixed_raw) diff --git a/src/sqlfluff/core/linter/linted_file.py b/src/sqlfluff/core/linter/linted_file.py index 4c149fc6f0c..168ddbcae64 100644 --- a/src/sqlfluff/core/linter/linted_file.py +++ b/src/sqlfluff/core/linter/linted_file.py @@ -30,9 +30,9 @@ from sqlfluff.core.templaters import TemplatedFile # Classes needed only for type checking -from sqlfluff.core.parser.segments.base import BaseSegment, FixPatch +from sqlfluff.core.parser.segments.base import BaseSegment, FixPatch, EnrichedFixPatch -from sqlfluff.core.linter.common import NoQaDirective, EnrichedFixPatch +from sqlfluff.core.linter.common import NoQaDirective # Instantiate the linter logger linter_logger: logging.Logger = logging.getLogger("sqlfluff.linter") @@ -203,9 +203,7 @@ def is_clean(self) -> bool: return not any(self.get_violations(filter_ignore=True)) @staticmethod - def _log_hints( - patch: Union[EnrichedFixPatch, FixPatch], templated_file: TemplatedFile - ): + def _log_hints(patch: FixPatch, templated_file: TemplatedFile): """Log hints for debugging during patch generation.""" # This next bit is ALL FOR LOGGING AND DEBUGGING max_log_length = 10 @@ -279,18 +277,16 @@ def fix_string(self) -> Tuple[Any, bool]: dedupe_buffer = [] # We use enumerate so that we get an index for each patch. This is entirely # so when debugging logs we can find a given patch again! - patch: Union[EnrichedFixPatch, FixPatch] + patch: FixPatch # Could be FixPatch or its subclass, EnrichedFixPatch for idx, patch in enumerate( - self.tree.iter_patches(templated_str=self.templated_file.templated_str) + self.tree.iter_patches(templated_file=self.templated_file) ): linter_logger.debug(" %s Yielded patch: %s", idx, patch) self._log_hints(patch, self.templated_file) - # Attempt to convert to source space. + # Get source_slice. try: - source_slice = self.templated_file.templated_slice_to_source_slice( - patch.templated_slice, - ) + enriched_patch = patch.enrich(self.templated_file) except ValueError: # pragma: no cover linter_logger.info( " - Skipping. Source space Value Error. i.e. attempted " @@ -301,10 +297,10 @@ def fix_string(self) -> Tuple[Any, bool]: continue # Check for duplicates - dedupe_tuple = (source_slice, patch.fixed_raw) - if dedupe_tuple in dedupe_buffer: + if enriched_patch.dedupe_tuple() in dedupe_buffer: linter_logger.info( - " - Skipping. Source space Duplicate: %s", dedupe_tuple + " - Skipping. Source space Duplicate: %s", + enriched_patch.dedupe_tuple(), ) continue @@ -318,19 +314,10 @@ def fix_string(self) -> Tuple[Any, bool]: # Get the affected raw slices. local_raw_slices = self.templated_file.raw_slices_spanning_source_slice( - source_slice + enriched_patch.source_slice ) local_type_list = [slc.slice_type for slc in local_raw_slices] - enriched_patch = EnrichedFixPatch( - source_slice=source_slice, - templated_slice=patch.templated_slice, - patch_category=patch.patch_category, - fixed_raw=patch.fixed_raw, - templated_str=self.templated_file.templated_str[patch.templated_slice], - source_str=self.templated_file.source_str[source_slice], - ) - # Deal with the easy cases of 1) New code at end 2) only literals if not local_type_list or set(local_type_list) == {"literal"}: linter_logger.info( diff --git a/src/sqlfluff/core/parser/lexer.py b/src/sqlfluff/core/parser/lexer.py index b625b20e176..d4ea18a66bd 100644 --- a/src/sqlfluff/core/parser/lexer.py +++ b/src/sqlfluff/core/parser/lexer.py @@ -535,6 +535,31 @@ def elements_to_segments( ) ) + # Generate placeholders for any source-only slices that *follow* + # the last element. This happens, for example, if a Jinja templated + # file ends with "{% endif %}", and there's no trailing newline. + if idx == len(elements) - 1: + so_slices = [ + so + for so in source_only_slices + if so.source_idx >= source_slice.stop + ] + for so_slice in so_slices: + segment_buffer.append( + TemplateSegment( + pos_marker=PositionMarker( + slice(so_slice.source_idx, so_slice.end_source_idx()), + slice( + element.template_slice.stop, + element.template_slice.stop, + ), + templated_file, + ), + source_str=so_slice.raw, + block_type=so_slice.slice_type, + ) + ) + # Convert to tuple before return return tuple(segment_buffer) diff --git a/src/sqlfluff/core/parser/segments/base.py b/src/sqlfluff/core/parser/segments/base.py index f12eb4eb924..09afa33b28e 100644 --- a/src/sqlfluff/core/parser/segments/base.py +++ b/src/sqlfluff/core/parser/segments/base.py @@ -13,7 +13,16 @@ from copy import deepcopy from dataclasses import dataclass, field, replace from io import StringIO -from typing import Any, Callable, Dict, Optional, List, Tuple, NamedTuple, Iterator +from typing import ( + Any, + Callable, + Dict, + Optional, + List, + Tuple, + Iterator, + Union, +) import logging from tqdm import tqdm @@ -36,21 +45,54 @@ from sqlfluff.core.parser.matchable import Matchable from sqlfluff.core.parser.markers import PositionMarker from sqlfluff.core.parser.context import ParseContext +from sqlfluff.core.templaters.base import TemplatedFile # Instantiate the linter logger (only for use in methods involved with fixing.) linter_logger = logging.getLogger("sqlfluff.linter") -class FixPatch(NamedTuple): +@dataclass +class FixPatch: """An edit patch for a templated file.""" templated_slice: slice fixed_raw: str # The patch category, functions mostly for debugging and explanation # than for function. It allows traceability of *why* this patch was - # generated. It has no siginificance for processing. + # generated. It has no significance for processing. patch_category: str + def enrich(self, templated_file: TemplatedFile) -> "EnrichedFixPatch": + """Convert patch to source space.""" + source_slice = templated_file.templated_slice_to_source_slice( + self.templated_slice, + ) + return EnrichedFixPatch( + source_slice=source_slice, + templated_slice=self.templated_slice, + patch_category=self.patch_category, + fixed_raw=self.fixed_raw, + templated_str=templated_file.templated_str[self.templated_slice], + source_str=templated_file.source_str[source_slice], + ) + + +@dataclass +class EnrichedFixPatch(FixPatch): + """An edit patch for a source file.""" + + source_slice: slice + templated_str: str + source_str: str + + def enrich(self, templated_file: TemplatedFile) -> "EnrichedFixPatch": + """No-op override of base class function.""" + return self + + def dedupe_tuple(self): + """Generate a tuple of this fix for deduping.""" + return (self.source_slice, self.fixed_raw) + @dataclass class AnchorEditInfo: @@ -1176,7 +1218,9 @@ def _validate_segment_after_fixes(self, rule_code, dialect, fixes_applied, segme def _log_apply_fixes_check_issue(message, *args): # pragma: no cover linter_logger.critical(message, *args) - def iter_patches(self, templated_str: str) -> Iterator[FixPatch]: + def iter_patches( + self, templated_file: TemplatedFile + ) -> Iterator[Union[EnrichedFixPatch, FixPatch]]: """Iterate through the segments generating fix patches. The patches are generated in TEMPLATED space. This is important @@ -1188,6 +1232,7 @@ def iter_patches(self, templated_str: str) -> Iterator[FixPatch]: """ # Does it match? If so we can ignore it. assert self.pos_marker + templated_str = templated_file.templated_str matches = self.raw == templated_str[self.pos_marker.templated_slice] if matches: return @@ -1256,7 +1301,7 @@ def iter_patches(self, templated_str: str) -> Iterator[FixPatch]: insert_buff = "" # Now we deal with any changes *within* the segment itself. - yield from segment.iter_patches(templated_str=templated_str) + yield from segment.iter_patches(templated_file=templated_file) # Once we've dealt with any patches from the segment, update # our position markers. @@ -1266,13 +1311,22 @@ def iter_patches(self, templated_str: str) -> Iterator[FixPatch]: # or insert. Also valid if we still have an insertion buffer here. end_diff = self.pos_marker.templated_slice.stop - templated_idx if end_diff or insert_buff: - yield FixPatch( - slice( - self.pos_marker.templated_slice.stop - end_diff, - self.pos_marker.templated_slice.stop, - ), - insert_buff, + source_slice = segment.pos_marker.source_slice + templated_slice = slice( + self.pos_marker.templated_slice.stop - end_diff, + self.pos_marker.templated_slice.stop, + ) + # By returning an EnrichedFixPatch (rather than FixPatch), which + # includes a source_slice field, we ensure that fixes adjacent + # to source-only slices (e.g. {% endif %}) are placed + # appropriately relative to source-only slices. + yield EnrichedFixPatch( + source_slice=source_slice, + templated_slice=templated_slice, patch_category="end_point", + fixed_raw=insert_buff, + templated_str=templated_file.templated_str[templated_slice], + source_str=templated_file.source_str[source_slice], ) def edit(self, raw): diff --git a/src/sqlfluff/core/rules/base.py b/src/sqlfluff/core/rules/base.py index 797d14f0e7b..7bb69abf8ed 100644 --- a/src/sqlfluff/core/rules/base.py +++ b/src/sqlfluff/core/rules/base.py @@ -656,16 +656,18 @@ def indent(self) -> str: space = " " return space * self.tab_space_size if self.indent_unit == "space" else tab - def is_final_segment(self, context: RuleContext) -> bool: + def is_final_segment(self, context: RuleContext, filter_meta: bool = True) -> bool: """Is the current segment the final segment in the parse tree.""" - if len(self.filter_meta(context.siblings_post)) > 0: + siblings_post = context.siblings_post + if filter_meta: + siblings_post = self.filter_meta(siblings_post) + if len(siblings_post) > 0: # This can only fail on the last segment return False elif len(context.segment.segments) > 0: # This can only fail on the last base segment return False - elif context.segment.is_meta: - # We can't fail on a meta segment + elif filter_meta and context.segment.is_meta: return False else: # We know we are at a leaf of the tree but not necessarily at the end of the @@ -674,9 +676,9 @@ def is_final_segment(self, context: RuleContext) -> bool: # one. child_segment = context.segment for parent_segment in context.parent_stack[::-1]: - possible_children = [ - s for s in parent_segment.segments if not s.is_meta - ] + possible_children = parent_segment.segments + if filter_meta: + possible_children = [s for s in possible_children if not s.is_meta] if len(possible_children) > possible_children.index(child_segment) + 1: return False child_segment = parent_segment diff --git a/src/sqlfluff/core/templaters/slicers/tracer.py b/src/sqlfluff/core/templaters/slicers/tracer.py index 537eb2b6b15..1ae181822a3 100644 --- a/src/sqlfluff/core/templaters/slicers/tracer.py +++ b/src/sqlfluff/core/templaters/slicers/tracer.py @@ -289,7 +289,6 @@ def _slice_template(self) -> List[RawFileSlice]: # parts of the tag at a time. unique_alternate_id = None alternate_code = None - trimmed_content = "" if elem_type.endswith("_end") or elem_type == "raw_begin": block_type = block_types[elem_type] block_subtype = None @@ -436,6 +435,16 @@ def _slice_template(self) -> List[RawFileSlice]: "endfor", "endif", ): + # Replace RawSliceInfo for this slice with one that has + # alternate ID and code for tracking. This ensures, for + # instance, that if a file ends with "{% endif %} (with + # no newline following), that we still generate a + # TemplateSliceInfo for it. + unique_alternate_id = self.next_slice_id() + alternate_code = f"{result[-1].raw}\0{unique_alternate_id}_0" + self.raw_slice_info[result[-1]] = RawSliceInfo( + unique_alternate_id, alternate_code, [] + ) # Record potential forward jump over this block. self.raw_slice_info[result[stack[-1]]].next_slice_indices.append( block_idx diff --git a/src/sqlfluff/rules/L009.py b/src/sqlfluff/rules/L009.py index 781b4b5380c..d953e9b043b 100644 --- a/src/sqlfluff/rules/L009.py +++ b/src/sqlfluff/rules/L009.py @@ -91,7 +91,7 @@ def _eval(self, context: RuleContext) -> Optional[LintResult]: """ # We only care about the final segment of the parse tree. - if not self.is_final_segment(context): + if not self.is_final_segment(context, filter_meta=False): return None # Include current segment for complete stack and reverse. diff --git a/test/api/simple_test.py b/test/api/simple_test.py index 76ad3551776..50128436dd5 100644 --- a/test/api/simple_test.py +++ b/test/api/simple_test.py @@ -72,16 +72,16 @@ "description": "Keywords must be consistently upper case.", }, { - "code": "L009", + "code": "L014", "line_no": 1, "line_pos": 34, - "description": "Files must end with a single trailing newline.", + "description": "Unquoted identifiers must be consistently lower case.", }, { - "code": "L014", + "code": "L009", "line_no": 1, - "line_pos": 34, - "description": "Unquoted identifiers must be consistently lower case.", + "line_pos": 41, + "description": "Files must end with a single trailing newline.", }, ] diff --git a/test/core/templaters/jinja_test.py b/test/core/templaters/jinja_test.py index b50ab31bc71..ba7dd72a3a1 100644 --- a/test/core/templaters/jinja_test.py +++ b/test/core/templaters/jinja_test.py @@ -822,6 +822,10 @@ def test__templater_jinja_slice_template(test, result): ("block_end", slice(113, 127, None), slice(11, 11, None)), ("block_start", slice(27, 46, None), slice(11, 11, None)), ("literal", slice(46, 57, None), slice(11, 22, None)), + ("block_end", slice(57, 70, None), slice(22, 22, None)), + ("block_start", slice(70, 89, None), slice(22, 22, None)), + ("block_end", slice(100, 113, None), slice(22, 22, None)), + ("block_end", slice(113, 127, None), slice(22, 22, None)), ], ), ( @@ -910,8 +914,20 @@ def test__templater_jinja_slice_template(test, result): ("literal", slice(91, 92, None), slice(0, 0, None)), ("block_end", slice(92, 104, None), slice(0, 0, None)), ("literal", slice(104, 113, None), slice(0, 9, None)), - ("templated", slice(113, 139, None), slice(9, 29, None)), - ("literal", slice(139, 156, None), slice(29, 46, None)), + ("templated", slice(113, 139, None), slice(9, 28, None)), + ("literal", slice(139, 156, None), slice(28, 28, None)), + ], + ), + ( + # Test for issue 2822: Handle slicing when there's no newline after + # the Jinja block end. + "{% if true %}\nSELECT 1 + 1\n{%- endif %}", + None, + [ + ("block_start", slice(0, 13, None), slice(0, 0, None)), + ("literal", slice(13, 26, None), slice(0, 13, None)), + ("literal", slice(26, 27, None), slice(13, 13, None)), + ("block_end", slice(27, 39, None), slice(13, 13, None)), ], ), ], diff --git a/test/fixtures/rules/std_rule_cases/L009.yml b/test/fixtures/rules/std_rule_cases/L009.yml index f68cf806f02..50f9e289206 100644 --- a/test/fixtures/rules/std_rule_cases/L009.yml +++ b/test/fixtures/rules/std_rule_cases/L009.yml @@ -33,3 +33,9 @@ test_pass_templated_macro_newlines: {{ columns }} {% endmacro %} SELECT {{ get_keyed_nulls("other_id") }} + +test_fail_templated_no_newline: + # Tricky because there's no newline at the end of the file (following the + # templated code). + fail_str: "{% if true %}\nSELECT 1 + 1\n{%- endif %}" + fix_str: "{% if true %}\nSELECT 1 + 1\n{%- endif %}\n"