Skip to content

Commit

Permalink
fix(input)!: Prevent str from being processed as arbitrary bytes
Browse files Browse the repository at this point in the history
This will turn panics into compile errors and open the way for some
unsafe for #115
  • Loading branch information
epage committed Feb 7, 2023
1 parent 8ad67de commit d8533a3
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 46 deletions.
14 changes: 7 additions & 7 deletions src/character/complete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::error::ErrMode;
use crate::error::ErrorKind;
use crate::error::ParseError;
use crate::input::{
split_at_offset1_complete, split_at_offset_complete, AsBytes, AsChar, ContainsToken, Input,
split_at_offset1_complete, split_at_offset_complete, AsBStr, AsChar, ContainsToken, Input,
};
use crate::input::{Compare, CompareResult};
use crate::IResult;
Expand Down Expand Up @@ -204,7 +204,7 @@ where
)]
pub fn not_line_ending<T, E: ParseError<T>>(input: T) -> IResult<T, <T as Input>::Slice, E>
where
T: Input + AsBytes,
T: Input + AsBStr,
T: Compare<&'static str>,
<T as Input>::Token: AsChar,
{
Expand All @@ -215,7 +215,7 @@ where
None => Ok(input.next_slice(input.input_len())),
Some(offset) => {
let (new_input, res) = input.next_slice(offset);
let bytes = new_input.as_bytes();
let bytes = new_input.as_bstr();
let nth = bytes[0];
if nth == b'\r' {
let comp = new_input.compare("\r\n");
Expand Down Expand Up @@ -952,7 +952,7 @@ mod tests {
let a: &[u8] = b"abcd";
let b: &[u8] = b"1234";
let c: &[u8] = b"a123";
let d: &[u8] = "azé12".as_bytes();
let d: &[u8] = "azé12".as_bstr();
let e: &[u8] = b" ";
let f: &[u8] = b" ;";
//assert_eq!(alpha1::<_, Error<_>>(a), Err(ErrMode::Incomplete(Needed::Size(1))));
Expand All @@ -965,7 +965,7 @@ mod tests {
}))
);
assert_eq!(alpha1::<_, Error<_>>(c), Ok((&c[1..], &b"a"[..])));
assert_eq!(alpha1::<_, Error<_>>(d), Ok(("é12".as_bytes(), &b"az"[..])));
assert_eq!(alpha1::<_, Error<_>>(d), Ok(("é12".as_bstr(), &b"az"[..])));
assert_eq!(
digit1(a),
Err(ErrMode::Backtrack(Error {
Expand Down Expand Up @@ -993,7 +993,7 @@ mod tests {
assert_eq!(hex_digit1::<_, Error<_>>(c), Ok((empty, c)));
assert_eq!(
hex_digit1::<_, Error<_>>(d),
Ok(("zé12".as_bytes(), &b"a"[..]))
Ok(("zé12".as_bstr(), &b"a"[..]))
);
assert_eq!(
hex_digit1(e),
Expand Down Expand Up @@ -1029,7 +1029,7 @@ mod tests {
assert_eq!(alphanumeric1::<_, Error<_>>(c), Ok((empty, c)));
assert_eq!(
alphanumeric1::<_, Error<_>>(d),
Ok(("é12".as_bytes(), &b"az"[..]))
Ok(("é12".as_bstr(), &b"az"[..]))
);
assert_eq!(space1::<_, Error<_>>(e), Ok((empty, e)));
assert_eq!(space1::<_, Error<_>>(f), Ok((&b";"[..], &b" "[..])));
Expand Down
10 changes: 5 additions & 5 deletions src/character/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::combinator::opt;
use crate::error::ParseError;
use crate::error::{ErrMode, ErrorKind, Needed};
use crate::input::Compare;
use crate::input::{AsBytes, AsChar, Input, InputIsStreaming, Offset, ParseTo};
use crate::input::{AsBStr, AsChar, Input, InputIsStreaming, Offset, ParseTo};
use crate::IResult;
use crate::Parser;

Expand Down Expand Up @@ -102,7 +102,7 @@ pub fn not_line_ending<I, E: ParseError<I>, const STREAMING: bool>(
) -> IResult<I, <I as Input>::Slice, E>
where
I: InputIsStreaming<STREAMING>,
I: Input + AsBytes,
I: Input + AsBStr,
I: Compare<&'static str>,
<I as Input>::Token: AsChar,
{
Expand Down Expand Up @@ -1178,7 +1178,7 @@ where
I: Input,
O: HexUint,
<I as Input>::Token: AsChar,
<I as Input>::Slice: AsBytes,
<I as Input>::Slice: AsBStr,
{
let invalid_offset = input
.offset_for(|c| {
Expand Down Expand Up @@ -1213,7 +1213,7 @@ where
let (remaining, parsed) = input.next_slice(offset);

let mut res = O::default();
for c in parsed.as_bytes() {
for c in parsed.as_bstr() {
let nibble = *c as char;
let nibble = nibble.to_digit(16).unwrap_or(0) as u8;
let nibble = O::from(nibble);
Expand Down Expand Up @@ -1316,7 +1316,7 @@ where
<I as Input>::Slice: ParseTo<O>,
<I as Input>::Token: AsChar + Copy,
<I as Input>::IterOffsets: Clone,
I: AsBytes,
I: AsBStr,
{
let (i, s) = if STREAMING {
crate::number::streaming::recognize_float_or_exceptions(input)?
Expand Down
14 changes: 7 additions & 7 deletions src/character/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::error::ErrorKind;
use crate::error::Needed;
use crate::error::ParseError;
use crate::input::{
split_at_offset1_streaming, split_at_offset_streaming, AsBytes, AsChar, ContainsToken, Input,
split_at_offset1_streaming, split_at_offset_streaming, AsBStr, AsChar, ContainsToken, Input,
};
use crate::input::{Compare, CompareResult};
use crate::IResult;
Expand Down Expand Up @@ -220,7 +220,7 @@ where
)]
pub fn not_line_ending<T, E: ParseError<T>>(input: T) -> IResult<T, <T as Input>::Slice, E>
where
T: Input + AsBytes,
T: Input + AsBStr,
T: Compare<&'static str>,
<T as Input>::Token: AsChar,
{
Expand All @@ -231,7 +231,7 @@ where
None => Err(ErrMode::Incomplete(Needed::Unknown)),
Some(offset) => {
let (new_input, res) = input.next_slice(offset);
let bytes = new_input.as_bytes();
let bytes = new_input.as_bstr();
let nth = bytes[0];
if nth == b'\r' {
let comp = new_input.compare("\r\n");
Expand Down Expand Up @@ -900,7 +900,7 @@ mod tests {
let a: &[u8] = b"abcd";
let b: &[u8] = b"1234";
let c: &[u8] = b"a123";
let d: &[u8] = "azé12".as_bytes();
let d: &[u8] = "azé12".as_bstr();
let e: &[u8] = b" ";
let f: &[u8] = b" ;";
//assert_eq!(alpha1::<_, Error<_>>(a), Err(ErrMode::Incomplete(Needed::new(1))));
Expand All @@ -910,7 +910,7 @@ mod tests {
Err(ErrMode::Backtrack(Error::new(b, ErrorKind::Alpha)))
);
assert_eq!(alpha1::<_, Error<_>>(c), Ok((&c[1..], &b"a"[..])));
assert_eq!(alpha1::<_, Error<_>>(d), Ok(("é12".as_bytes(), &b"az"[..])));
assert_eq!(alpha1::<_, Error<_>>(d), Ok(("é12".as_bstr(), &b"az"[..])));
assert_eq!(
digit1(a),
Err(ErrMode::Backtrack(Error::new(a, ErrorKind::Digit)))
Expand Down Expand Up @@ -941,7 +941,7 @@ mod tests {
);
assert_eq!(
hex_digit1::<_, Error<_>>(d),
Ok(("zé12".as_bytes(), &b"a"[..]))
Ok(("zé12".as_bstr(), &b"a"[..]))
);
assert_eq!(
hex_digit1(e),
Expand Down Expand Up @@ -974,7 +974,7 @@ mod tests {
);
assert_eq!(
alphanumeric1::<_, Error<_>>(d),
Ok(("é12".as_bytes(), &b"az"[..]))
Ok(("é12".as_bstr(), &b"az"[..]))
);
assert_eq!(
space1::<_, Error<_>>(e),
Expand Down
60 changes: 51 additions & 9 deletions src/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
//! | [`SliceLen`] |Calculate the input length|
//! | [`InputIsStreaming`] | Marks the input as being the complete buffer or a partial buffer for streaming input |
//! | [`AsBytes`] |Casts the input type to a byte slice|
//! | [`AsBStr`] |Casts the input type to a slice of ASCII / UTF-8-like bytes|
//! | [`Compare`] |Character comparison operations|
//! | [`Accumulate`] |Abstracts something which can extend an `Extend`|
//! | [`FindSlice`] |Look for a substring in self|
Expand Down Expand Up @@ -905,13 +906,6 @@ impl<'a> AsBytes for &'a [u8] {
}
}

/// A `&str` exposes its UTF-8 encoded bytes.
impl AsBytes for &str {
    #[inline(always)]
    fn as_bytes(&self) -> &[u8] {
        // Name the inherent `str` method explicitly so the call cannot resolve
        // back to this trait method.
        str::as_bytes(self)
    }
}

impl<I> AsBytes for Located<I>
where
I: AsBytes,
Expand Down Expand Up @@ -940,6 +934,54 @@ where
}
}

/// Helper trait for types that can be viewed as a slice of ASCII / UTF-8-like bytes
///
/// See [`AsBytes`] for the general "cast to a byte slice" view.
pub trait AsBStr {
/// Casts the input type to a slice of ASCII / UTF-8-like bytes
fn as_bstr(&self) -> &[u8];
}

/// A byte slice is its own byte-string view.
impl AsBStr for &[u8] {
    #[inline(always)]
    fn as_bstr(&self) -> &[u8] {
        *self
    }
}

/// A `&str` is viewed through its UTF-8 encoded bytes.
impl AsBStr for &str {
    #[inline(always)]
    fn as_bstr(&self) -> &[u8] {
        str::as_bytes(self)
    }
}

impl<I> AsBStr for Located<I>
where
I: AsBStr,
{
fn as_bstr(&self) -> &[u8] {
self.input.as_bstr()
}
}

impl<I, S> AsBStr for Stateful<I, S>
where
I: AsBStr,
{
fn as_bstr(&self) -> &[u8] {
self.input.as_bstr()
}
}

impl<I> AsBStr for Streaming<I>
where
I: AsBStr,
{
#[inline(always)]
fn as_bstr(&self) -> &[u8] {
self.0.as_bstr()
}
}

/// Indicates whether a comparison was successful, an error, or
/// if more data was needed
#[derive(Debug, Eq, PartialEq)]
Expand Down Expand Up @@ -1059,11 +1101,11 @@ impl<'a, 'b> Compare<&'b str> for &'a str {
/// Compares a `&str` input against a byte-slice pattern by taking the input's
/// byte view and calling the same-named comparison on that `&[u8]`.
// NOTE(review): each method body below shows two consecutive calls because this
// view interleaves the before (`AsBytes`) and after (`AsBStr`) lines of a diff —
// confirm against the real file, which should contain only one of them.
impl<'a, 'b> Compare<&'b [u8]> for &'a str {
#[inline(always)]
fn compare(&self, t: &'b [u8]) -> CompareResult {
AsBytes::as_bytes(self).compare(t)
AsBStr::as_bstr(self).compare(t)
}
#[inline(always)]
fn compare_no_case(&self, t: &'b [u8]) -> CompareResult {
AsBytes::as_bytes(self).compare_no_case(t)
AsBStr::as_bstr(self).compare_no_case(t)
}
}

Expand Down
18 changes: 9 additions & 9 deletions src/number/complete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use crate::character::complete::{char, digit1, sign};
use crate::combinator::{cut_err, map, opt};
use crate::error::ParseError;
use crate::error::{make_error, ErrMode, ErrorKind};
use crate::input::{AsBytes, AsChar, Compare, Input, Offset, SliceLen};
use crate::input::{AsBStr, AsBytes, AsChar, Compare, Input, Offset, SliceLen};
use crate::lib::std::ops::{Add, Shl};
use crate::sequence::{pair, tuple};
use crate::*;
Expand Down Expand Up @@ -1442,7 +1442,7 @@ pub fn hex_u32<I, E: ParseError<I>>(input: I) -> IResult<I, u32, E>
where
I: Input,
<I as Input>::Token: AsChar,
<I as Input>::Slice: AsBytes,
<I as Input>::Slice: AsBStr,
{
let invalid_offset = input
.offset_for(|c| {
Expand All @@ -1461,7 +1461,7 @@ where
let (remaining, parsed) = input.next_slice(offset);

let res = parsed
.as_bytes()
.as_bstr()
.iter()
.rev()
.enumerate()
Expand All @@ -1482,7 +1482,7 @@ where
T: Offset + Compare<&'static str>,
<T as Input>::Token: AsChar + Copy,
<T as Input>::IterOffsets: Clone,
T: AsBytes,
T: AsBStr,
{
tuple((
opt(alt((char('+'), char('-')))),
Expand Down Expand Up @@ -1510,7 +1510,7 @@ where
T: Offset + Compare<&'static str>,
<T as Input>::Token: AsChar + Copy,
<T as Input>::IterOffsets: Clone,
T: AsBytes,
T: AsBStr,
{
alt((
|i: T| {
Expand Down Expand Up @@ -1546,7 +1546,7 @@ pub fn recognize_float_parts<T, E: ParseError<T>>(
input: T,
) -> IResult<T, (bool, <T as Input>::Slice, <T as Input>::Slice, i32), E>
where
T: Input + Compare<&'static [u8]> + AsBytes,
T: Input + Compare<&'static [u8]> + AsBStr,
<T as Input>::Token: AsChar + Copy,
<T as Input>::Slice: SliceLen,
{
Expand All @@ -1564,7 +1564,7 @@ where
// match number
let mut zero_count = 0usize;
let mut offset = None;
for (pos, c) in i.as_bytes().iter().enumerate() {
for (pos, c) in i.as_bstr().iter().enumerate() {
if *c >= b'0' && *c <= b'9' {
if *c == b'0' {
zero_count += 1;
Expand Down Expand Up @@ -1635,7 +1635,7 @@ where
<T as Input>::Slice: ParseTo<f32>,
<T as Input>::Token: AsChar + Copy,
<T as Input>::IterOffsets: Clone,
T: AsBytes,
T: AsBStr,
{
let (i, s) = recognize_float_or_exceptions(input)?;
match s.parse_to() {
Expand Down Expand Up @@ -1674,7 +1674,7 @@ where
<T as Input>::Slice: ParseTo<f64>,
<T as Input>::Token: AsChar + Copy,
<T as Input>::IterOffsets: Clone,
T: AsBytes,
T: AsBStr,
{
let (i, s) = recognize_float_or_exceptions(input)?;
match s.parse_to() {
Expand Down

0 comments on commit d8533a3

Please sign in to comment.