Skip to content

Commit

Permalink
test(python): Refactor pivot tests (#6012)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Jan 3, 2023
1 parent 359b026 commit a050aef
Showing 1 changed file with 98 additions and 73 deletions.
171 changes: 98 additions & 73 deletions py-polars/tests/unit/test_pivot.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,37 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import pytest

import polars as pl
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
from polars.internals.type_aliases import PivotAgg


def test_pivot() -> None:
df = pl.DataFrame(
{
"foo": ["A", "A", "B", "B", "C"],
"N": [1, 2, 2, 4, 2],
"bar": ["k", "l", "m", "n", "o"],
}
)
result = df.pivot(values="N", index="foo", columns="bar")

expected = pl.DataFrame(
[
("A", 1, 2, None, None, None),
("B", None, None, 2, 4, None),
("C", None, None, None, None, 2),
],
columns=["foo", "k", "l", "m", "n", "o"],
)
assert_frame_equal(result, expected)


def test_pivot_list() -> None:
df = pl.DataFrame({"a": [1, 2, 3], "b": [[1, 1], [2, 2], [3, 3]]})

Expand All @@ -22,56 +44,61 @@ def test_pivot_list() -> None:
}
)
out = df.pivot("b", index="a", columns="a", aggregate_fn="first", sort_columns=True)
assert out.frame_equal(expected, null_equal=True)


def test_pivot() -> None:
assert_frame_equal(out, expected)


@pytest.mark.parametrize(
["agg_fn", "expected_rows"],
[
("first", [("a", 2, None, None), ("b", None, None, 10)]),
("count", [("a", 2, None, None), ("b", None, 2, 1)]),
("min", [("a", 2, None, None), ("b", None, 8, 10)]),
("max", [("a", 4, None, None), ("b", None, 8, 10)]),
("sum", [("a", 6, None, None), ("b", None, 8, 10)]),
("mean", [("a", 3.0, None, None), ("b", None, 8.0, 10.0)]),
("median", [("a", 3.0, None, None), ("b", None, 8.0, 10.0)]),
],
)
def test_pivot_aggregate(agg_fn: PivotAgg, expected_rows: list[tuple[Any]]) -> None:
df = pl.DataFrame(
{
"a": [1, 1, 2, 2, 3],
"b": ["a", "a", "b", "b", "b"],
"c": [2, 4, None, 8, 10],
}
)

with pytest.deprecated_call():
gb = df.groupby("b").pivot(
pivot_column="a",
values_column="c",
)
assert gb.count().rows() == [("a", 2, None, None), ("b", None, 2, 1)]
assert gb.first().rows() == [("a", 2, None, None), ("b", None, None, 10)]
assert gb.max().rows() == [("a", 4, None, None), ("b", None, 8, 10)]
assert gb.mean().rows() == [("a", 3.0, None, None), ("b", None, 8.0, 10.0)]
assert gb.median().rows() == [("a", 3.0, None, None), ("b", None, 8.0, 10.0)]
assert gb.min().rows() == [("a", 2, None, None), ("b", None, 8, 10)]
assert gb.sum().rows() == [("a", 6, None, None), ("b", None, 8, 10)]

agg_fns: list[PivotAgg] = ["sum", "min", "max", "mean", "count", "median", "mean"]
for agg_fn in agg_fns:
out = df.pivot(
values="c", index="b", columns="a", aggregate_fn=agg_fn, sort_columns=True
)
assert out.shape == (2, 4)
assert out.rows() == getattr(gb, agg_fn)().rows()

# example in polars-book
result = df.pivot(
values="c", index="b", columns="a", aggregate_fn=agg_fn, sort_columns=True
)
assert result.rows() == expected_rows


@pytest.mark.parametrize(
["agg_fn", "expected_rows"],
[
("first", [("a", 2, None, None), ("b", None, None, 10)]),
("count", [("a", 2, None, None), ("b", None, 2, 1)]),
("min", [("a", 2, None, None), ("b", None, 8, 10)]),
("max", [("a", 4, None, None), ("b", None, 8, 10)]),
("sum", [("a", 6, None, None), ("b", None, 8, 10)]),
("mean", [("a", 3.0, None, None), ("b", None, 8.0, 10.0)]),
("median", [("a", 3.0, None, None), ("b", None, 8.0, 10.0)]),
],
)
def test_pivot_groupby_aggregate(
agg_fn: PivotAgg, expected_rows: list[tuple[Any]]
) -> None:
df = pl.DataFrame(
{
"foo": ["A", "A", "B", "B", "C"],
"N": [1, 2, 2, 4, 2],
"bar": ["k", "l", "m", "n", "o"],
"a": [1, 1, 2, 2, 3],
"b": ["a", "a", "b", "b", "b"],
"c": [2, 4, None, 8, 10],
}
)
with pytest.deprecated_call():
out = df.groupby("foo").pivot(pivot_column="bar", values_column="N").first()

