Skip to content

Commit

Permalink
Tag unsafe (#3581)
Browse files: browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 5, 2022
1 parent e668c2c commit 4f579fa
Show file tree
Hide file tree
Showing 26 changed files with 410 additions and 386 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ mod test {
));

let groups = s.group_tuples(false, true);
let aggregated = s.agg_list(&groups);
let aggregated = unsafe { s.agg_list(&groups) };
match aggregated.get(0) {
AnyValue::List(s) => {
assert!(matches!(s.dtype(), DataType::Categorical(_)));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,15 @@ impl CategoricalChunked {

pub fn value_counts(&self) -> Result<DataFrame> {
let groups = self.logical().group_tuples(true, false);
let logical_values = self
.logical()
.clone()
.into_series()
.agg_first(&groups)
.u32()
.unwrap()
.clone();
let logical_values = unsafe {
self.logical()
.clone()
.into_series()
.agg_first(&groups)
.u32()
.unwrap()
.clone()
};

let mut values = self.clone();
*values.logical_mut() = logical_values;
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-core/src/chunked_array/object/extension/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ mod test {
let ca = ObjectChunked::new("", values);

let groups = GroupsProxy::Idx(vec![(0, vec![0, 1]), (2, vec![2]), (3, vec![3])].into());
let out = ca.agg_list(&groups);
let out = unsafe { ca.agg_list(&groups) };
assert!(matches!(out.dtype(), DataType::List(_)));
assert_eq!(out.len(), groups.len());
}
Expand All @@ -214,7 +214,7 @@ mod test {
let ca = ObjectChunked::new("", values);

let groups = vec![(0, vec![0, 1]), (2, vec![2]), (3, vec![3])].into();
let out = ca.agg_list(&GroupsProxy::Idx(groups));
let out = unsafe { ca.agg_list(&GroupsProxy::Idx(groups)) };
let a = out.explode().unwrap();

let ca_foo = a.as_any().downcast_ref::<ObjectChunked<Foo>>().unwrap();
Expand Down
40 changes: 21 additions & 19 deletions polars/polars-core/src/frame/groupby/aggregations/agg_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@ use super::*;
use crate::chunked_array::builder::AnonymousOwnedListBuilder;

pub trait AggList {
fn agg_list(&self, _groups: &GroupsProxy) -> Series;
/// # Safety
///
/// groups should be in bounds
unsafe fn agg_list(&self, _groups: &GroupsProxy) -> Series;
}

impl<T> AggList for ChunkedArray<T>
where
T: PolarsNumericType,
ChunkedArray<T>: IntoSeries,
{
fn agg_list(&self, groups: &GroupsProxy) -> Series {
unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series {
let ca = self.rechunk();

match groups {
Expand All @@ -35,7 +38,7 @@ where
length_so_far += idx_len as i64;
// Safety:
// group tuples are in bounds
unsafe {
{
list_values.extend(idx.iter().map(|idx| {
debug_assert!((*idx as usize) < values.len());
*values.get_unchecked(*idx as usize)
Expand All @@ -51,7 +54,7 @@ where
let mut validity = MutableBitmap::from_len_set(list_values.len());

let mut count = 0;
groups.iter().for_each(|(_, idx)| unsafe {
groups.iter().for_each(|(_, idx)| {
for i in idx {
if !old_validity.get_bit_unchecked(*i as usize) {
validity.set_bit_unchecked(count, false)
Expand Down Expand Up @@ -96,7 +99,7 @@ where

length_so_far += len as i64;
list_values.extend_from_slice(&values[first as usize..(first + len) as usize]);
unsafe {
{
// Safety:
// we know that offsets has allocated enough slots
offsets.push_unchecked(length_so_far);
Expand All @@ -108,7 +111,7 @@ where
let mut validity = MutableBitmap::from_len_set(list_values.len());

let mut count = 0;
groups.iter().for_each(|[first, len]| unsafe {
groups.iter().for_each(|[first, len]| {
for i in *first..(*first + *len) {
if !old_validity.get_bit_unchecked(i as usize) {
validity.set_bit_unchecked(count, false)
Expand Down Expand Up @@ -140,13 +143,13 @@ where
}

impl AggList for BooleanChunked {
fn agg_list(&self, groups: &GroupsProxy) -> Series {
unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series {
match groups {
GroupsProxy::Idx(groups) => {
let mut builder =
ListBooleanChunkedBuilder::new(self.name(), groups.len(), self.len());
for idx in groups.all().iter() {
let ca = unsafe { self.take_unchecked(idx.into()) };
let ca = { self.take_unchecked(idx.into()) };
builder.append(&ca)
}
builder.finish().into_series()
Expand All @@ -165,13 +168,13 @@ impl AggList for BooleanChunked {
}

impl AggList for Utf8Chunked {
fn agg_list(&self, groups: &GroupsProxy) -> Series {
unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series {
match groups {
GroupsProxy::Idx(groups) => {
let mut builder =
ListUtf8ChunkedBuilder::new(self.name(), groups.len(), self.len());
for idx in groups.all().iter() {
let ca = unsafe { self.take_unchecked(idx.into()) };
let ca = { self.take_unchecked(idx.into()) };
builder.append(&ca)
}
builder.finish().into_series()
Expand Down Expand Up @@ -230,7 +233,7 @@ fn agg_list_list<F: Fn(&ListChunked, bool, &mut Vec<i64>, &mut i64, &mut Vec<Arr
}

impl AggList for ListChunked {
fn agg_list(&self, groups: &GroupsProxy) -> Series {
unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series {
match groups {
GroupsProxy::Idx(groups) => {
let func = |ca: &ListChunked,
Expand All @@ -247,7 +250,7 @@ impl AggList for ListChunked {
*length_so_far += idx_len as i64;
// Safety:
// group tuples are in bounds
unsafe {
{
let mut s = ca.take_unchecked(idx.into());
let arr = s.chunks.pop().unwrap();
list_values.push(arr);
Expand Down Expand Up @@ -278,7 +281,7 @@ impl AggList for ListChunked {
let arr = s.chunks.pop().unwrap();
list_values.push(arr);

unsafe {
{
// Safety:
// we know that offsets has allocated enough slots
offsets.push_unchecked(*length_so_far);
Expand All @@ -295,14 +298,14 @@ impl AggList for ListChunked {

#[cfg(feature = "object")]
impl<T: PolarsObject> AggList for ObjectChunked<T> {
fn agg_list(&self, groups: &GroupsProxy) -> Series {
unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series {
let mut can_fast_explode = true;
let mut offsets = Vec::<i64>::with_capacity(groups.len() + 1);
let mut length_so_far = 0i64;
offsets.push(length_so_far);

// we know that iterators length
let iter = unsafe {
let iter = {
groups
.iter()
.flat_map(|indicator| {
Expand Down Expand Up @@ -341,7 +344,7 @@ impl<T: PolarsObject> AggList for ObjectChunked<T> {
// this is safe because we just created the PolarsExtension
// meaning that the sentinel is heap allocated and the dereference of the
// pointer does not fail
unsafe { pe.set_to_series_fn::<T>() };
pe.set_to_series_fn::<T>();
let extension_array = Arc::new(pe.take_and_forget()) as ArrayRef;
let extension_dtype = extension_array.data_type();

Expand All @@ -363,7 +366,7 @@ impl<T: PolarsObject> AggList for ObjectChunked<T> {

#[cfg(feature = "dtype-struct")]
impl AggList for StructChunked {
fn agg_list(&self, groups: &GroupsProxy) -> Series {
unsafe fn agg_list(&self, groups: &GroupsProxy) -> Series {
let s = self.clone().into_series();
match groups {
GroupsProxy::Idx(groups) => {
Expand All @@ -373,8 +376,7 @@ impl AggList for StructChunked {
Some(self.dtype().clone()),
);
for idx in groups.all().iter() {
let taken =
unsafe { s.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize)) };
let taken = s.take_iter_unchecked(&mut idx.iter().map(|i| *i as usize));
builder.append_series(&taken)
}
builder.finish().into_series()
Expand Down

0 comments on commit 4f579fa

Please sign in to comment.