Skip to content

Commit

Permalink
fix[rust]: fast to_list keep logical type (#4551)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 24, 2022
1 parent eb3b53a commit 8732182
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 36 deletions.
2 changes: 1 addition & 1 deletion polars/polars-core/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,7 @@ impl ListChunked {
}

/// Overwrite the inner dtype recorded on this list array's field with `dtype`.
///
/// After the call the field's dtype is `DataType::List(Box::new(dtype))`.
/// Only the field metadata is touched in this function.
pub(crate) fn with_inner_type(&mut self, dtype: DataType) {
// NOTE(review): this page renders a diff without +/- markers and the file
// header reports one addition and one deletion, so this strict check is
// presumably the *removed* line and the `debug_assert_eq!` below its
// replacement — confirm against the repository before relying on both.
assert_eq!(dtype.to_physical(), self.inner_dtype());
// Debug-only sanity check: both sides are reduced to their physical
// representation before comparing (per the commit title, the goal is to
// keep a logical inner type).
debug_assert_eq!(dtype.to_physical(), self.inner_dtype().to_physical());
// `Arc::make_mut` yields a unique `&mut` to the field, cloning it first if
// the `Arc` is shared.
let field = Arc::make_mut(&mut self.field);
field.coerce(DataType::List(Box::new(dtype)));
}
Expand Down
1 change: 1 addition & 0 deletions polars/polars-core/src/series/ops/to_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ fn reshape_fast_path(name: &str, s: &Series) -> Series {
};

let mut ca = ListChunked::from_chunks(name, chunks);
ca.with_inner_type(s.dtype().clone());
ca.set_fast_explode();
ca.into_series()
}
Expand Down
21 changes: 21 additions & 0 deletions polars/polars-lazy/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
use super::*;

/// List operations that can be carried by the `FunctionExpr::ListExpr` variant.
#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ListFunction {
/// Concatenate the input columns into a list column (dispatches to `list::concat`).
Concat,
}

#[cfg(feature = "is_in")]
pub(super) fn contains(args: &mut [Series]) -> Result<Series> {
let list = &args[0];
let is_in = &args[1];
Expand All @@ -9,3 +16,17 @@ pub(super) fn contains(args: &mut [Series]) -> Result<Series> {
ca.into_series()
})
}

/// Concatenate the input series into one list series via `lst_concat`.
///
/// `s[0]` is moved out of the slice (a default `Series` is left in its place)
/// and used as the left-hand side; `s[1..]` are the series appended to it.
/// If the first series is not already a list, it is reshaped with
/// `&[-1, 1]` before concatenating.
///
/// # Errors
/// Returns an error when the reshape of a non-list first input fails, when
/// the reshaped series cannot be viewed as a list, or when `lst_concat` fails.
///
/// # Panics
/// Panics if `s` is empty, because `s[0]` is indexed unconditionally.
pub(super) fn concat(s: &mut [Series]) -> Result<Series> {
    let mut first = std::mem::take(&mut s[0]);
    let other = &s[1..];

    let first_ca = match first.list().ok() {
        Some(ca) => ca,
        None => {
            // Propagate failures to the caller instead of panicking: this
            // function already returns `Result`, so `?` keeps a bad input
            // from aborting the whole query.
            first = first.reshape(&[-1, 1])?;
            first.list()?
        }
    };
    first_ca.lst_concat(other).map(|ca| ca.into_series())
}
62 changes: 61 additions & 1 deletion polars/polars-lazy/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ mod clip;
mod fill_null;
#[cfg(feature = "is_in")]
mod is_in;
#[cfg(feature = "is_in")]
#[cfg(any(feature = "is_in", feature = "list"))]
mod list;
mod nan;
mod pow;
Expand All @@ -25,6 +25,8 @@ mod temporal;
#[cfg(feature = "trigonometry")]
mod trigonometry;

#[cfg(feature = "list")]
pub(super) use list::ListFunction;
use polars_core::prelude::*;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -75,6 +77,8 @@ pub enum FunctionExpr {
min: Option<AnyValue<'static>>,
max: Option<AnyValue<'static>>,
},
#[cfg(feature = "list")]
ListExpr(ListFunction),
}

