From 2b28777966a5667fa827b54d6a12d776bf2ef826 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sun, 31 Mar 2024 12:48:39 +0200 Subject: [PATCH] fix: Ensure Binary -> Binview cast doesn't overflow the buffer size (#15408) --- .../polars-arrow/src/compute/cast/utf8_to.rs | 106 ++++++++++++++++-- 1 file changed, 99 insertions(+), 7 deletions(-) diff --git a/crates/polars-arrow/src/compute/cast/utf8_to.rs b/crates/polars-arrow/src/compute/cast/utf8_to.rs index fadb6552beab..0f8892e498aa 100644 --- a/crates/polars-arrow/src/compute/cast/utf8_to.rs +++ b/crates/polars-arrow/src/compute/cast/utf8_to.rs @@ -5,6 +5,7 @@ use polars_utils::slice::GetSaferUnchecked; use polars_utils::vec::PushUnchecked; use crate::array::*; +use crate::buffer::Buffer; use crate::datatypes::ArrowDataType; use crate::offset::Offset; use crate::types::NativeType; @@ -69,14 +70,51 @@ pub fn utf8_to_binary<O: Offset>( } } +// Different types to test the overflow path. +#[cfg(not(test))] +type OffsetType = u32; + +// To trigger overflow +#[cfg(test)] +type OffsetType = i8; + +// If we don't do this the GC of binview will trigger. As we will split up buffers into multiple +// chunks so that we don't overflow the offset u32. +fn truncate_buffer(buf: &Buffer<u8>) -> Buffer<u8> { + // * 2, as it must be able to hold u32::MAX offset + u32::MAX len. 
+ buf.clone() + .sliced(0, std::cmp::min(buf.len(), OffsetType::MAX as usize * 2)) +} + pub fn binary_to_binview<O: Offset>(arr: &BinaryArray<O>) -> BinaryViewArray { - let buffer_idx = 0_u32; - let base_ptr = arr.values().as_ptr() as usize; + // Ensure we didn't accidentally set wrong type + #[cfg(not(debug_assertions))] + { + assert_eq!( + std::mem::size_of::<OffsetType>(), + std::mem::size_of::<u32>() + ); + } let mut views = Vec::with_capacity(arr.len()); let mut uses_buffer = false; + + let mut base_buffer = arr.values().clone(); + // Offset into the buffer + let mut base_ptr = base_buffer.as_ptr() as usize; + + // Offset into the binview buffers + let mut buffer_idx = 0_u32; + + // Binview buffers + // Note that the buffer may look far further than u32::MAX, but as we don't clone data + let mut buffers = vec![truncate_buffer(&base_buffer)]; + for bytes in arr.values_iter() { - let len: u32 = bytes.len().try_into().unwrap(); + let len: u32 = bytes + .len() + .try_into() + .expect("max string/binary length exceeded"); let mut payload = [0; 16]; payload[0..4].copy_from_slice(&len.to_le_bytes()); @@ -85,18 +123,42 @@ pub fn binary_to_binview<O: Offset>(arr: &BinaryArray<O>) -> BinaryViewArray { payload[4..4 + bytes.len()].copy_from_slice(bytes); } else { uses_buffer = true; + + // Copy the parts we know are correct. unsafe { payload[4..8].copy_from_slice(bytes.get_unchecked_release(0..4)) }; - let offset = (bytes.as_ptr() as usize - base_ptr) as u32; payload[0..4].copy_from_slice(&len.to_le_bytes()); - payload[8..12].copy_from_slice(&buffer_idx.to_le_bytes()); - payload[12..16].copy_from_slice(&offset.to_le_bytes()); + + let current_bytes_ptr = bytes.as_ptr() as usize; + let offset = current_bytes_ptr - base_ptr; + + // Here we check the overflow of the buffer offset. 
+ if let Ok(offset) = OffsetType::try_from(offset) { + #[allow(clippy::unnecessary_cast)] + let offset = offset as u32; + payload[12..16].copy_from_slice(&offset.to_le_bytes()); + payload[8..12].copy_from_slice(&buffer_idx.to_le_bytes()); + } else { + let len = base_buffer.len() - offset; + + // Set new buffer + base_buffer = base_buffer.clone().sliced(offset, len); + base_ptr = base_buffer.as_ptr() as usize; + + // And add the (truncated) one to the buffers + buffers.push(truncate_buffer(&base_buffer)); + buffer_idx = buffer_idx.checked_add(1).expect("max buffers exceeded"); + + let offset = 0u32; + payload[12..16].copy_from_slice(&offset.to_le_bytes()); + payload[8..12].copy_from_slice(&buffer_idx.to_le_bytes()); + } } let value = View::from_le_bytes(payload); unsafe { views.push_unchecked(value) }; } let buffers = if uses_buffer { - Arc::from([arr.values().clone()]) + Arc::from(buffers) } else { Arc::from([]) }; @@ -114,3 +176,33 @@ pub fn binary_to_binview<O: Offset>(arr: &BinaryArray<O>) -> BinaryViewArray { pub fn utf8_to_utf8view<O: Offset>(arr: &Utf8Array<O>) -> Utf8ViewArray { unsafe { binary_to_binview(&arr.to_binary()).to_utf8view_unchecked() } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn overflowing_utf8_to_binview() { + let values = [ + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", + "123", + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", + "234", + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", + "lksafjdlkakjslkjsafkjdalkjfalkdsalkjfaslkfjlkakdsjfkajfksdajfkasjdflkasjdf", + "324", + ]; + let array = Utf8Array::<i64>::from_slice(values); + + let out = 
utf8_to_utf8view(&array); + // Ensure we hit the multiple buffers part. + assert_eq!(out.buffers().len(), 6); + // Ensure we created a valid binview + let out = out.values_iter().collect::<Vec<_>>(); + assert_eq!(out, values); + } +}