Skip to content

Commit

Permalink
correct take_iter to be trusted-len
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 31, 2021
1 parent 1edaf66 commit fbc0df0
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 14 deletions.
7 changes: 4 additions & 3 deletions polars/polars-arrow/src/kernels/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use arrow::bitmap::MutableBitmap;
use arrow::buffer::Buffer;
use arrow::datatypes::{DataType, PhysicalType};
use arrow::types::NativeType;
use std::iter::FromIterator;
use std::sync::Arc;

/// # Safety
Expand Down Expand Up @@ -291,6 +290,7 @@ pub unsafe fn take_no_null_bool_opt_iter_unchecked<I: IntoIterator<Item = Option

/// # Safety
/// - no bounds checks
/// - iterator must be TrustedLen
#[inline]
pub unsafe fn take_no_null_utf8_iter_unchecked<I: IntoIterator<Item = usize>>(
arr: &LargeStringArray,
Expand All @@ -300,11 +300,12 @@ pub unsafe fn take_no_null_utf8_iter_unchecked<I: IntoIterator<Item = usize>>(
debug_assert!(idx < arr.len());
arr.value_unchecked(idx)
});
Arc::new(MutableUtf8Array::<i64>::from_iter_values(iter).into())
Arc::new(MutableUtf8Array::<i64>::from_trusted_len_values_iter_unchecked(iter).into())
}

/// # Safety
/// - no bounds checks
/// - iterator must be TrustedLen
#[inline]
pub unsafe fn take_utf8_iter_unchecked<I: IntoIterator<Item = usize>>(
arr: &LargeStringArray,
Expand All @@ -320,7 +321,7 @@ pub unsafe fn take_utf8_iter_unchecked<I: IntoIterator<Item = usize>>(
}
});

Arc::new(LargeStringArray::from_iter(iter))
Arc::new(LargeStringArray::from_trusted_len_iter_unchecked(iter))
}

/// # Safety
Expand Down
1 change: 1 addition & 0 deletions polars/polars-arrow/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use arrow::bitmap::Bitmap;
use arrow::types::NativeType;
use std::ops::BitAnd;

#[derive(Clone)]
pub struct TrustMyLength<I: Iterator<Item = J>, J> {
iter: I,
len: usize,
Expand Down
11 changes: 7 additions & 4 deletions polars/polars-core/src/chunked_array/ops/take/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@ use arrow::array::UInt32Array;
use polars_arrow::array::PolarsArray;

// Utility traits
pub trait TakeIterator: Iterator<Item = usize> {
pub trait TakeIterator: Iterator<Item = usize> + TrustedLen {
fn check_bounds(&self, bound: usize) -> Result<()>;
}
pub trait TakeIteratorNulls: Iterator<Item = Option<usize>> {
pub trait TakeIteratorNulls: Iterator<Item = Option<usize>> + TrustedLen {
fn check_bounds(&self, bound: usize) -> Result<()>;
}

unsafe impl TrustedLen for &mut dyn TakeIterator {}
unsafe impl TrustedLen for &mut dyn TakeIteratorNulls {}

// Implement for the ref as well
impl TakeIterator for &mut dyn TakeIterator {
fn check_bounds(&self, bound: usize) -> Result<()> {
Expand All @@ -26,7 +29,7 @@ impl TakeIteratorNulls for &mut dyn TakeIteratorNulls {
// Clonable iterators may implement the traits above
impl<I> TakeIterator for I
where
I: Iterator<Item = usize> + Clone + Sized,
I: Iterator<Item = usize> + Clone + Sized + TrustedLen,
{
fn check_bounds(&self, bound: usize) -> Result<()> {
// clone so that the iterator can be used again.
Expand All @@ -50,7 +53,7 @@ where
}
impl<I> TakeIteratorNulls for I
where
I: Iterator<Item = Option<usize>> + Clone + Sized,
I: Iterator<Item = Option<usize>> + Clone + Sized + TrustedLen,
{
fn check_bounds(&self, bound: usize) -> Result<()> {
// clone so that the iterator can be used again.
Expand Down
6 changes: 3 additions & 3 deletions polars/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1337,7 +1337,7 @@ impl DataFrame {
/// ```
pub fn take_iter<I>(&self, iter: I) -> Result<Self>
where
I: Iterator<Item = usize> + Clone + Sync,
I: Iterator<Item = usize> + Clone + Sync + TrustedLen,
{
let new_col = POOL.install(|| {
self.columns
Expand All @@ -1358,7 +1358,7 @@ impl DataFrame {
/// This doesn't do any bound checking but checks null validity.
pub unsafe fn take_iter_unchecked<I>(&self, mut iter: I) -> Self
where
I: Iterator<Item = usize> + Clone + Sync,
I: Iterator<Item = usize> + Clone + Sync + TrustedLen,
{
if std::env::var("POLARS_VERT_PAR").is_ok() {
let idx_ca: NoNull<UInt32Chunked> = iter.into_iter().map(|idx| idx as u32).collect();
Expand Down Expand Up @@ -1407,7 +1407,7 @@ impl DataFrame {
/// Null validity is checked
pub unsafe fn take_opt_iter_unchecked<I>(&self, mut iter: I) -> Self
where
I: Iterator<Item = Option<usize>> + Clone + Sync,
I: Iterator<Item = Option<usize>> + Clone + Sync + TrustedLen,
{
if std::env::var("POLARS_VERT_PAR").is_ok() {
let idx_ca: UInt32Chunked = iter.into_iter().map(|opt| opt.map(|v| v as u32)).collect();
Expand Down
6 changes: 4 additions & 2 deletions polars/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -766,13 +766,15 @@ impl Series {
/// Take the group tuples.
///
/// # Safety
/// Group tuples have to be in bounds.
/// - Group tuples have to be in bounds.
pub unsafe fn take_group_values(&self, groups: &GroupTuples) -> Series {
let len = groups.iter().map(|g| g.1.len()).sum::<usize>();
self.take_iter_unchecked(
&mut groups
.iter()
.map(|g| g.1.iter().map(|idx| *idx as usize))
.flatten(),
.flatten()
.trust_my_length(len),
)
}
}
Expand Down
6 changes: 4 additions & 2 deletions polars/polars-core/src/series/series_trait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,8 @@ pub trait SeriesTrait:
///
/// # Safety
///
/// This doesn't check any bounds.
/// - This doesn't check any bounds.
/// - Iterator must be TrustedLen
unsafe fn take_iter_unchecked(&self, _iter: &mut dyn TakeIterator) -> Series {
invalid_operation_panic!(self)
}
Expand All @@ -526,7 +527,8 @@ pub trait SeriesTrait:
///
/// # Safety
///
/// This doesn't check any bounds.
/// - This doesn't check any bounds.
/// - Iterator must be TrustedLen
unsafe fn take_opt_iter_unchecked(&self, _iter: &mut dyn TakeIteratorNulls) -> Series {
invalid_operation_panic!(self)
}
Expand Down

0 comments on commit fbc0df0

Please sign in to comment.