Skip to content

Commit

Permalink
add conversion between categorical <-> arrow dictionary
Browse files Browse the repository at this point in the history
This allows for round trips between polars
categorical and arrow. This means
that serde to/from parquet and IPC now
also works as expected.
  • Loading branch information
ritchie46 committed Sep 17, 2021
1 parent fce369f commit 06448f6
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 1 deletion.
55 changes: 55 additions & 0 deletions polars/polars-core/src/chunked_array/categorical/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use crate::chunked_array::RevMapping;
use crate::prelude::*;
use arrow::array::DictionaryArray;

impl From<&CategoricalChunked> for DictionaryArray<u32> {
    /// Builds an arrow `DictionaryArray<u32>` from a categorical chunked array.
    ///
    /// The input is rechunked first and the first chunk of the result supplies
    /// the keys.
    ///
    /// * `RevMapping::Local`: the keys are cloned as they are and paired with a
    ///   clone of the array stored in the mapping.
    /// * `RevMapping::Global`: every non-null key is looked up in the mapping's
    ///   hash map and replaced by the value found there; null keys stay null.
    ///   The remapped keys are paired with a clone of the mapping's values.
    ///
    /// # Panics
    ///
    /// Panics when the rechunked array yields no chunk, when
    /// `categorical_map` is `None`, or when a key is absent from the hash map
    /// of a `RevMapping::Global`.
    fn from(ca: &CategoricalChunked) -> Self {
        let rechunked = ca.rechunk();
        let physical = rechunked.downcast_iter().next().unwrap();
        let rev_map = &**rechunked.categorical_map.as_ref().unwrap();

        // Resolve the (keys, values) pair per mapping kind, then assemble the
        // dictionary array in a single place.
        let (dict_keys, dict_values) = match rev_map {
            RevMapping::Local(categories) => (physical.clone(), categories.clone()),
            RevMapping::Global(key_lookup, categories, _uuid) => {
                let remapped = PrimitiveArray::from_trusted_len_iter(
                    physical
                        .into_iter()
                        .map(|maybe_key| maybe_key.map(|key| *key_lookup.get(key).unwrap())),
                );
                (remapped, categories.clone())
            }
        };

        DictionaryArray::from_data(dict_keys, Arc::new(dict_values))
    }
}

#[cfg(test)]
mod test {
    use super::*;
    use crate::{reset_string_cache, SINGLE_LOCK};
    use std::convert::TryFrom;

    /// Casts a utf8 array to categorical, converts it to an arrow dictionary
    /// array, and checks that a `Series` built from that array is categorical
    /// with the original length and null count.
    #[test]
    fn test_categorical_round_trip() -> Result<()> {
        // Keep the guard alive for the whole test; the string cache is reset
        // right after the lock is taken.
        let _lock = SINGLE_LOCK.lock();
        reset_string_cache();

        let input = [
            Some("foo"),
            None,
            Some("bar"),
            Some("foo"),
            Some("foo"),
            Some("bar"),
        ];
        let utf8 = Utf8Chunked::new_from_opt_slice("a", &input);
        let categorical = utf8.cast::<CategoricalType>()?;

        // categorical -> arrow dictionary -> series
        let dictionary = DictionaryArray::<u32>::from(&categorical);
        let chunk: ArrayRef = Arc::new(dictionary);
        let series = Series::try_from(("foo", chunk))?;

        assert_eq!(series.dtype(), &DataType::Categorical);
        assert_eq!(series.null_count(), 1);
        assert_eq!(series.len(), 6);

        Ok(())
    }
}
3 changes: 3 additions & 0 deletions polars/polars-core/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ pub mod temporal;
mod trusted_len;
pub mod upstream_traits;
use arrow::array::Array;
#[cfg(feature = "dtype-categorical")]
pub(crate) mod categorical;
pub(crate) mod list;

use polars_arrow::prelude::*;

#[cfg(feature = "dtype-categorical")]
Expand Down
10 changes: 9 additions & 1 deletion polars/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1789,7 +1789,15 @@ impl<'a> Iterator for RecordBatchIter<'a> {
let batch_cols = self
.columns
.iter()
.map(|s| s.chunks()[self.idx].clone())
.map(|s| {
#[cfg(feature = "dtype-categorical")]
if let DataType::Categorical = s.dtype() {
let ca = s.categorical().unwrap();
let arr: DictionaryArray<u32> = ca.into();
return Arc::new(arr) as ArrayRef;
}
s.chunks()[self.idx].clone()
})
.collect();
self.idx += 1;

Expand Down
29 changes: 29 additions & 0 deletions polars/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1863,6 +1863,35 @@ impl std::convert::TryFrom<(&str, Vec<ArrayRef>)> for Series {
.collect();
Ok(Date64Chunked::new_from_chunks(name, chunks).into_series())
}
#[cfg(feature = "dtype-categorical")]
ArrowDataType::Dictionary(key_type, value_type) => {
use crate::chunked_array::builder::CategoricalChunkedBuilder;
match (&**key_type, &**value_type) {
(ArrowDataType::UInt32, ArrowDataType::LargeUtf8) => {
let chunks = chunks.iter().map(|arr| &**arr).collect::<Vec<_>>();
let arr = arrow::compute::concat::concatenate(&chunks)?;

let arr = arr.as_any().downcast_ref::<DictionaryArray<u32>>().unwrap();
let keys = arr.keys();
let values = arr.values();
let values = values.as_any().downcast_ref::<LargeStringArray>().unwrap();

let mut builder = CategoricalChunkedBuilder::new(name, keys.len());
let iter = keys.into_iter().map(|opt_key| {
opt_key.map(|k| unsafe { values.value_unchecked(*k as usize) })
});
builder.from_iter(iter);
Ok(builder.finish().into())
}
(k, v) => Err(PolarsError::InvalidOperation(
format!(
"Cannot create polars series dictionary type of key: {:?} value: {:?}",
k, v
)
.into(),
)),
}
}
dt => Err(PolarsError::InvalidOperation(
format!("Cannot create polars series from {:?} type", dt).into(),
)),
Expand Down

0 comments on commit 06448f6

Please sign in to comment.