Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(expr): correctly handle unicode for substr #9079

Merged
merged 4 commits into from
Apr 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions e2e_test/batch/functions/substr.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ select substr('W7Jc3Vyufj', (INT '-2147483648'));
----
W7Jc3Vyufj

statement error length in substr should be non-negative
statement error negative substring length not allowed
select substr('W7Jc3Vyufj', INT '-2147483648', INT '-2147483648');

query T
Expand All @@ -26,4 +26,4 @@ select substr('W7Jc3Vyufj', INT '-2147483648', INT '2147483647');
query T
select substr('a', 2147483646, 1);
----
(empty)
(empty)
74 changes: 44 additions & 30 deletions src/expr/src/vector_op/substr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,40 +12,45 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::cmp::{max, min};
use std::fmt::Write;

use risingwave_expr_macro::function;

use crate::{bail, Result};
use crate::{ExprError, Result};

#[function("substr(varchar, int32) -> varchar")]
pub fn substr_start(s: &str, start: i32, writer: &mut dyn Write) -> Result<()> {
let start = (start.saturating_sub(1).max(0) as usize).min(s.len());
writer.write_str(&s[start..]).unwrap();
Ok(())
}
let skip = start.saturating_sub(1).max(0) as usize;

let substr = s.chars().skip(skip);
for char in substr {
writer.write_char(char).unwrap();
}

// #[function("substr(varchar, 0, int32) -> varchar")]
pub fn substr_for(s: &str, count: i32, writer: &mut dyn Write) -> Result<()> {
let end = min(count as usize, s.len());
writer.write_str(&s[..end]).unwrap();
Ok(())
}

#[function("substr(varchar, int32, int32) -> varchar")]
pub fn substr_start_for(s: &str, start: i32, count: i32, writer: &mut dyn Write) -> Result<()> {
if count < 0 {
bail!("length in substr should be non-negative: {}", count);
return Err(ExprError::InvalidParam {
name: "length",
reason: "negative substring length not allowed".to_string(),
});
}
let start = start.saturating_sub(1);
// NOTE: we use `s.len()` here as an upper bound.
// This is so it will return an empty slice if it exceeds
// the length of `s`.
// 0 <= begin <= s.len()
let begin = min(max(start, 0) as usize, s.len());
let end = (start.saturating_add(count).max(0) as usize).min(s.len());
writer.write_str(&s[begin..end]).unwrap();

let skip = start.saturating_sub(1).max(0) as usize;
let take = if start >= 1 {
count as usize
} else {
count.saturating_add(start.saturating_sub(1)).max(0) as usize
};

let substr = s.chars().skip(skip).take(take);
for char in substr {
writer.write_char(char).unwrap();
}

Ok(())
}

Expand All @@ -56,20 +61,31 @@ mod tests {
#[test]
fn test_substr() -> Result<()> {
let s = "cxscgccdd";
let us = "上海自来水来自海上";

let cases = [
(s, Some(4), None, "cgccdd"),
(s, None, Some(3), "cxs"),
(s, Some(4), Some(-2), "[unused result]"),
(s, Some(4), Some(2), "cg"),
(s, Some(-1), Some(-5), "[unused result]"),
(s, Some(-1), Some(5), "cxs"),
(s, 4, None, "cgccdd"),
(s, 4, Some(-2), "[unused result]"),
(s, 4, Some(2), "cg"),
(s, -1, Some(-5), "[unused result]"),
(s, -1, Some(0), ""),
(s, -1, Some(1), ""),
(s, -1, Some(2), ""),
(s, -1, Some(3), "c"),
(s, -1, Some(5), "cxs"),
// Unicode test
(us, 1, Some(3), "上海自"),
(us, 3, Some(3), "自来水"),
(us, 6, Some(2), "来自"),
(us, 6, Some(100), "来自海上"),
(us, 6, None, "来自海上"),
("Mér", 1, Some(2), "Mé"),
];

for (s, off, len, expected) in cases {
let mut writer = String::new();
match (off, len) {
(Some(off), Some(len)) => {
match len {
Some(len) => {
let result = substr_start_for(s, off, len, &mut writer);
if len < 0 {
assert!(result.is_err());
Expand All @@ -78,9 +94,7 @@ mod tests {
result?
}
}
(Some(off), None) => substr_start(s, off, &mut writer)?,
(None, Some(len)) => substr_for(s, len, &mut writer)?,
_ => unreachable!(),
None => substr_start(s, off, &mut writer)?,
}
assert_eq!(writer, expected);
}
Expand Down
2 changes: 1 addition & 1 deletion src/tests/sqlsmith/src/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ fn is_numeric_overflow_error(db_error: &str) -> bool {

/// Negative substr error
fn is_neg_substr_error(db_error: &str) -> bool {
db_error.contains("length in substr should be non-negative")
db_error.contains("negative substring length not allowed")
}

/// Certain errors are permitted to occur. This is because:
Expand Down