From 3a79aa85f4fbb132a5e1789afc7732565fd61e0f Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 20 May 2024 17:02:15 +0200 Subject: [PATCH] Fuzzy-match lines when applying edits from the assistant (#12056) This uses Jaro-Winkler similarity for now, which seemed to produce pretty good results in my tests. We can easily swap it with something else if needed. Release Notes: - N/A --- Cargo.lock | 17 ++- crates/assistant/Cargo.toml | 1 + crates/assistant/src/assistant_panel.rs | 6 +- crates/assistant/src/search.rs | 153 ++++++++++++++---------- 4 files changed, 103 insertions(+), 74 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a89e343e1d77..b2d2f17f70ef 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -368,6 +368,7 @@ dependencies = [ "serde_json", "settings", "smol", + "strsim 0.11.1", "telemetry_events", "theme", "tiktoken-rs", @@ -1684,7 +1685,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4c2f7349907b712260e64b0afe2f84692af14a454be26187d9df565c7f69266a" dependencies = [ "memchr", - "regex-automata 0.3.8", + "regex-automata 0.3.9", "serde", ] @@ -2094,7 +2095,7 @@ dependencies = [ "bitflags 1.3.2", "clap_lex 0.2.4", "indexmap 1.9.3", - "strsim", + "strsim 0.10.0", "termcolor", "textwrap", ] @@ -2118,7 +2119,7 @@ dependencies = [ "anstream", "anstyle", "clap_lex 0.5.1", - "strsim", + "strsim 0.10.0", ] [[package]] @@ -8141,9 +8142,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2f401f4955220693b56f8ec66ee9c78abffd8d1c4f23dc41a23839eb88f0795" +checksum = "59b23e92ee4318893fa3fe3e6fb365258efbfe6ac6ab30f090cdcbb7aa37efa9" [[package]] name = "regex-automata" @@ -9783,6 +9784,12 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "strum" version = "0.25.0" diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index cc6cc2e8fd8b..4a3b131a5281 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -40,6 +40,7 @@ serde.workspace = true serde_json.workspace = true settings.workspace = true smol.workspace = true +strsim = "0.11" telemetry_events.workspace = true theme.workspace = true tiktoken-rs.workspace = true diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 9b2b76de4043..603c5aa4398b 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -3058,9 +3058,9 @@ impl ConversationEditor { .entry(buffer) .or_insert(Vec::<(Range, _)>::new()); for suggestion in suggestions { - let ranges = - fuzzy_search_lines(snapshot.as_rope(), &suggestion.old_text); - if let Some(range) = ranges.first() { + if let Some(range) = + fuzzy_search_lines(snapshot.as_rope(), &suggestion.old_text) + { let edit_start = snapshot.anchor_after(range.start); let edit_end = snapshot.anchor_before(range.end); if let Err(ix) = edits.binary_search_by(|(range, _)| { diff --git a/crates/assistant/src/search.rs b/crates/assistant/src/search.rs index f7b957bfdcec..7e8b18ae5056 100644 --- a/crates/assistant/src/search.rs +++ b/crates/assistant/src/search.rs @@ -6,51 +6,75 @@ use std::ops::Range; /// /// Returns a vector of ranges of byte offsets in the buffer corresponding /// to the entire lines of the buffer. -pub fn fuzzy_search_lines(haystack: &Rope, needle: &str) -> Vec> { - let mut matches = Vec::new(); +pub fn fuzzy_search_lines(haystack: &Rope, needle: &str) -> Option> { + const SIMILARITY_THRESHOLD: f64 = 0.8; + + let mut best_match: Option<(Range, f64)> = None; // (range, score) let mut haystack_lines = haystack.chunks().lines(); let mut haystack_line_start = 0; - while let Some(haystack_line) = haystack_lines.next() { + while let Some(mut haystack_line) = haystack_lines.next() { let next_haystack_line_start = haystack_line_start + haystack_line.len() + 1; - let mut trimmed_needle_lines = needle.lines().map(|line| line.trim()); - if Some(haystack_line.trim()) == trimmed_needle_lines.next() { - let match_start = haystack_line_start; - let mut match_end = next_haystack_line_start; - let matched = loop { - match (haystack_lines.next(), trimmed_needle_lines.next()) { - (Some(haystack_line), Some(needle_line)) => { - // Haystack line differs from needle line: not a match. - if haystack_line.trim() == needle_line { - match_end = haystack_lines.offset(); - } else { - break false; - } + let mut advanced_to_next_haystack_line = false; + + let mut matched = true; + let match_start = haystack_line_start; + let mut match_end = next_haystack_line_start; + let mut match_score = 0.0; + let mut needle_lines = needle.lines().peekable(); + while let Some(needle_line) = needle_lines.next() { + let similarity = line_similarity(haystack_line, needle_line); + if similarity >= SIMILARITY_THRESHOLD { + match_end = haystack_lines.offset(); + match_score += similarity; + + if needle_lines.peek().is_some() { + if let Some(next_haystack_line) = haystack_lines.next() { + advanced_to_next_haystack_line = true; + haystack_line = next_haystack_line; + } else { + matched = false; + break; } - // We exhausted the haystack but not the query: not a match. - (None, Some(_)) => break false, - // We exhausted the query: it's a match. - (_, None) => break true, + } else { + break; } - }; - - if matched { - matches.push(match_start..match_end) + } else { + matched = false; + break; } + } - // Advance to the next line. - haystack_lines.seek(next_haystack_line_start); + if matched + && best_match + .as_ref() + .map(|(_, best_score)| match_score > *best_score) + .unwrap_or(true) + { + best_match = Some((match_start..match_end, match_score)); } + if advanced_to_next_haystack_line { + haystack_lines.seek(next_haystack_line_start); + } haystack_line_start = next_haystack_line_start; } - matches + + best_match.map(|(range, _)| range) +} + +/// Calculates the similarity between two lines, ignoring leading and trailing whitespace, +/// using the Jaro-Winkler distance. +/// +/// Returns a value between 0.0 and 1.0, where 1.0 indicates an exact match. +fn line_similarity(line1: &str, line2: &str) -> f64 { + strsim::jaro_winkler(line1.trim(), line2.trim()) } #[cfg(test)] mod test { use super::*; use gpui::{AppContext, Context as _}; - use language::{Buffer, OffsetRangeExt}; + use language::Buffer; use unindent::Unindent as _; use util::test::marked_text_ranges; @@ -79,17 +103,11 @@ mod test { ); » - assert_eq!( + « assert_eq!( "something", "else", ); - - if b { - « assert_eq!( - 1 + 2, - 3, - ); - » } + » } "# .unindent(), @@ -99,7 +117,7 @@ mod test { let buffer = cx.new_model(|cx| Buffer::local(&text, cx)); let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); - let actual_ranges = fuzzy_search_lines( + let actual_range = fuzzy_search_lines( snapshot.as_rope(), &" assert_eq!( @@ -108,43 +126,46 @@ mod test { ); " .unindent(), - ); - assert_eq!( - actual_ranges, - expected_ranges, - "actual: {:?}, expected: {:?}", - actual_ranges - .iter() - .map(|range| range.to_point(&snapshot)) - .collect::>(), - expected_ranges - .iter() - .map(|range| range.to_point(&snapshot)) - .collect::>() - ); + ) + .unwrap(); + assert_eq!(actual_range, expected_ranges[0]); - let actual_ranges = fuzzy_search_lines( + let actual_range = fuzzy_search_lines( snapshot.as_rope(), &" assert_eq!( 1 + 2, 3, - ); + ); + " + .unindent(), + ) + .unwrap(); + assert_eq!(actual_range, expected_ranges[0]); + + let actual_range = fuzzy_search_lines( + snapshot.as_rope(), + &" + asst_eq!( + \"something\", + \"els\" + ) + " + .unindent(), + ) + .unwrap(); + assert_eq!(actual_range, expected_ranges[1]); + + let actual_range = fuzzy_search_lines( + snapshot.as_rope(), + &" + assert_eq!( + 2 + 1, + 3, + ); " .unindent(), ); - assert_eq!( - actual_ranges, - expected_ranges, - "actual: {:?}, expected: {:?}", - actual_ranges - .iter() - .map(|range| range.to_point(&snapshot)) - .collect::>(), - expected_ranges - .iter() - .map(|range| range.to_point(&snapshot)) - .collect::>() - ); + assert_eq!(actual_range, None); } }