Skip to content

Commit

Permalink
pivot: fix categorical logicaltype (#4048)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 17, 2022
1 parent 396018b commit 20fe262
Show file tree
Hide file tree
Showing 15 changed files with 179 additions and 81 deletions.
15 changes: 12 additions & 3 deletions polars/polars-core/src/chunked_array/logical/categorical/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,25 @@ impl CategoricalChunked {
self.bit_settings & 1 << 1 != 0
}

pub(crate) fn from_cats_and_rev_map(idx: UInt32Chunked, rev_map: Arc<RevMapping>) -> Self {
/// Create a [`CategoricalChunked`] from an array of `idx` and an existing [`RevMapping`]: `rev_map`.
///
/// # Safety
/// Invariant in `v < rev_map.len() for v in idx` must be hold.
pub unsafe fn from_cats_and_rev_map_unchecked(
idx: UInt32Chunked,
rev_map: Arc<RevMapping>,
) -> Self {
let mut logical = Logical::<UInt32Type, _>::new_logical::<CategoricalType>(idx);
logical.2 = Some(DataType::Categorical(Some(rev_map)));
Self {
logical,
bit_settings: 0,
bit_settings: Default::default(),
}
}

pub(crate) fn set_rev_map(&mut self, rev_map: Arc<RevMapping>, keep_fast_unique: bool) {
/// # Safety
/// The existing index values must be in bounds of the new [`RevMapping`].
pub(crate) unsafe fn set_rev_map(&mut self, rev_map: Arc<RevMapping>, keep_fast_unique: bool) {
self.logical.2 = Some(DataType::Categorical(Some(rev_map)));
if !keep_fast_unique {
self.set_fast_unique(false)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::series::IsSorted;
impl CategoricalChunked {
pub fn append(&mut self, other: &Self) -> Result<()> {
let new_rev_map = self.merge_categorical_map(other)?;
self.set_rev_map(new_rev_map, false);
unsafe { self.set_rev_map(new_rev_map, false) };

let len = self.len();
new_chunks(&mut self.logical.chunks, &other.logical().chunks, len);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@ use super::*;
impl CategoricalChunked {
pub fn full_null(name: &str, length: usize) -> CategoricalChunked {
let cats = UInt32Chunked::full_null(name, length);
CategoricalChunked::from_cats_and_rev_map(cats, Arc::new(RevMapping::default()))
unsafe {
CategoricalChunked::from_cats_and_rev_map_unchecked(
cats,
Arc::new(RevMapping::default()),
)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,24 @@ impl CategoricalChunked {
UInt32Chunked::from_iter_values(self.logical().name(), map.keys().copied())
}
};
let mut out = CategoricalChunked::from_cats_and_rev_map(ca, cat_map.clone());
out.set_fast_unique(true);
Ok(out)
// safety:
// we only removed some indexes so we are still in bounds
unsafe {
let mut out =
CategoricalChunked::from_cats_and_rev_map_unchecked(ca, cat_map.clone());
out.set_fast_unique(true);
Ok(out)
}
} else {
let ca = self.logical().unique()?;
Ok(CategoricalChunked::from_cats_and_rev_map(
ca,
cat_map.clone(),
))
// safety:
// we only removed some indexes so we are still in bounds
unsafe {
Ok(CategoricalChunked::from_cats_and_rev_map_unchecked(
ca,
cat_map.clone(),
))
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ impl CategoricalChunked {
_ => self.logical().zip_with(mask, other.logical())?,
};
let new_state = self.merge_categorical_map(other)?;
Ok(CategoricalChunked::from_cats_and_rev_map(cats, new_state))

// Safety:
// we checked the rev_maps.
unsafe {
Ok(CategoricalChunked::from_cats_and_rev_map_unchecked(
cats, new_state,
))
}
}
}
3 changes: 2 additions & 1 deletion polars/polars-core/src/chunked_array/ops/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ pub(crate) unsafe fn arr_to_any_value<'a>(
#[cfg(feature = "dtype-categorical")]
DataType::Categorical(Some(rev_map)) => {
let cats = s.u32().unwrap().clone();
let out = CategoricalChunked::from_cats_and_rev_map(cats, rev_map.clone());
let out =
CategoricalChunked::from_cats_and_rev_map_unchecked(cats, rev_map.clone());
s = out.into_series();
}
DataType::Date
Expand Down
7 changes: 6 additions & 1 deletion polars/polars-core/src/chunked_array/ops/explode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,12 @@ impl ChunkExplode for ListChunked {
#[cfg(feature = "dtype-categorical")]
DataType::Categorical(rev_map) => {
let cats = s.u32().unwrap().clone();
s = CategoricalChunked::from_cats_and_rev_map(cats, rev_map.unwrap()).into_series();
// safety:
// rev_map is from same array, so we are still in bounds
s = unsafe {
CategoricalChunked::from_cats_and_rev_map_unchecked(cats, rev_map.unwrap())
.into_series()
};
}
#[cfg(feature = "dtype-date")]
DataType::Date => s = s.into_date(),
Expand Down
33 changes: 24 additions & 9 deletions polars/polars-core/src/chunked_array/ops/sort/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,14 @@ impl CategoricalChunked {
let sorted = ca.sort(options.descending);
let arr = sorted.downcast_iter().next().unwrap().clone();
let rev_map = RevMapping::Local(arr);
CategoricalChunked::from_cats_and_rev_map(
self.logical().clone(),
Arc::new(rev_map),
)
// safety:
// we only reordered the indexes so we are still in bounds
unsafe {
CategoricalChunked::from_cats_and_rev_map_unchecked(
self.logical().clone(),
Arc::new(rev_map),
)
}
}
RevMapping::Global(_, _, _) => {
// a global rev map must always point to the same string values
Expand All @@ -61,15 +65,26 @@ impl CategoricalChunked {
);
let cats: NoNull<UInt32Chunked> =
vals.into_iter().map(|(idx, _v)| idx).collect_trusted();
CategoricalChunked::from_cats_and_rev_map(
cats.into_inner(),
self.get_rev_map().clone(),
)
// safety:
// we only reordered the indexes so we are still in bounds
unsafe {
CategoricalChunked::from_cats_and_rev_map_unchecked(
cats.into_inner(),
self.get_rev_map().clone(),
)
}
}
}
} else {
let cats = self.logical().sort_with(options);
CategoricalChunked::from_cats_and_rev_map(cats, self.get_rev_map().clone())
// safety:
// we only reordered the indexes so we are still in bounds
unsafe {
CategoricalChunked::from_cats_and_rev_map_unchecked(
cats,
self.get_rev_map().clone(),
)
}
}
}

Expand Down
12 changes: 11 additions & 1 deletion polars/polars-core/src/frame/groupby/pivot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,17 @@ pub struct Pivot<'df> {
// Takes a `DataFrame` that only consists of the column aggregates that are pivoted by
// the values in `columns`
fn finish_logical_type(column: &mut Series, dtype: &DataType) {
*column = column.cast(dtype).unwrap();
*column = match dtype {
#[cfg(feature = "dtype-categorical")]
DataType::Categorical(Some(rev_map)) => {
let ca = column.u32().unwrap();
unsafe {
CategoricalChunked::from_cats_and_rev_map_unchecked(ca.clone(), rev_map.clone())
}
.into_series()
}
_ => column.cast(dtype).unwrap(),
};
}

impl<'df> Pivot<'df> {
Expand Down
7 changes: 6 additions & 1 deletion polars/polars-core/src/frame/hash_join/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,12 @@ impl DataFrame {
let ca_left = s_left.categorical().unwrap();
let new_rev_map = ca_left.merge_categorical_map(s_right.categorical().unwrap())?;
let logical = s.u32().unwrap().clone();
CategoricalChunked::from_cats_and_rev_map(logical, new_rev_map).into_series()
// safety:
// categorical maps are merged
unsafe {
CategoricalChunked::from_cats_and_rev_map_unchecked(logical, new_rev_map)
.into_series()
}
}
dt @ DataType::Datetime(_, _)
| dt @ DataType::Time
Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/series/from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl Series {
#[cfg(feature = "dtype-categorical")]
Categorical(rev_map) => {
let cats = UInt32Chunked::from_chunks(name, chunks);
CategoricalChunked::from_cats_and_rev_map(cats, rev_map.clone().unwrap())
CategoricalChunked::from_cats_and_rev_map_unchecked(cats, rev_map.clone().unwrap())
.into_series()
}
Boolean => BooleanChunked::from_chunks(name, chunks).into_series(),
Expand Down
12 changes: 9 additions & 3 deletions polars/polars-core/src/series/implementations/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ impl IntoSeries for CategoricalChunked {

impl SeriesWrap<CategoricalChunked> {
fn finish_with_state(&self, keep_fast_unique: bool, cats: UInt32Chunked) -> CategoricalChunked {
let mut out = CategoricalChunked::from_cats_and_rev_map(cats, self.0.get_rev_map().clone());
let mut out = unsafe {
CategoricalChunked::from_cats_and_rev_map_unchecked(cats, self.0.get_rev_map().clone())
};
if keep_fast_unique && self.0.can_fast_unique() {
out.set_fast_unique(true)
}
Expand Down Expand Up @@ -127,7 +129,9 @@ impl private::PrivateSeries for SeriesWrap<CategoricalChunked> {
let cats = left.zip_outer_join_column(&right, opt_join_tuples);
let cats = cats.u32().unwrap().clone();

CategoricalChunked::from_cats_and_rev_map(cats, new_rev_map).into_series()
unsafe {
CategoricalChunked::from_cats_and_rev_map_unchecked(cats, new_rev_map).into_series()
}
}
fn group_tuples(&self, multithreaded: bool, sorted: bool) -> GroupsProxy {
self.0.logical().group_tuples(multithreaded, sorted)
Expand Down Expand Up @@ -196,7 +200,9 @@ impl SeriesTrait for SeriesWrap<CategoricalChunked> {
let other = other.categorical()?;
self.0.logical_mut().extend(other.logical());
let new_rev_map = self.0.merge_categorical_map(other)?;
self.0.set_rev_map(new_rev_map, false);
// safety:
// rev_maps are merged
unsafe { self.0.set_rev_map(new_rev_map, false) };
Ok(())
} else {
Err(PolarsError::SchemaMisMatch(
Expand Down
9 changes: 8 additions & 1 deletion polars/polars-core/src/series/into.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,14 @@ impl Series {
let arr = ca.logical().chunks()[chunk_idx].clone();
let cats = UInt32Chunked::from_chunks("", vec![arr]);

let new = CategoricalChunked::from_cats_and_rev_map(cats, ca.get_rev_map().clone());
// safety:
// we only take a single chunk and change nothing about the index/rev_map mapping
let new = unsafe {
CategoricalChunked::from_cats_and_rev_map_unchecked(
cats,
ca.get_rev_map().clone(),
)
};

let arr: DictionaryArray<u32> = (&new).into();
Box::new(arr) as ArrayRef
Expand Down
50 changes: 0 additions & 50 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,39 +744,6 @@ def test_groupby() -> None:
assert df.groupby("b").agg_list().shape == (2, 3)


def test_pivot() -> None:
df = pl.DataFrame(
{
"a": [1, 2, 3, 4, 5],
"b": ["a", "a", "b", "b", "b"],
"c": [None, 1, None, 1, None],
}
)
gb = df.groupby("b").pivot("a", "c")
assert gb.first().shape == (2, 6)
assert gb.max().shape == (2, 6)
assert gb.mean().shape == (2, 6)
assert gb.count().shape == (2, 6)
assert gb.median().shape == (2, 6)

for agg_fn in ["sum", "min", "max", "mean", "count", "median", "mean"]:
out = df.pivot(
values="c", index="b", columns="a", aggregate_fn=agg_fn, sort_columns=True
)
assert out.shape == (2, 6)

# example in polars-book
df = pl.DataFrame(
{
"foo": ["A", "A", "B", "B", "C"],
"N": [1, 2, 2, 4, 2],
"bar": ["k", "l", "m", "n", "o"],
}
)
out = df.groupby("foo").pivot(pivot_column="bar", values_column="N").first()
assert out.shape == (3, 6)


def test_join() -> None:
df_left = pl.DataFrame(
{
Expand Down Expand Up @@ -2059,23 +2026,6 @@ def test_get_item() -> None:
_ = df[pl.Series("", ["hello Im a string"])]


def test_pivot_list() -> None:
df = pl.DataFrame({"a": [1, 2, 3], "b": [[1, 1], [2, 2], [3, 3]]})

expected = pl.DataFrame(
{
"a": [1, 2, 3],
"1": [[1, 1], None, None],
"2": [None, [2, 2], None],
"3": [None, None, [3, 3]],
}
)

out = df.pivot("b", index="a", columns="a", aggregate_fn="first", sort_columns=True)

assert out.frame_equal(expected, null_equal=True)


@pytest.mark.parametrize("as_series,inner_dtype", [(True, pl.Series), (False, list)])
def test_to_dict(as_series: bool, inner_dtype: Any) -> None:
df = pl.DataFrame(
Expand Down
69 changes: 69 additions & 0 deletions py-polars/tests/test_pivot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import polars as pl


def test_pivot_list() -> None:
df = pl.DataFrame({"a": [1, 2, 3], "b": [[1, 1], [2, 2], [3, 3]]})

expected = pl.DataFrame(
{
"a": [1, 2, 3],
"1": [[1, 1], None, None],
"2": [None, [2, 2], None],
"3": [None, None, [3, 3]],
}
)

out = df.pivot("b", index="a", columns="a", aggregate_fn="first", sort_columns=True)

assert out.frame_equal(expected, null_equal=True)


def test_pivot() -> None:
df = pl.DataFrame(
{
"a": [1, 2, 3, 4, 5],
"b": ["a", "a", "b", "b", "b"],
"c": [None, 1, None, 1, None],
}
)
gb = df.groupby("b").pivot("a", "c")
assert gb.first().shape == (2, 6)
assert gb.max().shape == (2, 6)
assert gb.mean().shape == (2, 6)
assert gb.count().shape == (2, 6)
assert gb.median().shape == (2, 6)

for agg_fn in ["sum", "min", "max", "mean", "count", "median", "mean"]:
out = df.pivot(
values="c", index="b", columns="a", aggregate_fn=agg_fn, sort_columns=True
)
assert out.shape == (2, 6)

# example in polars-book
df = pl.DataFrame(
{
"foo": ["A", "A", "B", "B", "C"],
"N": [1, 2, 2, 4, 2],
"bar": ["k", "l", "m", "n", "o"],
}
)
out = df.groupby("foo").pivot(pivot_column="bar", values_column="N").first()
assert out.shape == (3, 6)


def test_pivot_categorical_3968() -> None:
df = pl.DataFrame(
{
"foo": ["one", "one", "one", "two", "two", "two"],
"bar": ["A", "B", "C", "A", "B", "C"],
"baz": [1, 2, 3, 4, 5, 6],
}
)

assert df.with_column(pl.col("baz").cast(str).cast(pl.Categorical)).to_dict(
False
) == {
"foo": ["one", "one", "one", "two", "two", "two"],
"bar": ["A", "B", "C", "A", "B", "C"],
"baz": ["1", "2", "3", "4", "5", "6"],
}

0 comments on commit 20fe262

Please sign in to comment.