Skip to content

Commit

Permalink
move the bounds check to the implementation of TakeIterators
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 26, 2021
1 parent f2dfd3e commit 37f08f9
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 67 deletions.
61 changes: 5 additions & 56 deletions polars/polars-core/src/chunked_array/ops/take/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,57 +64,6 @@ macro_rules! take_opt_iter_n_chunks_unchecked {
}};
}

/// Verifies that every index in `indices` is strictly smaller than `len`.
///
/// * `TakeIdx::Iter`: the iterator is shallow-cloned so the caller can still
///   consume the original afterwards, and each yielded index is compared
///   against `len`.
/// * `TakeIdx::Array`: every non-null value of the array is compared against
///   `len`; null entries are skipped.
///
/// # Errors
///
/// * `PolarsError::OutOfBounds` if any index is `>= len`.
/// * `PolarsError::ValueError` for any other `TakeIdx` variant, which this
///   function does not support.
fn check_bounds<I, INulls>(len: usize, indices: &TakeIdx<I, INulls>) -> Result<()>
where
    I: TakeIterator,
    INulls: TakeIteratorNulls,
{
    let in_bounds = match indices {
        // Clone so that the caller can iterate the indices a second time.
        TakeIdx::Iter(i) => i.shallow_clone().all(|idx| idx < len),
        TakeIdx::Array(arr) => {
            // Widen each `u32` index to `usize` instead of narrowing `len` to
            // `u32`: narrowing silently truncates lengths above `u32::MAX` and
            // would then reject indices that are actually valid.
            if arr.null_count() == 0 {
                // No nulls: scan the raw values buffer directly.
                arr.values()
                    .as_slice()
                    .iter()
                    .all(|&idx| (idx as usize) < len)
            } else {
                // Null slots carry no index and are therefore ignored.
                (*arr)
                    .into_iter()
                    .flatten()
                    .all(|&idx| (idx as usize) < len)
            }
        }
        _ => {
            return Err(PolarsError::ValueError(
                "iterator with options not supported".into(),
            ))
        }
    };

    if in_bounds {
        Ok(())
    } else {
        Err(PolarsError::OutOfBounds("index is out of bounds".into()))
    }
}

