Skip to content

Commit

Permalink
reduce generic bloat (#2782)
Browse files Browse the repository at this point in the history
A lot of `take_*` methods are called with
an iterator over &[IdxSize] take simply
applies a closure `Fn(v: &IdxSize) -> *v as usize`.
Assuming that every closure is a duplication of the
take_* functions this is a lot of bloat.

Now we implement the conversion to `TakeIdx` that uses
a function item, so one instantiation.
  • Loading branch information
ritchie46 committed Feb 26, 2022
1 parent 4bafb7c commit 9befa6d
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 40 deletions.
31 changes: 31 additions & 0 deletions polars/polars-core/src/chunked_array/ops/take/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,34 @@ where
TakeIdx::IterNulls(iter)
}
}

#[inline]
fn to_usize(idx: &IdxSize) -> usize {
*idx as usize
}

/// Conversion from `&[IdxSize]` to Unchecked TakeIdx
impl<'a> From<&'a [IdxSize]>
for TakeIdx<
'a,
std::iter::Map<std::slice::Iter<'a, IdxSize>, fn(&IdxSize) -> usize>,
Dummy<Option<usize>>,
>
{
fn from(slice: &'a [IdxSize]) -> Self {
TakeIdx::Iter(slice.iter().map(to_usize))
}
}

