Skip to content

Commit

Permalink
fix list builders for logical types
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 31, 2021
1 parent 6e19c95 commit 8c5653f
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 38 deletions.
44 changes: 26 additions & 18 deletions polars/polars-core/src/chunked_array/builder/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ pub trait ListBuilderTrait {

pub struct ListPrimitiveChunkedBuilder<T>
where
T: PolarsNumericType,
T: NumericNative,
{
pub builder: LargePrimitiveBuilder<T::Native>,
pub builder: LargePrimitiveBuilder<T>,
field: Field,
fast_explode: bool,
}
Expand All @@ -35,12 +35,17 @@ macro_rules! finish_list_builder {

impl<T> ListPrimitiveChunkedBuilder<T>
where
T: PolarsNumericType,
T: NumericNative,
{
pub fn new(name: &str, capacity: usize, values_capacity: usize) -> Self {
let values = MutablePrimitiveArray::<T::Native>::with_capacity(values_capacity);
let builder = LargePrimitiveBuilder::<T::Native>::new_with_capacity(values, capacity);
let field = Field::new(name, DataType::List(Box::new(T::get_dtype())));
pub fn new(
name: &str,
capacity: usize,
values_capacity: usize,
logical_type: DataType,
) -> Self {
let values = MutablePrimitiveArray::<T>::with_capacity(values_capacity);
let builder = LargePrimitiveBuilder::<T>::new_with_capacity(values, capacity);
let field = Field::new(name, DataType::List(Box::new(logical_type)));

Self {
builder,
Expand All @@ -49,7 +54,7 @@ where
}
}

pub fn append_slice(&mut self, opt_v: Option<&[T::Native]>) {
pub fn append_slice(&mut self, opt_v: Option<&[T]>) {
match opt_v {
Some(items) => {
let values = self.builder.mut_values();
Expand All @@ -67,7 +72,7 @@ where
}
/// Appends from an iterator over values
#[inline]
pub fn append_iter_values<I: Iterator<Item = T::Native> + TrustedLen>(&mut self, iter: I) {
pub fn append_iter_values<I: Iterator<Item = T> + TrustedLen>(&mut self, iter: I) {
let values = self.builder.mut_values();

if iter.size_hint().0 == 0 {
Expand All @@ -81,7 +86,7 @@ where

/// Appends from an iterator over values
#[inline]
pub fn append_iter<I: Iterator<Item = Option<T::Native>> + TrustedLen>(&mut self, iter: I) {
pub fn append_iter<I: Iterator<Item = Option<T>> + TrustedLen>(&mut self, iter: I) {
let values = self.builder.mut_values();

if iter.size_hint().0 == 0 {
Expand All @@ -96,7 +101,7 @@ where

impl<T> ListBuilderTrait for ListPrimitiveChunkedBuilder<T>
where
T: PolarsNumericType,
T: NumericNative,
{
#[inline]
fn append_opt_series(&mut self, opt_s: Option<&Series>) {
Expand All @@ -123,10 +128,7 @@ where
let values = self.builder.mut_values();

arrays.iter().for_each(|x| {
let arr = x
.as_any()
.downcast_ref::<PrimitiveArray<T::Native>>()
.unwrap();
let arr = x.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();

if arr.null_count() == 0 {
values.extend_from_slice(arr.values().as_slice())
Expand Down Expand Up @@ -284,10 +286,16 @@ pub fn get_list_builder(
list_capacity: usize,
name: &str,
) -> Box<dyn ListBuilderTrait> {
let physical_type = dt.to_physical();

macro_rules! get_primitive_builder {
($type:ty) => {{
let builder =
ListPrimitiveChunkedBuilder::<$type>::new(&name, list_capacity, value_capacity);
let builder = ListPrimitiveChunkedBuilder::<$type>::new(
&name,
list_capacity,
value_capacity,
dt.clone(),
);
Box::new(builder)
}};
}
Expand All @@ -304,7 +312,7 @@ pub fn get_list_builder(
}};
}
match_arrow_data_type_apply_macro!(
dt,
physical_type,
get_primitive_builder,
get_utf8_builder,
get_bool_builder
Expand Down
4 changes: 2 additions & 2 deletions polars/polars-core/src/chunked_array/builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ mod test {

#[test]
fn test_list_builder() {
let mut builder = ListPrimitiveChunkedBuilder::<Int32Type>::new("a", 10, 5);
let mut builder = ListPrimitiveChunkedBuilder::<i32>::new("a", 10, 5, DataType::Int32);

// create a series containing two chunks
let mut s1 = Int32Chunked::new_from_slice("a", &[1, 2, 3]).into_series();
Expand All @@ -209,7 +209,7 @@ mod test {
assert_eq!(out.get(0).unwrap().len(), 6);
assert_eq!(out.get(1).unwrap().len(), 3);

let mut builder = ListPrimitiveChunkedBuilder::<Int32Type>::new("a", 10, 5);
let mut builder = ListPrimitiveChunkedBuilder::<i32>::new("a", 10, 5, DataType::Int32);
builder.append_series(&s1);
builder.append_null();

Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/chunked_array/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ mod test {

#[test]
fn test_cast_list() -> Result<()> {
let mut builder = ListPrimitiveChunkedBuilder::<Int32Type>::new("a", 10, 10);
let mut builder = ListPrimitiveChunkedBuilder::<i32>::new("a", 10, 10, DataType::Int32);
builder.append_slice(Some(&[1i32, 2, 3]));
builder.append_slice(Some(&[1i32, 2, 3]));
let ca = builder.finish();
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/fmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ mod test {

#[test]
fn test_fmt_list() {
let mut builder = ListPrimitiveChunkedBuilder::<Int32Type>::new("a", 10, 10);
let mut builder = ListPrimitiveChunkedBuilder::<i32>::new("a", 10, 10, DataType::Int32);
builder.append_slice(Some(&[1, 2, 3]));
builder.append_slice(None);
let list = builder.finish().into_series();
Expand Down
8 changes: 6 additions & 2 deletions polars/polars-core/src/frame/groupby/aggregations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -501,8 +501,12 @@ where
ListArray::<i64>::from_data(data_type, offsets.into(), Arc::new(array), None)
}
_ => {
let mut builder =
ListPrimitiveChunkedBuilder::<T>::new(self.name(), groups.len(), self.len());
let mut builder = ListPrimitiveChunkedBuilder::<T::Native>::new(
self.name(),
groups.len(),
self.len(),
self.dtype().clone(),
);
for (_first, idx) in groups {
let s = unsafe {
self.take_unchecked(idx.iter().map(|i| *i as usize).into())
Expand Down
20 changes: 10 additions & 10 deletions polars/polars-core/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,19 +267,19 @@ macro_rules! match_arrow_data_type_apply_macro {
DataType::Utf8 => $macro_utf8!($($opt_args)*),
DataType::Boolean => $macro_bool!($($opt_args)*),
#[cfg(feature = "dtype-u8")]
DataType::UInt8 => $macro!(UInt8Type $(, $opt_args)*),
DataType::UInt8 => $macro!(u8 $(, $opt_args)*),
#[cfg(feature = "dtype-u16")]
DataType::UInt16 => $macro!(UInt16Type $(, $opt_args)*),
DataType::UInt32 => $macro!(UInt32Type $(, $opt_args)*),
DataType::UInt64 => $macro!(UInt64Type $(, $opt_args)*),
DataType::UInt16 => $macro!(u16 $(, $opt_args)*),
DataType::UInt32 => $macro!(u32 $(, $opt_args)*),
DataType::UInt64 => $macro!(u64 $(, $opt_args)*),
#[cfg(feature = "dtype-i8")]
DataType::Int8 => $macro!(Int8Type $(, $opt_args)*),
DataType::Int8 => $macro!(i8 $(, $opt_args)*),
#[cfg(feature = "dtype-i16")]
DataType::Int16 => $macro!(Int16Type $(, $opt_args)*),
DataType::Int32 => $macro!(Int32Type $(, $opt_args)*),
DataType::Int64 => $macro!(Int64Type $(, $opt_args)*),
DataType::Float32 => $macro!(Float32Type $(, $opt_args)*),
DataType::Float64 => $macro!(Float64Type $(, $opt_args)*),
DataType::Int16 => $macro!(i16 $(, $opt_args)*),
DataType::Int32 => $macro!(i32 $(, $opt_args)*),
DataType::Int64 => $macro!(i64 $(, $opt_args)*),
DataType::Float32 => $macro!(f32 $(, $opt_args)*),
DataType::Float64 => $macro!(f64 $(, $opt_args)*),
dt => panic!("not implemented for dtype {:?}", dt),
}
}};
Expand Down
8 changes: 6 additions & 2 deletions polars/polars-lazy/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,12 @@ pub fn arange(low: Expr, high: Expr, step: usize) -> Expr {
let sb = sb.cast(&DataType::Int64)?;
let low = sa.i64()?;
let high = sb.i64()?;
let mut builder =
ListPrimitiveChunkedBuilder::<Int64Type>::new("arange", low.len(), low.len() * 3);
let mut builder = ListPrimitiveChunkedBuilder::<i64>::new(
"arange",
low.len(),
low.len() * 3,
DataType::Int64,
);

low.into_iter()
.zip(high.into_iter())
Expand Down
4 changes: 2 additions & 2 deletions py-polars/src/list_construction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub fn py_seq_to_list(name: &str, seq: &PyAny, dtype: &PyAny) -> PyResult<Series
let (seq, len) = get_pyseq(seq)?;
let s = match dtype {
DataType::Int64 => {
let mut builder = ListPrimitiveChunkedBuilder::<Int64Type>::new(name, len, len * 5);
let mut builder = ListPrimitiveChunkedBuilder::<i64>::new(name, len, len * 5, DataType::Int64);
for sub_seq in seq.iter()? {
let sub_seq = sub_seq?;
let (sub_seq, len) = get_pyseq(sub_seq)?;
Expand All @@ -31,7 +31,7 @@ pub fn py_seq_to_list(name: &str, seq: &PyAny, dtype: &PyAny) -> PyResult<Series
builder.finish().into_series()
}
DataType::Float64 => {
let mut builder = ListPrimitiveChunkedBuilder::<Float64Type>::new(name, len, len * 5);
let mut builder = ListPrimitiveChunkedBuilder::<f64>::new(name, len, len * 5, DataType::Float64);
for sub_seq in seq.iter()? {
let sub_seq = sub_seq?;
let (sub_seq, len) = get_pyseq(sub_seq)?;
Expand Down
21 changes: 21 additions & 0 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1192,3 +1192,24 @@ def test_filter_with_all_expansion():
)
out = df.filter(~pl.fold(True, lambda acc, s: acc & s.is_null(), pl.all()))
assert out.shape == (2, 3)


def test_diff_datetime():

df = pl.DataFrame(
{
"timestamp": ["2021-02-01", "2021-03-1", "2021-04-1"],
"guild": [1, 2, 3],
"char": ["a", "a", "b"],
}
)

out = (
df.with_columns(
[
pl.col("timestamp").str.strptime(pl.Date, fmt="%Y-%m-%d"),
]
).with_columns([pl.col("timestamp").diff().over(pl.col("char"))])
)["timestamp"]

assert out[0] == out[1]

0 comments on commit 8c5653f

Please sign in to comment.