impl<T> ChunkTake for ChunkedArray<T>
where
T: PolarsNumericType,
Expand Down Expand Up @@ -205,7 +154,7 @@ where
I: TakeIterator,
INulls: TakeIteratorNulls,
{
check_bounds(self.len(), &indices)?;
indices.check_bounds(self.len())?;
// Safety:
// just checked bounds
Ok(unsafe { self.take_unchecked(indices) })
Expand Down Expand Up @@ -289,7 +238,7 @@ impl ChunkTake for BooleanChunked {
I: TakeIterator,
INulls: TakeIteratorNulls,
{
check_bounds(self.len(), &indices)?;
indices.check_bounds(self.len())?;
// Safety:
// just checked bounds
Ok(unsafe { self.take_unchecked(indices) })
Expand Down Expand Up @@ -367,7 +316,7 @@ impl ChunkTake for Utf8Chunked {
I: TakeIterator,
INulls: TakeIteratorNulls,
{
check_bounds(self.len(), &indices)?;
indices.check_bounds(self.len())?;
// Safety:
// just checked bounds
Ok(unsafe { self.take_unchecked(indices) })
Expand Down Expand Up @@ -427,7 +376,7 @@ impl ChunkTake for ListChunked {
I: TakeIterator,
INulls: TakeIteratorNulls,
{
check_bounds(self.len(), &indices)?;
indices.check_bounds(self.len())?;
// Safety:
// just checked bounds
Ok(unsafe { self.take_unchecked(indices) })
Expand Down Expand Up @@ -538,7 +487,7 @@ impl<T: PolarsObject> ChunkTake for ObjectChunked<T> {
I: TakeIterator,
INulls: TakeIteratorNulls,
{
check_bounds(self.len(), &indices)?;
indices.check_bounds(self.len())?;
// Safety:
// just checked bounds
Ok(unsafe { self.take_unchecked(indices) })
Expand Down
96 changes: 85 additions & 11 deletions polars/polars-core/src/chunked_array/ops/take/traits.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,24 @@
//! Traits that indicate the allowed arguments in a ChunkedArray::take operation.
use crate::prelude::*;
use arrow::array::UInt32Array;
use arrow::array::{Array, UInt32Array};

// Utility traits
pub trait TakeIterator: Iterator<Item = usize> {
fn shallow_clone<'a>(&'a self) -> Box<dyn TakeIterator + 'a>;
fn check_bounds(&self, bound: usize) -> Result<()>;
}
pub trait TakeIteratorNulls: Iterator<Item = Option<usize>> {
fn shallow_clone<'a>(&'a self) -> Box<dyn TakeIteratorNulls + 'a>;
fn check_bounds(&self, bound: usize) -> Result<()>;
}

// Implement for the ref as well
impl TakeIterator for &mut dyn TakeIterator {
fn shallow_clone<'a>(&'a self) -> Box<dyn TakeIterator + 'a> {
(**self).shallow_clone()
fn check_bounds(&self, bound: usize) -> Result<()> {
(**self).check_bounds(bound)
}
}
impl TakeIteratorNulls for &mut dyn TakeIteratorNulls {
fn shallow_clone<'a>(&'a self) -> Box<dyn TakeIteratorNulls + 'a> {
(**self).shallow_clone()
fn check_bounds(&self, bound: usize) -> Result<()> {
(**self).check_bounds(bound)
}
}

Expand All @@ -27,16 +27,48 @@ impl<I> TakeIterator for I
where
I: Iterator<Item = usize> + Clone + Sized,
{
fn shallow_clone<'a>(&'a self) -> Box<dyn TakeIterator + 'a> {
Box::new(self.clone())
fn check_bounds(&self, bound: usize) -> Result<()> {
// clone so that the iterator can be used again.
let iter = self.clone();
let mut inbounds = true;

for i in iter {
if i >= bound {
inbounds = false;
break;
}
}
if inbounds {
Ok(())
} else {
Err(PolarsError::OutOfBounds(
"take indices are out of bounds".into(),
))
}
}
}
impl<I> TakeIteratorNulls for I
where
I: Iterator<Item = Option<usize>> + Clone + Sized,
{
fn shallow_clone<'a>(&'a self) -> Box<dyn TakeIteratorNulls + 'a> {
Box::new(self.clone())
fn check_bounds(&self, bound: usize) -> Result<()> {
// clone so that the iterator can be used again.
let iter = self.clone();
let mut inbounds = true;

for i in iter.flatten() {
if i >= bound {
inbounds = false;
break;
}
}
if inbounds {
Ok(())
} else {
Err(PolarsError::OutOfBounds(
"take indices are out of bounds".into(),
))
}
}
}

Expand All @@ -52,6 +84,48 @@ where
IterNulls(INulls),
}

impl<'a, I, INulls> TakeIdx<'a, I, INulls>
where
    I: TakeIterator,
    INulls: TakeIteratorNulls,
{
    /// Verifies that every index held by this `TakeIdx` is strictly smaller
    /// than `bound`.
    ///
    /// The iterator variants delegate to the `check_bounds` method of the
    /// wrapped iterator. For the array variant every non-null value is
    /// compared against `bound`; null entries are skipped.
    ///
    /// # Errors
    ///
    /// For the array variant, returns `PolarsError::OutOfBounds` if any
    /// non-null index is `>= bound`. For the iterator variants, returns
    /// whatever the wrapped iterator's `check_bounds` returns.
    pub(crate) fn check_bounds(&self, bound: usize) -> Result<()> {
        match self {
            TakeIdx::Iter(i) => i.check_bounds(bound),
            TakeIdx::IterNulls(i) => i.check_bounds(bound),
            TakeIdx::Array(arr) => {
                // Widen each `u32` index to `usize` instead of narrowing
                // `bound` to `u32`: narrowing silently truncates bounds above
                // `u32::MAX` and would then reject indices that are valid.
                let in_bounds = if arr.null_count() == 0 {
                    // No nulls: scan the raw values buffer directly.
                    arr.values()
                        .as_slice()
                        .iter()
                        .all(|&idx| (idx as usize) < bound)
                } else {
                    // Null slots carry no index and are therefore ignored.
                    (*arr)
                        .into_iter()
                        .flatten()
                        .all(|&idx| (idx as usize) < bound)
                };

                if in_bounds {
                    Ok(())
                } else {
                    Err(PolarsError::OutOfBounds(
                        "take indices are out of bounds".into(),
                    ))
                }
            }
        }
    }
}

/// Dummy type, we need to instantiate all generic types, so we fill one with a dummy.
pub type Dummy<T> = std::iter::Once<T>;

Expand Down

0 comments on commit 37f08f9

Please sign in to comment.