Skip to content

Commit

Permalink
support all arrow dictionary keys < 64 bit (#3508)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed May 26, 2022
1 parent e3a89a5 commit 487379a
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 116 deletions.
152 changes: 36 additions & 116 deletions polars/polars-core/src/series/from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,133 +192,53 @@ impl Series {
let chunks = chunks.iter().map(|arr| &**arr).collect::<Vec<_>>();
let arr = arrow::compute::concatenate::concatenate(&chunks)?;

let (keys, values) = match (key_type, &**value_type) {
(IntegerType::Int8, ArrowDataType::LargeUtf8) => {
let arr = arr.as_any().downcast_ref::<DictionaryArray<i8>>().unwrap();
let keys = arr.keys();
let keys = cast(keys, &ArrowDataType::UInt32)
.unwrap()
.as_any()
.downcast_ref::<PrimitiveArray<u32>>()
.unwrap()
.clone();
let values = arr.values();
let values = values.as_any().downcast_ref::<LargeStringArray>().unwrap();
(keys, values.clone())
}
(IntegerType::Int16, ArrowDataType::LargeUtf8) => {
let arr = arr.as_any().downcast_ref::<DictionaryArray<i16>>().unwrap();
let keys = arr.keys();
let keys = cast(keys, &ArrowDataType::UInt32)
.unwrap()
.as_any()
.downcast_ref::<PrimitiveArray<u32>>()
.unwrap()
.clone();
let values = arr.values();
let values = values.as_any().downcast_ref::<LargeStringArray>().unwrap();
(keys, values.clone())
}
(IntegerType::Int32, ArrowDataType::LargeUtf8) => {
let arr = arr.as_any().downcast_ref::<DictionaryArray<i32>>().unwrap();
if !matches!(
value_type.as_ref(),
ArrowDataType::Utf8 | ArrowDataType::LargeUtf8
) {
return Err(PolarsError::ComputeError(
"polars only support dictionaries with string like values".into(),
));
}

macro_rules! unpack_keys_values {
($dt:ty) => {{
let arr = arr.as_any().downcast_ref::<DictionaryArray<$dt>>().unwrap();
let keys = arr.keys();
let keys = cast(keys, &ArrowDataType::UInt32)
.unwrap()
.as_any()
.downcast_ref::<PrimitiveArray<u32>>()
.unwrap()
.clone();
let keys = cast(keys, &ArrowDataType::UInt32).unwrap();
let values = arr.values();
let values = values.as_any().downcast_ref::<LargeStringArray>().unwrap();
(keys, values.clone())
let values = cast(&**values, &ArrowDataType::LargeUtf8)?;
(keys, values)
}};
}

let (keys, values) = match key_type {
IntegerType::Int8 => {
unpack_keys_values!(i8)
}
(IntegerType::Int64, ArrowDataType::LargeUtf8) => {
let arr = arr.as_any().downcast_ref::<DictionaryArray<i64>>().unwrap();
let keys = arr.keys();
let keys = cast(keys, &ArrowDataType::UInt32)
.unwrap()
.as_any()
.downcast_ref::<PrimitiveArray<u32>>()
.unwrap()
.clone();
let values = arr.values();
let values = values.as_any().downcast_ref::<LargeStringArray>().unwrap();
(keys, values.clone())
IntegerType::UInt8 => {
unpack_keys_values!(u8)
}
(IntegerType::UInt32, ArrowDataType::LargeUtf8) => {
let arr = arr.as_any().downcast_ref::<DictionaryArray<u32>>().unwrap();
let keys = arr.keys();
let values = arr.values();
let values = values.as_any().downcast_ref::<LargeStringArray>().unwrap();
(keys.clone(), values.clone())
IntegerType::Int16 => {
unpack_keys_values!(i16)
}
(IntegerType::Int8, ArrowDataType::Utf8) => {
let arr = arr.as_any().downcast_ref::<DictionaryArray<i8>>().unwrap();
let keys = arr.keys();
let keys = cast(keys, &ArrowDataType::UInt32)
.unwrap()
.as_any()
.downcast_ref::<PrimitiveArray<u32>>()
.unwrap()
.clone();
let values = arr.values();
let values = values.as_any().downcast_ref::<Utf8Array<i32>>().unwrap();
let values = cast(values, &ArrowDataType::LargeUtf8)
.unwrap()
.as_any()
.downcast_ref::<Utf8Array<i64>>()
.unwrap()
.clone();
(keys, values)
IntegerType::UInt16 => {
unpack_keys_values!(u16)
}
(IntegerType::Int16, ArrowDataType::Utf8) => {
let arr = arr.as_any().downcast_ref::<DictionaryArray<i16>>().unwrap();
let keys = arr.keys();
let keys = cast(keys, &ArrowDataType::UInt32)
.unwrap()
.as_any()
.downcast_ref::<PrimitiveArray<u32>>()
.unwrap()
.clone();
let values = arr.values();
let values = values.as_any().downcast_ref::<Utf8Array<i32>>().unwrap();
let values = cast(values, &ArrowDataType::LargeUtf8)
.unwrap()
.as_any()
.downcast_ref::<Utf8Array<i64>>()
.unwrap()
.clone();
(keys, values)
IntegerType::Int32 => {
unpack_keys_values!(i32)
}
(IntegerType::Int32, ArrowDataType::Utf8) => {
let arr = arr.as_any().downcast_ref::<DictionaryArray<i32>>().unwrap();
let keys = arr.keys();
let keys = cast(keys, &ArrowDataType::UInt32)
.unwrap()
.as_any()
.downcast_ref::<PrimitiveArray<u32>>()
.unwrap()
.clone();
let values = arr.values();
let values = values.as_any().downcast_ref::<Utf8Array<i32>>().unwrap();
let values = cast(values, &ArrowDataType::LargeUtf8)
.unwrap()
.as_any()
.downcast_ref::<Utf8Array<i64>>()
.unwrap()
.clone();
(keys, values)
IntegerType::UInt32 => {
unpack_keys_values!(u32)
}
(k, v) => {
return Err(PolarsError::InvalidOperation(
format!(
"Cannot create polars series dictionary type of key: {:?} value: {:?}",
k, v
)
.into(),
_ => {
return Err(PolarsError::ComputeError(
"dictionaries with 64 bits keys are not supported by polars".into(),
))
}
};
let keys = keys.as_any().downcast_ref::<PrimitiveArray<u32>>().unwrap();
let values = values.as_any().downcast_ref::<Utf8Array<i64>>().unwrap();

let mut builder = CategoricalChunkedBuilder::new(name, keys.len());
let iter = keys
Expand Down
18 changes: 18 additions & 0 deletions py-polars/tests/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,21 @@ def test_pandas_string_none_conversion_3298() -> None:
df_pd = pd.DataFrame(data)
df_pl = pl.DataFrame(df_pd)
assert df_pl.to_series().to_list() == [None, "b", "c", "d"]


def test_cat_int_types_3500() -> None:
    """Dictionary arrays with Int8 or UInt8 keys must load as Categorical.

    A pyarrow dictionary array is re-cast to each key type and passed to
    ``pl.from_arrow``; the result has to equal the same strings cast to
    ``pl.Categorical``.
    """
    expected_values = ["a", "a", "b"]

    with pl.StringCache():
        # A pandas categorical Series is the simplest route to an
        # enum / categorical / dictionary typed pyarrow array.
        pd_categorical = pd.Series(expected_values, dtype="category")
        dict_array = pa.Array.from_pandas(pd_categorical)

        # The in-memory key of each category can be a signed or an unsigned
        # 8-bit integer: pandas uses Int8, while DuckDB uses UInt8.
        for key_type in (pa.int8(), pa.uint8()):
            dict_type = pa.dictionary(index_type=key_type, value_type=pa.utf8())
            converted = pl.from_arrow(dict_array.cast(dict_type))
            expected = pl.Series(expected_values).cast(pl.Categorical)
            assert converted.series_equal(expected)

0 comments on commit 487379a

Please sign in to comment.