Skip to content

Commit

Permalink
Improved API of getting mutable from Buffer (jorgecarleitao#1399)
Browse files Browse the repository at this point in the history
* Removed potential source of UB

* Improved API
  • Loading branch information
jorgecarleitao authored and ritchie46 committed Mar 29, 2023
1 parent 0605625 commit 6c04bd6
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 82 deletions.
50 changes: 21 additions & 29 deletions src/array/binary/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ impl<O: Offset> BinaryArray<O> {
impl_into_array!();

/// Try to convert this `BinaryArray` to a `MutableBinaryArray`
pub fn into_mut(mut self) -> Either<Self, MutableBinaryArray<O>> {
#[must_use]
pub fn into_mut(self) -> Either<Self, MutableBinaryArray<O>> {
use Either::*;
if let Some(bitmap) = self.validity {
match bitmap.into_mut() {
Expand All @@ -217,29 +218,26 @@ impl<O: Offset> BinaryArray<O> {
self.values,
Some(bitmap),
)),
Right(mutable_bitmap) => match (
self.values.get_mut().map(std::mem::take),
self.offsets.get_mut(),
) {
(None, None) => Left(BinaryArray::new(
Right(mutable_bitmap) => match (self.values.into_mut(), self.offsets.into_mut()) {
(Left(values), Left(offsets)) => Left(BinaryArray::new(
self.data_type,
self.offsets,
self.values,
offsets,
values,
Some(mutable_bitmap.into()),
)),
(None, Some(offsets)) => Left(BinaryArray::new(
(Left(values), Right(offsets)) => Left(BinaryArray::new(
self.data_type,
offsets.into(),
self.values,
values,
Some(mutable_bitmap.into()),
)),
(Some(mutable_values), None) => Left(BinaryArray::new(
(Right(values), Left(offsets)) => Left(BinaryArray::new(
self.data_type,
self.offsets,
mutable_values.into(),
offsets,
values.into(),
Some(mutable_bitmap.into()),
)),
(Some(values), Some(offsets)) => Right(
(Right(values), Right(offsets)) => Right(
MutableBinaryArray::try_new(
self.data_type,
offsets,
Expand All @@ -251,29 +249,23 @@ impl<O: Offset> BinaryArray<O> {
},
}
} else {
match (
self.values.get_mut().map(std::mem::take),
self.offsets.get_mut(),
) {
(None, None) => Left(BinaryArray::new(
self.data_type,
self.offsets,
self.values,
None,
)),
(None, Some(offsets)) => Left(BinaryArray::new(
match (self.values.into_mut(), self.offsets.into_mut()) {
(Left(values), Left(offsets)) => {
Left(BinaryArray::new(self.data_type, offsets, values, None))
}
(Left(values), Right(offsets)) => Left(BinaryArray::new(
self.data_type,
offsets.into(),
self.values,
values,
None,
)),
(Some(values), None) => Left(BinaryArray::new(
(Right(values), Left(offsets)) => Left(BinaryArray::new(
self.data_type,
self.offsets,
offsets,
values.into(),
None,
)),
(Some(values), Some(offsets)) => Right(
(Right(values), Right(offsets)) => Right(
MutableBinaryArray::try_new(self.data_type, offsets, values, None).unwrap(),
),
}
Expand Down
18 changes: 9 additions & 9 deletions src/array/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ impl<T: NativeType> PrimitiveArray<T> {

/// Returns an option of a mutable reference to the values of this [`PrimitiveArray`].
pub fn get_mut_values(&mut self) -> Option<&mut [T]> {
self.values.get_mut().map(|x| x.as_mut())
self.values.get_mut_slice()
}

/// Returns its internal representation
Expand All @@ -282,7 +282,7 @@ impl<T: NativeType> PrimitiveArray<T> {
///
/// This function is primarily used to re-use memory regions.
#[must_use]
pub fn into_mut(mut self) -> Either<Self, MutablePrimitiveArray<T>> {
pub fn into_mut(self) -> Either<Self, MutablePrimitiveArray<T>> {
use Either::*;

if let Some(bitmap) = self.validity {
Expand All @@ -292,28 +292,28 @@ impl<T: NativeType> PrimitiveArray<T> {
self.values,
Some(bitmap),
)),
Right(mutable_bitmap) => match self.values.get_mut().map(std::mem::take) {
Some(values) => Right(
Right(mutable_bitmap) => match self.values.into_mut() {
Right(values) => Right(
MutablePrimitiveArray::try_new(
self.data_type,
values,
Some(mutable_bitmap),
)
.unwrap(),
),
None => Left(PrimitiveArray::new(
Left(values) => Left(PrimitiveArray::new(
self.data_type,
self.values,
values,
Some(mutable_bitmap.into()),
)),
},
}
} else {
match self.values.get_mut().map(std::mem::take) {
Some(values) => {
match self.values.into_mut() {
Right(values) => {
Right(MutablePrimitiveArray::try_new(self.data_type, values, None).unwrap())
}
None => Left(PrimitiveArray::new(self.data_type, self.values, None)),
Left(values) => Left(PrimitiveArray::new(self.data_type, values, None)),
}
}
}
Expand Down
47 changes: 21 additions & 26 deletions src/array/utf8/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ impl<O: Offset> Utf8Array<O> {
impl_into_array!();

/// Try to convert this `Utf8Array` to a `MutableUtf8Array`
pub fn into_mut(mut self) -> Either<Self, MutableUtf8Array<O>> {
#[must_use]
pub fn into_mut(self) -> Either<Self, MutableUtf8Array<O>> {
use Either::*;
if let Some(bitmap) = self.validity {
match bitmap.into_mut() {
Expand All @@ -239,44 +240,41 @@ impl<O: Offset> Utf8Array<O> {
Some(bitmap),
)
}),
Right(mutable_bitmap) => match (
self.values.get_mut().map(std::mem::take),
self.offsets.get_mut(),
) {
(None, None) => {
Right(mutable_bitmap) => match (self.values.into_mut(), self.offsets.into_mut()) {
(Left(values), Left(offsets)) => {
// Safety: invariants are preserved
Left(unsafe {
Utf8Array::new_unchecked(
self.data_type,
self.offsets,
self.values,
offsets,
values,
Some(mutable_bitmap.into()),
)
})
}
(None, Some(offsets)) => {
(Left(values), Right(offsets)) => {
// Safety: invariants are preserved
Left(unsafe {
Utf8Array::new_unchecked(
self.data_type,
offsets.into(),
self.values,
values,
Some(mutable_bitmap.into()),
)
})
}
(Some(mutable_values), None) => {
(Right(values), Left(offsets)) => {
// Safety: invariants are preserved
Left(unsafe {
Utf8Array::new_unchecked(
self.data_type,
self.offsets,
mutable_values.into(),
offsets,
values.into(),
Some(mutable_bitmap.into()),
)
})
}
(Some(values), Some(offsets)) => Right(unsafe {
(Right(values), Right(offsets)) => Right(unsafe {
MutableUtf8Array::new_unchecked(
self.data_type,
offsets,
Expand All @@ -287,20 +285,17 @@ impl<O: Offset> Utf8Array<O> {
},
}
} else {
match (
self.values.get_mut().map(std::mem::take),
self.offsets.get_mut(),
) {
(None, None) => Left(unsafe {
Utf8Array::new_unchecked(self.data_type, self.offsets, self.values, None)
match (self.values.into_mut(), self.offsets.into_mut()) {
(Left(values), Left(offsets)) => {
Left(unsafe { Utf8Array::new_unchecked(self.data_type, offsets, values, None) })
}
(Left(values), Right(offsets)) => Left(unsafe {
Utf8Array::new_unchecked(self.data_type, offsets.into(), values, None)
}),
(None, Some(offsets)) => Left(unsafe {
Utf8Array::new_unchecked(self.data_type, offsets.into(), self.values, None)
(Right(values), Left(offsets)) => Left(unsafe {
Utf8Array::new_unchecked(self.data_type, offsets, values.into(), None)
}),
(Some(values), None) => Left(unsafe {
Utf8Array::new_unchecked(self.data_type, self.offsets, values.into(), None)
}),
(Some(values), Some(offsets)) => Right(unsafe {
(Right(values), Right(offsets)) => Right(unsafe {
MutableUtf8Array::new_unchecked(self.data_type, offsets, values, None)
}),
}
Expand Down
51 changes: 41 additions & 10 deletions src/buffer/immutable.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::{iter::FromIterator, ops::Deref, sync::Arc, usize};

use either::Either;

use super::Bytes;
use super::IntoIter;

Expand All @@ -21,16 +23,16 @@ use super::IntoIter;
/// assert_eq!(buffer.as_ref(), [1, 2, 3].as_ref());
///
/// // it supports copy-on-write semantics (i.e. back to a `Vec`)
/// let vec: &mut [u32] = buffer.get_mut().unwrap();
/// assert_eq!(vec, &mut [1, 2, 3]);
/// let vec: Vec<u32> = buffer.into_mut().right().unwrap();
/// assert_eq!(vec, vec![1, 2, 3]);
///
/// // cloning and slicing is `O(1)` (data is shared)
/// let mut buffer: Buffer<u32> = vec![1, 2, 3].into();
/// let mut sliced = buffer.clone();
/// sliced.slice(1, 1);
/// assert_eq!(sliced.as_ref(), [2].as_ref());
/// // but cloning forbids getting mut since `slice` and `buffer` now share data
/// assert_eq!(buffer.get_mut(), None);
/// assert_eq!(buffer.get_mut_slice(), None);
/// ```
#[derive(Clone)]
pub struct Buffer<T> {
Expand Down Expand Up @@ -178,18 +180,47 @@ impl<T> Buffer<T> {

/// Returns a mutable reference to its underlying [`Vec`], if possible.
///
/// This operation returns [`Some`] iff this [`Buffer`]:
/// * has not been sliced with an offset
/// This operation returns [`Either::Right`] iff this [`Buffer`]:
/// * has not been cloned (i.e. [`Arc`]`::get_mut` yields [`Some`])
/// * has not been imported from the c data interface (FFI)
pub fn get_mut(&mut self) -> Option<&mut Vec<T>> {
if self.offset != 0 {
None
} else {
Arc::get_mut(&mut self.data).and_then(|b| b.get_vec())
#[inline]
pub fn into_mut(mut self) -> Either<Self, Vec<T>> {
match Arc::get_mut(&mut self.data)
.and_then(|b| b.get_vec())
.map(std::mem::take)
{
Some(inner) => Either::Right(inner),
None => Either::Left(self),
}
}

/// Returns a mutable reference to its underlying `Vec`, if possible.
/// Note that only `[self.offset(), self.offset() + self.len()[` in this vector is visible
/// by this buffer.
///
/// This operation returns [`Some`] iff this [`Buffer`]:
/// * has not been cloned (i.e. [`Arc`]`::get_mut` yields [`Some`])
/// * has not been imported from the c data interface (FFI)
/// # Safety
/// The caller must ensure that the vector in the mutable reference keeps a length of at least `self.offset() + self.len() - 1`.
#[inline]
pub unsafe fn get_mut(&mut self) -> Option<&mut Vec<T>> {
Arc::get_mut(&mut self.data).and_then(|b| b.get_vec())
}

/// Returns a mutable reference to its slice, if possible.
///
/// This operation returns [`Some`] iff this [`Buffer`]:
/// * has not been cloned (i.e. [`Arc`]`::get_mut` yields [`Some`])
/// * has not been imported from the c data interface (FFI)
#[inline]
pub fn get_mut_slice(&mut self) -> Option<&mut [T]> {
Arc::get_mut(&mut self.data)
.and_then(|b| b.get_vec())
// Safety: the invariant of this struct
.map(|x| unsafe { x.get_unchecked_mut(self.offset..self.offset + self.length) })
}

/// Get the strong count of underlying `Arc` data buffer.
pub fn shared_count_strong(&self) -> usize {
Arc::strong_count(&self.data)
Expand Down
12 changes: 4 additions & 8 deletions src/offset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,16 +363,12 @@ impl<O: Offset> OffsetsBuffer<O> {

/// Copy-on-write API to convert [`OffsetsBuffer`] into [`Offsets`].
#[inline]
pub fn get_mut(&mut self) -> Option<Offsets<O>> {
pub fn into_mut(self) -> either::Either<Self, Offsets<O>> {
self.0
.get_mut()
.map(|x| {
let mut new = vec![O::zero()];
std::mem::swap(x, &mut new);
new
})
.into_mut()
// Safety: Offsets and OffsetsBuffer share invariants
.map(|offsets| unsafe { Offsets::new_unchecked(offsets) })
.map_right(|offsets| unsafe { Offsets::new_unchecked(offsets) })
.map_left(Self)
}

/// Returns a reference to its internal [`Buffer`].
Expand Down

0 comments on commit 6c04bd6

Please sign in to comment.