diff --git a/Cargo.lock b/Cargo.lock index 48efd88ff..ec8a406c8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1722,6 +1722,18 @@ dependencies = [ "wasi 0.11.0+wasi-snapshot-preview1", ] +[[package]] +name = "getrandom" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasi 0.14.7+wasi-0.2.4", +] + [[package]] name = "gimli" version = "0.31.1" @@ -2172,9 +2184,9 @@ checksum = "d75a2a4b1b190afb6f5425f10f6a8f959d2ea0b9c2b1d79553551850539e4674" [[package]] name = "js-sys" -version = "0.3.76" +version = "0.3.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6717b6b5b077764fb5966237269cb3c64edddde4b14ce42647430a78ced9e7b7" +checksum = "ec48937a97411dcb524a265206ccd4c90bb711fca92b2792c407f268825b9305" dependencies = [ "once_cell", "wasm-bindgen", @@ -3202,6 +3214,7 @@ dependencies = [ "sqlx", "tokio", "tree-sitter", + "uuid", ] [[package]] @@ -3629,6 +3642,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + [[package]] name = "rand" version = "0.8.5" @@ -5182,11 +5201,13 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.11.0" +version = "1.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8c5f0a0af699448548ad1a2fbf920fb4bee257eae39953ba95cb84891a0446a" +checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" dependencies = [ - "getrandom 0.2.15", + "getrandom 0.3.3", + "js-sys", + "wasm-bindgen", ] [[package]] @@ -5270,6 +5291,24 @@ version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" +[[package]] +name = "wasi" +version = "0.14.7+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "883478de20367e224c0090af9cf5f9fa85bed63a95c1abf3afc5c083ebc06e8c" +dependencies = [ + "wasip2", +] + +[[package]] +name = "wasip2" +version = "1.0.1+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" +dependencies = [ + "wit-bindgen", +] + [[package]] name = "wasite" version = "0.1.0" @@ -5278,20 +5317,22 @@ checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b" [[package]] name = "wasm-bindgen" -version = "0.2.99" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a474f6281d1d70c17ae7aa6a613c87fce69a127e2624002df63dcb39d6cf6396" +checksum = "c1da10c01ae9f1ae40cbfac0bac3b1e724b320abfcf52229f80b547c0d250e2d" dependencies = [ "cfg-if", "once_cell", + "rustversion", "wasm-bindgen-macro", + "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.99" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f89bb38646b4f81674e8f5c3fb81b562be1fd936d84320f3264486418519c79" +checksum = "671c9a5a66f49d8a47345ab942e2cb93c7d1d0339065d4f8139c486121b43b19" dependencies = [ "bumpalo", "log", @@ -5303,9 +5344,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.49" +version = "0.4.54" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38176d9b44ea84e9184eff0bc34cc167ed044f816accfe5922e54d84cf48eca2" +checksum = "7e038d41e478cc73bae0ff9b36c60cff1c98b8f38f8d7e8061e79ee63608ac5c" dependencies = [ "cfg-if", "js-sys", @@ -5316,9 +5357,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.99" +version = "0.2.104" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cc6181fd9a7492eef6fef1f33961e3695e4579b9872a6f7c83aee556666d4fe" +checksum = "7ca60477e4c59f5f2986c50191cd972e3a50d8a95603bc9434501cf156a9a119" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -5326,9 +5367,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.99" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30d7a95b763d3c45903ed6c81f156801839e5ee968bb07e534c44df0fcd330c2" +checksum = "9f07d2f20d4da7b26400c9f4a0511e6e0345b040694e8a75bd41d578fa4421d7" dependencies = [ "proc-macro2", "quote", @@ -5339,15 +5380,18 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.99" +version = "0.2.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "943aab3fdaaa029a6e0271b35ea10b72b943135afe9bffca82384098ad0e06a6" +checksum = "bad67dc8b2a1a6e5448428adec4c3e84c43e561d8c9ee8a9e5aabeb193ec41d1" +dependencies = [ + "unicode-ident", +] [[package]] name = "web-sys" -version = "0.3.76" +version = "0.3.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04dd7223427d52553d3702c004d3b2fe07c148165faa56313cb00211e31c12bc" +checksum = "9367c417a924a74cae129e6a2ae3b47fabb1f8995595ab474029da749a8be120" dependencies = [ "js-sys", "wasm-bindgen", @@ -5587,6 +5631,12 @@ version = "0.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d135d17ab770252ad95e9a872d365cf3090e3be864a34ab46f48555993efc904" +[[package]] +name = "wit-bindgen" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" + [[package]] name = "write-json" version = "0.1.4" diff --git a/crates/pgt_text_size/src/lib.rs b/crates/pgt_text_size/src/lib.rs index 133f6192e..f87dc6445 100644 --- a/crates/pgt_text_size/src/lib.rs +++ 
b/crates/pgt_text_size/src/lib.rs @@ -21,6 +21,7 @@ mod range; mod size; +mod text_range_replacement; mod traits; #[cfg(feature = "serde")] @@ -29,7 +30,12 @@ mod serde_impls; #[cfg(feature = "schema")] mod schemars_impls; -pub use crate::{range::TextRange, size::TextSize, traits::TextLen}; +pub use crate::{ + range::TextRange, + size::TextSize, + text_range_replacement::{TextRangeReplacement, TextRangeReplacementBuilder}, + traits::TextLen, +}; #[cfg(target_pointer_width = "16")] compile_error!("text-size assumes usize >= u32 and does not work on 16-bit targets"); diff --git a/crates/pgt_text_size/src/range.rs b/crates/pgt_text_size/src/range.rs index baab91e9f..44f9f5d59 100644 --- a/crates/pgt_text_size/src/range.rs +++ b/crates/pgt_text_size/src/range.rs @@ -1,3 +1,5 @@ +use std::num::TryFromIntError; + use cmp::Ordering; use { @@ -443,6 +445,28 @@ where } } +impl TryFrom<&Range> for TextRange { + type Error = TryFromIntError; + + fn try_from(value: &Range) -> Result { + let start: TextSize = value.start.try_into()?; + let end: TextSize = value.end.try_into()?; + + Ok(TextRange { start, end }) + } +} + +impl TryFrom> for TextRange { + type Error = TryFromIntError; + + fn try_from(value: Range) -> Result { + let start: TextSize = value.start.try_into()?; + let end: TextSize = value.end.try_into()?; + + Ok(TextRange { start, end }) + } +} + macro_rules! 
ops { (impl $Op:ident for TextRange by fn $f:ident = $op:tt) => { impl $Op<&TextSize> for TextRange { diff --git a/crates/pgt_text_size/src/text_range_replacement.rs b/crates/pgt_text_size/src/text_range_replacement.rs new file mode 100644 index 000000000..22b17a6e3 --- /dev/null +++ b/crates/pgt_text_size/src/text_range_replacement.rs @@ -0,0 +1,353 @@ +use crate::{TextRange, TextSize}; + +#[derive(Debug)] +enum AdjustmentDirection { + Lengthened, + Shortened, +} + +#[derive(Debug)] +struct AdjustmentMarker { + #[allow(dead_code)] + original_range: TextRange, + adjusted_range: TextRange, + replacement_txt: String, + registered_offset_at_start: TextSize, + + #[allow(dead_code)] + adjustment_direction: AdjustmentDirection, + range_size_difference: TextSize, +} + +impl AdjustmentMarker { + fn new(original_range: TextRange, replacement_txt: &str) -> Self { + let og_range_len = usize::from(original_range.len()); + + let (range_size_difference, adjustment_direction) = if og_range_len <= replacement_txt.len() + { + ( + replacement_txt.len() - og_range_len, + AdjustmentDirection::Lengthened, + ) + } else { + ( + og_range_len - replacement_txt.len(), + AdjustmentDirection::Shortened, + ) + }; + + AdjustmentMarker { + original_range, + adjustment_direction, + replacement_txt: replacement_txt.into(), + range_size_difference: TextSize::new(range_size_difference.try_into().unwrap()), + + // will be calculated during `.build()` + adjusted_range: original_range, + registered_offset_at_start: 0.into(), + } + } + + /// If the original text is `select $1 from $2` and the adjusted text is `select email from auth.x`, + /// and you index into the `x` in the adjusted string, this will "correct" the adjusted range + /// as if it had the original length ('$2', so a length of 2). + /// + /// So, the resulting `TextSize` *will* be corrected "to the left" as though we indexed onto the `u`, since `$2` has a range + /// of two characters. 
+ /// + /// The TextSize *will still* consider the offsets of previous replacements (3 to the right, since `email` is longer than `$1`). + fn adjusted_end_within_clamped_range(&self, position: TextSize) -> TextSize { + let clamped_end = self.adjusted_range.end() - self.range_size_difference; + std::cmp::min(position, clamped_end - TextSize::new(1)) + } +} + +/// Builder for creating a `TextRangeReplacement` that tracks text range adjustments. +/// +/// This builder allows you to register multiple text replacements and their effects on ranges, +/// then build a tracker that can map positions between the original and adjusted text. +#[derive(Debug)] +pub struct TextRangeReplacementBuilder { + markers: Vec, + text: String, +} + +impl TextRangeReplacementBuilder { + /// Creates a new empty builder for range adjustments tracking. + pub fn new(text: &str) -> Self { + Self { + markers: vec![], + text: text.to_string(), + } + } + + /// Registers a text replacement that affects range positions. + /// + /// #### Arguments + /// + /// * `original_range` - The range in the original text that will be replaced + /// * `replacement_text` - The text that will replace the content in the original range + pub fn replace_range(&mut self, original_range: TextRange, replacement_text: &str) { + if usize::from(original_range.len()) == replacement_text.len() { + // if the replacement text is the same length as the to-replace range, + // we can just immediately apply the replacement. + let range: std::ops::Range = original_range.into(); + self.text.replace_range(range, replacement_text); + return; + } + + self.markers + .push(AdjustmentMarker::new(original_range, replacement_text)); + } + + /// Builds the range adjustments tracker from all registered adjustments. + /// + /// The adjustments are processed in order of their starting positions in the original text. + /// Currently only supports lengthening adjustments (where replacement text is longer + /// than the original range). 
+ /// + /// # Panics + /// + /// Panics if any adjustment involves shortening the text, as this is not yet implemented. + pub fn build(mut self) -> TextRangeReplacement { + self.markers.sort_by_key(|r| r.original_range.start()); + + let mut total_offset: usize = 0; + + for marker in self.markers.iter_mut() { + match marker.adjustment_direction { + AdjustmentDirection::Lengthened => { + marker.registered_offset_at_start = total_offset.try_into().unwrap(); + + marker.adjusted_range = TextRange::new( + marker.original_range.start() + marker.registered_offset_at_start, + marker.original_range.end() + + marker.registered_offset_at_start + + marker.range_size_difference, + ); + + total_offset += usize::from(marker.range_size_difference); + } + AdjustmentDirection::Shortened => { + unimplemented!( + "So far, we've only ever lengthened TextRanges. Consider filling up your range with spaces" + ) + } + } + } + + for marker in self.markers.iter().rev() { + let std_range: std::ops::Range = marker.original_range.into(); + self.text + .replace_range(std_range, marker.replacement_txt.as_str()); + } + + TextRangeReplacement { + markers: self.markers, + text: self.text, + } + } +} + +/// Tracks text range adjustments and provides mapping between original and adjusted positions. +/// +/// This struct maintains information about how text ranges have been modified (typically by +/// replacing placeholders with actual values) and can map positions from the adjusted text +/// back to their corresponding positions in the original text. +/// +/// # Example +/// +/// If you have original text `"select $1 from $2"` and replace `$1` with `email` and +/// `$2` with `auth.users`, this tracker can map positions in the adjusted text +/// `"select email from auth.users"` back to positions in the original text. +#[derive(Debug)] +pub struct TextRangeReplacement { + markers: Vec, + text: String, +} + +impl TextRangeReplacement { + /// Returns the adjusted text. 
+ pub fn text(&self) -> &str { + &self.text + } + + /// Maps a position in the adjusted text back to its corresponding position in the original text. + /// + /// + /// #### Example + /// + /// If the original text was `"select $1 from $2"` and it was adjusted to + /// `"select email from auth.users"`, then calling this method with the position + /// of `'m'` in `"email"` would return the position of `'1'` in the original text, and using the position of `'e'` in + /// `'email'` will give you the first `'$'`. + /// + /// The position tracker "clamps" positions, so if you call it with the position of `'l'` in `'email'` , + /// you'd still get the position of `'1'`. + /// + /// The position of `'f'` in `'from'` will give you the position of `'f'` in `'from'`. + pub fn to_original_position(&self, adjusted_position: TextSize) -> TextSize { + if let Some(marker) = self + .markers + .iter() + .rev() + .find(|marker| adjusted_position >= marker.adjusted_range.start()) + { + if adjusted_position >= marker.adjusted_range.end() { + adjusted_position + .checked_sub(marker.registered_offset_at_start) + .unwrap() + .checked_sub(marker.range_size_difference) + .unwrap() + } else { + let clamped = marker.adjusted_end_within_clamped_range(adjusted_position); + clamped + .checked_sub(marker.registered_offset_at_start) + .unwrap() + } + } else { + adjusted_position + } + } + + /// Maps a range in the adjusted text back to its corresponding range in the original text. 
+ #[allow(dead_code)] + pub fn to_original_range(&self, adjusted_range: TextRange) -> TextRange { + // todo(@juleswritescode): optimize with windows + TextRange::new( + self.to_original_position(adjusted_range.start()), + self.to_original_position(adjusted_range.end()), + ) + } +} + +#[cfg(test)] +mod tests { + use crate::TextSize; + + use crate::text_range_replacement::TextRangeReplacementBuilder; + + #[test] + fn tracks_adjustments() { + let sql = "select $1 from $2 where $3 = $4 limit 5;"; + + let range_1: std::ops::Range = 7..9; // $1 + let range_2: std::ops::Range = 15..17; // $2 + let range_3: std::ops::Range = 24..26; // $3 + let range_4: std::ops::Range = 29..31; // $4 + let og_end = sql.len(); + + let mut replacement_builder = TextRangeReplacementBuilder::new(sql); + + let replacement_4 = "'00000000-0000-0000-0000-000000000000'"; + let replacement_3 = "id"; + let replacement_2 = "auth.users"; + let replacement_1 = "email"; + + // start in the middle – the builder can deal with unordered replacements + replacement_builder.replace_range(range_2.clone().try_into().unwrap(), replacement_2); + replacement_builder.replace_range(range_4.clone().try_into().unwrap(), replacement_4); + replacement_builder.replace_range(range_1.clone().try_into().unwrap(), replacement_1); + replacement_builder.replace_range(range_3.clone().try_into().unwrap(), replacement_3); + + let text_replacement = replacement_builder.build(); + + assert_eq!( + text_replacement.text(), + "select email from auth.users where id = '00000000-0000-0000-0000-000000000000' limit 5;" + ); + + let repl_range_1 = 7..12; // email + let repl_range_2 = 18..28; // auth.users + let repl_range_3 = 35..37; // id + let repl_range_4 = 40..78; // '00000000-0000-0000-0000-000000000000' + + // |select |email from auth.users where id = '00000000-0000-0000-0000-000000000000' limit 5; + // maps to + // |select |$1 from $2 where $3 = $4 limit 5; + for i in 0..repl_range_1.clone().start { + let between_og_0_1 = 
0..range_1.start; + let adjusted = + text_replacement.to_original_position(TextSize::new(i.try_into().unwrap())); + assert!(between_og_0_1.contains(&usize::from(adjusted))); + } + + // select |email| from auth.users where id = '00000000-0000-0000-0000-000000000000' limit 5; + // maps to + // select |$1| from $2 where $3 = $4 limit 5; + for i in repl_range_1.clone() { + let pos = TextSize::new(i.try_into().unwrap()); + let og_pos = text_replacement.to_original_position(pos); + assert!(range_1.contains(&usize::from(og_pos))); + } + + // select email| from |auth.users where id = '00000000-0000-0000-0000-000000000000' limit 5; + // maps to + // select $1| from |$2 where $3 = $4 limit 5; + for i in repl_range_1.end..repl_range_2.clone().start { + let between_og_1_2 = range_1.end..range_2.start; + let adjusted = + text_replacement.to_original_position(TextSize::new(i.try_into().unwrap())); + assert!(between_og_1_2.contains(&usize::from(adjusted))); + } + + // select email from |auth.users| where id = '00000000-0000-0000-0000-000000000000' limit 5; + // maps to + // select $1 from |$2| where $3 = $4 limit 5; + for i in repl_range_2.clone() { + let pos = TextSize::new(i.try_into().unwrap()); + let og_pos = text_replacement.to_original_position(pos); + assert!(range_2.contains(&usize::from(og_pos))); + } + + // select email from auth.users| where |id = '00000000-0000-0000-0000-000000000000' limit 5; + // maps to + // select $1 from $2| where |$3 = $4 limit 5; + for i in repl_range_2.end..repl_range_3.clone().start { + let between_og_2_3 = range_2.end..range_3.start; + let adjusted = + text_replacement.to_original_position(TextSize::new(i.try_into().unwrap())); + assert!(between_og_2_3.contains(&usize::from(adjusted))); + } + + // select email from auth.users where |id| = '00000000-0000-0000-0000-000000000000' limit 5; + // maps to + // select $1 from $2 where |$3| = $4 limit 5; + // + // NOTE: this isn't even handled by the tracker, since `id` has the same length as `$3` +
for i in repl_range_3.clone() { + let pos = TextSize::new(i.try_into().unwrap()); + let og_pos = text_replacement.to_original_position(pos); + assert!(range_3.contains(&usize::from(og_pos))); + } + + // select email from auth.users where id| = |'00000000-0000-0000-0000-000000000000' limit 5; + // maps to + // select $1 from $2 where $3| = |$4 limit 5; + for i in repl_range_3.end..repl_range_4.clone().start { + let between_og_3_4 = range_3.end..range_4.start; + let adjusted = + text_replacement.to_original_position(TextSize::new(i.try_into().unwrap())); + assert!(between_og_3_4.contains(&usize::from(adjusted))); + } + + // select email from auth.users where id = |'00000000-0000-0000-0000-000000000000'| limit 5; + // maps to + // select $1 from $2 where $3 = |$4| limit 5; + for i in repl_range_4.clone() { + let pos = TextSize::new(i.try_into().unwrap()); + let og_pos = text_replacement.to_original_position(pos); + assert!(range_4.contains(&usize::from(og_pos))); + } + + // select email from auth.users where id = '00000000-0000-0000-0000-000000000000'| limit 5;| + // maps to + // select $1 from $2 where $3 = $4| limit 5;| + for i in repl_range_4.end..sql.len() { + let between_og_4_end = range_4.end..og_end; + let adjusted = + text_replacement.to_original_position(TextSize::new(i.try_into().unwrap())); + assert!(between_og_4_end.contains(&usize::from(adjusted))); + } + } +} diff --git a/crates/pgt_typecheck/Cargo.toml b/crates/pgt_typecheck/Cargo.toml index 9c0db5d91..1c326a675 100644 --- a/crates/pgt_typecheck/Cargo.toml +++ b/crates/pgt_typecheck/Cargo.toml @@ -24,6 +24,7 @@ pgt_treesitter_grammar.workspace = true sqlx.workspace = true tokio.workspace = true tree-sitter.workspace = true +uuid = { version = "1.18.1", features = ["v4"] } [dev-dependencies] insta.workspace = true diff --git a/crates/pgt_typecheck/src/diagnostics.rs b/crates/pgt_typecheck/src/diagnostics.rs index 2117adbed..994bf8e19 100644 --- a/crates/pgt_typecheck/src/diagnostics.rs +++ 
b/crates/pgt_typecheck/src/diagnostics.rs @@ -2,7 +2,7 @@ use std::io; use pgt_console::markup; use pgt_diagnostics::{Advices, Diagnostic, LogCategory, MessageAndDescription, Severity, Visit}; -use pgt_text_size::TextRange; +use pgt_text_size::{TextRange, TextRangeReplacement, TextSize}; use sqlx::postgres::{PgDatabaseError, PgSeverity}; /// A specialized diagnostic for the typechecker. @@ -97,7 +97,7 @@ impl Advices for TypecheckAdvices { pub(crate) fn create_type_error( pg_err: &PgDatabaseError, ts: &tree_sitter::Tree, - positions_valid: bool, + txt_replacement: TextRangeReplacement, ) -> TypecheckDiagnostic { let position = pg_err.position().and_then(|pos| match pos { sqlx::postgres::PgErrorPosition::Original(pos) => Some(pos - 1), @@ -105,18 +105,16 @@ pub(crate) fn create_type_error( }); let range = position.and_then(|pos| { - if positions_valid { - ts.root_node() - .named_descendant_for_byte_range(pos, pos) - .map(|node| { - TextRange::new( - node.start_byte().try_into().unwrap(), - node.end_byte().try_into().unwrap(), - ) - }) - } else { - None - } + let adjusted = txt_replacement.to_original_position(TextSize::new(pos.try_into().unwrap())); + + ts.root_node() + .named_descendant_for_byte_range(adjusted.into(), adjusted.into()) + .map(|node| { + TextRange::new( + node.start_byte().try_into().unwrap(), + node.end_byte().try_into().unwrap(), + ) + }) }); let severity = match pg_err.severity() { diff --git a/crates/pgt_typecheck/src/lib.rs b/crates/pgt_typecheck/src/lib.rs index 854cdfadc..ecb4f987b 100644 --- a/crates/pgt_typecheck/src/lib.rs +++ b/crates/pgt_typecheck/src/lib.rs @@ -48,7 +48,7 @@ pub async fn check_sql( // each typecheck operation. 
conn.close_on_drop(); - let (prepared, positions_valid) = apply_identifiers( + let replacement = apply_identifiers( params.identifiers, params.schema_cache, params.tree, @@ -68,17 +68,13 @@ pub async fn check_sql( conn.execute(&*search_path_query).await?; } - let res = conn.prepare(&prepared).await; + let res = conn.prepare(replacement.text()).await; match res { Ok(_) => Ok(None), Err(sqlx::Error::Database(err)) => { let pg_err = err.downcast_ref::(); - Ok(Some(create_type_error( - pg_err, - params.tree, - positions_valid, - ))) + Ok(Some(create_type_error(pg_err, params.tree, replacement))) } Err(err) => Err(err), } diff --git a/crates/pgt_typecheck/src/typed_identifier.rs b/crates/pgt_typecheck/src/typed_identifier.rs index 199c5950e..ca355dfc6 100644 --- a/crates/pgt_typecheck/src/typed_identifier.rs +++ b/crates/pgt_typecheck/src/typed_identifier.rs @@ -1,16 +1,16 @@ use pgt_schema_cache::PostgresType; +use pgt_text_size::{TextRangeReplacement, TextRangeReplacementBuilder}; use pgt_treesitter::queries::{ParameterMatch, TreeSitterQueriesExecutor}; -/// A typed identifier is a parameter that has a type associated with it. /// It is used to replace parameters within the SQL string. #[derive(Debug)] pub struct TypedIdentifier { /// The path of the parameter, usually the name of the function. /// This is because `fn_name.arg_name` is a valid reference within a SQL function. 
pub path: String, - /// The name of the argument + /// The name of the parameter pub name: Option, - /// The type of the argument with schema and name + /// The type of the parameter with schema and name pub type_: IdentifierType, } @@ -27,7 +27,7 @@ pub fn apply_identifiers<'a>( schema_cache: &'a pgt_schema_cache::SchemaCache, cst: &'a tree_sitter::Tree, sql: &'a str, -) -> (String, bool) { +) -> TextRangeReplacement { let mut executor = TreeSitterQueriesExecutor::new(cst.root_node(), sql); executor.add_query_results::(); @@ -50,30 +50,21 @@ pub fn apply_identifiers<'a>( }) .collect(); - let mut result = sql.to_string(); + let mut text_range_replacement_builder = TextRangeReplacementBuilder::new(sql); - let mut valid_positions = true; - - // Apply replacements in reverse order to maintain correct byte offsets - for (range, type_, is_array) in replacements.into_iter().rev() { + for (range, type_, is_array) in replacements { let default_value = get_formatted_default_value(type_, is_array); + let range_size = range.end - range.start; + // if the default_value is shorter than "range", fill it up with spaces - let default_value = if default_value.len() < range.end - range.start { - format!("{: range.end - range.start { - valid_positions = false; - } - - result.replace_range(range, &default_value); + let default_value = format!("{: String // Get the base default value for this type let default = resolve_default_value(pg_type); - let default = if default.len() > "NULL".len() { - // If the default value is longer than "NULL", use "NULL" instead - "NULL".to_string() - } else { - // Otherwise, use the default value - default - }; - // For arrays, wrap the default in array syntax if is_array { format!("'{{{}}}'", default) @@ -231,6 +214,7 @@ fn resolve_type<'a>( #[cfg(test)] mod tests { + use pgt_schema_cache::SchemaCache; use sqlx::{Executor, PgPool}; #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] @@ -325,15 +309,62 @@ mod tests { let tree = parser.parse(input, 
None).unwrap(); - let (sql_out, valid_pos) = - super::apply_identifiers(identifiers, &schema_cache, &tree, input); + let replacement = super::apply_identifiers(identifiers, &schema_cache, &tree, input); - assert!(valid_pos); assert_eq!( - sql_out, + replacement.text(), // the numeric parameters are filled with 0; - // all values of the enums are longer than `NULL`, so we use `NULL` instead - "select 0 + 0 + 0 + 0 + 0 + NULL " + "select 0 + 0 + 0 + 0 + 0 + 'critical'" + ); + } + + #[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")] + async fn test_longer_identifiers(pool: PgPool) { + // create or replace function retrieve(uid uuid, mail text) + // returns uuid + // as $$ + // select id from auth.users where email_change_confirm_status = uid and email = mail; + // $$ + // language sql immutable; + + let input = r#"select id from auth.users where email_change_confirm_status = uid and email = mail;"#; + + let identifiers = vec![ + super::TypedIdentifier { + path: "retrieve".to_string(), + name: Some("uid".to_string()), + type_: super::IdentifierType { + schema: None, + name: "uuid".to_string(), + is_array: false, + }, + }, + super::TypedIdentifier { + path: "retrieve".to_string(), + name: Some("mail".to_string()), + type_: super::IdentifierType { + schema: None, + name: "text".to_string(), + is_array: false, + }, + }, + ]; + + let mut parser = tree_sitter::Parser::new(); + parser + .set_language(&pgt_treesitter_grammar::LANGUAGE.into()) + .expect("Error loading sql language"); + + let schema_cache = SchemaCache::load(&pool).await.unwrap(); + + let tree = parser.parse(input, None).unwrap(); + + let replacement = super::apply_identifiers(identifiers, &schema_cache, &tree, input); + + assert_eq!( + replacement.text(), + // two spaces at the end because mail is longer than '' + r#"select id from auth.users where email_change_confirm_status = '00000000-0000-0000-0000-000000000000' and email = '' ;"# ); } } diff --git a/crates/pgt_workspace/src/workspace/server.rs 
b/crates/pgt_workspace/src/workspace/server.rs index b13a34422..12b8310e4 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -463,7 +463,7 @@ impl Workspace for WorkspaceServer { // Combined async context for both typecheck and plpgsql_check let async_results = run_async(async move { stream::iter(input) - .map(|(id, range, ast, cst, sign)| { + .map(|(id, range, ast, cst, fn_sig)| { let pool = pool.clone(); let path = path_clone.clone(); let schema_cache = Arc::clone(&schema_cache); @@ -484,7 +484,7 @@ impl Workspace for WorkspaceServer { tree: &cst, schema_cache: schema_cache.as_ref(), search_path_patterns, - identifiers: sign + identifiers: fn_sig .map(|s| { s.args .iter()