Skip to content

Commit

Permalink
expose categorical round trip to python and add more dictionary types: c…
Browse files Browse the repository at this point in the history
…loses #1308
  • Loading branch information
ritchie46 committed Sep 20, 2021
1 parent 8eb817d commit 942caf2
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 40 deletions.
27 changes: 27 additions & 0 deletions polars/polars-core/src/chunked_array/categorical/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::chunked_array::RevMapping;
use crate::prelude::*;
use arrow::array::DictionaryArray;
use arrow::compute::cast::cast;

impl From<&CategoricalChunked> for DictionaryArray<u32> {
fn from(ca: &CategoricalChunked) -> Self {
Expand All @@ -22,6 +23,32 @@ impl From<&CategoricalChunked> for DictionaryArray<u32> {
}
}
}
/// Builds an arrow `DictionaryArray` with `i64` keys from a categorical column.
///
/// The column is rechunked and the keys are read from its first chunk.
/// How the keys are produced depends on the reverse mapping:
///
/// * `RevMapping::Local`: the stored keys are cast to `Int64` and paired with
///   a copy of the local category array.
/// * `RevMapping::Global`: every non-null key is looked up in the reverse map
///   and the mapped value (widened to `i64`) becomes the dictionary key,
///   paired with a copy of the global value array.
///
/// # Panics
///
/// Panics if the column has no `categorical_map`, if no chunk is available
/// after rechunking, or if a key is absent from the global reverse map.
impl From<&CategoricalChunked> for DictionaryArray<i64> {
    fn from(ca: &CategoricalChunked) -> Self {
        let single = ca.rechunk();
        let physical = single.downcast_iter().next().unwrap();
        let rev_map = &**single.categorical_map.as_ref().unwrap();

        match rev_map {
            RevMapping::Local(categories) => {
                // Widen the stored keys to i64; null slots are carried over by the cast.
                let widened = cast(physical, &ArrowDataType::Int64).unwrap();
                let keys = widened
                    .as_any()
                    .downcast_ref::<PrimitiveArray<i64>>()
                    .unwrap()
                    .clone();
                DictionaryArray::from_data(keys, Arc::new(categories.clone()))
            }
            RevMapping::Global(lookup, categories, _uuid) => {
                // Translate each stored key through the reverse map before widening.
                let keys = PrimitiveArray::from_trusted_len_iter(
                    physical
                        .into_iter()
                        .map(|key| key.map(|k| *lookup.get(k).unwrap() as i64)),
                );
                DictionaryArray::from_data(keys, Arc::new(categories.clone()))
            }
        }
    }
}

