Skip to content

Commit

Permalink
fix(rust, python): fix categorical in struct anyvalue issue (#5987)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 2, 2023
1 parent 1b04b36 commit db0c741
Show file tree
Hide file tree
Showing 11 changed files with 114 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub mod stringcache;
pub use builder::*;
pub(crate) use merge::*;
pub(crate) use ops::{CategoricalTakeRandomGlobal, CategoricalTakeRandomLocal};
use polars_utils::sync::SyncPtr;

use super::*;
use crate::prelude::*;
Expand Down Expand Up @@ -147,7 +148,7 @@ impl LogicalType for CategoricalChunked {

unsafe fn get_any_value_unchecked(&self, i: usize) -> AnyValue<'_> {
match self.logical.0.get_unchecked(i) {
Some(i) => AnyValue::Categorical(i, self.get_rev_map()),
Some(i) => AnyValue::Categorical(i, self.get_rev_map(), SyncPtr::new_null()),
None => AnyValue::Null,
}
}
Expand Down Expand Up @@ -295,7 +296,7 @@ mod test {
);
assert!(matches!(
s.get(0)?,
AnyValue::Categorical(0, RevMapping::Local(_))
AnyValue::Categorical(0, RevMapping::Local(_), _)
));

let groups = s.group_tuples(false, true);
Expand Down
33 changes: 26 additions & 7 deletions polars/polars-core/src/chunked_array/ops/any_value.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use std::convert::TryFrom;

#[cfg(feature = "dtype-categorical")]
use polars_utils::sync::SyncPtr;

#[cfg(feature = "object")]
use crate::chunked_array::object::extension::polars_extension::PolarsExtension;
use crate::prelude::*;
Expand Down Expand Up @@ -70,7 +73,7 @@ pub(crate) unsafe fn arr_to_any_value<'a>(
DataType::Categorical(rev_map) => {
let arr = &*(arr as *const dyn Array as *const UInt32Array);
let v = arr.value_unchecked(idx);
AnyValue::Categorical(v, rev_map.as_ref().unwrap().as_ref())
AnyValue::Categorical(v, rev_map.as_ref().unwrap().as_ref(), SyncPtr::new_null())
}
#[cfg(feature = "dtype-struct")]
DataType::Struct(flds) => {
Expand Down Expand Up @@ -120,12 +123,28 @@ impl<'a> AnyValue<'a> {
let idx = *idx;
unsafe {
arr.values().iter().zip(*flds).map(move |(arr, fld)| {
// TODO! this is hacky. Investigate if we only should put physical types
// into structs
if let Some(arr) = arr.as_any().downcast_ref::<DictionaryArray<u32>>() {
let keys = arr.keys();
arr_to_any_value(keys, idx, fld.data_type())
} else {
// The dictionary arrays categories don't have to map to the rev-map in the dtype
// so we set the array pointer with values of the dictionary array.
#[cfg(feature = "dtype-categorical")]
{
if let Some(arr) = arr.as_any().downcast_ref::<DictionaryArray<u32>>() {
let keys = arr.keys();
let values = arr.values();
let values =
values.as_any().downcast_ref::<Utf8Array<i64>>().unwrap();
let arr = &*(keys as *const dyn Array as *const UInt32Array);
let v = arr.value_unchecked(idx);
let DataType::Categorical(Some(rev_map)) = fld.data_type() else {
unimplemented!()
};
AnyValue::Categorical(v, rev_map, SyncPtr::from_const(values))
} else {
arr_to_any_value(&**arr, idx, fld.data_type())
}
}

#[cfg(not(feature = "dtype-categorical"))]
{
arr_to_any_value(&**arr, idx, fld.data_type())
}
})
Expand Down
10 changes: 7 additions & 3 deletions polars/polars-core/src/datatypes/any_value.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use arrow::types::PrimitiveType;
#[cfg(feature = "dtype-categorical")]
use polars_utils::sync::SyncPtr;
use polars_utils::unwrap::UnwrapUncheckedRelease;

use super::*;
Expand Down Expand Up @@ -56,7 +58,9 @@ pub enum AnyValue<'a> {
#[cfg(feature = "dtype-time")]
Time(i64),
#[cfg(feature = "dtype-categorical")]
Categorical(u32, &'a RevMapping),
// If syncptr is_null the data is in the rev-map
// otherwise it is in the array pointer
Categorical(u32, &'a RevMapping, SyncPtr<Utf8Array<i64>>),
/// Nested type, contains arrays that are filled with one of the datetypes.
List(Series),
#[cfg(feature = "object")]
Expand Down Expand Up @@ -357,7 +361,7 @@ impl<'a> AnyValue<'a> {
Boolean(_) => DataType::Boolean,
Utf8(_) => DataType::Utf8,
#[cfg(feature = "dtype-categorical")]
Categorical(_, _) => DataType::Categorical(None),
Categorical(_, _, _) => DataType::Categorical(None),
List(s) => DataType::List(Box::new(s.dtype().clone())),
#[cfg(feature = "dtype-struct")]
Struct(_, _, fields) => DataType::Struct(fields.to_vec()),
Expand Down Expand Up @@ -616,7 +620,7 @@ impl PartialEq for AnyValue<'_> {
// should it?
(Null, Null) => true,
#[cfg(feature = "dtype-categorical")]
(Categorical(idx_l, rev_l), Categorical(idx_r, rev_r)) => match (rev_l, rev_r) {
(Categorical(idx_l, rev_l, _), Categorical(idx_r, rev_r, _)) => match (rev_l, rev_r) {
(RevMapping::Global(_, _, id_l), RevMapping::Global(_, _, id_r)) => {
id_l == id_r && idx_l == idx_r
}
Expand Down
8 changes: 6 additions & 2 deletions polars/polars-core/src/fmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -746,8 +746,12 @@ impl Display for AnyValue<'_> {
write!(f, "{nt}")
}
#[cfg(feature = "dtype-categorical")]
AnyValue::Categorical(idx, rev) => {
let s = rev.get(*idx);
AnyValue::Categorical(idx, rev, arr) => {
let s = if arr.is_null() {
rev.get(*idx)
} else {
unsafe { arr.deref_unchecked().value(*idx as usize) }
};
write!(f, "\"{s}\"")
}
AnyValue::List(s) => write!(f, "{}", s.fmt_list()),
Expand Down
10 changes: 9 additions & 1 deletion polars/polars-core/src/series/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,15 @@ impl<'a> From<&AnyValue<'a>> for DataType {
Int8(_) => DataType::Int8,
Int16(_) => DataType::Int16,
#[cfg(feature = "dtype-categorical")]
Categorical(_, rev_map) => DataType::Categorical(Some(Arc::new((*rev_map).clone()))),
Categorical(_, rev_map, arr) => {
if arr.is_null() {
DataType::Categorical(Some(Arc::new((*rev_map).clone())))
} else {
let array = unsafe { arr.deref_unchecked().clone() };
let rev_map = RevMapping::Local(array);
DataType::Categorical(Some(Arc::new(rev_map)))
}
}
#[cfg(feature = "object")]
Object(o) => DataType::Object(o.type_name()),
#[cfg(feature = "object")]
Expand Down
8 changes: 7 additions & 1 deletion polars/polars-core/src/series/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,13 @@ impl Series {
AnyValue::Utf8(s) => Cow::Borrowed(s),
AnyValue::Null => Cow::Borrowed("null"),
#[cfg(feature = "dtype-categorical")]
AnyValue::Categorical(idx, rev) => Cow::Borrowed(rev.get(idx)),
AnyValue::Categorical(idx, rev, arr) => {
if arr.is_null() {
Cow::Borrowed(rev.get(idx))
} else {
unsafe { Cow::Borrowed(arr.deref_unchecked().value(idx as usize)) }
}
}
av => Cow::Owned(format!("{av}")),
};
Ok(out)
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-io/src/csv/write_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ fn write_anyvalue(f: &mut Vec<u8>, value: AnyValue, options: &SerializeOptions)
AnyValue::Boolean(v) => write!(f, "{v}"),
AnyValue::Utf8(v) => fmt_and_escape_str(f, v, options),
#[cfg(feature = "dtype-categorical")]
AnyValue::Categorical(idx, rev_map) => {
AnyValue::Categorical(idx, rev_map, _) => {
let v = rev_map.get(idx);
fmt_and_escape_str(f, v, options)
}
Expand Down
12 changes: 11 additions & 1 deletion polars/polars-lazy/polars-plan/src/logical_plan/lit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,17 @@ impl TryFrom<AnyValue<'_>> for LiteralValue {
AnyValue::List(l) => Ok(Self::Series(SpecialEq::new(l))),
AnyValue::Utf8Owned(o) => Ok(Self::Utf8(o.into())),
#[cfg(feature = "dtype-categorical")]
AnyValue::Categorical(c, rev_mapping) => Ok(Self::Utf8(rev_mapping.get(c).to_string())),
AnyValue::Categorical(c, rev_mapping, arr) => {
if arr.is_null() {
Ok(Self::Utf8(rev_mapping.get(c).to_string()))
} else {
unsafe {
Ok(Self::Utf8(
arr.deref_unchecked().value(c as usize).to_string(),
))
}
}
}
_ => Err(PolarsError::ComputeError(
"Unsupported AnyValue type variant, cannot convert to Literal".into(),
)),
Expand Down
25 changes: 24 additions & 1 deletion polars/polars-utils/src/sync.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
/// Utility that allows use to send pointers to another thread.
/// This is better than going through `usize` as MIRI can follow these.
#[derive(Copy, Clone)]
#[derive(Copy, Clone, Debug)]
#[repr(transparent)]
pub struct SyncPtr<T>(*mut T);

impl<T> SyncPtr<T> {
Expand All @@ -12,10 +13,32 @@ impl<T> SyncPtr<T> {
Self(ptr)
}

/// # Safety
///
/// This will make a pointer sync and send.
/// Ensure that you don't break aliasing rules.
pub unsafe fn from_const(ptr: *const T) -> Self {
Self(ptr as *mut T)
}

pub fn new_null() -> Self {
Self(std::ptr::null_mut())
}

#[inline(always)]
pub fn get(self) -> *mut T {
self.0
}

pub fn is_null(&self) -> bool {
self.0.is_null()
}

/// # Safety
/// Derefs a raw pointer, no guarantees whatsoever.
pub unsafe fn deref_unchecked(&self) -> &'static T {
&*(self.0 as *const T)
}
}

unsafe impl<T> Sync for SyncPtr<T> {}
Expand Down
8 changes: 6 additions & 2 deletions py-polars/src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,12 @@ impl IntoPy<PyObject> for Wrap<AnyValue<'_>> {
AnyValue::Boolean(v) => v.into_py(py),
AnyValue::Utf8(v) => v.into_py(py),
AnyValue::Utf8Owned(v) => v.into_py(py),
AnyValue::Categorical(idx, rev) => {
let s = rev.get(idx);
AnyValue::Categorical(idx, rev, arr) => {
let s = if arr.is_null() {
rev.get(idx)
} else {
unsafe { arr.deref_unchecked().value(idx as usize) }
};
s.into_py(py)
}
AnyValue::Date(v) => {
Expand Down
14 changes: 14 additions & 0 deletions py-polars/tests/unit/io/test_lazy_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,17 @@ def test_streaming_categorical() -> None:
"name": ["Bob", "Alice"],
"amount": [400, 200],
}


def test_parquet_struct_categorical() -> None:
if os.name != "nt":
df = pl.DataFrame(
[
pl.Series("a", ["bob"], pl.Categorical),
pl.Series("b", ["foo"], pl.Categorical),
]
)
df.write_parquet("/tmp/tmp.pq")
with pl.StringCache():
out = pl.read_parquet("/tmp/tmp.pq").select(pl.col("b").value_counts())
assert out.to_dict(False) == {"b": [{"b": "foo", "counts": 1}]}

0 comments on commit db0c741

Please sign in to comment.