Skip to content

Commit

Permalink
recursively convert arrow logical types in to_arrow (#4067)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 18, 2022
1 parent d707480 commit 083745a
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 5 deletions.
6 changes: 1 addition & 5 deletions polars/polars-core/src/series/from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl Series {
let chunks = cast_chunks(&chunks, &DataType::Utf8, false).unwrap();
Ok(Utf8Chunked::from_chunks(name, chunks).into_series())
}
ArrowDataType::List(_) => {
ArrowDataType::List(_) | ArrowDataType::LargeList(_) => {
let chunks = chunks.iter().map(convert_inner_types).collect();
Ok(ListChunked::from_chunks(name, chunks).into_series())
}
Expand Down Expand Up @@ -177,10 +177,6 @@ impl Series {
ArrowTimeUnit::Nanosecond => s,
})
}
ArrowDataType::LargeList(_) => {
let chunks = chunks.iter().map(convert_inner_types).collect();
Ok(ListChunked::from_chunks(name, chunks).into_series())
}
ArrowDataType::Null => {
// we don't support null types yet so we use a small digit type filled with nulls
let len = chunks.iter().fold(0, |acc, array| acc + array.len());
Expand Down
23 changes: 23 additions & 0 deletions polars/polars-core/src/series/into.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,29 @@ impl Series {
/// 1 on 1 mapping for logical/ categoricals, etc.
pub fn to_arrow(&self, chunk_idx: usize) -> ArrayRef {
match self.dtype() {
// Special-case lists so that logical types nested in the
// inner values are converted recursively as well.
DataType::List(inner) => {
let ca = self.list().unwrap();
let arr = ca.chunks[chunk_idx].clone();
let arr = arr.as_any().downcast_ref::<ListArray<i64>>().unwrap();

let s = unsafe {
Series::from_chunks_and_dtype_unchecked("", vec![arr.values().clone()], inner)
};
let new_values = s.to_arrow(0);

let data_type = ListArray::<i64>::default_datatype(inner.to_arrow());
let arr = unsafe {
ListArray::<i64>::new_unchecked(
data_type,
arr.offsets().clone(),
new_values,
arr.validity().cloned(),
)
};
Box::new(arr)
}
#[cfg(feature = "dtype-categorical")]
DataType::Categorical(_) => {
let ca = self.categorical().unwrap();
Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/io/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,16 @@ def test_lazy_self_join_file_cache_prop_3979(io_test_dir: str) -> None:

assert a.join(b, how="cross").collect().shape == (3, 17)
assert b.join(a, how="cross").collect().shape == (3, 17)


def test_recursive_logical_type() -> None:
    """Round-trip a list-of-categorical column through parquet via pyarrow.

    Builds a frame with a categorical column, aggregates it into a list
    column per group, writes the result to an in-memory parquet buffer with
    ``use_pyarrow=True``, reads it back, and checks that the nested
    categorical dtype and the frame shape are as expected.

    The function carries the ``test_`` prefix so that pytest's default
    discovery rules actually collect and run it.
    """
    df = pl.DataFrame({"str": ["A", "B", "A", "B", "C"], "group": [1, 1, 2, 1, 2]})
    df = df.with_column(pl.col("str").cast(pl.Categorical))

    # Aggregating into a list yields a List(Categorical) column, i.e. a
    # logical type nested inside a list.
    df_groups = df.groupby("group").agg([pl.col("str").list().alias("cat_list")])

    f = io.BytesIO()
    df_groups.write_parquet(f, use_pyarrow=True)
    # Rewind the buffer so the reader starts at the beginning of the data.
    f.seek(0)
    read = pl.read_parquet(f, use_pyarrow=True)

    assert read.dtypes == [pl.Int64, pl.List(pl.Categorical)]
    assert read.shape == (2, 2)


# Backward-compatible alias for the original name. It does not match the
# ``test_`` prefix, so pytest does not collect the test a second time.
recursive_logical_type = test_recursive_logical_type

0 comments on commit 083745a

Please sign in to comment.