diff --git a/compiler/rustc_ast/src/tokenstream.rs b/compiler/rustc_ast/src/tokenstream.rs index 4111182c3b7dc..74a8bc059ab18 100644 --- a/compiler/rustc_ast/src/tokenstream.rs +++ b/compiler/rustc_ast/src/tokenstream.rs @@ -978,10 +978,57 @@ pub struct DelimSpan { impl DelimSpan { pub fn from_single(sp: Span) -> Self { - DelimSpan { open: sp, close: sp } + let Some(sm) = rustc_span::source_map::get_source_map() else { + // No source map available. + return Self { open: sp, close: sp }; + }; + + let (open, close) = sm + .span_to_source(sp, |src, start, end| { + let src = match src.get(start..end) { + Some(s) if s.len() >= 2 => s.as_bytes(), + _ => return Ok((sp, sp)), + }; + + // Only check the first and last characters. + // If there is white space or other characters + // other than `( ... )`, `[ ... ]`, and `{ ... }`. + // I assume that is intentionally included in this + // span so we don't want to shrink the span by + // searching for the delimiters, and setting + // the open and close spans to some more interior + // position. + let first = src[0]; + let last = src[src.len() - 1]; + + // Thought maybe scan through if first is '(', '[', or '{' + // and see if the last matches up (e.g. make sure it's not some + // extra mismatched delimiter) + + let pos = (sp.hi() - sp.lo()).0.checked_sub(1).unwrap_or(0); + // If these return `None` just use the default because that + // means the span is too small for there to be a matched pair. + let Some(open) = sp.subspan(0..1) else { + return Ok((sp, sp)); + }; + let Some(close) = sp.subspan(pos..) else { + return Ok((sp, sp)); + }; + + Ok(match (first, last) { + (b'(', b')') | (b'{', b'}') | (b'[', b']') => (open, close), + (_, _) => (sp, sp), + }) + }) + .ok() + .unwrap_or((sp, sp)); + + debug_assert!(open.lo() <= close.lo()); + DelimSpan { open, close } } pub fn from_pair(open: Span, close: Span) -> Self { + debug_assert!(open.lo() <= close.lo()); DelimSpan { open, close } } @@ -990,6 +1037,7 @@ impl DelimSpan { } pub fn entire(self) -> Span { + debug_assert!(self.open.lo() <= self.close.lo()); self.open.with_hi(self.close.hi()) } } diff --git a/compiler/rustc_span/src/lib.rs b/compiler/rustc_span/src/lib.rs index 2e03ccb1aa1a3..72631f1050ee6 100644 --- a/compiler/rustc_span/src/lib.rs +++ b/compiler/rustc_span/src/lib.rs @@ -77,7 +77,7 @@ use std::cmp::{self, Ordering}; use std::fmt::Display; use std::hash::Hash; use std::io::{self, Read}; -use std::ops::{Add, Range, Sub}; +use std::ops::{Add, Bound, Range, RangeBounds, Sub}; use std::path::{Path, PathBuf}; use std::str::FromStr; use std::sync::Arc; @@ -1176,6 +1176,37 @@ impl Span { pub fn normalize_to_macro_rules(self) -> Span { self.map_ctxt(|ctxt| ctxt.normalize_to_macro_rules()) } + + /// This function is similar to `Span::from_inner`, but it + /// will return `None` if the relative Range span exceeds + /// the bounds of span. + pub fn subspan>(self, subspan: R) -> Option + where + u32: TryFrom, + { + let lo = self.lo().0; + let hi = self.hi().0; + + let start = match subspan.start_bound() { + Bound::Included(s) => u32::try_from(*s).ok()?, + Bound::Excluded(s) => u32::try_from(*s).ok()?.checked_add(1)?, + Bound::Unbounded => 0, + }; + + let end = match subspan.end_bound() { + Bound::Included(e) => u32::try_from(*e).ok()?.checked_add(1)?, + Bound::Excluded(e) => u32::try_from(*e).ok()?, + Bound::Unbounded => hi - lo, + }; + + let new_lo = lo.checked_add(start)?; + let new_hi = lo.checked_add(end)?; + if new_lo > hi || new_hi > hi { + return None; + } + + Some(self.with_lo(BytePos(new_lo)).with_hi(BytePos(new_hi))) + } } impl Default for Span {