fix(rust, python): properly deal with categoricals in streaming queries (#5974)
ritchie46 committed Jan 1, 2023
1 parent f8f9a7c commit 96cfab4
Showing 7 changed files with 142 additions and 35 deletions.
15 changes: 12 additions & 3 deletions polars/polars-core/src/chunked_array/cast.rs
@@ -62,9 +62,18 @@ where
fn cast_impl(&self, data_type: &DataType, checked: bool) -> PolarsResult<Series> {
match data_type {
#[cfg(feature = "dtype-categorical")]
DataType::Categorical(_) => Err(PolarsError::ComputeError(
"Cannot cast numeric types to 'Categorical'".into(),
)),
DataType::Categorical(_) => {
if self.dtype() == &DataType::UInt32 {
// safety:
// we are guarded by the type system.
let ca = unsafe { &*(self as *const ChunkedArray<T> as *const UInt32Chunked) };
CategoricalChunked::from_global_indices(ca.clone()).map(|ca| ca.into_series())
} else {
Err(PolarsError::ComputeError(
"Cannot cast numeric types to 'Categorical'".into(),
))
}
}
#[cfg(feature = "dtype-struct")]
DataType::Struct(fields) => {
// cast to first field dtype
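A rough usage sketch of what this new cast branch enables, assuming the polars-core API of this version (the `IUseStringCache` guard is the one touched elsewhere in this commit; names and exact signatures should be taken as illustrative, not definitive):

use polars_core::prelude::*;
use polars_core::IUseStringCache;

fn roundtrip_categorical() -> PolarsResult<()> {
    // keep the global string cache alive for the whole round trip
    let _hold = IUseStringCache::new();

    // build a categorical; its u32 codes are now global string-cache indices
    let s = Series::new("name", &["Bob", "Alice", "Bob"]).cast(&DataType::Categorical(None))?;

    // drop to the physical UInt32 representation and cast back:
    // this hits the `from_global_indices` branch added above
    let physical = s.to_physical_repr().into_owned();
    let restored = physical.cast(&DataType::Categorical(None))?;

    assert!(matches!(restored.dtype(), DataType::Categorical(_)));
    Ok(())
}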
@@ -1,11 +1,12 @@
use std::hash::{Hash, Hasher};

use arrow::array::*;
use hashbrown::hash_map::RawEntryMut;
use hashbrown::hash_map::{Entry, RawEntryMut};
use polars_arrow::trusted_len::PushUnchecked;
use polars_utils::HashSingle;

use crate::datatypes::PlHashMap;
use crate::error::PolarsError::ComputeError;
use crate::frame::groupby::hashing::HASHMAP_INIT_SIZE;
use crate::prelude::*;
use crate::{using_string_cache, StringCache, POOL};
@@ -414,6 +415,59 @@ fn fill_global_to_local(local_to_global: &[u32], global_to_local: &mut PlHashMap
}
}

impl CategoricalChunked {
/// Create a [`CategoricalChunked`] from categorical indices. The indices will
/// probe the global string cache.
pub(crate) fn from_global_indices(cats: UInt32Chunked) -> PolarsResult<CategoricalChunked> {
let cache = crate::STRING_CACHE.read_map();
let len = cache.len() as u32;
drop(cache);
let mut oob = false;

// fastest happy path
for cat in cats.into_iter().flatten() {
if cat >= len {
oob = true
}
}

if oob {
return Err(ComputeError("Cannot construct 'Categorical' from these categories. At least one of them is out of bounds.".into()));
}
Ok(unsafe { Self::from_global_indices_unchecked(cats) })
}

/// Create a [`CategoricalChunked`] from categorical indices. The indices will
/// probe the global string cache.
///
/// # Safety
///
/// This does not do any bounds checks
pub unsafe fn from_global_indices_unchecked(cats: UInt32Chunked) -> CategoricalChunked {
let cache = crate::STRING_CACHE.read_map();

let cap = std::cmp::min(std::cmp::min(cats.len(), cache.len()), HASHMAP_INIT_SIZE);
let mut rev_map = PlHashMap::with_capacity(cap);
let mut str_values = MutableUtf8Array::with_capacities(cap, cap * 24);

for arr in cats.downcast_iter() {
for cat in arr.into_iter().flatten().copied() {
let offset = str_values.len() as u32;

if let Entry::Vacant(entry) = rev_map.entry(cat) {
entry.insert(offset);
let str_val = cache.get_unchecked(cat);
str_values.push(Some(str_val))
}
}
}

let rev_map = RevMapping::Global(rev_map, str_values.into(), cache.uuid);

CategoricalChunked::from_cats_and_rev_map_unchecked(cats, Arc::new(rev_map))
}
}

#[cfg(test)]
mod test {
use crate::chunked_array::categorical::CategoricalChunkedBuilder;
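The core of `from_global_indices_unchecked` is a dedup pass: every distinct global code is copied out of the string cache exactly once and recorded in a local rev-map. A minimal std-only sketch of that idea, with a plain `Vec<String>` standing in for the string cache and `MutableUtf8Array`:

use std::collections::HashMap;

// Map global codes to a local rev-map; each distinct category string is
// copied out of the (mock) global cache exactly once.
fn build_local_rev_map(global_codes: &[u32], global_cache: &[String]) -> (HashMap<u32, u32>, Vec<String>) {
    let mut rev_map: HashMap<u32, u32> = HashMap::new();
    let mut local_strings: Vec<String> = Vec::new();

    for &code in global_codes {
        rev_map.entry(code).or_insert_with(|| {
            let local_idx = local_strings.len() as u32;
            local_strings.push(global_cache[code as usize].clone());
            local_idx
        });
    }
    (rev_map, local_strings)
}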
@@ -1,6 +1,6 @@
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Mutex, MutexGuard};
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use std::time::{SystemTime, UNIX_EPOCH};

use ahash::RandomState;
@@ -106,6 +106,11 @@ pub(crate) struct SCacheInner {
}

impl SCacheInner {
#[inline]
pub(crate) unsafe fn get_unchecked(&self, cat: u32) -> &str {
self.payloads.get_unchecked(cat as usize).as_str()
}

pub(crate) fn len(&self) -> usize {
self.map.len()
}
@@ -166,7 +171,8 @@ impl Default for SCacheInner {
In *eager* you need to specifically toggle the global string cache to have a global effect.
In *lazy* it is toggled on at the start of a computation run and turned off (deleted) when a
/// result is produced.
pub(crate) struct StringCache(pub(crate) Mutex<SCacheInner>);
#[derive(Default)]
pub(crate) struct StringCache(pub(crate) RwLock<SCacheInner>);

impl StringCache {
/// The global `StringCache` will always use a predictable seed. This allows local builders to mimic
@@ -176,8 +182,12 @@ impl StringCache {
}

/// Lock the string cache
pub(crate) fn lock_map(&self) -> MutexGuard<SCacheInner> {
self.0.lock().unwrap()
pub(crate) fn lock_map(&self) -> RwLockWriteGuard<SCacheInner> {
self.0.write().unwrap()
}

pub(crate) fn read_map(&self) -> RwLockReadGuard<SCacheInner> {
self.0.read().unwrap()
}

pub(crate) fn clear(&self) {
@@ -186,12 +196,6 @@ impl StringCache {
}
}

impl Default for StringCache {
fn default() -> Self {
StringCache(Mutex::new(Default::default()))
}
}

pub(crate) static STRING_CACHE: Lazy<StringCache> = Lazy::new(Default::default);

type StrHashGlobal = SmartString<LazyCompact>;
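The switch from `Mutex` to `RwLock` lets many threads probe the cache concurrently through `read_map`, while interning new strings through `lock_map` still takes an exclusive guard. A small std-only sketch of that access pattern (the types here are illustrative stand-ins, not the actual polars internals):

use std::sync::RwLock;

struct MockCache {
    payloads: Vec<String>,
}

fn main() {
    let cache = RwLock::new(MockCache {
        payloads: vec!["Bob".into(), "Alice".into()],
    });

    {
        // multiple read guards can coexist; this is what `read_map` hands out
        let r1 = cache.read().unwrap();
        let r2 = cache.read().unwrap();
        assert_eq!(r1.payloads[0], "Bob");
        assert_eq!(r2.payloads[1], "Alice");
    } // read guards dropped here

    // interning a new category needs the exclusive guard, like `lock_map`
    let mut w = cache.write().unwrap();
    w.payloads.push("Carol".into());
    assert_eq!(w.payloads.len(), 3);
}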
30 changes: 24 additions & 6 deletions polars/polars-lazy/polars-pipe/src/executors/sinks/groupby/mod.rs
@@ -6,6 +6,8 @@ mod utils;

pub(crate) use generic::*;
use polars_core::prelude::*;
#[cfg(feature = "dtype-categorical")]
use polars_core::using_string_cache;
pub(crate) use primitive::*;
pub(crate) use string::*;

@@ -16,13 +18,29 @@ pub(super) fn physical_agg_to_logical(cols: &mut [Series], output_schema: &Schem
}
match dtype {
#[cfg(feature = "dtype-categorical")]
DataType::Categorical(Some(rev_map)) => {
let cats = s.u32().unwrap().clone();
// safety:
// the rev-map comes from these categoricals
unsafe {
*s = CategoricalChunked::from_cats_and_rev_map_unchecked(cats, rev_map.clone())
DataType::Categorical(rev_map) => {
if let Some(rev_map) = rev_map {
let cats = s.u32().unwrap().clone();
// safety:
// the rev-map comes from these categoricals
unsafe {
*s = CategoricalChunked::from_cats_and_rev_map_unchecked(
cats,
rev_map.clone(),
)
.into_series()
}
} else {
let cats = s.u32().unwrap().clone();
if using_string_cache() {
// Safety, we go from logical to primitive back to logical so the categoricals should still match the global map.
*s = unsafe {
CategoricalChunked::from_global_indices_unchecked(cats).into_series()
};
} else {
// we set the global string cache once we start a streaming pipeline
unreachable!()
}
}
}
_ => {
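A loose sketch of the restore step above (simplified types, not the real polars internals): the physical u32 keys are resolved either through the rev-map carried on the dtype or, failing that, through the global string cache that a streaming pipeline enables up front.

// Resolve physical categorical codes back to strings using whichever mapping is available.
fn restore_categories(
    cats: &[u32],
    rev_map: Option<&Vec<String>>,      // rev-map carried on the dtype, if any
    global_cache: Option<&Vec<String>>, // global string cache, if active
) -> Result<Vec<String>, String> {
    let lookup = rev_map
        .or(global_cache)
        .ok_or_else(|| "streaming pipelines enable the global string cache up front".to_string())?;
    cats.iter()
        .map(|&c| {
            lookup
                .get(c as usize)
                .cloned()
                .ok_or_else(|| format!("category index {c} out of bounds"))
        })
        .collect()
}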
16 changes: 15 additions & 1 deletion polars/polars-lazy/polars-plan/src/logical_plan/functions/mod.rs
@@ -6,6 +6,8 @@ use std::fmt::{Debug, Display, Formatter};
use std::sync::Arc;

use polars_core::prelude::*;
#[cfg(feature = "dtype-categorical")]
use polars_core::IUseStringCache;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

@@ -181,7 +183,19 @@ impl FunctionNode {
panic!("activate feature 'dtype-struct'")
}
}
Pipeline { function, .. } => Arc::get_mut(function).unwrap().call_udf(df),
Pipeline { function, .. } => {
// we use a global string cache here as streaming chunks all have different rev maps
#[cfg(feature = "dtype-categorical")]
{
let _hold = IUseStringCache::new();
Arc::get_mut(function).unwrap().call_udf(df)
}

#[cfg(not(feature = "dtype-categorical"))]
{
Arc::get_mut(function).unwrap().call_udf(df)
}
}
}
}
}
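The `_hold` binding relies on RAII: the cache stays enabled for exactly as long as the guard lives, i.e. until after `call_udf` returns. A minimal sketch of that guard pattern with a mock refcount (`IUseStringCache` itself manages the real global cache state in polars-core):

use std::sync::atomic::{AtomicU32, Ordering};

static CACHE_REFCOUNT: AtomicU32 = AtomicU32::new(0);

// Mock stand-in for `IUseStringCache`: enable on construction, release on drop.
struct UseStringCacheGuard;

impl UseStringCacheGuard {
    fn new() -> Self {
        CACHE_REFCOUNT.fetch_add(1, Ordering::SeqCst);
        UseStringCacheGuard
    }
}

impl Drop for UseStringCacheGuard {
    fn drop(&mut self) {
        CACHE_REFCOUNT.fetch_sub(1, Ordering::SeqCst);
    }
}

fn main() {
    {
        let _hold = UseStringCacheGuard::new();
        // ... run the pipeline / call_udf while the cache is held ...
        assert_eq!(CACHE_REFCOUNT.load(Ordering::SeqCst), 1);
    } // `_hold` drops here, releasing the cache
    assert_eq!(CACHE_REFCOUNT.load(Ordering::SeqCst), 0);
}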
19 changes: 5 additions & 14 deletions polars/polars-lazy/src/physical_plan/streaming/convert.rs
@@ -45,20 +45,11 @@ fn is_streamable(node: Node, expr_arena: &Arena<AExpr>) -> bool {
AExpr::Function { options, .. } | AExpr::AnonymousFunction { options, .. } => {
matches!(options.collect_groups, ApplyOptions::ApplyFlat)
}
AExpr::Column(_) | AExpr::Literal(_) | AExpr::BinaryExpr { .. } | AExpr::Alias(_, _) => {
true
}
AExpr::Cast { data_type, .. } => {
// a Categorical's indices are bound to the rev map
// in its data type and streaming different chunks
// will create different rev-maps that are hard
// to combine in a streaming sort, join, groupby
match data_type {
#[cfg(feature = "dtype-categorical")]
DataType::Categorical(_) => false,
_ => true,
}
}
AExpr::Column(_)
| AExpr::Literal(_)
| AExpr::BinaryExpr { .. }
| AExpr::Alias(_, _)
| AExpr::Cast { .. } => true,
_ => false,
})
}
17 changes: 17 additions & 0 deletions py-polars/tests/unit/io/test_lazy_parquet.py
@@ -158,3 +158,20 @@ def test_parquet_statistics(io_test_dir: str, capfd: CaptureFixture[str]) -> Non
"parquet file can be skipped, the statistics were sufficient"
" to apply the predicate." in captured
)


def test_streaming_categorical() -> None:
if os.name != "nt":
pl.DataFrame(
[
pl.Series("name", ["Bob", "Alice", "Bob"], pl.Categorical),
pl.Series("amount", [100, 200, 300]),
]
).write_parquet("/tmp/tmp.pq")
with pl.StringCache():
assert pl.scan_parquet("/tmp/tmp.pq").groupby("name").agg(
pl.col("amount").sum()
).collect().to_dict(False) == {
"name": ["Bob", "Alice"],
"amount": [400, 200],
}
