Skip to content

Commit

Permalink
reduce monomorphization of take_rand structs
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 26, 2021
1 parent 165758c commit ff0b7d7
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 19 deletions.
12 changes: 6 additions & 6 deletions polars/polars-core/src/chunked_array/ops/compare_inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ macro_rules! impl_traits {
($struct:ty, $T:tt) => {
impl<$T> PartialEqInner for $struct
where
$T: PolarsNumericType + Sync,
$T: NumericNative + Sync,
{
#[inline]
unsafe fn eq_element_unchecked(&self, idx_a: usize, idx_b: usize) -> bool {
Expand All @@ -53,7 +53,7 @@ macro_rules! impl_traits {

impl<$T> PartialOrdInner for $struct
where
$T: PolarsNumericType + Sync,
$T: NumericNative + Sync,
{
#[inline]
unsafe fn cmp_element_unchecked(&self, idx_a: usize, idx_b: usize) -> Ordering {
Expand Down Expand Up @@ -106,11 +106,11 @@ where
};
Box::new(t)
} else {
let t = NumTakeRandomSingleChunk::<'_, T> { arr };
let t = NumTakeRandomSingleChunk::<'_, T::Native> { arr };
Box::new(t)
}
} else {
let t = NumTakeRandomChunked::<'_, T> {
let t = NumTakeRandomChunked::<'_, T::Native> {
chunks: chunks.collect(),
chunk_lens: self.chunks.iter().map(|a| a.len() as u32).collect(),
};
Expand Down Expand Up @@ -219,11 +219,11 @@ where
};
Box::new(t)
} else {
let t = NumTakeRandomSingleChunk::<'_, T> { arr };
let t = NumTakeRandomSingleChunk::<'_, T::Native> { arr };
Box::new(t)
}
} else {
let t = NumTakeRandomChunked::<'_, T> {
let t = NumTakeRandomChunked::<'_, T::Native> {
chunks: chunks.collect(),
chunk_lens: self.chunks.iter().map(|a| a.len() as u32).collect(),
};
Expand Down
21 changes: 11 additions & 10 deletions polars/polars-core/src/chunked_array/ops/take/take_random.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,16 @@ where
}
}