assert out.shape == (3, 6)
assert out.rows() == [
("A", 1, 2, None, None, None),
("B", None, None, 2, 4, None),
("C", None, None, None, None, 2),
]
pivot = df.groupby("b").pivot(pivot_column="a", values_column="c")
result = getattr(pivot, agg_fn)()
assert result.rows() == expected_rows


def test_pivot_categorical_3968() -> None:
Expand All @@ -98,21 +125,13 @@ def test_pivot_categorical_index() -> None:
columns=[("A", pl.Categorical), ("B", pl.Categorical)],
)

result = df.pivot(values="B", index=["A"], columns="B", aggregate_fn="count")
expected = {"A": ["Fire", "Water"], "Car": [1, 2], "Ship": [1, None]}
assert (
df.pivot(values="B", index=["A"], columns="B", aggregate_fn="count").to_dict(
False
)
== expected
)
assert result.to_dict(False) == expected

# test expression dispatch
assert (
df.pivot(values="B", index=["A"], columns="B", aggregate_fn=pl.count()).to_dict(
False
)
== expected
)
result = df.pivot(values="B", index=["A"], columns="B", aggregate_fn=pl.count())
assert result.to_dict(False) == expected

df = pl.DataFrame(
{
Expand All @@ -122,14 +141,14 @@ def test_pivot_categorical_index() -> None:
},
columns=[("A", pl.Categorical), ("B", pl.Categorical), ("C", pl.Categorical)],
)
assert df.pivot(
values="B", index=["A", "C"], columns="B", aggregate_fn="count"
).to_dict(False) == {
result = df.pivot(values="B", index=["A", "C"], columns="B", aggregate_fn="count")
expected = {
"A": ["Fire", "Water"],
"C": ["Paper", "Paper"],
"Car": [1, 2],
"Ship": [1, None],
}
assert result.to_dict(False) == expected


def test_pivot_multiple_values_column_names_5116() -> None:
Expand All @@ -141,17 +160,18 @@ def test_pivot_multiple_values_column_names_5116() -> None:
"c2": ["C", "C", "D", "D"] * 2,
}
)
assert df.pivot(values=["x1", "x2"], index="c1", columns="c2").to_dict(False) == {
result = df.pivot(values=["x1", "x2"], index="c1", columns="c2")
expected = {
"c1": ["A", "B"],
"x1_C": [1, 2],
"x1_D": [3, 4],
"x2_C": [8, 7],
"x2_D": [6, 5],
}
assert result.to_dict(False) == expected


def test_pivot_floats() -> None:

df = pl.DataFrame(
{
"article": ["a", "a", "a", "b", "b", "b"],
Expand All @@ -161,41 +181,46 @@ def test_pivot_floats() -> None:
}
)

assert df.pivot(values="price", index="weight", columns="quantity",).to_dict(
False
) == {
result = df.pivot(values="price", index="weight", columns="quantity")
expected = {
"weight": [1.0, 4.4, 8.8],
"1.0": [1.0, 3.0, 5.0],
"5.0": [2.0, None, None],
"7.5": [6.0, None, None],
}
assert result.to_dict(False) == expected

assert df.pivot(
values="price",
index=["article", "weight"],
columns="quantity",
).to_dict(False) == {
result = df.pivot(values="price", index=["article", "weight"], columns="quantity")
expected = {
"article": ["a", "a", "b", "b"],
"weight": [1.0, 4.4, 1.0, 8.8],
"1.0": [1.0, 3.0, 4.0, 5.0],
"5.0": [2.0, None, None, None],
"7.5": [None, None, 6.0, None],
}
assert result.to_dict(False) == expected


def test_pivot_reinterpret_5907() -> None:
assert pl.DataFrame(
df = pl.DataFrame(
{
"A": pl.Series([3, -2, 3, -2], dtype=pl.Int32),
"B": ["x", "x", "y", "y"],
"C": [100, 50, 500, -80],
}
).pivot(
)

result = df.pivot(
index=["A"], values=["C"], columns=["B"], aggregate_fn=pl.element().sum()
).to_dict(
False
) == {
"A": [3, -2],
"x": [100, 50],
"y": [500, -80],
}
)
expected = {"A": [3, -2], "x": [100, 50], "y": [500, -80]}
assert result.to_dict(False) == expected


def test_pivot_subclassed_df() -> None:
class SubClassedDataFrame(pl.DataFrame):
pass

df = SubClassedDataFrame({"a": [1, 2], "b": [3, 4]})
result = df.pivot(values="b", index="a", columns="a", aggregate_fn="first")
assert isinstance(result, SubClassedDataFrame)

0 comments on commit a050aef

Please sign in to comment.