Skip to content

Commit

Permalink
fix(rust, python): don't allow categorical append that is not under s… (
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 13, 2022
1 parent 1e03e52 commit b497bbb
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,22 @@ use crate::series::IsSorted;

impl CategoricalChunked {
/// Appends the values of `other` to this categorical array in place.
///
/// After a successful append the array is flagged as not sorted.
///
/// # Errors
///
/// Returns `PolarsError::ComputeError` when both arrays carry a
/// `RevMapping::Local` mapping and the two mappings are not the very same
/// object (pointer comparison). Any error produced by
/// `merge_categorical_map` is propagated unchanged.
pub fn append(&mut self, other: &Self) -> PolarsResult<()> {
    // Validate first: nothing on `self` may be mutated before we know the
    // append is allowed, otherwise a rejected call leaves `self` half-updated.
    let is_local_different_source =
        match (self.get_rev_map().as_ref(), other.get_rev_map().as_ref()) {
            (RevMapping::Local(arr_l), RevMapping::Local(arr_r)) => !std::ptr::eq(arr_l, arr_r),
            _ => false,
        };
    if is_local_different_source {
        return Err(PolarsError::ComputeError("Cannot concat Categoricals coming from a different source. Consider setting a global StringCache.".into()));
    }

    // Length of `self` before anything from `other` has been added.
    let len = self.len();
    let new_rev_map = self.merge_categorical_map(other)?;
    // NOTE(review): `set_rev_map` is unsafe; presumably sound here because the
    // merged mapping covers the categories of both inputs -- confirm against
    // the contract documented on `set_rev_map`.
    unsafe { self.set_rev_map(new_rev_map, false) };

    self.logical_mut().length += other.len() as IdxSize;
    new_chunks(&mut self.logical.chunks, &other.logical().chunks, len);

    // Concatenation gives no ordering guarantee, so drop any sorted flag.
    self.logical.set_sorted2(IsSorted::Not);
    Ok(())
}
Expand Down
30 changes: 15 additions & 15 deletions py-polars/tests/unit/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,20 +163,21 @@ def test_categorical_is_in_list() -> None:


def test_unset_sorted_on_append() -> None:
    # Both frames are built under one global StringCache: appending
    # categoricals that come from different local sources is rejected with
    # "Cannot concat Categoricals coming from a different source".
    with pl.StringCache():
        df1 = pl.DataFrame(
            [
                pl.Series("key", ["a", "b", "a", "b"], dtype=pl.Categorical),
                pl.Series("val", [1, 2, 3, 4]),
            ]
        ).sort("key")
        df2 = pl.DataFrame(
            [
                pl.Series("key", ["a", "b", "a", "b"], dtype=pl.Categorical),
                pl.Series("val", [5, 6, 7, 8]),
            ]
        ).sort("key")
        # Each input is sorted on "key", but their concatenation is not; the
        # groupby must still count four rows for each of the two keys.
        df = pl.concat([df1, df2], rechunk=False)
        assert df.groupby("key").count()["count"].to_list() == [4, 4]


def test_categorical_error_on_local_cmp() -> None:
Expand Down Expand Up @@ -267,4 +268,3 @@ def test_categorical_list_concat_4762() -> None:
q = df.lazy().select([pl.concat_list([pl.col("x").cast(pl.Categorical)] * 2)])
with pl.StringCache():
assert q.collect().to_dict(False) == expected
assert q.collect().to_dict(False) == expected

0 comments on commit b497bbb

Please sign in to comment.