Skip to content

Commit

Permalink
fix[rust]: Categorical::default repsect global string cache (#4496)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 19, 2022
1 parent 3496e33 commit 90f1feb
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,14 @@ pub enum RevMapping {
impl Default for RevMapping {
fn default() -> Self {
let slice: &[Option<&str>] = &[];
RevMapping::Local(Utf8Array::<i64>::from(slice))
let cats = Utf8Array::<i64>::from(slice);
if use_string_cache() {
let cache = &mut crate::STRING_CACHE.lock_map();
let id = cache.uuid;
RevMapping::Global(Default::default(), cats, id)
} else {
RevMapping::Local(cats)
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use super::*;
impl CategoricalChunked {
pub fn full_null(name: &str, length: usize) -> CategoricalChunked {
let cats = UInt32Chunked::full_null(name, length);

unsafe {
CategoricalChunked::from_cats_and_rev_map_unchecked(
cats,
Expand Down
15 changes: 15 additions & 0 deletions py-polars/tests/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,3 +210,18 @@ def test_shift_and_fill() -> None:
s = df.with_column(pl.col("a").shift_and_fill(1, "c"))["a"]
assert s.dtype == pl.Categorical
assert s.to_list() == ["c", "a"]


def test_merge_lit_under_global_cache_4491() -> None:
with pl.StringCache():
df = pl.DataFrame(
[
pl.Series("label", ["foo", "bar"], dtype=pl.Categorical),
pl.Series("value", [3, 9]),
]
)
assert df.with_column(
pl.when(pl.col("value") > 5)
.then(pl.col("label"))
.otherwise(pl.lit(None, pl.Categorical))
).to_dict(False) == {"label": [None, "bar"], "value": [3, 9]}

0 comments on commit 90f1feb

Please sign in to comment.