#[cfg(feature = "trigonometry")]
Expand Down Expand Up @@ -102,20 +106,38 @@ impl FunctionExpr {
_cntxt: Context,
fields: &[Field],
) -> Result<Field> {
// set a dtype
let with_dtype = |dtype: DataType| Ok(Field::new(fields[0].name(), dtype));

// map a single dtype
let map_dtype = |func: &dyn Fn(&DataType) -> DataType| {
let dtype = func(fields[0].data_type());
Ok(Field::new(fields[0].name(), dtype))
};

// map all dtypes
#[cfg(feature = "list")]
let map_dtypes = |func: &dyn Fn(&[&DataType]) -> DataType| {
let mut fld = fields[0].clone();
let dtypes = fields.iter().map(|fld| fld.data_type()).collect::<Vec<_>>();
let new_type = func(&dtypes);
fld.coerce(new_type);
Ok(fld)
};

#[cfg(any(feature = "rolling_window", feature = "trigonometry"))]
// set float supertype
let float_dtype = || {
map_dtype(&|dtype| match dtype {
DataType::Float32 => DataType::Float32,
_ => DataType::Float64,
})
};

// map to same type
let same_type = || map_dtype(&|dtype| dtype.clone());

// get supertype of all types
let super_type = || {
let mut first = fields[0].clone();
let mut st = first.data_type().clone();
Expand All @@ -126,6 +148,30 @@ impl FunctionExpr {
Ok(first)
};

// inner super type of lists
#[cfg(feature = "list")]
let inner_super_type_list = || {
map_dtypes(&|dts| {
let mut super_type_inner = None;

for dt in dts {
match dt {
DataType::List(inner) => match super_type_inner {
None => super_type_inner = Some(*inner.clone()),
Some(st_inner) => {
super_type_inner = get_supertype(&st_inner, inner).ok()
}
},
dt => match super_type_inner {
None => super_type_inner = Some((*dt).clone()),
Some(st_inner) => super_type_inner = get_supertype(&st_inner, dt).ok(),
},
}
}
DataType::List(Box::new(super_type_inner.unwrap()))
})
};

use FunctionExpr::*;
match self {
NullCount => with_dtype(IDX_DTYPE),
Expand Down Expand Up @@ -173,6 +219,13 @@ impl FunctionExpr {
Nan(n) => n.get_field(fields),
#[cfg(feature = "round_series")]
Clip { .. } => same_type(),
#[cfg(feature = "list")]
ListExpr(l) => {
use ListFunction::*;
match l {
Concat => inner_super_type_list(),
}
}
}
}
}
Expand Down Expand Up @@ -309,6 +362,13 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
Clip { min, max } => {
map_owned!(clip::clip, min.clone(), max.clone())
}
#[cfg(feature = "list")]
ListExpr(lf) => {
use ListFunction::*;
match lf {
Concat => wrap!(list::concat),
}
}
}
}
}
Expand Down
39 changes: 5 additions & 34 deletions polars/polars-lazy/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ use polars_core::utils::arrow::temporal_conversions::SECONDS_IN_DAY;
#[cfg(feature = "rank")]
use polars_core::utils::coalesce_nulls_series;
use polars_core::utils::get_supertype;
#[cfg(feature = "list")]
use polars_ops::prelude::ListNameSpaceImpl;
use rayon::prelude::*;

#[cfg(feature = "arg_where")]
use crate::dsl::function_expr::FunctionExpr;
#[cfg(feature = "list")]
use crate::dsl::function_expr::ListFunction;
use crate::dsl::*;
use crate::prelude::*;
use crate::utils::has_wildcard;

Expand Down Expand Up @@ -255,39 +256,9 @@ pub fn concat_str<E: AsRef<[Expr]>>(s: E, sep: &str) -> Expr {
pub fn concat_lst<E: AsRef<[IE]>, IE: Into<Expr> + Clone>(s: E) -> Expr {
let s = s.as_ref().iter().map(|e| e.clone().into()).collect();

let function = SpecialEq::new(Arc::new(move |s: &mut [Series]| {
let mut first = std::mem::take(&mut s[0]);
let other = &s[1..];

let first_ca = match first.list().ok() {
Some(ca) => ca,
None => {
first = first.reshape(&[-1, 1]).unwrap();
first.list().unwrap()
}
};
first_ca.lst_concat(other).map(|ca| ca.into_series())
}) as Arc<dyn SeriesUdf>);
Expr::AnonymousFunction {
Expr::Function {
input: s,
function,
output_type: GetOutput::map_dtypes(|dts| {
let mut super_type_inner = None;

for dt in dts {
match dt {
DataType::List(inner) => match super_type_inner {
None => super_type_inner = Some(*inner.clone()),
Some(st_inner) => super_type_inner = get_supertype(&st_inner, inner).ok(),
},
dt => match super_type_inner {
None => super_type_inner = Some((*dt).clone()),
Some(st_inner) => super_type_inner = get_supertype(&st_inner, dt).ok(),
},
}
}
DataType::List(Box::new(super_type_inner.unwrap()))
}),
function: FunctionExpr::ListExpr(ListFunction::Concat),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
input_wildcard_expansion: true,
Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/test_datelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -1319,3 +1319,14 @@ def test_shift_and_fill_group_logicals() -> None:
assert df.select(
pl.col("d").shift_and_fill(-1, pl.col("d").max()).over("s")
).dtypes == [pl.Date]


def test_date_arr_concat() -> None:
    """Concatenating column "d" with itself gives the same list[date] result
    whether "d" starts out as a date column or as a list[date] column."""
    expected = {"d": [[date(2000, 1, 1), date(2000, 1, 1)]]}

    # first case: dtype date, second case: dtype list[date]
    for values in ([date(2000, 1, 1)], [[date(2000, 1, 1)]]):
        df = pl.DataFrame({"d": values})
        out = df.select(pl.col("d").arr.concat(pl.col("d")))
        assert out.to_dict(False) == expected

0 comments on commit 8732182

Please sign in to comment.