/// Conversion from `&[IdxSize]` to Unchecked TakeIdx
impl<'a> From<&'a Vec<IdxSize>>
for TakeIdx<
'a,
std::iter::Map<std::slice::Iter<'a, IdxSize>, fn(&IdxSize) -> usize>,
Dummy<Option<usize>>,
>
{
fn from(slice: &'a Vec<IdxSize>) -> Self {
(&**slice).into()
}
}
51 changes: 18 additions & 33 deletions polars/polars-core/src/frame/groupby/aggregations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,7 @@ where
)
},
_ => {
let take = unsafe {
self.take_unchecked(idx.iter().map(|i| *i as usize).into())
};
let take = unsafe { self.take_unchecked(idx.into()) };
take.min()
}
}
Expand Down Expand Up @@ -283,9 +281,7 @@ where
)
},
_ => {
let take = unsafe {
self.take_unchecked(idx.iter().map(|i| *i as usize).into())
};
let take = unsafe { self.take_unchecked(idx.into()) };
take.max()
}
}
Expand Down Expand Up @@ -332,9 +328,7 @@ where
)
},
_ => {
let take = unsafe {
self.take_unchecked(idx.iter().map(|i| *i as usize).into())
};
let take = unsafe { self.take_unchecked(idx.into()) };
take.sum()
}
}
Expand Down Expand Up @@ -408,9 +402,7 @@ where
.unwrap()
}),
_ => {
let take = unsafe {
self.take_unchecked(idx.iter().map(|i| *i as usize).into())
};
let take = unsafe { self.take_unchecked(idx.into()) };
let opt_sum: Option<T::Native> = take.sum();
opt_sum.map(|sum| sum.to_f64().unwrap() / idx.len() as f64)
}
Expand Down Expand Up @@ -441,7 +433,7 @@ where
if idx.is_empty() {
return None;
}
let take = unsafe { ca.take_unchecked(idx.iter().map(|i| *i as usize).into()) };
let take = unsafe { ca.take_unchecked(idx.into()) };
take.var_as_series().unpack::<T>().unwrap().get(0)
}),
GroupsProxy::Slice(groups) => agg_helper_slice::<T, _>(groups, |[first, len]| {
Expand All @@ -465,7 +457,7 @@ where
if idx.is_empty() {
return None;
}
let take = unsafe { ca.take_unchecked(idx.iter().map(|i| *i as usize).into()) };
let take = unsafe { ca.take_unchecked(idx.into()) };
take.std_as_series().unpack::<T>().unwrap().get(0)
}),
GroupsProxy::Slice(groups) => agg_helper_slice::<T, _>(groups, |[first, len]| {
Expand Down Expand Up @@ -496,7 +488,7 @@ where
if idx.is_empty() | invalid_quantile {
return None;
}
let take = unsafe { ca.take_unchecked(idx.iter().map(|i| *i as usize).into()) };
let take = unsafe { ca.take_unchecked(idx.into()) };
take.quantile_as_series(quantile, interpol)
.unwrap() // checked with invalid quantile check
.unpack::<T>()
Expand Down Expand Up @@ -528,7 +520,7 @@ where
if idx.is_empty() {
return None;
}
let take = unsafe { ca.take_unchecked(idx.iter().map(|i| *i as usize).into()) };
let take = unsafe { ca.take_unchecked(idx.into()) };
take.median_as_series().unpack::<T>().unwrap().get(0)
}),
GroupsProxy::Slice(groups) => agg_helper_slice::<T, _>(groups, |[first, len]| {
Expand Down Expand Up @@ -596,9 +588,7 @@ where
.unwrap()
}),
_ => {
let take = unsafe {
self.take_unchecked(idx.iter().map(|i| *i as usize).into())
};
let take = unsafe { self.take_unchecked(idx.into()) };
let opt_sum: Option<T::Native> = take.sum();
opt_sum.map(|sum| sum.to_f64().unwrap() / idx.len() as f64)
}
Expand Down Expand Up @@ -629,7 +619,7 @@ where
if idx.is_empty() {
return None;
}
let take = unsafe { self.take_unchecked(idx.iter().map(|i| *i as usize).into()) };
let take = unsafe { self.take_unchecked(idx.into()) };
take.var_as_series().unpack::<Float64Type>().unwrap().get(0)
}),
GroupsProxy::Slice(groups) => {
Expand All @@ -654,7 +644,7 @@ where
if idx.is_empty() {
return None;
}
let take = unsafe { self.take_unchecked(idx.iter().map(|i| *i as usize).into()) };
let take = unsafe { self.take_unchecked(idx.into()) };
take.std_as_series().unpack::<Float64Type>().unwrap().get(0)
}),
GroupsProxy::Slice(groups) => {
Expand Down Expand Up @@ -685,7 +675,7 @@ where
if idx.is_empty() {
return None;
}
let take = unsafe { self.take_unchecked(idx.iter().map(|i| *i as usize).into()) };
let take = unsafe { self.take_unchecked(idx.into()) };
take.quantile_as_series(quantile, interpol)
.unwrap()
.unpack::<Float64Type>()
Expand Down Expand Up @@ -714,7 +704,7 @@ where
if idx.is_empty() {
return None;
}
let take = unsafe { self.take_unchecked(idx.iter().map(|i| *i as usize).into()) };
let take = unsafe { self.take_unchecked(idx.into()) };
take.median_as_series()
.unpack::<Float64Type>()
.unwrap()
Expand Down Expand Up @@ -802,10 +792,7 @@ where
self.dtype().clone(),
);
for idx in groups.all().iter() {
let s = unsafe {
self.take_unchecked(idx.iter().map(|i| *i as usize).into())
.into_series()
};
let s = unsafe { self.take_unchecked(idx.into()).into_series() };
builder.append_series(&s);
}
return Some(builder.finish().into_series());
Expand Down Expand Up @@ -885,7 +872,7 @@ impl AggList for BooleanChunked {
let mut builder =
ListBooleanChunkedBuilder::new(self.name(), groups.len(), self.len());
for idx in groups.all().iter() {
let ca = unsafe { self.take_unchecked(idx.iter().map(|i| *i as usize).into()) };
let ca = unsafe { self.take_unchecked(idx.into()) };
builder.append(&ca)
}
Some(builder.finish().into_series())
Expand All @@ -910,7 +897,7 @@ impl AggList for Utf8Chunked {
let mut builder =
ListUtf8ChunkedBuilder::new(self.name(), groups.len(), self.len());
for idx in groups.all().iter() {
let ca = unsafe { self.take_unchecked(idx.iter().map(|i| *i as usize).into()) };
let ca = unsafe { self.take_unchecked(idx.into()) };
builder.append(&ca)
}
Some(builder.finish().into_series())
Expand Down Expand Up @@ -987,8 +974,7 @@ impl AggList for ListChunked {
// Safety:
// group tuples are in bounds
unsafe {
let mut s =
ca.take_unchecked((idx.iter().map(|idx| *idx as usize)).into());
let mut s = ca.take_unchecked(idx.into());
let arr = s.chunks.pop().unwrap();
list_values.push(arr);

Expand Down Expand Up @@ -1050,8 +1036,7 @@ impl<T: PolarsObject> AggList for ObjectChunked<T> {
GroupsIndicator::Idx((_first, idx)) => {
// Safety:
// group tuples always in bounds
let group_vals =
self.take_unchecked((idx.iter().map(|idx| *idx as usize)).into());
let group_vals = self.take_unchecked(idx.into());

(group_vals, idx.len() as IdxSize)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ impl PhysicalAggregation for AggregationExpr {
let ca = unsafe {
// Safety
// The indexes of the groupby operation are never out of bounds
ca.take_unchecked(idx.iter().map(|i| *i as usize).into())
ca.take_unchecked(idx.into())
};
let s = ca.explode()?;
length_so_far += s.len() as i64;
Expand Down
8 changes: 2 additions & 6 deletions polars/polars-time/src/groupby/dynamic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,7 @@ impl Wrap<&DataFrame> {
.iter()
// we just flat map, because iterate over groups so we almost always need to reallocate
.flat_map(|base_g| {
let dt = unsafe {
dt.take_unchecked((base_g.1.iter().map(|i| *i as usize)).into())
};
let dt = unsafe { dt.take_unchecked(base_g.1.into()) };

let vals = dt.downcast_iter().next().unwrap();
let ts = vals.values().as_slice();
Expand Down Expand Up @@ -252,9 +250,7 @@ impl Wrap<&DataFrame> {
groups
.par_iter()
.flat_map(|base_g| {
let dt = unsafe {
dt.take_unchecked((base_g.1.iter().map(|i| *i as usize)).into())
};
let dt = unsafe { dt.take_unchecked(base_g.1.into()) };
let vals = dt.downcast_iter().next().unwrap();
let ts = vals.values().as_slice();
let (sub_groups, _, _) = groupby_windows(
Expand Down

0 comments on commit 9befa6d

Please sign in to comment.