Skip to content

Commit

Permalink
pivot: fix logical type of multiple indexes (#4159)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 26, 2022
1 parent e2dacbf commit ba400f4
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 25 deletions.
42 changes: 20 additions & 22 deletions polars/polars-core/src/frame/groupby/pivot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,23 @@ pub enum PivotAgg {
Last,
}

/// Converts the series `s` back into the requested `logical_type`.
///
/// When the `dtype-categorical` feature is enabled and `logical_type` is a
/// `DataType::Categorical` that carries a rev-map, the `u32` values of `s`
/// are paired with a clone of that rev-map to rebuild a categorical series.
/// Every other dtype is handled by a plain `cast`.
///
/// # Panics
///
/// Panics if `s.u32()` fails in the categorical branch, or if
/// `s.cast(logical_type)` returns an error in the fallback branch.
fn restore_logical_type(s: &Series, logical_type: &DataType) -> Series {
    match logical_type {
        #[cfg(feature = "dtype-categorical")]
        DataType::Categorical(Some(rev_map)) => {
            let physical = s.u32().unwrap().clone();
            // SAFETY: the rev-map comes from these categoricals, so it is
            // expected to describe every category index held in `physical`.
            // NOTE(review): this is an invariant the callers must uphold;
            // it is not checked here.
            let categorical = unsafe {
                CategoricalChunked::from_cats_and_rev_map_unchecked(physical, rev_map.clone())
            };
            categorical.into_series()
        }
        other => s.cast(other).unwrap(),
    }
}

impl DataFrame {
/// Do a pivot operation based on the group key, a pivot column and an aggregation function on the values column.
///
Expand Down Expand Up @@ -141,29 +158,11 @@ impl DataFrame {

let row_index = match count {
0 => {
let mut s = Series::new(
let s = Series::new(
&index[0],
row_to_idx.into_iter().map(|(k, _)| k).collect::<Vec<_>>(),
);
// restore logical type
match index_s.dtype() {
#[cfg(feature = "dtype-categorical")]
DataType::Categorical(Some(rev_map)) => {
let cats = s.u32().unwrap().clone();
// safety:
// the rev-map comes from these categoricals
s = unsafe {
CategoricalChunked::from_cats_and_rev_map_unchecked(
cats,
rev_map.clone(),
)
.into_series()
};
}
_ => {
s = s.cast(index_s.dtype()).unwrap();
}
}
let s = restore_logical_type(&s, index_s.dtype());
Some(vec![s])
}
_ => None,
Expand Down Expand Up @@ -218,8 +217,7 @@ impl DataFrame {
})
.collect::<Vec<_>>(),
);
// restore logical type
s.cast(index_s[i].dtype()).unwrap()
restore_logical_type(&s, index_s[i].dtype())
})
.collect::<Vec<_>>(),
),
Expand Down
19 changes: 16 additions & 3 deletions py-polars/tests/test_pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,19 @@ def test_pivot_categorical_index() -> None:
columns=[("A", pl.Categorical), ("B", pl.Categorical)],
)

assert df.pivot(values="B", index=["A"], columns="B", aggregate_fn="count").to_dict(
False
) == {"A": ["Fire", "Water"], "Car": [1, 2], "Ship": [1, None]}
df = pl.DataFrame(
{
"A": ["Fire", "Water", "Water", "Fire"],
"B": ["Car", "Car", "Car", "Ship"],
"C": ["Paper", "Paper", "Paper", "Paper"],
},
columns=[("A", pl.Categorical), ("B", pl.Categorical), ("C", pl.Categorical)],
)
assert df.pivot(
values="B", index=["A", "C"], columns="B", aggregate_fn="count"
).to_dict(False) == {
"A": ["Fire", "Water"],
"C": ["Paper", "Paper"],
"Car": [1, 2],
"Ship": [1, None],
}

0 comments on commit ba400f4

Please sign in to comment.