Skip to content

Commit

Permalink
fix(rust): tag amortized iter unsafe and add safe alternatives (#10881)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 3, 2023
1 parent 1a8fba5 commit 01a90a3
Show file tree
Hide file tree
Showing 16 changed files with 401 additions and 360 deletions.
2 changes: 1 addition & 1 deletion crates/polars-core/src/chunked_array/array/iterator.rs
Expand Up @@ -47,7 +47,7 @@ impl ArrayChunked {
// Safety:
// inner type passed as physical type
let series_container = unsafe {
Box::new(Series::from_chunks_and_dtype_unchecked(
Box::pin(Series::from_chunks_and_dtype_unchecked(
name,
vec![inner_values.clone()],
&iter_dtype,
Expand Down
71 changes: 41 additions & 30 deletions crates/polars-core/src/chunked_array/comparison/mod.rs
Expand Up @@ -683,45 +683,56 @@ impl ChunkCompare<&str> for Utf8Chunked {
impl ChunkCompare<&ListChunked> for ListChunked {
type Item = BooleanChunked;
fn equal(&self, rhs: &ListChunked) -> BooleanChunked {
self.amortized_iter()
.zip(rhs.amortized_iter())
.map(|(left, right)| match (left, right) {
(Some(l), Some(r)) => Some(l.as_ref().series_equal_missing(r.as_ref())),
_ => None,
})
.collect_trusted()
// SAFETY: unstable series never lives longer than the iterator.
unsafe {
self.amortized_iter()
.zip(rhs.amortized_iter())
.map(|(left, right)| match (left, right) {
(Some(l), Some(r)) => Some(l.as_ref().series_equal_missing(r.as_ref())),
_ => None,
})
.collect_trusted()
}
}

fn equal_missing(&self, rhs: &ListChunked) -> BooleanChunked {
self.amortized_iter()
.zip(rhs.amortized_iter())
.map(|(left, right)| match (left, right) {
(Some(l), Some(r)) => l.as_ref().series_equal_missing(r.as_ref()),
(None, None) => true,
_ => false,
})
.collect_trusted()
// SAFETY: unstable series never lives longer than the iterator.
unsafe {
self.amortized_iter()
.zip(rhs.amortized_iter())
.map(|(left, right)| match (left, right) {
(Some(l), Some(r)) => l.as_ref().series_equal_missing(r.as_ref()),
(None, None) => true,
_ => false,
})
.collect_trusted()
}
}

fn not_equal(&self, rhs: &ListChunked) -> BooleanChunked {
self.amortized_iter()
.zip(rhs.amortized_iter())
.map(|(left, right)| match (left, right) {
(Some(l), Some(r)) => Some(!l.as_ref().series_equal_missing(r.as_ref())),
_ => None,
})
.collect_trusted()
// SAFETY: unstable series never lives longer than the iterator.
unsafe {
self.amortized_iter()
.zip(rhs.amortized_iter())
.map(|(left, right)| match (left, right) {
(Some(l), Some(r)) => Some(!l.as_ref().series_equal_missing(r.as_ref())),
_ => None,
})
.collect_trusted()
}
}

// Element-wise "not equal" in which nulls take part in the comparison:
// two nulls compare as equal (result `false`), a null against a non-null
// compares as unequal (result `true`), and two non-null lists are compared
// with `series_equal_missing` and the result negated.
fn not_equal_missing(&self, rhs: &ListChunked) -> BooleanChunked {
self.amortized_iter()
.zip(rhs.amortized_iter())
.map(|(left, right)| match (left, right) {
(Some(l), Some(r)) => !l.as_ref().series_equal_missing(r.as_ref()),
(None, None) => false,
_ => true,
})
.collect_trusted()
// SAFETY: unstable series never lives longer than the iterator.
unsafe {
self.amortized_iter()
.zip(rhs.amortized_iter())
.map(|(left, right)| match (left, right) {
(Some(l), Some(r)) => !l.as_ref().series_equal_missing(r.as_ref()),
(None, None) => false,
_ => true,
})
.collect_trusted()
}
}

// The following are not implemented because gt, lt comparison of series don't make sense.
Expand Down
104 changes: 69 additions & 35 deletions crates/polars-core/src/chunked_array/list/iterator.rs
@@ -1,4 +1,5 @@
use std::marker::PhantomData;
use std::pin::Pin;
use std::ptr::NonNull;

use crate::prelude::*;
Expand All @@ -7,7 +8,7 @@ use crate::utils::CustomIterTools;

pub struct AmortizedListIter<'a, I: Iterator<Item = Option<ArrayBox>>> {
len: usize,
series_container: Box<Series>,
series_container: Pin<Box<Series>>,
inner: NonNull<ArrayRef>,
lifetime: PhantomData<&'a ArrayRef>,
iter: I,
Expand All @@ -19,7 +20,7 @@ pub struct AmortizedListIter<'a, I: Iterator<Item = Option<ArrayBox>>> {
impl<'a, I: Iterator<Item = Option<ArrayBox>>> AmortizedListIter<'a, I> {
pub(crate) fn new(
len: usize,
series_container: Box<Series>,
series_container: Pin<Box<Series>>,
inner: NonNull<ArrayRef>,
iter: I,
inner_dtype: DataType,
Expand Down Expand Up @@ -111,11 +112,20 @@ impl ListChunked {
/// this function still needs precautions. The returned series should never be cloned or kept
/// alive longer than a single iteration, as every call on `next` of the iterator will change
/// the contents of that Series.
pub fn amortized_iter(&self) -> AmortizedListIter<impl Iterator<Item = Option<ArrayBox>> + '_> {
///
/// # Safety
/// The lifetime of [UnstableSeries] is bound to the iterator. Keeping it alive
/// longer than the iterator is UB.
pub unsafe fn amortized_iter(
&self,
) -> AmortizedListIter<impl Iterator<Item = Option<ArrayBox>> + '_> {
self.amortized_iter_with_name("")
}

pub fn amortized_iter_with_name(
/// # Safety
/// The lifetime of [UnstableSeries] is bound to the iterator. Keeping it alive
/// longer than the iterator is UB.
pub unsafe fn amortized_iter_with_name(
&self,
name: &str,
) -> AmortizedListIter<impl Iterator<Item = Option<ArrayBox>> + '_> {
Expand Down Expand Up @@ -143,7 +153,7 @@ impl ListChunked {
&iter_dtype,
);
s.clear_settings();
Box::new(s)
Box::pin(s)
};

let ptr = series_container.array_ref(0) as *const ArrayRef as *mut ArrayRef;
Expand All @@ -157,6 +167,23 @@ impl ListChunked {
)
}

/// Apply a closure `F` elementwise and collect the results into a `ChunkedArray<V>`.
///
/// This is a safe entry point over the `unsafe` `amortized_iter`: the closure is
/// mapped over the amortized iterator, the resulting `Option<K>` values are gathered
/// into one array by `K::array_from_iter`, and that array becomes the single chunk
/// of the returned `ChunkedArray`, which carries the same name as `self`.
///
/// `f` receives an `Option<UnstableSeries<'a>>` per element (presumably `None` for
/// null list entries; confirm against `AmortizedListIter::next`) and returns the
/// `Option<K>` to store for that element.
#[must_use]
pub fn apply_amortized_generic<'a, F, K, V>(&'a self, f: F) -> ChunkedArray<V>
where
V: PolarsDataType,
F: FnMut(Option<UnstableSeries<'a>>) -> Option<K> + Copy,
K: ArrayFromElementIter,
K::ArrayType: StaticallyMatchesPolarsType<V>,
{
// TODO! make an amortized iter that does not flatten

// SAFETY: unstable series never lives longer than the iterator.
// NOTE(review): this relies on `f` not stashing the `UnstableSeries<'a>` it is
// handed; the `'a` in the closure signature does not by itself prevent that.
// Confirm that callers only use the value inside the closure body.
let element_iter = unsafe { self.amortized_iter().map(f) };
let array = K::array_from_iter(element_iter);
// The whole result is materialised as one array, so the output has exactly one chunk.
ChunkedArray::from_chunk_iter(self.name(), std::iter::once(array))
}

/// Apply a closure `F` elementwise.
#[must_use]
pub fn apply_amortized<'a, F>(&'a self, mut f: F) -> Self
Expand All @@ -167,18 +194,20 @@ impl ListChunked {
return self.clone();
}
let mut fast_explode = self.null_count() == 0;
let mut ca: ListChunked = self
.amortized_iter()
.map(|opt_v| {
opt_v.map(|v| {
let out = f(v);
if out.is_empty() {
fast_explode = false;
}
out
// SAFETY: unstable series never lives longer than the iterator.
let mut ca: ListChunked = unsafe {
self.amortized_iter()
.map(|opt_v| {
opt_v.map(|v| {
let out = f(v);
if out.is_empty() {
fast_explode = false;
}
out
})
})
})
.collect_trusted();
.collect_trusted()
};

ca.rename(self.name());
if fast_explode {
Expand All @@ -195,22 +224,24 @@ impl ListChunked {
return Ok(self.clone());
}
let mut fast_explode = self.null_count() == 0;
let mut ca: ListChunked = self
.amortized_iter()
.map(|opt_v| {
opt_v
.map(|v| {
let out = f(v);
if let Ok(out) = &out {
if out.is_empty() {
fast_explode = false
}
};
out
})
.transpose()
})
.collect::<PolarsResult<_>>()?;
// SAFETY: unstable series never lives longer than the iterator.
let mut ca: ListChunked = unsafe {
self.amortized_iter()
.map(|opt_v| {
opt_v
.map(|v| {
let out = f(v);
if let Ok(out) = &out {
if out.is_empty() {
fast_explode = false
}
};
out
})
.transpose()
})
.collect::<PolarsResult<_>>()?
};
ca.rename(self.name());
if fast_explode {
ca.set_fast_explode();
Expand All @@ -232,8 +263,11 @@ mod test {
builder.append_series(&Series::new("", &[1, 1])).unwrap();
let ca = builder.finish();

ca.amortized_iter().zip(&ca).for_each(|(s1, s2)| {
assert!(s1.unwrap().as_ref().series_equal(&s2.unwrap()));
});
// SAFETY: unstable series never lives longer than the iterator.
unsafe {
ca.amortized_iter().zip(&ca).for_each(|(s1, s2)| {
assert!(s1.unwrap().as_ref().series_equal(&s2.unwrap()));
})
};
}
}
4 changes: 3 additions & 1 deletion crates/polars-lazy/src/physical_plan/expressions/apply.rs
Expand Up @@ -193,7 +193,9 @@ impl ApplyExpr {
// then unpack the lists and finally create iterators from this list chunked arrays.
let mut iters = acs
.iter_mut()
.map(|ac| ac.iter_groups(self.pass_name_to_apply))
.map(|ac|
// SAFETY: unstable series never lives longer than the iterator.
unsafe { ac.iter_groups(self.pass_name_to_apply) })
.collect::<Vec<_>>();

// length of the items to iterate over
Expand Down
32 changes: 17 additions & 15 deletions crates/polars-lazy/src/physical_plan/expressions/binary.rs
Expand Up @@ -112,21 +112,23 @@ impl BinaryExpr {
mut ac_r: AggregationContext<'a>,
) -> PolarsResult<AggregationContext<'a>> {
let name = ac_l.series().name().to_string();
let mut ca: ListChunked = ac_l
.iter_groups(false)
.zip(ac_r.iter_groups(false))
.map(|(l, r)| {
match (l, r) {
(Some(l), Some(r)) => {
let l = l.as_ref();
let r = r.as_ref();
Some(apply_operator(l, r, self.op))
},
_ => None,
}
.transpose()
})
.collect::<PolarsResult<_>>()?;
// SAFETY: unstable series never lives longer than the iterator.
let mut ca: ListChunked = unsafe {
ac_l.iter_groups(false)
.zip(ac_r.iter_groups(false))
.map(|(l, r)| {
match (l, r) {
(Some(l), Some(r)) => {
let l = l.as_ref();
let r = r.as_ref();
Some(apply_operator(l, r, self.op))
},
_ => None,
}
.transpose()
})
.collect::<PolarsResult<_>>()?
};
ca.rename(&name);

// try if we can reuse the groups
Expand Down
21 changes: 12 additions & 9 deletions crates/polars-lazy/src/physical_plan/expressions/filter.rs
Expand Up @@ -49,17 +49,20 @@ impl PhysicalExpr for FilterExpr {
let (mut ac_s, mut ac_predicate) = (ac_s?, ac_predicate?);

if ac_predicate.is_aggregated() || ac_s.is_aggregated() {
let preds = ac_predicate.iter_groups(false);
// SAFETY: unstable series never lives longer than the iterator.
let preds = unsafe { ac_predicate.iter_groups(false) };
let s = ac_s.aggregated();
let ca = s.list()?;
let mut out = ca
.amortized_iter()
.zip(preds)
.map(|(opt_s, opt_pred)| match (opt_s, opt_pred) {
(Some(s), Some(pred)) => s.as_ref().filter(pred.as_ref().bool()?).map(Some),
_ => Ok(None),
})
.collect::<PolarsResult<ListChunked>>()?;
// SAFETY: unstable series never lives longer than the iterator.
let mut out = unsafe {
ca.amortized_iter()
.zip(preds)
.map(|(opt_s, opt_pred)| match (opt_s, opt_pred) {
(Some(s), Some(pred)) => s.as_ref().filter(pred.as_ref().bool()?).map(Some),
_ => Ok(None),
})
.collect::<PolarsResult<ListChunked>>()?
};
out.rename(s.name());
ac_s.with_series(out.into_series(), true, Some(&self.expr))?;
ac_s.update_groups = WithSeriesLen;
Expand Down
Expand Up @@ -5,7 +5,10 @@ use polars_core::series::unstable::UnstableSeries;
use super::*;

impl<'a> AggregationContext<'a> {
pub(super) fn iter_groups(
/// # Safety
/// The lifetime of [UnstableSeries] is bound to the iterator. Keeping it alive
/// longer than the iterator is UB.
pub(super) unsafe fn iter_groups(
&mut self,
keep_names: bool,
) -> Box<dyn Iterator<Item = Option<UnstableSeries<'_>>> + '_> {
Expand Down
36 changes: 19 additions & 17 deletions crates/polars-lazy/src/physical_plan/expressions/mod.rs
Expand Up @@ -265,23 +265,25 @@ impl<'a> AggregationContext<'a> {
});
},
_ => {
let groups = self
.series()
.list()
.expect("impl error, should be a list at this point")
.amortized_iter()
.map(|s| {
if let Some(s) = s {
let len = s.as_ref().len() as IdxSize;
let new_offset = offset + len;
let out = [offset, len];
offset = new_offset;
out
} else {
[offset, 0]
}
})
.collect_trusted();
// SAFETY: unstable series never lives longer than the iterator.
let groups = unsafe {
self.series()
.list()
.expect("impl error, should be a list at this point")
.amortized_iter()
.map(|s| {
if let Some(s) = s {
let len = s.as_ref().len() as IdxSize;
let new_offset = offset + len;
let out = [offset, len];
offset = new_offset;
out
} else {
[offset, 0]
}
})
.collect_trusted()
};
self.groups = Cow::Owned(GroupsProxy::Slice {
groups,
rolling: false,
Expand Down

0 comments on commit 01a90a3

Please sign in to comment.