Skip to content

Commit

Permalink
fix(rust, python): fix pivot on floating point indexes (#5704)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 2, 2022
1 parent f95a49b commit e95396a
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 6 deletions.
10 changes: 6 additions & 4 deletions polars/polars-core/src/chunked_array/ops/bit_repr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,8 @@ impl Reinterpret for Int64Chunked {
}

impl UInt64Chunked {
pub(crate) fn reinterpret_float(&self) -> Float64Chunked {
#[doc(hidden)]
pub fn _reinterpret_float(&self) -> Float64Chunked {
let chunks = self
.downcast_iter()
.map(|array| {
Expand Down Expand Up @@ -220,7 +221,8 @@ impl UInt64Chunked {
}
}
impl UInt32Chunked {
pub(crate) fn reinterpret_float(&self) -> Float32Chunked {
#[doc(hidden)]
pub fn _reinterpret_float(&self) -> Float32Chunked {
let chunks = self
.downcast_iter()
.map(|array| {
Expand Down Expand Up @@ -258,7 +260,7 @@ impl Float32Chunked {
let s = self.bit_repr_small().into_series();
let out = f(&s);
let out = out.u32().unwrap();
out.reinterpret_float().into()
out._reinterpret_float().into()
}
}
impl Float64Chunked {
Expand All @@ -269,6 +271,6 @@ impl Float64Chunked {
let s = self.bit_repr_large().into_series();
let out = f(&s);
let out = out.u64().unwrap();
out.reinterpret_float().into()
out._reinterpret_float().into()
}
}
4 changes: 2 additions & 2 deletions polars/polars-core/src/chunked_array/ops/unique/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ impl ChunkUnique<Float32Type> for Float32Chunked {
fn unique(&self) -> PolarsResult<ChunkedArray<Float32Type>> {
let ca = self.bit_repr_small();
let ca = ca.unique()?;
Ok(ca.reinterpret_float())
Ok(ca._reinterpret_float())
}

fn arg_unique(&self) -> PolarsResult<IdxCa> {
Expand All @@ -408,7 +408,7 @@ impl ChunkUnique<Float64Type> for Float64Chunked {
fn unique(&self) -> PolarsResult<ChunkedArray<Float64Type>> {
let ca = self.bit_repr_large();
let ca = ca.unique()?;
Ok(ca.reinterpret_float())
Ok(ca._reinterpret_float())
}

fn arg_unique(&self) -> PolarsResult<IdxCa> {
Expand Down
2 changes: 2 additions & 0 deletions polars/polars-core/src/datatypes/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,8 @@ impl<'a> Hash for AnyValue<'a> {
UInt64(v) => state.write_u64(*v),
Utf8(v) => state.write(v.as_bytes()),
Utf8Owned(v) => state.write(v.as_bytes()),
Float32(v) => state.write_u32(v.to_bits()),
Float64(v) => state.write_u64(v.to_bits()),
#[cfg(feature = "dtype-binary")]
Binary(v) => state.write(v),
#[cfg(feature = "dtype-binary")]
Expand Down
8 changes: 8 additions & 0 deletions polars/polars-ops/src/pivot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,14 @@ fn restore_logical_type(s: &Series, logical_type: &DataType) -> Series {
.into_series()
}
}
DataType::Float32 if matches!(s.dtype(), DataType::UInt32) => {
let ca = s.u32().unwrap();
ca._reinterpret_float().into_series()
}
DataType::Float64 if matches!(s.dtype(), DataType::UInt64) => {
let ca = s.u64().unwrap();
ca._reinterpret_float().into_series()
}
_ => s.cast(logical_type).unwrap(),
}
}
Expand Down
33 changes: 33 additions & 0 deletions py-polars/tests/unit/test_pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,36 @@ def test_pivot_multiple_values_column_names_5116() -> None:
"x2_C": [8, 7],
"x2_D": [6, 5],
}


def test_pivot_floats() -> None:

df = pl.DataFrame(
{
"article": ["a", "a", "a", "b", "b", "b"],
"weight": [1.0, 1.0, 4.4, 1.0, 8.8, 1.0],
"quantity": [1.0, 5.0, 1.0, 1.0, 1.0, 7.5],
"price": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
}
)

assert df.pivot(values="price", index="weight", columns="quantity",).to_dict(
False
) == {
"weight": [1.0, 4.4, 8.8],
"1.0": [1.0, 3.0, 5.0],
"5.0": [2.0, None, None],
"7.5": [6.0, None, None],
}

assert df.pivot(
values="price",
index=["article", "weight"],
columns="quantity",
).to_dict(False) == {
"article": ["a", "a", "b", "b"],
"weight": [1.0, 4.4, 1.0, 8.8],
"1.0": [1.0, 3.0, 4.0, 5.0],
"5.0": [2.0, None, None, None],
"7.5": [None, None, 6.0, None],
}

0 comments on commit e95396a

Please sign in to comment.