Skip to content

Commit

Permalink
improve logical types in Lists (#2281)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 6, 2022
1 parent 81129eb commit e137974
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 16 deletions.
6 changes: 6 additions & 0 deletions polars/polars-core/src/chunked_array/list/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,10 @@ impl ListChunked {
_ => unreachable!(),
}
}

pub fn to_logical(&mut self, inner_dtype: DataType) {
assert_eq!(inner_dtype.to_physical(), self.inner_dtype());
let fld = Arc::make_mut(&mut self.field);
fld.coerce(DataType::List(Box::new(inner_dtype)))
}
}
18 changes: 13 additions & 5 deletions polars/polars-core/src/chunked_array/ops/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,20 @@ unsafe fn arr_to_any_value<'a>(
let v: ArrayRef = downcast!(LargeListArray).into();
let mut s = Series::try_from(("", v)).unwrap();

if let DataType::Categorical = **dt {
let mut s_new = s.cast(&DataType::Categorical).unwrap();
let ca: &mut CategoricalChunked = s_new.get_inner_mut().as_mut();
ca.categorical_map = categorical_map.clone();
s = s_new;
match **dt {
DataType::Categorical => {
let mut s_new = s.cast(&DataType::Categorical).unwrap();
let ca: &mut CategoricalChunked = s_new.get_inner_mut().as_mut();
ca.categorical_map = categorical_map.clone();
s = s_new;
}
DataType::Date
| DataType::Datetime(_, _)
| DataType::Time
| DataType::Duration(_) => s = s.cast(dt).unwrap(),
_ => {}
}

AnyValue::List(s)
}
#[cfg(feature = "dtype-categorical")]
Expand Down
25 changes: 19 additions & 6 deletions polars/polars-core/src/chunked_array/ops/explode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,19 +249,32 @@ impl ChunkExplode for ListChunked {
last = o;
}
if !has_empty {
panic!()
panic!("could have fast exploded")
}
}

let values = Series::try_from((self.name(), values)).unwrap();
values.explode_by_offsets(offsets)
};
debug_assert_eq!(s.name(), self.name());
if let DataType::Categorical = self.inner_dtype() {
let ca = s.u32().unwrap();
let mut ca = ca.clone();
ca.categorical_map = self.categorical_map.clone();
s = ca.cast(&DataType::Categorical)?;
// make sure we restore the logical type
match self.inner_dtype() {
#[cfg(feature = "dtype-categorical")]
DataType::Categorical => {
let ca = s.u32().unwrap();
let mut ca = ca.clone();
ca.categorical_map = self.categorical_map.clone();
s = ca.cast(&DataType::Categorical)?;
}
#[cfg(feature = "dtype-date")]
DataType::Date => s = s.into_date(),
#[cfg(feature = "dtype-datetime")]
DataType::Datetime(tu, tz) => s = s.into_datetime(tu, tz),
#[cfg(feature = "dtype-duration")]
DataType::Duration(tu) => s = s.into_duration(tu),
#[cfg(feature = "dtype-time")]
DataType::Time => s = s.into_time(),
_ => {}
}

Ok((s, offsets_buf))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,7 @@ mod test {
}

#[test]
#[cfg(feature = "dtype-datetime")]
fn test_arithmetic_dispatch() {
let s = Int64Chunked::new("", &[1, 2, 3])
.into_datetime(TimeUnit::Nanoseconds, None)
Expand Down
31 changes: 26 additions & 5 deletions polars/polars-core/src/series/ops/to_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,20 @@ impl Series {
let values = &s.chunks()[0];

let offsets = vec![0i64, values.len() as i64];
let inner_type = self.dtype();

let data_type = ListArray::<i64>::default_datatype(self.dtype().to_arrow());
let data_type = ListArray::<i64>::default_datatype(inner_type.to_physical().to_arrow());

let arr = ListArray::from_data(data_type, offsets.into(), values.clone(), None);
let name = self.name();

Ok(ListChunked::new_from_chunks(
self.name(),
vec![Arc::new(arr)],
))
let mut ca = ListChunked::new_from_chunks(name, vec![Arc::new(arr)]);
if self.dtype() != &self.dtype().to_physical() {
ca.to_logical(inner_type.clone())
}
ca.set_fast_explode();

Ok(ca)
}

pub fn reshape(&self, dims: &[i64]) -> Result<Series> {
Expand Down Expand Up @@ -117,6 +122,22 @@ mod test {
Ok(())
}

/// A `Date` series wrapped by `to_list` must format its values as dates and
/// must explode back into a series whose dtype is `Date`.
#[test]
#[cfg(all(feature = "temporal", feature = "dtype-date"))]
fn test_to_list_logical() -> Result<()> {
    // Build the date series from string input.
    let raw = Utf8Chunked::new("a", &["2021-01-01", "2021-01-02", "2021-01-03"]);
    let dates = raw.as_date(None)?.into_series();

    // The whole series is expected to end up as one single list entry.
    let list = dates.to_list().unwrap();
    assert_eq!(list.len(), 1);

    // check if dtype is maintained all the way to formatting
    let rendered = format!("{:?}", list);
    assert!(rendered.contains("[2021-01-01, 2021-01-02, 2021-01-03]"));

    // Exploding the list must hand back the logical dtype.
    let exploded = list.explode().unwrap();
    assert_eq!(exploded.dtype(), &DataType::Date);
    Ok(())
}

#[test]
fn test_reshape() -> Result<()> {
let s = Series::new("a", &[1, 2, 3, 4]);
Expand Down

0 comments on commit e137974

Please sign in to comment.