
Commit

Prevent rules incorrectly returning conflicting fixes to same position (#2830)

* Fix L052 bug deleting space after Snowflake SET statement

* Make a similar fix elsewhere in L052, DRY up the code

* Update tests to raise error on fixes with duplicate anchors

* Fix L036 bug: Deleting the same segment multiple times

* Rework the fixes to use object *identity*, not *equality*

* Comments

* Update L052 to use IdentitySet

* Coverage checker

* Update src/sqlfluff/core/linter/linter.py

* Update test/fixtures/rules/std_rule_cases/L052.yml

Co-authored-by: Barry Pollard <barry_pollard@hotmail.com>

* PR review: Raise log level for these messages to CRITICAL

* Add comment per code review

* Update test/fixtures/rules/std_rule_cases/L052.yml

* Bug fix to IdentitySet

* Update L039 to avoid returning multiple fixes with same anchors

* Add test case for fixed L039 issue

* Fix one of the post-fix parse error bugs in L053

* Fix L053 bug, allow create_before + create_after for same anchor

* Tweaks

* Fix type annotation

* Comments, tweaks

* Fix bug in L050, fix missing space in message

* Remove commented-out code

* Add "pragma: no cover"

* Add "pragma: no cover"

* PR review

* More L039 comments

* Discard fixes from lint results if there are conflicts

Co-authored-by: Barry Hart <barry.hart@mailchimp.com>
Co-authored-by: Barry Pollard <barry_pollard@hotmail.com>
3 people committed Mar 10, 2022
1 parent e94005e commit 0e91f42
Showing 11 changed files with 366 additions and 87 deletions.
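The heart of the change: before applying a rule's fixes, the linter now groups them by the identity (id()) of their anchor segment and rejects any anchor that accumulates a conflicting combination of edits; only a create_before + create_after pair may legitimately share an anchor. Below is a standalone sketch of that grouping and validity check. It is illustrative only, with a simplified Fix class standing in for sqlfluff's LintFix and a plain object standing in for a parsed segment.

# Illustrative sketch only; simplified stand-ins, not the code from this commit.
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class Fix:
    edit_type: str  # "delete", "replace", "create_before" or "create_after"
    anchor: Any     # the segment this fix is anchored to

@dataclass
class AnchorEditInfo:
    delete: int = 0
    replace: int = 0
    create_before: int = 0
    create_after: int = 0
    fixes: List = field(default_factory=list)

    def add(self, fix):
        # Record the fix and bump the counter for its edit type.
        self.fixes.append(fix)
        setattr(self, fix.edit_type, getattr(self, fix.edit_type) + 1)

    @property
    def is_valid(self):
        # 0-1 fixes are always fine; exactly 2 are fine only as a
        # create_before + create_after pair; anything more conflicts.
        if len(self.fixes) <= 1:
            return True
        if len(self.fixes) == 2:
            return self.create_before == 1 and self.create_after == 1
        return False

def compute_anchor_edit_info(fixes) -> Dict[int, AnchorEditInfo]:
    # Key on id() because two distinct segments may compare as equal.
    anchor_info = defaultdict(AnchorEditInfo)
    for fix in fixes:
        anchor_info[id(fix.anchor)].add(fix)
    return dict(anchor_info)

seg = object()  # pretend this is a parsed segment
paired = [Fix("create_before", seg), Fix("create_after", seg)]
clashing = [Fix("delete", seg), Fix("replace", seg)]
print(all(i.is_valid for i in compute_anchor_edit_info(paired).values()))    # True
print(all(i.is_valid for i in compute_anchor_edit_info(clashing).values()))  # False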
30 changes: 27 additions & 3 deletions src/sqlfluff/core/linter/linter.py
@@ -329,6 +329,12 @@ def remove_templated_errors(
result.append(e)
return result

@staticmethod
def _report_conflicting_fixes_same_anchor(message: str): # pragma: no cover
# This function exists primarily in order to let us monkeypatch it at
# runtime (replacing it with a function that raises an exception).
linter_logger.critical(message)

@staticmethod
def _warn_unfixable(code: str):
linter_logger.warning(
@@ -514,12 +520,30 @@ def lint_fix_parsed(
if fix and fixes:
linter_logger.info(f"Applying Fixes [{crawler.code}]: {fixes}")
# Do some sanity checks on the fixes before applying.
if fixes == last_fixes: # pragma: no cover
anchor_info = BaseSegment.compute_anchor_edit_info(fixes)
if any(
not info.is_valid for info in anchor_info.values()
): # pragma: no cover
message = (
f"Rule {crawler.code} returned conflicting fixes with the "
f"same anchor. This is only supported for create_before+"
f"create_after, so the fixes will not be applied. {fixes!r}"
)
cls._report_conflicting_fixes_same_anchor(message)
for lint_result in linting_errors:
lint_result.fixes = []
elif fixes == last_fixes: # pragma: no cover
# If we generate the same fixes two times in a row,
# that means we're in a loop, and we want to stop.
# (Fixes should address issues, hence different
# and/or fewer fixes next time.)
cls._warn_unfixable(crawler.code)
else:
# This is the happy path. We have fixes, now we want to
# apply them.
last_fixes = fixes
new_tree, _ = tree.apply_fixes(
config.get("dialect_obj"), crawler.code, fixes
config.get("dialect_obj"), crawler.code, anchor_info
)
# Check for infinite loops
if new_tree.raw not in previous_versions:
@@ -530,7 +554,7 @@
continue
else:
# Applying these fixes took us back to a state which we've
# seen before. Abort.
# seen before. We're in a loop, so we want to stop.
cls._warn_unfixable(crawler.code)

if loop == 0:
141 changes: 118 additions & 23 deletions src/sqlfluff/core/parser/segments/base.py
@@ -8,10 +8,12 @@
analysis.
"""

from collections import defaultdict
from collections.abc import MutableSet
from copy import deepcopy
from dataclasses import replace
from dataclasses import dataclass, field, replace
from io import StringIO
from typing import Any, Callable, Optional, List, Tuple, NamedTuple, Iterator
from typing import Any, Callable, Dict, Optional, List, Tuple, NamedTuple, Iterator
import logging

from tqdm import tqdm
@@ -50,6 +52,48 @@ class FixPatch(NamedTuple):
patch_category: str


@dataclass
class AnchorEditInfo:
"""For a given fix anchor, count of the fix edit types and fixes for it."""

delete: int = field(default=0)
replace: int = field(default=0)
create_before: int = field(default=0)
create_after: int = field(default=0)
fixes: List = field(default_factory=list)

def add(self, fix):
"""Adds the fix and updates stats."""
self.fixes.append(fix)
setattr(self, fix.edit_type, getattr(self, fix.edit_type) + 1)

@property
def total(self):
"""Returns total count of fixes."""
return len(self.fixes)

@property
def is_valid(self):
"""Returns True if valid combination of fixes for anchor.
Cases:
* 0-1 fixes of any type: Valid
* 2 fixes: Valid if and only if types are create_before and create_after
"""
if self.total <= 1:
# Definitely valid (i.e. no conflict) if 0 or 1. In practice, this
# function probably won't be called if there are 0 fixes, but 0 is
# valid; it simply means "no fixes to apply".
return True
if self.total == 2:
# This is only OK for this special case. We allow this because
# the intent is clear (i.e. no conflict): Insert something *before*
# the segment and something else *after* the segment.
return self.create_before == 1 and self.create_after == 1
# Definitely bad if > 2.
return False # pragma: no cover


class BaseSegment:
"""The base segment element.
@@ -1008,13 +1052,21 @@ def apply_fixes(self, dialect, rule_code, fixes):
else:
seg = todo_buffer.pop(0)

fix_buff = fixes.copy()
unused_fixes = []
while fix_buff:
f = fix_buff.pop()
# Look for identity not just equality.
# This handles potential positioning ambiguity.
if f.anchor is seg:
# Look for identity not just equality.
# This handles potential positioning ambiguity.
anchor_info: Optional[AnchorEditInfo] = fixes.pop(id(seg), None)
if anchor_info is not None:
seg_fixes = anchor_info.fixes
if (
len(seg_fixes) == 2
and seg_fixes[0].edit_type == "create_after"
): # pragma: no cover
# Must be create_before & create_after. Swap so the
# "before" comes first.
seg_fixes.reverse()

for f in anchor_info.fixes:
assert f.anchor is seg
fixes_applied.append(f)
linter_logger.debug(
"Matched fix against segment: %s -> %s", f, seg
@@ -1027,9 +1079,13 @@
"create_before",
"create_after",
):
if f.edit_type == "create_after":
# in the case of a creation after, also add this
# segment before the edit.
if (
f.edit_type == "create_after"
and len(anchor_info.fixes) == 1
):
# in the case of a creation after that is not part
# of a create_before/create_after pair, also add
# this segment before the edit.
seg_buffer.append(seg)

# We're doing a replacement (it could be a single
@@ -1051,18 +1107,8 @@
f.edit_type, f
)
)
# We've applied a fix here. Move on, this also consumes the
# fix
# TODO: Maybe deal with overlapping fixes later.
break
else:
# We've not used the fix so we should keep it in the list
# for later.
unused_fixes.append(f)
else:
seg_buffer.append(seg)
# Switch over the the unused list
fixes = unused_fixes + fix_buff
# Invalidate any caches
self.invalidate_caches()

@@ -1090,6 +1136,17 @@
else:
return self, fixes

@classmethod
def compute_anchor_edit_info(cls, fixes) -> Dict[int, AnchorEditInfo]:
"""Group and count fixes by anchor, return dictionary."""
anchor_info = defaultdict(AnchorEditInfo) # type: ignore
for fix in fixes:
# :TRICKY: Use segment id() as the dictionary key since
# different segments may compare as equal.
anchor_id = id(fix.anchor)
anchor_info[anchor_id].add(fix)
return dict(anchor_info)

def _validate_segment_after_fixes(self, rule_code, dialect, fixes_applied, segment):
"""Checks correctness of new segment against match or parse grammar."""
root_parse_context = RootParseContext(dialect=dialect)
@@ -1117,7 +1174,7 @@ def _validate_segment_after_fixes(self, rule_code, dialect, fixes_applied, segment):

@staticmethod
def _log_apply_fixes_check_issue(message, *args): # pragma: no cover
linter_logger.warning(message, *args)
linter_logger.critical(message, *args)

def iter_patches(self, templated_str: str) -> Iterator[FixPatch]:
"""Iterate through the segments generating fix patches.
@@ -1336,3 +1393,41 @@ def get_table_references(self):
for stmt in self.get_children("statement"):
references |= stmt.get_table_references()
return references


class IdentitySet(MutableSet):
"""Similar to built-in set(), but based on object IDENTITY.
This is often important when working with BaseSegment and other types,
where different object instances may compare as equal.
Copied from: https://stackoverflow.com/questions/16994307/identityset-in-python
"""

key = id # should return a hashable object

def __init__(self, iterable=()):
self.map = {} # id -> object
self |= iterable # add elements from iterable to the set (union)

def __len__(self): # Sized
return len(self.map)

def __iter__(self): # Iterable
return self.map.values().__iter__() # pragma: no cover

def __contains__(self, x): # Container
return self.key(x) in self.map

def add(self, value): # MutableSet
"""Add an element."""
self.map[self.key(value)] = value

def discard(self, value): # MutableSet
"""Remove an element. Do not raise an exception if absent."""
self.map.pop(self.key(value), None) # pragma: no cover

def __repr__(self): # pragma: no cover
if not self:
return "%s()" % (self.__class__.__name__,)
return "%s(%r)" % (self.__class__.__name__, list(self))
35 changes: 24 additions & 11 deletions src/sqlfluff/rules/L036.py
@@ -5,6 +5,7 @@
from sqlfluff.core.parser import WhitespaceSegment

from sqlfluff.core.parser import BaseSegment, NewlineSegment
from sqlfluff.core.parser.segments.base import IdentitySet
from sqlfluff.core.rules.base import BaseRule, LintFix, LintResult, RuleContext
from sqlfluff.core.rules.doc_decorators import document_fix_compatible
from sqlfluff.core.rules.functional import Segments
@@ -254,6 +255,7 @@ def _eval_single_select_target_element(

def _fixes_for_move_after_select_clause(
stop_seg: BaseSegment,
delete_segments: Optional[Segments] = None,
add_newline: bool = True,
) -> List[LintFix]:
"""Cleans up by moving leftover select_clause segments.
@@ -277,10 +279,25 @@
start_seg=start_seg,
stop_seg=stop_seg,
)
fixes = [
LintFix.delete(seg) for seg in move_after_select_clause
# :TRICKY: Below, we have a couple places where we
# filter to guard against deleting the same segment
# multiple times -- this is illegal.
# :TRICKY: Use IdentitySet rather than set() since
# different segments may compare as equal.
all_deletes = IdentitySet(
fix.anchor for fix in fixes if fix.edit_type == "delete"
)
fixes_ = []
for seg in delete_segments or []:
if seg not in all_deletes:
fixes.append(LintFix.delete(seg))
all_deletes.add(seg)
fixes_ += [
LintFix.delete(seg)
for seg in move_after_select_clause
if seg not in all_deletes
]
fixes.append(
fixes_.append(
LintFix.create_after(
self._choose_anchor_segment(
context,
Expand All @@ -292,7 +309,7 @@ def _fixes_for_move_after_select_clause(
+ list(move_after_select_clause),
)
)
return fixes
return fixes_

if select_stmt.segments[after_select_clause_idx].is_type("newline"):
# Since we're deleting the newline, we should also delete all
@@ -304,10 +321,7 @@
loop_while=sp.is_type("whitespace"),
start_seg=select_children[start_idx],
)
fixes += [LintFix.delete(seg) for seg in to_delete]

if to_delete:

# The select_clause is immediately followed by a
# newline. Delete the newline in order to avoid leaving
# behind an empty line after fix, *unless* we stopped
@@ -325,15 +339,15 @@
)

fixes += _fixes_for_move_after_select_clause(
to_delete[-1],
to_delete[-1], to_delete
)
elif select_stmt.segments[after_select_clause_idx].is_type(
"whitespace"
):
# The select_clause has stuff after (most likely a comment)
# Delete the whitespace immediately after the select clause
# so the other stuff aligns nicely based on where the select
# clause started
# clause started.
fixes += [
LintFix.delete(
select_stmt.segments[after_select_clause_idx],
@@ -354,11 +368,10 @@
loop_while=sp.is_type("whitespace"),
start_seg=select_children[select_clause_idx - 1],
)
fixes += [LintFix.delete(seg) for seg in to_delete]

if to_delete:
fixes += _fixes_for_move_after_select_clause(
to_delete[-1],
to_delete,
# If we deleted a newline, create a newline.
any(seg for seg in to_delete if seg.is_type("newline")),
)

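The _report_conflicting_fixes_same_anchor hook added in linter.py exists mainly so that tests can monkeypatch it. A hypothetical pytest sketch (not taken from this commit; the repository's own test fixtures differ) of how the hook could be turned from a CRITICAL log message into a hard failure:

# Hypothetical test sketch; the test name and SQL are illustrative only.
from sqlfluff.core import FluffConfig, Linter

def _raise_conflict(message):
    raise ValueError(f"Conflicting fixes with the same anchor: {message}")

def test_fixes_never_conflict(monkeypatch):
    # Replace the logging hook so any conflicting fixes abort the test run.
    monkeypatch.setattr(
        Linter, "_report_conflicting_fixes_same_anchor", staticmethod(_raise_conflict)
    )
    linter = Linter(config=FluffConfig(overrides={"dialect": "ansi"}))
    linter.lint_string("SELECT a, b FROM tbl\n", fix=True)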