fix[rust]: always use global string cache when parsing categoricals from csv (#5001)
ritchie46 committed Sep 27, 2022
1 parent 893215a commit 4726805
Showing 4 changed files with 80 additions and 25 deletions.
@@ -1,6 +1,6 @@
 use std::borrow::Borrow;
 use std::hash::{Hash, Hasher};
-use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::atomic::{AtomicU32, Ordering};
 use std::sync::{Mutex, MutexGuard};
 use std::time::{SystemTime, UNIX_EPOCH};

@@ -11,7 +11,33 @@ use smartstring::{LazyCompact, SmartString};
 use crate::frame::groupby::hashing::HASHMAP_INIT_SIZE;
 use crate::prelude::PlHashMap;

-pub(crate) static USE_STRING_CACHE: AtomicBool = AtomicBool::new(false);
+/// We use atomic reference counting
+/// to determine how many threads use the string cache
+/// if the refcount is zero, we may clear the string cache.
+pub(crate) static USE_STRING_CACHE: AtomicU32 = AtomicU32::new(0);
+
+/// RAII for the string cache
+pub struct IUseStringCache {}
+
+impl Default for IUseStringCache {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl IUseStringCache {
+    /// Hold the StringCache
+    pub fn new() -> IUseStringCache {
+        toggle_string_cache(true);
+        IUseStringCache {}
+    }
+}
+
+impl Drop for IUseStringCache {
+    fn drop(&mut self) {
+        toggle_string_cache(false)
+    }
+}

 pub fn with_string_cache<F: FnOnce() -> T, T>(func: F) -> T {
     toggle_string_cache(true);
@@ -25,20 +51,26 @@ pub fn with_string_cache<F: FnOnce() -> T, T>(func: F) -> T {
 /// This is used to cache the string categories locally.
 /// This allows join operations on categorical types.
 pub fn toggle_string_cache(toggle: bool) {
-    USE_STRING_CACHE.store(toggle, Ordering::Release);
-    if !toggle {
-        STRING_CACHE.clear()
+    if toggle {
+        USE_STRING_CACHE.fetch_add(1, Ordering::Release);
+    } else {
+        let previous = USE_STRING_CACHE.fetch_sub(1, Ordering::Release);
+        if previous == 0 || previous == 1 {
+            USE_STRING_CACHE.store(0, Ordering::Release);
+            STRING_CACHE.clear()
+        }
     }
 }

 /// Reset the global string cache used for the Categorical Types.
 pub fn reset_string_cache() {
     USE_STRING_CACHE.store(0, Ordering::Release);
     STRING_CACHE.clear()
 }

 /// Check if string cache is set.
 pub fn using_string_cache() -> bool {
-    USE_STRING_CACHE.load(Ordering::Acquire)
+    USE_STRING_CACHE.load(Ordering::Acquire) > 0
 }

 pub(crate) struct SCacheInner {
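The mechanism introduced above can be illustrated in isolation. Below is a minimal, self-contained sketch of the same pattern, not the Polars code itself; ACTIVE_USERS and CacheGuard are made-up names. A global AtomicU32 counts the scopes that currently need the shared cache, an RAII guard increments the count when it is created and decrements it when it is dropped, and the cache becomes eligible for clearing only when the count returns to zero.

use std::sync::atomic::{AtomicU32, Ordering};

// Global refcount of scopes that currently need the shared cache.
static ACTIVE_USERS: AtomicU32 = AtomicU32::new(0);

struct CacheGuard;

impl CacheGuard {
    fn new() -> CacheGuard {
        ACTIVE_USERS.fetch_add(1, Ordering::Release);
        CacheGuard
    }
}

impl Drop for CacheGuard {
    fn drop(&mut self) {
        // `fetch_sub` returns the previous value: 1 (or 0) means no user is left.
        if ACTIVE_USERS.fetch_sub(1, Ordering::Release) <= 1 {
            println!("refcount hit zero: the shared cache may be cleared now");
        }
    }
}

fn main() {
    let outer = CacheGuard::new(); // e.g. a user-held string-cache scope
    {
        let _inner = CacheGuard::new(); // e.g. a CSV read that needs the cache
    } // `_inner` drops: the count goes 2 -> 1, the cache survives
    drop(outer); // the count goes 1 -> 0, the cache may now be cleared
}

Because the count is per use rather than a single boolean, a CSV read that enables the cache inside a user-held scope no longer wipes the user's cache when it finishes.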
8 changes: 4 additions & 4 deletions polars/polars-core/src/frame/hash_join/mod.rs
@@ -1003,6 +1003,8 @@ impl DataFrame {
 mod test {
     use crate::df;
     use crate::prelude::*;
+    #[cfg(feature = "dtype-categorical")]
+    use crate::{reset_string_cache, IUseStringCache};

     fn create_frames() -> (DataFrame, DataFrame) {
         let s0 = Series::new("days", &[0, 1, 2]);
@@ -1198,14 +1200,12 @@ mod test {
             .join(&df_b, ["a", "b"], ["foo", "bar"], JoinType::Left, None)
             .unwrap();
         let ca = joined.column("ham").unwrap().utf8().unwrap();
-        dbg!(&df_a, &df_b);
         assert_eq!(Vec::from(ca), correct_ham);
         let joined_inner_hack = df_a.inner_join(&df_b, ["dummy"], ["dummy"]).unwrap();
         let joined_inner = df_a
             .join(&df_b, ["a", "b"], ["foo", "bar"], JoinType::Inner, None)
             .unwrap();

-        dbg!(&joined_inner_hack, &joined_inner);
         assert!(joined_inner_hack
             .column("ham")
             .unwrap()
@@ -1265,8 +1265,8 @@ mod test {
         df_a.try_apply("b", |s| s.cast(&DataType::Categorical(None)))
             .unwrap();
         // create a new cache
-        toggle_string_cache(false);
-        toggle_string_cache(true);
+        reset_string_cache();
+        let sc = IUseStringCache::new();

         df_b.try_apply("bar", |s| s.cast(&DataType::Categorical(None)))
             .unwrap();
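The test change above shows the calling pattern this commit expects from user code. The following is a usage sketch of that pattern, assuming the dtype-categorical feature is enabled and that reset_string_cache and IUseStringCache are reachable from the polars-core crate root, as the test's use crate::{reset_string_cache, IUseStringCache} suggests; the column names mirror the test frames.

use polars_core::prelude::*;
use polars_core::{reset_string_cache, IUseStringCache};

fn join_on_shared_categoricals(mut df_a: DataFrame, mut df_b: DataFrame) -> DataFrame {
    // Start from a clean cache, then keep it alive for this whole scope.
    reset_string_cache();
    let _sc = IUseStringCache::new();

    // Both casts happen while the same cache is held, so the categorical
    // columns share one global mapping of strings to category ids.
    df_a.try_apply("b", |s| s.cast(&DataType::Categorical(None)))
        .unwrap();
    df_b.try_apply("bar", |s| s.cast(&DataType::Categorical(None)))
        .unwrap();

    df_a.inner_join(&df_b, ["b"], ["bar"]).unwrap()
}

Binding the guard to a variable ties the cache's lifetime to the surrounding scope, so it is released even on an early return or panic, which paired toggle calls cannot guarantee.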
21 changes: 21 additions & 0 deletions polars/polars-io/src/csv/read.rs
@@ -1,3 +1,6 @@
+#[cfg(feature = "dtype-categorical")]
+use polars_core::toggle_string_cache;
+
 use super::*;

 #[derive(Copy, Clone, Debug, Eq, PartialEq)]
@@ -368,6 +371,9 @@
         #[allow(unused_mut)]
         let mut to_cast_local = vec![];

+        #[cfg(feature = "dtype-categorical")]
+        let mut has_categorical = false;
+
         let mut df = if let Some(schema) = self.schema_overwrite {
             // This branch we check if there are dtypes we cannot parse.
             // We only support a few dtypes in the parser and later cast to the required dtype
@@ -390,12 +396,22 @@
                             fld.coerce(DataType::Int32);
                             Some(fld)
                         }
+                        #[cfg(feature = "dtype-categorical")]
+                        DataType::Categorical(_) => {
+                            has_categorical = true;
+                            Some(fld)
+                        }
                         _ => Some(fld),
                     }
                 })
                 .collect();
             let schema = Schema::from(fields);

+            #[cfg(feature = "dtype-categorical")]
+            if has_categorical {
+                toggle_string_cache(true);
+            }
+
             // we cannot overwrite self, because the lifetime is already instantiated with `a, and
             // the lifetime that accompanies this scope is shorter.
             // So we just build_csv_reader from here
@@ -495,6 +511,11 @@
         }

         cast_columns(&mut df, &to_cast_local, true)?;
+
+        #[cfg(feature = "dtype-categorical")]
+        if has_categorical {
+            toggle_string_cache(false);
+        }
         Ok(df)
     }
 }
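In short, the reader now scans the overwritten schema for Categorical fields and, when it finds any, holds the global string cache for the duration of the parse so that every chunk (and thread) maps strings through the same cache. A compact, self-contained sketch of that flow, with stand-in names (Dtype, parse_csv, and a local counter) rather than the real polars-io types:

use std::sync::atomic::{AtomicU32, Ordering};

// Stand-in for the global USE_STRING_CACHE refcount.
static USE_STRING_CACHE: AtomicU32 = AtomicU32::new(0);

fn toggle_string_cache(enable: bool) {
    if enable {
        USE_STRING_CACHE.fetch_add(1, Ordering::Release);
    } else if USE_STRING_CACHE.fetch_sub(1, Ordering::Release) <= 1 {
        // Last user released the cache: the real code clears it here.
        USE_STRING_CACHE.store(0, Ordering::Release);
    }
}

#[derive(PartialEq)]
enum Dtype {
    Int64,
    Categorical,
}

// Placeholder for the chunked (possibly multi-threaded) CSV parse.
fn parse_csv(dtypes: &[Dtype]) -> usize {
    dtypes.len()
}

fn read_csv(dtypes: &[Dtype]) -> usize {
    let has_categorical = dtypes.contains(&Dtype::Categorical);
    if has_categorical {
        // Enable before any chunk is parsed so all chunks share one mapping.
        toggle_string_cache(true);
    }
    let n_columns = parse_csv(dtypes);
    if has_categorical {
        // Release; the cache is cleared only if no other scope still holds it.
        toggle_string_cache(false);
    }
    n_columns
}

fn main() {
    assert_eq!(read_csv(&[Dtype::Categorical, Dtype::Int64]), 2);
}

Because the toggle is refcounted, the trailing toggle_string_cache(false) only clears the cache when no other caller still holds it.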
32 changes: 17 additions & 15 deletions py-polars/tests/unit/io/test_csv.py
@@ -834,19 +834,21 @@ def test_csv_categorical_lifetime() -> None:
"""
)

for string_cache in [True, False]:
pl.toggle_string_cache(string_cache)
df = pl.read_csv(
csv.encode(), dtypes={"a": pl.Categorical, "b": pl.Categorical}
)
assert df.dtypes == [pl.Categorical, pl.Categorical]
assert df.to_dict(False) == {
"a": ["needs_escape", ' "needs escape foo', ' "needs escape foo'],
"b": ["b", "b", None],
}
df = pl.read_csv(csv.encode(), dtypes={"a": pl.Categorical, "b": pl.Categorical})
assert df.dtypes == [pl.Categorical, pl.Categorical]
assert df.to_dict(False) == {
"a": ["needs_escape", ' "needs escape foo', ' "needs escape foo'],
"b": ["b", "b", None],
}

if string_cache:
assert (df["a"] == df["b"]).to_list() == [False, False, False]
else:
with pytest.raises(pl.ComputeError):
df["a"] == df["b"] # noqa: B015
assert (df["a"] == df["b"]).to_list() == [False, False, False]


def test_csv_categorical_categorical_merge() -> None:
N = 50
f = io.BytesIO()
pl.DataFrame({"x": ["A"] * N + ["B"] * N}).write_csv(f)
f.seek(0)
assert pl.read_csv(f, dtypes={"x": pl.Categorical}, sample_size=10).unique()[
"x"
].to_list() == ["A", "B"]
