Skip to content

Commit

Permalink
use specialized list collect for most common numerical types
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 9, 2022
1 parent cc659f8 commit 8f9010c
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 24 deletions.
24 changes: 23 additions & 1 deletion polars/polars-core/src/chunked_array/builder/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,27 @@ pub trait ListBuilderTrait {
fn finish(&mut self) -> ListChunked;
}

impl<S: ?Sized> ListBuilderTrait for Box<S>
where
S: ListBuilderTrait,
{
fn append_opt_series(&mut self, opt_s: Option<&Series>) {
(**self).append_opt_series(opt_s)
}

fn append_series(&mut self, s: &Series) {
(**self).append_series(s)
}

fn append_null(&mut self) {
(**self).append_null()
}

fn finish(&mut self) -> ListChunked {
(**self).finish()
}
}

pub struct ListPrimitiveChunkedBuilder<T>
where
T: NumericNative,
Expand Down Expand Up @@ -139,7 +160,8 @@ where
unsafe { values.extend_trusted_len_unchecked(arr.into_iter()) }
}
});
self.builder.try_push_valid().unwrap();
// overflow of i64 is far beyond polars capable lengths.
unsafe { self.builder.try_push_valid().unwrap_unchecked() };
}

fn finish(&mut self) -> ListChunked {
Expand Down
137 changes: 114 additions & 23 deletions polars/polars-core/src/chunked_array/upstream_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,25 +153,57 @@ where
}
}

fn primitive_series_collect<Ptr, Iter, Lb>(
mut nulls_so_far: usize,
iter: Iter,
s: &Series,
builder: &mut Lb,
) -> ListChunked
where
Ptr: Borrow<Series>,
Iter: Iterator<Item = Option<Ptr>>,
Lb: ?Sized + ListBuilderTrait,
{
// first fill all None's we encountered
while nulls_so_far > 0 {
builder.append_null();
nulls_so_far -= 1;
}

// now the first non None
builder.append_series(s);

// now we have added all Nones, we can consume the rest of the iterator.
for opt_s in iter {
match opt_s {
Some(s) => builder.append_series(s.borrow()),
None => builder.append_null(),
}
}

builder.finish()
}

impl<Ptr> FromIterator<Option<Ptr>> for ListChunked
where
Ptr: Borrow<Series>,
{
fn from_iter<I: IntoIterator<Item = Option<Ptr>>>(iter: I) -> Self {
// first pull all `None` values so that we can determine the inner `dtype`
let mut it = iter.into_iter();
let owned_v;
let mut cnt = 0;
let owned_s;
let mut nulls_so_far = 0;

loop {
let opt_v = it.next();

match opt_v {
Some(opt_v) => match opt_v {
Some(val) => {
owned_v = val;
owned_s = val;
break;
}
None => cnt += 1,
None => nulls_so_far += 1,
},
// end of iterator
None => {
Expand All @@ -180,28 +212,87 @@ where
}
}
}
let v = owned_v.borrow();
let s: &Series = owned_s.borrow();
let capacity = get_iter_capacity(&it);
let mut builder = get_list_builder(v.dtype(), capacity * 5, capacity, "collected");

// first fill all None's we encountered
while cnt > 0 {
builder.append_opt_series(None);
cnt -= 1;
}

// now the first non None
builder.append_series(v);

// now we have added all Nones, we can consume the rest of the iterator.
for opt_s in it {
match opt_s {
Some(s) => builder.append_series(s.borrow()),
None => builder.append_null(),
let estimated_s_size = std::cmp::min(s.len(), 1 << 18);
// use specialized builder for most common types
match s.dtype() {
DataType::UInt32 => primitive_series_collect(
nulls_so_far,
it,
s,
&mut ListPrimitiveChunkedBuilder::<u32>::new(
"collected",
capacity,
capacity * estimated_s_size,
s.dtype().clone(),
),
),
DataType::Int32 => primitive_series_collect(
nulls_so_far,
it,
s,
&mut ListPrimitiveChunkedBuilder::<i32>::new(
"collected",
capacity,
capacity * estimated_s_size,
s.dtype().clone(),
),
),
DataType::UInt64 => primitive_series_collect(
nulls_so_far,
it,
s,
&mut ListPrimitiveChunkedBuilder::<u64>::new(
"collected",
capacity,
capacity * estimated_s_size,
s.dtype().clone(),
),
),
DataType::Int64 => primitive_series_collect(
nulls_so_far,
it,
s,
&mut ListPrimitiveChunkedBuilder::<i64>::new(
"collected",
capacity,
capacity * estimated_s_size,
s.dtype().clone(),
),
),
DataType::Float32 => primitive_series_collect(
nulls_so_far,
it,
s,
&mut ListPrimitiveChunkedBuilder::<f32>::new(
"collected",
capacity,
capacity * estimated_s_size,
s.dtype().clone(),
),
),
DataType::Float64 => primitive_series_collect(
nulls_so_far,
it,
s,
&mut ListPrimitiveChunkedBuilder::<f64>::new(
"collected",
capacity,
capacity * estimated_s_size,
s.dtype().clone(),
),
),
_ => {
let mut builder = get_list_builder(
s.dtype(),
capacity * estimated_s_size,
capacity,
"collected",
);
primitive_series_collect(nulls_so_far, it, s, &mut builder)
}
}

builder.finish()
}
}

Expand Down

0 comments on commit 8f9010c

Please sign in to comment.