Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add debug assertions to validate NUL terminator in c strings #93979

Merged
merged 1 commit into from
Feb 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 29 additions & 12 deletions library/std/src/ffi/c_str.rs
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ impl CString {
let bytes: Vec<u8> = self.into();
match memchr::memchr(0, &bytes) {
Some(i) => Err(NulError(i, bytes)),
None => Ok(unsafe { CString::from_vec_unchecked(bytes) }),
None => Ok(unsafe { CString::_from_vec_unchecked(bytes) }),
}
}
}
Expand All @@ -405,7 +405,7 @@ impl CString {
// This allows better optimizations if lto enabled.
match memchr::memchr(0, bytes) {
Some(i) => Err(NulError(i, buffer)),
None => Ok(unsafe { CString::from_vec_unchecked(buffer) }),
None => Ok(unsafe { CString::_from_vec_unchecked(buffer) }),
}
}

Expand Down Expand Up @@ -451,10 +451,15 @@ impl CString {
/// ```
#[must_use]
#[stable(feature = "rust1", since = "1.0.0")]
pub unsafe fn from_vec_unchecked(mut v: Vec<u8>) -> CString {
pub unsafe fn from_vec_unchecked(v: Vec<u8>) -> Self {
debug_assert!(memchr::memchr(0, &v).is_none());
unsafe { Self::_from_vec_unchecked(v) }
}

unsafe fn _from_vec_unchecked(mut v: Vec<u8>) -> Self {
v.reserve_exact(1);
v.push(0);
CString { inner: v.into_boxed_slice() }
Self { inner: v.into_boxed_slice() }
}

/// Retakes ownership of a `CString` that was transferred to C via
Expand Down Expand Up @@ -578,7 +583,7 @@ impl CString {
pub fn into_string(self) -> Result<String, IntoStringError> {
String::from_utf8(self.into_bytes()).map_err(|e| IntoStringError {
error: e.utf8_error(),
inner: unsafe { CString::from_vec_unchecked(e.into_bytes()) },
inner: unsafe { Self::_from_vec_unchecked(e.into_bytes()) },
})
}

Expand Down Expand Up @@ -735,6 +740,11 @@ impl CString {
#[must_use]
#[stable(feature = "cstring_from_vec_with_nul", since = "1.58.0")]
pub unsafe fn from_vec_with_nul_unchecked(v: Vec<u8>) -> Self {
debug_assert!(memchr::memchr(0, &v).unwrap() + 1 == v.len());
unsafe { Self::_from_vec_with_nul_unchecked(v) }
}

unsafe fn _from_vec_with_nul_unchecked(v: Vec<u8>) -> Self {
Self { inner: v.into_boxed_slice() }
}

Expand Down Expand Up @@ -778,7 +788,7 @@ impl CString {
Some(nul_pos) if nul_pos + 1 == v.len() => {
// SAFETY: We know there is only one nul byte, at the end
// of the vec.
Ok(unsafe { Self::from_vec_with_nul_unchecked(v) })
Ok(unsafe { Self::_from_vec_with_nul_unchecked(v) })
}
Some(nul_pos) => Err(FromVecWithNulError {
error_kind: FromBytesWithNulErrorKind::InteriorNul(nul_pos),
Expand Down Expand Up @@ -811,7 +821,7 @@ impl ops::Deref for CString {

#[inline]
fn deref(&self) -> &CStr {
unsafe { CStr::from_bytes_with_nul_unchecked(self.as_bytes_with_nul()) }
unsafe { CStr::_from_bytes_with_nul_unchecked(self.as_bytes_with_nul()) }
}
}

Expand Down Expand Up @@ -922,7 +932,7 @@ impl From<Vec<NonZeroU8>> for CString {
};
// SAFETY: `v` cannot contain null bytes, given the type-level
// invariant of `NonZeroU8`.
CString::from_vec_unchecked(v)
Self::_from_vec_unchecked(v)
}
}
}
Expand Down Expand Up @@ -1215,7 +1225,7 @@ impl CStr {
unsafe {
let len = sys::strlen(ptr);
let ptr = ptr as *const u8;
CStr::from_bytes_with_nul_unchecked(slice::from_raw_parts(ptr, len as usize + 1))
Self::_from_bytes_with_nul_unchecked(slice::from_raw_parts(ptr, len as usize + 1))
}
}

Expand Down Expand Up @@ -1258,7 +1268,7 @@ impl CStr {
Some(nul_pos) if nul_pos + 1 == bytes.len() => {
// SAFETY: We know there is only one nul byte, at the end
// of the byte slice.
Ok(unsafe { Self::from_bytes_with_nul_unchecked(bytes) })
Ok(unsafe { Self::_from_bytes_with_nul_unchecked(bytes) })
}
Some(nul_pos) => Err(FromBytesWithNulError::interior_nul(nul_pos)),
None => Err(FromBytesWithNulError::not_nul_terminated()),
Expand Down Expand Up @@ -1287,12 +1297,19 @@ impl CStr {
#[stable(feature = "cstr_from_bytes", since = "1.10.0")]
#[rustc_const_stable(feature = "const_cstr_unchecked", since = "1.59.0")]
pub const unsafe fn from_bytes_with_nul_unchecked(bytes: &[u8]) -> &CStr {
// We're in a const fn, so this is the best we can do
debug_assert!(!bytes.is_empty() && bytes[bytes.len() - 1] == 0);
unsafe { Self::_from_bytes_with_nul_unchecked(bytes) }
}

#[inline]
const unsafe fn _from_bytes_with_nul_unchecked(bytes: &[u8]) -> &Self {
// SAFETY: Casting to CStr is safe because its internal representation
// is a [u8] too (safe only inside std).
// Dereferencing the obtained pointer is safe because it comes from a
// reference. Making a reference is then safe because its lifetime
// is bound by the lifetime of the given `bytes`.
unsafe { &*(bytes as *const [u8] as *const CStr) }
unsafe { &*(bytes as *const [u8] as *const Self) }
}

/// Returns the inner pointer to this C string.
Expand Down Expand Up @@ -1555,7 +1572,7 @@ impl ops::Index<ops::RangeFrom<usize>> for CStr {
// byte, since otherwise we could get an empty string that doesn't end
// in a null.
if index.start < bytes.len() {
unsafe { CStr::from_bytes_with_nul_unchecked(&bytes[index.start..]) }
unsafe { CStr::_from_bytes_with_nul_unchecked(&bytes[index.start..]) }
} else {
panic!(
"index out of bounds: the len is {} but the index is {}",
Expand Down
8 changes: 0 additions & 8 deletions library/std/src/ffi/c_str/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,6 @@ fn build_with_zero2() {
assert!(CString::new(vec![0]).is_err());
}

#[test]
fn build_with_zero3() {
unsafe {
let s = CString::from_vec_unchecked(vec![0]);
assert_eq!(s.as_bytes(), b"\0");
}
}

#[test]
fn formatted() {
let s = CString::new(&b"abc\x01\x02\n\xE2\x80\xA6\xFF"[..]).unwrap();
Expand Down