#[cfg(test)]
mod test {
Expand Down
8 changes: 8 additions & 0 deletions polars/polars-core/src/datatypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,14 @@ impl Schema {
))),
true,
),
DataType::Categorical => ArrowField::new(
f.name(),
ArrowDataType::Dictionary(
Box::new(ArrowDataType::UInt32),
Box::new(ArrowDataType::LargeUtf8),
),
true,
),
_ => f.to_arrow(),
}
})
Expand Down
132 changes: 116 additions & 16 deletions polars/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1866,31 +1866,131 @@ impl std::convert::TryFrom<(&str, Vec<ArrayRef>)> for Series {
#[cfg(feature = "dtype-categorical")]
ArrowDataType::Dictionary(key_type, value_type) => {
use crate::chunked_array::builder::CategoricalChunkedBuilder;
match (&**key_type, &**value_type) {
(ArrowDataType::UInt32, ArrowDataType::LargeUtf8) => {
let chunks = chunks.iter().map(|arr| &**arr).collect::<Vec<_>>();
let arr = arrow::compute::concat::concatenate(&chunks)?;
use arrow::compute::cast::cast;
let chunks = chunks.iter().map(|arr| &**arr).collect::<Vec<_>>();
let arr = arrow::compute::concat::concatenate(&chunks)?;

let (keys, values) = match (&**key_type, &**value_type) {
(ArrowDataType::Int8, ArrowDataType::LargeUtf8) => {
let arr = arr.as_any().downcast_ref::<DictionaryArray<i8>>().unwrap();
let keys = arr.keys();
let keys = cast(keys, &ArrowDataType::UInt32)
.unwrap()
.as_any()
.downcast_ref::<PrimitiveArray<u32>>()
.unwrap()
.clone();
let values = arr.values();
let values = values.as_any().downcast_ref::<LargeStringArray>().unwrap();
(keys, values.clone())
}
(ArrowDataType::Int16, ArrowDataType::LargeUtf8) => {
let arr = arr.as_any().downcast_ref::<DictionaryArray<i16>>().unwrap();
let keys = arr.keys();
let keys = cast(keys, &ArrowDataType::UInt32)
.unwrap()
.as_any()
.downcast_ref::<PrimitiveArray<u32>>()
.unwrap()
.clone();
let values = arr.values();
let values = values.as_any().downcast_ref::<LargeStringArray>().unwrap();
(keys, values.clone())
}
(ArrowDataType::Int32, ArrowDataType::LargeUtf8) => {
let arr = arr.as_any().downcast_ref::<DictionaryArray<i32>>().unwrap();
let keys = arr.keys();
let keys = cast(keys, &ArrowDataType::UInt32)
.unwrap()
.as_any()
.downcast_ref::<PrimitiveArray<u32>>()
.unwrap()
.clone();
let values = arr.values();
let values = values.as_any().downcast_ref::<LargeStringArray>().unwrap();
(keys, values.clone())
}
(ArrowDataType::UInt32, ArrowDataType::LargeUtf8) => {
let arr = arr.as_any().downcast_ref::<DictionaryArray<u32>>().unwrap();
let keys = arr.keys();
let values = arr.values();
let values = values.as_any().downcast_ref::<LargeStringArray>().unwrap();

let mut builder = CategoricalChunkedBuilder::new(name, keys.len());
let iter = keys.into_iter().map(|opt_key| {
opt_key.map(|k| unsafe { values.value_unchecked(*k as usize) })
});
builder.from_iter(iter);
Ok(builder.finish().into())
(keys.clone(), values.clone())
}
(k, v) => Err(PolarsError::InvalidOperation(
format!(
(ArrowDataType::Int8, ArrowDataType::Utf8) => {
let arr = arr.as_any().downcast_ref::<DictionaryArray<i8>>().unwrap();
let keys = arr.keys();
let keys = cast(keys, &ArrowDataType::UInt32)
.unwrap()
.as_any()
.downcast_ref::<PrimitiveArray<u32>>()
.unwrap()
.clone();
let values = arr.values();
let values = values.as_any().downcast_ref::<Utf8Array<i32>>().unwrap();
let values = cast(values, &ArrowDataType::LargeUtf8)
.unwrap()
.as_any()
.downcast_ref::<Utf8Array<i64>>()
.unwrap()
.clone();
(keys, values)
}
(ArrowDataType::Int16, ArrowDataType::Utf8) => {
let arr = arr.as_any().downcast_ref::<DictionaryArray<i16>>().unwrap();
let keys = arr.keys();
let keys = cast(keys, &ArrowDataType::UInt32)
.unwrap()
.as_any()
.downcast_ref::<PrimitiveArray<u32>>()
.unwrap()
.clone();
let values = arr.values();
let values = values.as_any().downcast_ref::<Utf8Array<i32>>().unwrap();
let values = cast(values, &ArrowDataType::LargeUtf8)
.unwrap()
.as_any()
.downcast_ref::<Utf8Array<i64>>()
.unwrap()
.clone();
(keys, values)
}
(ArrowDataType::Int32, ArrowDataType::Utf8) => {
let arr = arr.as_any().downcast_ref::<DictionaryArray<i32>>().unwrap();
let keys = arr.keys();
let keys = cast(keys, &ArrowDataType::UInt32)
.unwrap()
.as_any()
.downcast_ref::<PrimitiveArray<u32>>()
.unwrap()
.clone();
let values = arr.values();
let values = values.as_any().downcast_ref::<Utf8Array<i32>>().unwrap();
let values = cast(values, &ArrowDataType::LargeUtf8)
.unwrap()
.as_any()
.downcast_ref::<Utf8Array<i64>>()
.unwrap()
.clone();
(keys, values)
}
(k, v) => {
return Err(PolarsError::InvalidOperation(
format!(
"Cannot create polars series dictionary type of key: {:?} value: {:?}",
k, v
)
.into(),
)),
}
.into(),
))
}
};

let mut builder = CategoricalChunkedBuilder::new(name, keys.len());
let iter = keys
.into_iter()
.map(|opt_key| opt_key.map(|k| unsafe { values.value_unchecked(*k as usize) }));
builder.from_iter(iter);
Ok(builder.finish().into())
}
dt => Err(PolarsError::InvalidOperation(
format!("Cannot create polars series from {:?} type", dt).into(),
Expand Down
10 changes: 0 additions & 10 deletions py-polars/polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,6 @@ def coerce_arrow(array: pa.Array) -> pa.Array:
elif isinstance(array.type, pa.Decimal128Type):
array = pa.compute.cast(array, pa.float64())

# simplest solution is to cast to (large)-string arrays
# this is copy and expensive
elif isinstance(array.type, pa.DictionaryType):
if pa.types.is_string(array.type.value_type):
array = pa.compute.cast(array, pa.large_utf8())
else:
raise ValueError(
"polars does not support dictionary encoded types other than strings"
)

if hasattr(array, "num_chunks") and array.num_chunks > 1:
# we have to coerce before combining chunks, because pyarrow panics if
# offsets overflow
Expand Down
14 changes: 1 addition & 13 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,19 +310,7 @@ impl PyDataFrame {
let py = gil.python();
let pyarrow = py.import("pyarrow")?;

// Arrow does not know about our categorical type implementation, so we cast to utf8
let cols = self
.df
.get_columns()
.iter()
.map(|s| match s.dtype() {
DataType::Categorical => s.cast::<Utf8Type>().unwrap(),
_ => s.clone(),
})
.collect::<Vec<_>>();
let df = DataFrame::new_no_checks(cols);

let rbs = df
let rbs = self.df
.iter_record_batches()
.map(|rb| arrow_interop::to_py::to_py_rb(&rb, py, pyarrow))
.collect::<PyResult<_>>()?;
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import numpy as np
import pandas as pd
import pyarrow as pa
import pytest

import polars as pl
Expand Down Expand Up @@ -237,3 +238,14 @@ def test_ipc_schema():
f.seek(0)

assert pl.read_ipc_schema(f) == {"a": pl.Int64, "b": pl.Utf8, "c": pl.Boolean}


def test_categorical_round_trip():
    """A categorical column should survive a polars -> arrow -> polars round trip."""
    source = pl.DataFrame({"ints": [1, 2, 3], "cat": ["a", "b", "c"]}).with_column(
        pl.col("cat").cast(pl.Categorical)
    )

    # On the arrow side the column is expected to be reported as a dictionary type.
    arrow_table = source.to_arrow()
    cat_type_repr = str(arrow_table["cat"].type)
    assert "dictionary" in cat_type_repr

    # Reading the arrow table back must restore the categorical dtype.
    restored = pl.from_arrow(arrow_table)
    assert restored.dtypes == [pl.Int64, pl.Categorical]
2 changes: 1 addition & 1 deletion py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def test_arrow():

a = pa.array(["foo", "bar"], pa.dictionary(pa.int32(), pa.utf8()))
s = pl.Series("a", a)
assert s.dtype == pl.Utf8
assert s.dtype == pl.Categorical
assert (
pl.from_arrow(pa.array([["foo"], ["foo", "bar"]], pa.list_(pa.utf8()))).dtype
== pl.List
Expand Down

0 comments on commit 942caf2

Please sign in to comment.