diff --git a/e2e_test/batch/functions/substr.slt.part b/e2e_test/batch/functions/substr.slt.part index 651def53e84a..4f87043da3e8 100644 --- a/e2e_test/batch/functions/substr.slt.part +++ b/e2e_test/batch/functions/substr.slt.part @@ -4,7 +4,7 @@ select substr('W7Jc3Vyufj', (INT '-2147483648')); ---- W7Jc3Vyufj -statement error length in substr should be non-negative +statement error negative substring length not allowed select substr('W7Jc3Vyufj', INT '-2147483648', INT '-2147483648'); query T @@ -26,4 +26,4 @@ select substr('W7Jc3Vyufj', INT '-2147483648', INT '2147483647'); query T select substr('a', 2147483646, 1); ---- -(empty) \ No newline at end of file +(empty) diff --git a/src/expr/src/vector_op/substr.rs b/src/expr/src/vector_op/substr.rs index 5175be1eca45..d8accfb19ec8 100644 --- a/src/expr/src/vector_op/substr.rs +++ b/src/expr/src/vector_op/substr.rs @@ -12,40 +12,45 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::cmp::{max, min}; use std::fmt::Write; use risingwave_expr_macro::function; -use crate::{bail, Result}; +use crate::{ExprError, Result}; #[function("substr(varchar, int32) -> varchar")] pub fn substr_start(s: &str, start: i32, writer: &mut dyn Write) -> Result<()> { - let start = (start.saturating_sub(1).max(0) as usize).min(s.len()); - writer.write_str(&s[start..]).unwrap(); - Ok(()) -} + let skip = start.saturating_sub(1).max(0) as usize; + + let substr = s.chars().skip(skip); + for char in substr { + writer.write_char(char).unwrap(); + } -// #[function("substr(varchar, 0, int32) -> varchar")] -pub fn substr_for(s: &str, count: i32, writer: &mut dyn Write) -> Result<()> { - let end = min(count as usize, s.len()); - writer.write_str(&s[..end]).unwrap(); Ok(()) } #[function("substr(varchar, int32, int32) -> varchar")] pub fn substr_start_for(s: &str, start: i32, count: i32, writer: &mut dyn Write) -> Result<()> { if count < 0 { - bail!("length in substr should be non-negative: {}", count); + return Err(ExprError::InvalidParam { + name: "length", + reason: "negative substring length not allowed".to_string(), + }); } - let start = start.saturating_sub(1); - // NOTE: we use `s.len()` here as an upper bound. - // This is so it will return an empty slice if it exceeds - // the length of `s`. - // 0 <= begin <= s.len() - let begin = min(max(start, 0) as usize, s.len()); - let end = (start.saturating_add(count).max(0) as usize).min(s.len()); - writer.write_str(&s[begin..end]).unwrap(); + + let skip = start.saturating_sub(1).max(0) as usize; + let take = if start >= 1 { + count as usize + } else { + count.saturating_add(start.saturating_sub(1)).max(0) as usize + }; + + let substr = s.chars().skip(skip).take(take); + for char in substr { + writer.write_char(char).unwrap(); + } + Ok(()) } @@ -56,20 +61,31 @@ mod tests { #[test] fn test_substr() -> Result<()> { let s = "cxscgccdd"; + let us = "上海自来水来自海上"; let cases = [ - (s, Some(4), None, "cgccdd"), - (s, None, Some(3), "cxs"), - (s, Some(4), Some(-2), "[unused result]"), - (s, Some(4), Some(2), "cg"), - (s, Some(-1), Some(-5), "[unused result]"), - (s, Some(-1), Some(5), "cxs"), + (s, 4, None, "cgccdd"), + (s, 4, Some(-2), "[unused result]"), + (s, 4, Some(2), "cg"), + (s, -1, Some(-5), "[unused result]"), + (s, -1, Some(0), ""), + (s, -1, Some(1), ""), + (s, -1, Some(2), ""), + (s, -1, Some(3), "c"), + (s, -1, Some(5), "cxs"), + // Unicode test + (us, 1, Some(3), "上海自"), + (us, 3, Some(3), "自来水"), + (us, 6, Some(2), "来自"), + (us, 6, Some(100), "来自海上"), + (us, 6, None, "来自海上"), + ("Mér", 1, Some(2), "Mé"), ]; for (s, off, len, expected) in cases { let mut writer = String::new(); - match (off, len) { - (Some(off), Some(len)) => { + match len { + Some(len) => { let result = substr_start_for(s, off, len, &mut writer); if len < 0 { assert!(result.is_err()); @@ -78,9 +94,7 @@ mod tests { result? } } - (Some(off), None) => substr_start(s, off, &mut writer)?, - (None, Some(len)) => substr_for(s, len, &mut writer)?, - _ => unreachable!(), + None => substr_start(s, off, &mut writer)?, } assert_eq!(writer, expected); } diff --git a/src/tests/sqlsmith/src/validation.rs b/src/tests/sqlsmith/src/validation.rs index 13fec24bb836..37e71a51a811 100644 --- a/src/tests/sqlsmith/src/validation.rs +++ b/src/tests/sqlsmith/src/validation.rs @@ -64,7 +64,7 @@ fn is_numeric_overflow_error(db_error: &str) -> bool { /// Negative substr error fn is_neg_substr_error(db_error: &str) -> bool { - db_error.contains("length in substr should be non-negative") + db_error.contains("negative substring length not allowed") } /// Certain errors are permitted to occur. This is because: