diff --git a/src/distributions/slice.rs b/src/distributions/slice.rs
index 398cad18b2..224bf1712c 100644
--- a/src/distributions/slice.rs
+++ b/src/distributions/slice.rs
@@ -7,6 +7,8 @@
 // except according to those terms.
 
 use crate::distributions::{Distribution, Uniform};
+#[cfg(feature = "alloc")]
+use alloc::string::String;
 
 /// A distribution to sample items uniformly from a slice.
 ///
@@ -115,3 +117,41 @@ impl core::fmt::Display for EmptySlice {
 
 #[cfg(feature = "std")]
 impl std::error::Error for EmptySlice {}
+
+/// `DistString` support: `append_string` appends `len` chars, each sampled
+/// from the slice via `sample_iter`, to the given `String`.
+///
+/// Note: the `String` is potentially left with excess capacity; optionally the
+/// user may call `string.shrink_to_fit()` afterwards.
+#[cfg(feature = "alloc")]
+impl<'a> super::DistString for Slice<'a, char> {
+    fn append_string<R: crate::Rng + ?Sized>(&self, rng: &mut R, string: &mut String, len: usize) {
+        // Find the widest UTF-8 encoding among the chars to minimize extra space.
+        // Slices of 200 or more items are not scanned; they assume the 4-byte maximum.
+        let max_char_len = if self.slice.len() < 200 {
+            self.slice
+                .iter()
+                .try_fold(1, |max_len, char| {
+                    // Yielding `None` once the width reaches 4 stops the scan early;
+                    // `unwrap_or(4)` below then reports that maximum.
+                    Some(max_len.max(char.len_utf8())).filter(|len| *len < 4)
+                })
+                .unwrap_or(4)
+        } else {
+            4
+        };
+
+        // Extend in chunks so that capacity reserved for the worst case but left
+        // unused by one chunk can be reused by the next one.
+        // Skip the split for small length or only ascii slice.
+        let mut extend_len = if max_char_len == 1 || len < 100 { len } else { len / 4 };
+        let mut remain_len = len;
+        while extend_len > 0 {
+            string.reserve(max_char_len * extend_len);
+            string.extend(self.sample_iter(&mut *rng).take(extend_len));
+            remain_len -= extend_len;
+            // Never sample more than what is left; reaching 0 ends the loop.
+            extend_len = extend_len.min(remain_len);
+        }
+    }
+}
diff --git a/src/distributions/uniform.rs b/src/distributions/uniform.rs
index a2664768c9..713961e8e0 100644
--- a/src/distributions/uniform.rs
+++ b/src/distributions/uniform.rs
@@ -843,6 +843,31 @@ impl UniformSampler for UniformChar {
     }
 }
 
+/// `DistString` support: `append_string` appends `len` chars, each sampled
+/// from this distribution via `sample_iter`, to the given `String`.
+///
+/// Note: the `String` is potentially left with excess capacity if the range
+/// includes non ascii chars; optionally the user may call
+/// `string.shrink_to_fit()` afterwards.
+#[cfg(feature = "alloc")]
+impl super::DistString for Uniform<char> {
+    fn append_string<R: crate::Rng + ?Sized>(
+        &self, rng: &mut R, string: &mut alloc::string::String, len: usize,
+    ) {
+        // Getting the hi value to assume the required length to reserve in string.
+        let mut hi = self.0.sampler.low + self.0.sampler.range - 1;
+        // From `CHAR_SURROGATE_START` onwards, shift by `CHAR_SURROGATE_LEN`; presumably the
+        // sampler's raw values leave out the surrogate gap (see `UniformChar` to confirm).
+        if hi >= CHAR_SURROGATE_START {
+            hi += CHAR_SURROGATE_LEN;
+        }
+        // Get the utf8 length of hi to minimize extra space.
+        let max_char_len = char::from_u32(hi).map(char::len_utf8).unwrap_or(4);
+        string.reserve(max_char_len * len);
+        string.extend(self.sample_iter(rng).take(len))
+    }
+}
+
 /// The back-end implementing [`UniformSampler`] for floating-point types.
 ///
 /// Unless you are implementing [`UniformSampler`] for your own type, this type
@@ -1376,6 +1401,25 @@ mod tests {
             let c = d.sample(&mut rng);
             assert!((c as u32) < 0xD800 || (c as u32) > 0xDFFF);
         }
+        #[cfg(feature = "alloc")]
+        {
+            use crate::distributions::DistString;
+            // `d` is defined above this hunk; presumably its largest char needs 3 bytes.
+            let string1 = d.sample_string(&mut rng, 100);
+            assert_eq!(string1.capacity(), 300);
+            // Half-open range: the largest char is U+007F (1 byte in UTF-8).
+            let string2 = Uniform::new(
+                core::char::from_u32(0x0000).unwrap(),
+                core::char::from_u32(0x0080).unwrap(),
+            ).unwrap().sample_string(&mut rng, 100);
+            assert_eq!(string2.capacity(), 100);
+            // Inclusive range: the largest char is U+0080 (2 bytes in UTF-8).
+            let string3 = Uniform::new_inclusive(
+                core::char::from_u32(0x0000).unwrap(),
+                core::char::from_u32(0x0080).unwrap(),
+            ).unwrap().sample_string(&mut rng, 100);
+            assert_eq!(string3.capacity(), 200);
+        }
     }
 
     #[test]