#[allow(clippy::type_complexity)]
impl<'a, T> IntoTakeRandom<'a> for &'a ChunkedArray<T>
where
T: PolarsNumericType,
{
type Item = T::Native;
type TakeRandom = TakeRandBranch3<
NumTakeRandomCont<'a, T::Native>,
NumTakeRandomSingleChunk<'a, T>,
NumTakeRandomChunked<'a, T>,
NumTakeRandomSingleChunk<'a, T::Native>,
NumTakeRandomChunked<'a, T::Native>,
>;

#[inline]
Expand Down Expand Up @@ -271,17 +272,17 @@ impl<'a> IntoTakeRandom<'a> for &'a ListChunked {

pub struct NumTakeRandomChunked<'a, T>
where
T: PolarsNumericType,
T: NumericNative,
{
pub(crate) chunks: Vec<&'a PrimitiveArray<T::Native>>,
pub(crate) chunks: Vec<&'a PrimitiveArray<T>>,
pub(crate) chunk_lens: Vec<u32>,
}

impl<'a, T> TakeRandom for NumTakeRandomChunked<'a, T>
where
T: PolarsNumericType,
T: NumericNative,
{
type Item = T::Native;
type Item = T;

#[inline]
fn get(&self, index: usize) -> Option<Self::Item> {
Expand Down Expand Up @@ -317,16 +318,16 @@ where

pub struct NumTakeRandomSingleChunk<'a, T>
where
T: PolarsNumericType,
T: NumericNative,
{
pub(crate) arr: &'a PrimitiveArray<T::Native>,
pub(crate) arr: &'a PrimitiveArray<T>,
}

impl<'a, T> TakeRandom for NumTakeRandomSingleChunk<'a, T>
where
T: PolarsNumericType,
T: NumericNative,
{
type Item = T::Native;
type Item = T;

#[inline]
fn get(&self, index: usize) -> Option<Self::Item> {
Expand Down
3 changes: 2 additions & 1 deletion polars/polars-core/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ pub type CategoricalChunked = ChunkedArray<CategoricalType>;

pub trait NumericNative:
PartialOrd
+ NativeType
+ Num
+ NumCast
+ Zero
Expand Down Expand Up @@ -156,7 +157,7 @@ impl NumericNative for f32 {}
impl NumericNative for f64 {}

pub trait PolarsNumericType: Send + Sync + PolarsDataType + 'static {
type Native: NativeType + NumericNative;
type Native: NumericNative;
}
impl PolarsNumericType for UInt8Type {
type Native = u8;
Expand Down
18 changes: 16 additions & 2 deletions polars/polars-core/src/frame/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,11 @@ impl DataFrame {

let iter = columns.iter().map(|s| {
(0..s.len()).zip(row.0.iter_mut()).for_each(|(i, av)| {
*av = s.get(i);
// Safety:
// we iterate over the length of s, so we are in bounds
unsafe { *av = s.get_unchecked(i) };
});
// borrow checkery does not allow row borrow, so we deref from raw ptr.
// borrow checker does not allow row borrow, so we deref from raw ptr.
// we do all this to amortize allocs
// Safety:
// row is still alive
Expand Down Expand Up @@ -201,6 +203,8 @@ impl<'a> From<&AnyValue<'a>> for Field {
Date(_) => Field::new("", DataType::Date),
#[cfg(feature = "dtype-datetime")]
Datetime(_) => Field::new("", DataType::Datetime),
#[cfg(feature = "dtype-time")]
Time(_) => Field::new("", DataType::Time),
_ => unimplemented!(),
}
}
Expand Down Expand Up @@ -232,6 +236,8 @@ pub(crate) enum Buffer {
Date(PrimitiveChunkedBuilder<Int32Type>),
#[cfg(feature = "dtype-datetime")]
Datetime(PrimitiveChunkedBuilder<Int64Type>),
#[cfg(feature = "dtype-time")]
Time(PrimitiveChunkedBuilder<Int64Type>),
Float32(PrimitiveChunkedBuilder<Float32Type>),
Float64(PrimitiveChunkedBuilder<Float64Type>),
Utf8(Utf8ChunkedBuilder),
Expand All @@ -250,6 +256,8 @@ impl Debug for Buffer {
Date(_) => f.write_str("Date"),
#[cfg(feature = "dtype-datetime")]
Datetime(_) => f.write_str("datetime"),
#[cfg(feature = "dtype-time")]
Time(_) => f.write_str("time"),
Float32(_) => f.write_str("f32"),
Float64(_) => f.write_str("f64"),
Utf8(_) => f.write_str("utf8"),
Expand All @@ -275,6 +283,8 @@ impl Buffer {
(Date(builder), AnyValue::Null) => builder.append_null(),
#[cfg(feature = "dtype-datetime")]
(Datetime(builder), AnyValue::Datetime(v)) => builder.append_value(v),
#[cfg(feature = "dtype-time")]
(Time(builder), AnyValue::Time(v)) => builder.append_value(v),
(Float32(builder), AnyValue::Null) => builder.append_null(),
(Float64(builder), AnyValue::Float64(v)) => builder.append_value(v),
(Utf8(builder), AnyValue::Utf8(v)) => builder.append_value(v),
Expand All @@ -297,6 +307,8 @@ impl Buffer {
Date(b) => b.finish().into_date().into_series(),
#[cfg(feature = "dtype-datetime")]
Datetime(b) => b.finish().into_date().into_series(),
#[cfg(feature = "dtype-time")]
Time(b) => b.finish().into_date().into_series(),
Float32(b) => b.finish().into_series(),
Float64(b) => b.finish().into_series(),
Utf8(b) => b.finish().into_series(),
Expand All @@ -319,6 +331,8 @@ impl From<(&DataType, usize)> for Buffer {
Date => Buffer::Date(PrimitiveChunkedBuilder::new("", len)),
#[cfg(feature = "dtype-datetime")]
Datetime => Buffer::Datetime(PrimitiveChunkedBuilder::new("", len)),
#[cfg(feature = "dtype-time")]
Time => Buffer::Time(PrimitiveChunkedBuilder::new("", len)),
Float32 => Buffer::Float32(PrimitiveChunkedBuilder::new("", len)),
Float64 => Buffer::Float64(PrimitiveChunkedBuilder::new("", len)),
Utf8 => Buffer::Utf8(Utf8ChunkedBuilder::new("", len, len * 5)),
Expand Down

0 comments on commit ff0b7d7

Please sign in to comment.