Skip to content

Commit

Permalink
Fix Jinja variant location correction (#5814)
Browse files Browse the repository at this point in the history
  • Loading branch information
alanmcruickshank committed Apr 26, 2024
1 parent 3844bbc commit 22fc89e
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 46 deletions.
163 changes: 117 additions & 46 deletions src/sqlfluff/core/templaters/jinja.py
Expand Up @@ -718,6 +718,94 @@ def slice_file(
trace = tracer.trace(append_to_templated=kwargs.pop("append_to_templated", ""))
return trace.raw_sliced, trace.sliced_file, trace.templated_str

@staticmethod
def _rectify_templated_slices(
    length_deltas: Dict[int, int], sliced_template: List[TemplatedFileSlice]
):
    """Shift a variant's source slices back onto the original template.

    :TRICKY: Variants are rendered from a *modified* template, so the
    source positions recorded on their slices don't line up with the
    source file. We want them to _look like_ they came from the original
    template, so that lint issues and fixes raised against a variant can
    be combined with those from the original rendering.

    Args:
        length_deltas: Length changes introduced by the modifications.
            Each key is matched against a slice's corrected start
            position, and the value is the change in length that was
            applied to the slice starting there.
        sliced_template: The templated slices of the variant, in order.

    Returns:
        A new list of slices whose `source_slice` values have been
        corrected. All other fields are carried over untouched.
    """
    # NOTE: The deltas are consumed strictly front-to-back, so they must
    # be ordered by position. A dict gives no such guarantee, and the
    # entries may have been generated out of order anyway.
    pending = sorted(length_deltas.items(), key=lambda item: item[0])
    cursor = 0

    rectified: List[TemplatedFileSlice] = []
    # Running correction to apply to every position from here onward.
    shift = 0
    for templated in sliced_template:
        source = templated.source_slice
        # The start always moves by whatever has accumulated so far.
        new_start = source.start + shift
        if cursor < len(pending) and pending[cursor][0] == new_start:
            # This is a slice which was modified. Undo its length change
            # *before* working out the end, which "stretches" the slice
            # by adjusting the end more than the start. The updated
            # shift then carries forward to all later slices.
            shift -= pending[cursor][1]
            cursor += 1
        rectified.append(
            templated._replace(
                source_slice=slice(new_start, source.stop + shift)
            )
        )
    return rectified

@staticmethod
def _calculate_variant_score(
    raw_sliced: List[RawFileSlice],
    sliced_file: List[TemplatedFileSlice],
    uncovered_slices: Set[int],
    original_source_slices: Dict[int, slice],
) -> int:
    """Score a variant by the size of the uncovered slices it reaches.

    NOTE: The score has to be expressed in terms of the original file,
    but only positions in the modified file are available here. We
    therefore translate via the index of each slice in the raw file.

    Args:
        raw_sliced: Raw slices of the variant.
        sliced_file: Templated slices of the variant.
        uncovered_slices: Raw slice indices which count toward the score.
        original_source_slices: Source slice to measure for each raw
            slice index.

    Returns:
        The summed length of the original source slices for every raw
        slice which is both reached as a non-empty literal in this
        variant and listed in `uncovered_slices`.
    """
    # Step one: gather the source start positions (in the modified file)
    # of each literal which rendered to something non-empty.
    reached_starts = set()
    for templated in sliced_file:
        if templated.slice_type != "literal":
            continue
        if is_zero_slice(templated.templated_slice):
            continue
        reached_starts.add(templated.source_slice.start)

    # Step two: walk the raw slices by index, so that each hit can be
    # referred back to the unmodified source file, and total up the
    # lengths of the ones which count.
    score = 0
    for position, raw in enumerate(raw_sliced):
        if raw.source_idx not in reached_starts:
            continue
        if position in uncovered_slices:
            score += slice_length(original_source_slices[position])
    return score

def _handle_unreached_code(
self,
in_str: str,
Expand Down Expand Up @@ -745,7 +833,7 @@ def _handle_unreached_code(

max_variants_generated = 10
max_variants_returned = 5
variants: Dict[str, Tuple[int, JinjaTrace]] = {}
variants: Dict[str, Tuple[int, JinjaTrace, Dict[int, int]]] = {}

# Create a mapping of the original source slices before modification so
# we can adjust the positions post-modification.
Expand All @@ -758,19 +846,27 @@ def _handle_unreached_code(
tracer_probe = copy.deepcopy(tracer_copy)
tracer_trace = copy.deepcopy(tracer_copy)
override_raw_slices = []
# `length_deltas` is to keep track of the length changes associated
# with the changes we're making so we can correct the positions in
# the resulting template.
length_deltas: Dict[int, int] = {}
# Find a path that takes us to 'uncovered_slice'.
choices = tracer_probe.move_to_slice(uncovered_slice, 0)
for branch, options in choices.items():
tag = tracer_probe.raw_sliced[branch].tag
if tag in ("if", "elif"):
raw_file_slice = tracer_probe.raw_sliced[branch]
if raw_file_slice.tag in ("if", "elif"):
# Replace the existing "if" or "elif" expression with a new,
# hardcoded value that hits the target slice in the template
# (here that is options[0]).
new_value = "True" if options[0] == branch + 1 else "False"
tracer_trace.raw_slice_info[
tracer_probe.raw_sliced[branch]
].alternate_code = f"{{% {tag} {new_value} %}}"
new_source = f"{{% {raw_file_slice.tag} {new_value} %}}"
tracer_trace.raw_slice_info[raw_file_slice].alternate_code = (
new_source
)
override_raw_slices.append(branch)
length_deltas[raw_file_slice.source_idx] = len(new_source) - len(
raw_file_slice.raw
)

# Render and analyze the template with the overrides.
variant_key = tuple(
Expand Down Expand Up @@ -803,52 +899,27 @@ def _handle_unreached_code(
else:
# Compute a score for the variant based on the size of initially
# uncovered literal slices it hits.
# NOTE: We need to map this back to the positions in the original
# file, and only have the positions in the modified file here.
# That means we go translate back via the slice index in raw file.

# First, work out the literal positions in the modified file which
# are now covered.
_covered_source_positions = {
tfs.source_slice.start
for tfs in trace.sliced_file
if tfs.slice_type == "literal"
and not is_zero_slice(tfs.templated_slice)
}
# Second, convert these back into indices so we can use them to
# refer to the unmodified source file.
_covered_raw_slice_idxs = [
idx
for idx, raw_slice in enumerate(trace.raw_sliced)
if raw_slice.source_idx in _covered_source_positions
]

score = sum(
slice_length(original_source_slices[idx])
for idx in _covered_raw_slice_idxs
if idx in uncovered_slices
score = self._calculate_variant_score(
raw_sliced=trace.raw_sliced,
sliced_file=trace.sliced_file,
uncovered_slices=uncovered_slices,
original_source_slices=original_source_slices,
)

variants[variant_raw_str] = (score, trace)
variants[variant_raw_str] = (score, trace, length_deltas)

# Return the top-scoring variants.
sorted_variants: List[Tuple[int, JinjaTrace]] = sorted(
sorted_variants: List[Tuple[int, JinjaTrace, Dict[int, int]]] = sorted(
variants.values(), key=lambda v: v[0], reverse=True
)
for _, trace in sorted_variants[:max_variants_returned]:
# :TRICKY: Yield variants that _look like_ they were rendered from
# the original template, but actually were rendered from a modified
# template. This should ensure that lint issues and fixes for the
# variants are handled correctly and can be combined with those from
# the original template.
# To do this we run through modified slices and adjust their source
# slices to correspond with the original version. We do this by referencing
# their slice position in the original file, because we know we haven't
# changed the number or ordering of slices, just their length/content.
adjusted_slices: List[TemplatedFileSlice] = [
tfs._replace(source_slice=original_source_slices[idx])
for idx, tfs in enumerate(trace.sliced_file)
]
for _, trace, deltas in sorted_variants[:max_variants_returned]:
# Rectify the source slices of the generated template, which should
# ensure that lint issues and fixes for the variants are handled
# correctly and can be combined with those from the original template.
adjusted_slices = self._rectify_templated_slices(
deltas,
trace.sliced_file,
)
yield (
tracer_copy.raw_sliced,
adjusted_slices,
Expand Down
27 changes: 27 additions & 0 deletions test/core/templaters/jinja_test.py
Expand Up @@ -1823,17 +1823,44 @@ def test_undefined_magic_methods():
],
id="if_true_elif_type_error_else",
),
# https://github.com/sqlfluff/sqlfluff/issues/5803
pytest.param(
"inline_select.sql",
[
"select 2\n",
"select 1\n",
],
id="inline_select",
),
],
)
def test__templater_lint_unreached_code(sql_path: str, expected_renderings):
    """Test the variants the Jinja templater yields for unreached code.

    Checks the rendered string of each variant, and that every additional
    variant shares the raw slicing and the final source slice of the first.
    """
    fixture_dir = Path("test/fixtures/templater/jinja_lint_unreached_code")
    templater = JinjaTemplater()
    # For each variant keep: the rendered string, the raw slicing and the
    # source slice of the final templated slice.
    captured = [
        (
            templated_file.templated_str,
            templated_file.raw_sliced,
            templated_file.sliced_file[-1].source_slice,
        )
        for templated_file, _ in templater.process_with_variants(
            in_str=(fixture_dir / sql_path).read_text(),
            fname=str(sql_path),
            config=FluffConfig.from_path(str(fixture_dir)),
        )
    ]
    assert [rendered for rendered, _, _ in captured] == expected_renderings

    # The first variant yielded is treated as the root to compare against.
    _, root_slicing, root_final_slice = captured[0]
    # Every additional raw slicing should be the same as the root.
    for _, variant_slicing, _ in captured[1:]:
        assert variant_slicing == root_slicing
    # The final source slices should line up too.
    # NOTE: Clearly the `templated_slice` values _won't_ be the same.
    # We check the _final_ slice because it's very likely to be the same
    # _type_ in each variant, and if it's in the right place we can assume
    # that all of the others probably are.
    for _, _, variant_final_slice in captured[1:]:
        assert variant_final_slice == root_final_slice
@@ -0,0 +1 @@
select {% if 1 > 2 %}1{% else %}2{% endif %}

0 comments on commit 22fc89e

Please sign in to comment.