From 6c04bd64610d6a726c512a323fcabf0fa1ce81dd Mon Sep 17 00:00:00 2001 From: Jorge Leitao Date: Mon, 13 Feb 2023 07:52:53 +0100 Subject: [PATCH] Improved API of getting mutable from Buffer (#1399) * Removed potential source of UB * Improved API --- src/array/binary/mod.rs | 50 ++++++++++++++++--------------------- src/array/primitive/mod.rs | 18 +++++++------- src/array/utf8/mod.rs | 47 ++++++++++++++++------------------- src/buffer/immutable.rs | 51 ++++++++++++++++++++++++++++++-------- src/offset.rs | 12 +++------ 5 files changed, 96 insertions(+), 82 deletions(-) diff --git a/src/array/binary/mod.rs b/src/array/binary/mod.rs index 73860840bc5..86ebaa73573 100644 --- a/src/array/binary/mod.rs +++ b/src/array/binary/mod.rs @@ -206,7 +206,8 @@ impl BinaryArray { impl_into_array!(); /// Try to convert this `BinaryArray` to a `MutableBinaryArray` - pub fn into_mut(mut self) -> Either> { + #[must_use] + pub fn into_mut(self) -> Either> { use Either::*; if let Some(bitmap) = self.validity { match bitmap.into_mut() { @@ -217,29 +218,26 @@ impl BinaryArray { self.values, Some(bitmap), )), - Right(mutable_bitmap) => match ( - self.values.get_mut().map(std::mem::take), - self.offsets.get_mut(), - ) { - (None, None) => Left(BinaryArray::new( + Right(mutable_bitmap) => match (self.values.into_mut(), self.offsets.into_mut()) { + (Left(values), Left(offsets)) => Left(BinaryArray::new( self.data_type, - self.offsets, - self.values, + offsets, + values, Some(mutable_bitmap.into()), )), - (None, Some(offsets)) => Left(BinaryArray::new( + (Left(values), Right(offsets)) => Left(BinaryArray::new( self.data_type, offsets.into(), - self.values, + values, Some(mutable_bitmap.into()), )), - (Some(mutable_values), None) => Left(BinaryArray::new( + (Right(values), Left(offsets)) => Left(BinaryArray::new( self.data_type, - self.offsets, - mutable_values.into(), + offsets, + values.into(), Some(mutable_bitmap.into()), )), - (Some(values), Some(offsets)) => Right( + (Right(values), Right(offsets)) => Right( MutableBinaryArray::try_new( self.data_type, offsets, @@ -251,29 +249,23 @@ impl BinaryArray { }, } } else { - match ( - self.values.get_mut().map(std::mem::take), - self.offsets.get_mut(), - ) { - (None, None) => Left(BinaryArray::new( - self.data_type, - self.offsets, - self.values, - None, - )), - (None, Some(offsets)) => Left(BinaryArray::new( + match (self.values.into_mut(), self.offsets.into_mut()) { + (Left(values), Left(offsets)) => { + Left(BinaryArray::new(self.data_type, offsets, values, None)) + } + (Left(values), Right(offsets)) => Left(BinaryArray::new( self.data_type, offsets.into(), - self.values, + values, None, )), - (Some(values), None) => Left(BinaryArray::new( + (Right(values), Left(offsets)) => Left(BinaryArray::new( self.data_type, - self.offsets, + offsets, values.into(), None, )), - (Some(values), Some(offsets)) => Right( + (Right(values), Right(offsets)) => Right( MutableBinaryArray::try_new(self.data_type, offsets, values, None).unwrap(), ), } diff --git a/src/array/primitive/mod.rs b/src/array/primitive/mod.rs index 34229101df1..b8ea5fbb84e 100644 --- a/src/array/primitive/mod.rs +++ b/src/array/primitive/mod.rs @@ -260,7 +260,7 @@ impl PrimitiveArray { /// Returns an option of a mutable reference to the values of this [`PrimitiveArray`]. pub fn get_mut_values(&mut self) -> Option<&mut [T]> { - self.values.get_mut().map(|x| x.as_mut()) + self.values.get_mut_slice() } /// Returns its internal representation @@ -282,7 +282,7 @@ impl PrimitiveArray { /// /// This function is primarily used to re-use memory regions. #[must_use] - pub fn into_mut(mut self) -> Either> { + pub fn into_mut(self) -> Either> { use Either::*; if let Some(bitmap) = self.validity { @@ -292,8 +292,8 @@ impl PrimitiveArray { self.values, Some(bitmap), )), - Right(mutable_bitmap) => match self.values.get_mut().map(std::mem::take) { - Some(values) => Right( + Right(mutable_bitmap) => match self.values.into_mut() { + Right(values) => Right( MutablePrimitiveArray::try_new( self.data_type, values, @@ -301,19 +301,19 @@ impl PrimitiveArray { ) .unwrap(), ), - None => Left(PrimitiveArray::new( + Left(values) => Left(PrimitiveArray::new( self.data_type, - self.values, + values, Some(mutable_bitmap.into()), )), }, } } else { - match self.values.get_mut().map(std::mem::take) { - Some(values) => { + match self.values.into_mut() { + Right(values) => { Right(MutablePrimitiveArray::try_new(self.data_type, values, None).unwrap()) } - None => Left(PrimitiveArray::new(self.data_type, self.values, None)), + Left(values) => Left(PrimitiveArray::new(self.data_type, values, None)), } } } diff --git a/src/array/utf8/mod.rs b/src/array/utf8/mod.rs index 51055fe8dd1..e7e83906598 100644 --- a/src/array/utf8/mod.rs +++ b/src/array/utf8/mod.rs @@ -226,7 +226,8 @@ impl Utf8Array { impl_into_array!(); /// Try to convert this `Utf8Array` to a `MutableUtf8Array` - pub fn into_mut(mut self) -> Either> { + #[must_use] + pub fn into_mut(self) -> Either> { use Either::*; if let Some(bitmap) = self.validity { match bitmap.into_mut() { @@ -239,44 +240,41 @@ impl Utf8Array { Some(bitmap), ) }), - Right(mutable_bitmap) => match ( - self.values.get_mut().map(std::mem::take), - self.offsets.get_mut(), - ) { - (None, None) => { + Right(mutable_bitmap) => match (self.values.into_mut(), self.offsets.into_mut()) { + (Left(values), Left(offsets)) => { // Safety: invariants are preserved Left(unsafe { Utf8Array::new_unchecked( self.data_type, - self.offsets, - self.values, + offsets, + values, Some(mutable_bitmap.into()), ) }) } - (None, Some(offsets)) => { + (Left(values), Right(offsets)) => { // Safety: invariants are preserved Left(unsafe { Utf8Array::new_unchecked( self.data_type, offsets.into(), - self.values, + values, Some(mutable_bitmap.into()), ) }) } - (Some(mutable_values), None) => { + (Right(values), Left(offsets)) => { // Safety: invariants are preserved Left(unsafe { Utf8Array::new_unchecked( self.data_type, - self.offsets, - mutable_values.into(), + offsets, + values.into(), Some(mutable_bitmap.into()), ) }) } - (Some(values), Some(offsets)) => Right(unsafe { + (Right(values), Right(offsets)) => Right(unsafe { MutableUtf8Array::new_unchecked( self.data_type, offsets, @@ -287,20 +285,17 @@ impl Utf8Array { }, } } else { - match ( - self.values.get_mut().map(std::mem::take), - self.offsets.get_mut(), - ) { - (None, None) => Left(unsafe { - Utf8Array::new_unchecked(self.data_type, self.offsets, self.values, None) + match (self.values.into_mut(), self.offsets.into_mut()) { + (Left(values), Left(offsets)) => { + Left(unsafe { Utf8Array::new_unchecked(self.data_type, offsets, values, None) }) + } + (Left(values), Right(offsets)) => Left(unsafe { + Utf8Array::new_unchecked(self.data_type, offsets.into(), values, None) }), - (None, Some(offsets)) => Left(unsafe { - Utf8Array::new_unchecked(self.data_type, offsets.into(), self.values, None) + (Right(values), Left(offsets)) => Left(unsafe { + Utf8Array::new_unchecked(self.data_type, offsets, values.into(), None) }), - (Some(values), None) => Left(unsafe { - Utf8Array::new_unchecked(self.data_type, self.offsets, values.into(), None) - }), - (Some(values), Some(offsets)) => Right(unsafe { + (Right(values), Right(offsets)) => Right(unsafe { MutableUtf8Array::new_unchecked(self.data_type, offsets, values, None) }), } diff --git a/src/buffer/immutable.rs b/src/buffer/immutable.rs index e4b69ddc371..943bc3b3be3 100644 --- a/src/buffer/immutable.rs +++ b/src/buffer/immutable.rs @@ -1,5 +1,7 @@ use std::{iter::FromIterator, ops::Deref, sync::Arc, usize}; +use either::Either; + use super::Bytes; use super::IntoIter; @@ -21,8 +23,8 @@ use super::IntoIter; /// assert_eq!(buffer.as_ref(), [1, 2, 3].as_ref()); /// /// // it supports copy-on-write semantics (i.e. back to a `Vec`) -/// let vec: &mut [u32] = buffer.get_mut().unwrap(); -/// assert_eq!(vec, &mut [1, 2, 3]); +/// let vec: Vec = buffer.into_mut().right().unwrap(); +/// assert_eq!(vec, vec![1, 2, 3]); /// /// // cloning and slicing is `O(1)` (data is shared) /// let mut buffer: Buffer = vec![1, 2, 3].into(); @@ -30,7 +32,7 @@ use super::IntoIter; /// sliced.slice(1, 1); /// assert_eq!(sliced.as_ref(), [2].as_ref()); /// // but cloning forbids getting mut since `slice` and `buffer` now share data -/// assert_eq!(buffer.get_mut(), None); +/// assert_eq!(buffer.get_mut_slice(), None); /// ``` #[derive(Clone)] pub struct Buffer { @@ -178,18 +180,47 @@ impl Buffer { /// Returns a mutable reference to its underlying [`Vec`], if possible. /// - /// This operation returns [`Some`] iff this [`Buffer`]: - /// * has not been sliced with an offset + /// This operation returns [`Either::Right`] iff this [`Buffer`]: /// * has not been cloned (i.e. [`Arc`]`::get_mut` yields [`Some`]) /// * has not been imported from the c data interface (FFI) - pub fn get_mut(&mut self) -> Option<&mut Vec> { - if self.offset != 0 { - None - } else { - Arc::get_mut(&mut self.data).and_then(|b| b.get_vec()) + #[inline] + pub fn into_mut(mut self) -> Either> { + match Arc::get_mut(&mut self.data) + .and_then(|b| b.get_vec()) + .map(std::mem::take) + { + Some(inner) => Either::Right(inner), + None => Either::Left(self), } } + /// Returns a mutable reference to its underlying `Vec`, if possible. + /// Note that only `[self.offset(), self.offset() + self.len()[` in this vector is visible + /// by this buffer. + /// + /// This operation returns [`Some`] iff this [`Buffer`]: + /// * has not been cloned (i.e. [`Arc`]`::get_mut` yields [`Some`]) + /// * has not been imported from the c data interface (FFI) + /// # Safety + /// The caller must ensure that the vector in the mutable reference keeps a length of at least `self.offset() + self.len() - 1`. + #[inline] + pub unsafe fn get_mut(&mut self) -> Option<&mut Vec> { + Arc::get_mut(&mut self.data).and_then(|b| b.get_vec()) + } + + /// Returns a mutable reference to its slice, if possible. + /// + /// This operation returns [`Some`] iff this [`Buffer`]: + /// * has not been cloned (i.e. [`Arc`]`::get_mut` yields [`Some`]) + /// * has not been imported from the c data interface (FFI) + #[inline] + pub fn get_mut_slice(&mut self) -> Option<&mut [T]> { + Arc::get_mut(&mut self.data) + .and_then(|b| b.get_vec()) + // Safety: the invariant of this struct + .map(|x| unsafe { x.get_unchecked_mut(self.offset..self.offset + self.length) }) + } + /// Get the strong count of underlying `Arc` data buffer. pub fn shared_count_strong(&self) -> usize { Arc::strong_count(&self.data) diff --git a/src/offset.rs b/src/offset.rs index 2532084ef03..4ab062ef6da 100644 --- a/src/offset.rs +++ b/src/offset.rs @@ -363,16 +363,12 @@ impl OffsetsBuffer { /// Copy-on-write API to convert [`OffsetsBuffer`] into [`Offsets`]. #[inline] - pub fn get_mut(&mut self) -> Option> { + pub fn into_mut(self) -> either::Either> { self.0 - .get_mut() - .map(|x| { - let mut new = vec![O::zero()]; - std::mem::swap(x, &mut new); - new - }) + .into_mut() // Safety: Offsets and OffsetsBuffer share invariants - .map(|offsets| unsafe { Offsets::new_unchecked(offsets) }) + .map_right(|offsets| unsafe { Offsets::new_unchecked(offsets) }) + .map_left(Self) } /// Returns a reference to its internal [`Buffer`].