Skip to content

Commit

Permalink
fix(rust, python): allow nonstrict cast of categorical/enum to enum (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
mcrumiller committed Mar 10, 2024
1 parent 8a61d29 commit 419b891
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 34 deletions.
48 changes: 16 additions & 32 deletions crates/polars-core/src/chunked_array/logical/categorical/mod.rs
Expand Up @@ -130,18 +130,18 @@ impl CategoricalChunked {
}
}

// Convert to fixed enum. In case a value is not in the categories return Error
pub fn to_enum(&self, categories: &Utf8ViewArray, hash: u128) -> PolarsResult<Self> {
// Convert to fixed enum. Values not in categories are mapped to None.
pub fn to_enum(&self, categories: &Utf8ViewArray, hash: u128) -> Self {
// Fast paths
match self.get_rev_map().as_ref() {
RevMapping::Local(_, cur_hash) if hash == *cur_hash => {
return unsafe {
Ok(CategoricalChunked::from_cats_and_rev_map_unchecked(
CategoricalChunked::from_cats_and_rev_map_unchecked(
self.physical().clone(),
self.get_rev_map().clone(),
true,
self.get_ordering(),
))
)
};
},
_ => (),
Expand All @@ -159,34 +159,18 @@ impl CategoricalChunked {
let new_phys: UInt32Chunked = self
.physical()
.into_iter()
.map(|opt_v: Option<u32>| {
let Some(v) = opt_v else {
return Ok(None);
};

let Some(idx) = idx_map.get(&v) else {
polars_bail!(
not_in_enum,
value = old_rev_map.get(v),
categories = &categories
);
};
.map(|opt_v: Option<u32>| opt_v.and_then(|v| idx_map.get(&v).copied()))
.collect();

Ok(Some(*idx))
})
.collect::<PolarsResult<_>>()?;

Ok(
// SAFETY: we created the physical from the enum categories
unsafe {
CategoricalChunked::from_cats_and_rev_map_unchecked(
new_phys,
Arc::new(RevMapping::Local(categories.clone(), hash)),
true,
self.get_ordering(),
)
},
)
// SAFETY: we created the physical from the enum categories
unsafe {
CategoricalChunked::from_cats_and_rev_map_unchecked(
new_phys,
Arc::new(RevMapping::Local(categories.clone(), hash)),
true,
self.get_ordering(),
)
}
}

pub(crate) fn get_flags(&self) -> Settings {
Expand Down Expand Up @@ -373,7 +357,7 @@ impl LogicalType for CategoricalChunked {
polars_bail!(ComputeError: "can not cast to enum with global mapping")
};
Ok(self
.to_enum(categories, *hash)?
.to_enum(categories, *hash)
.set_ordering(*ordering, true)
.into_series()
.with_name(self.name()))
Expand Down
28 changes: 26 additions & 2 deletions py-polars/tests/unit/datatypes/test_enum.py
Expand Up @@ -117,6 +117,26 @@ def test_casting_to_an_enum_from_categorical() -> None:
assert_series_equal(s2, expected)


def test_casting_to_an_enum_from_categorical_nonstrict() -> None:
    """A non-strict Categorical -> Enum cast turns unknown categories into nulls."""
    enum_dtype = pl.Enum(["a", "b"])
    source = pl.Series([None, "a", "b", "c"], dtype=pl.Categorical)

    result = source.cast(enum_dtype, strict=False)

    assert result.dtype == enum_dtype
    # The input already holds one null; "c" is not an Enum category, so it
    # becomes the second null instead of raising.
    assert result.null_count() == 2
    assert_series_equal(
        result,
        pl.Series([None, "a", "b", None], dtype=enum_dtype),
    )


def test_casting_to_an_enum_from_enum_nonstrict() -> None:
    """A non-strict Enum -> narrower Enum cast turns dropped categories into nulls."""
    narrow_dtype = pl.Enum(["a", "b"])
    wide_dtype = pl.Enum(["a", "b", "c"])
    source = pl.Series([None, "a", "b", "c"], dtype=wide_dtype)

    result = source.cast(narrow_dtype, strict=False)

    assert result.dtype == narrow_dtype
    # The input already holds one null; "c" exists only in the wider Enum, so
    # it becomes the second null instead of raising.
    assert result.null_count() == 2
    assert_series_equal(
        result,
        pl.Series([None, "a", "b", None], dtype=narrow_dtype),
    )


def test_casting_to_an_enum_from_integer() -> None:
dtype = pl.Enum(["a", "b", "c"])
expected = pl.Series([None, "b", "a", "c"], dtype=dtype)
Expand All @@ -139,7 +159,9 @@ def test_casting_to_an_enum_oob_from_integer() -> None:
def test_casting_to_an_enum_from_categorical_nonexistent() -> None:
with pytest.raises(
pl.ComputeError,
match=("value 'c' is not present in Enum"),
match=(
r"conversion from `cat` to `enum` failed in column '' for 1 out of 4 values: \[\"c\"\]"
),
):
pl.Series([None, "a", "b", "c"], dtype=pl.Categorical).cast(pl.Enum(["a", "b"]))

Expand All @@ -159,7 +181,9 @@ def test_casting_to_an_enum_from_global_categorical() -> None:
def test_casting_to_an_enum_from_global_categorical_nonexistent() -> None:
with pytest.raises(
pl.ComputeError,
match=("value 'c' is not present in Enum"),
match=(
r"conversion from `cat` to `enum` failed in column '' for 1 out of 4 values: \[\"c\"\]"
),
):
pl.Series([None, "a", "b", "c"], dtype=pl.Categorical).cast(pl.Enum(["a", "b"]))

Expand Down

0 comments on commit 419b891

Please sign in to comment.