Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): auto-determine index/columns/values columns in pivot if one is left out, deprecate passing arguments positionally #12125

Closed
wants to merge 10 commits into from
31 changes: 25 additions & 6 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -6915,11 +6915,12 @@ def explode(
"""
return self.lazy().explode(columns, *more_columns).collect(_eager=True)

@deprecate_nonkeyword_arguments(allowed_args=None, version="0.19.14")
def pivot(
self,
values: ColumnNameOrSelector | Sequence[ColumnNameOrSelector] | None,
index: ColumnNameOrSelector | Sequence[ColumnNameOrSelector] | None,
columns: ColumnNameOrSelector | Sequence[ColumnNameOrSelector] | None,
values: ColumnNameOrSelector | Sequence[ColumnNameOrSelector] | None = None,
index: ColumnNameOrSelector | Sequence[ColumnNameOrSelector] | None = None,
columns: ColumnNameOrSelector | Sequence[ColumnNameOrSelector] | None = None,
aggregate_function: PivotAgg | Expr | None = None,
*,
maintain_order: bool = True,
Expand All @@ -6936,12 +6937,15 @@ def pivot(
----------
values
Column values to aggregate. Can be multiple columns if the *columns*
arguments contains multiple columns as well.
arguments contains multiple columns as well. If None, all columns not
specified by `index` or `columns` are used.
index
One or multiple keys to group by.
One or multiple keys to group by. If None, all columns not specified by
`values` or `columns` are used.
columns
Name of the column(s) whose values will be used as the header of the output
DataFrame.
DataFrame. If None, all columns not specified by `values` or `index` are
used.
aggregate_function
Choose from:

Expand All @@ -6950,6 +6954,9 @@ def pivot(
{'first', 'sum', 'max', 'min', 'mean', 'median', 'last', 'count'}
- An expression to do the aggregation.

Note that only two of `values`, `index`, and `columns` are required; if any are
not specified, the remaining columns are used.

maintain_order
Sort the grouped keys so that the output order is predictable.
sort_columns
Expand Down Expand Up @@ -7065,6 +7072,18 @@ def pivot(
index = _expand_selectors(self, index)
columns = _expand_selectors(self, columns)

# If only two of three values/index/columns are supplied, infer the third.
if (not values) + (not index) + (not columns) > 1:
raise ValueError(
"must provide at least two of `values`, `index`, and `columns`"
)
if not values:
values = [x for x in self.columns if x not in set(index).union(columns)]
elif not index:
index = [x for x in self.columns if x not in set(values).union(columns)]
elif not columns:
columns = [x for x in self.columns if x not in set(values).union(index)]

if isinstance(aggregate_function, str):
if aggregate_function == "first":
aggregate_expr = F.element().first()._pyexpr
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def _expand_selectors(
if is_selector(item):
selector_cols = expand_selector(frame, item)
expanded.extend(selector_cols)
else:
elif item is not None:
expanded.append(item)
return expanded

Expand Down
59 changes: 57 additions & 2 deletions py-polars/tests/unit/operations/test_pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,57 @@ def test_pivot() -> None:
assert_frame_equal(result, expected)


def test_pivot_missing_arg() -> None:
df = pl.DataFrame(
{
"foo": ["A", "A", "B", "B", "C"],
"N": [1, 2, 2, 4, 2],
"M": [1, 2, 2, 4, 2],
"bar": ["k", "l", "m", "n", "o"],
}
)

expected = pl.DataFrame(
[
("A", 1, 2, None, None, None, 1, 2, None, None, None),
("B", None, None, 2, 4, None, None, None, 2, 4, None),
("C", None, None, None, None, 2, None, None, None, None, 2),
],
schema=[
"foo",
"N_bar_k",
"N_bar_l",
"N_bar_m",
"N_bar_n",
"N_bar_o",
"M_bar_k",
"M_bar_l",
"M_bar_m",
"M_bar_n",
"M_bar_o",
],
)

result = df.pivot(values=None, index="foo", columns="bar", aggregate_function=None)
assert_frame_equal(result, expected)
result = df.pivot(
values=["N", "M"], index=None, columns="bar", aggregate_function=None
)
assert_frame_equal(result, expected)
result = df.pivot(
values=["N", "M"], index="foo", columns=None, aggregate_function=None
)
assert_frame_equal(result, expected)

# failing tests: only one of values/index/columns supplied
with pytest.raises(ValueError, match="must provide at least two of "):
df.pivot(values="N", index=None, columns=None, aggregate_function=None)
with pytest.raises(ValueError, match="must provide at least two of "):
df.pivot(values=None, index="foo", columns=None, aggregate_function=None)
with pytest.raises(ValueError, match="must provide at least two of "):
df.pivot(values=None, index=None, columns="bar", aggregate_function=None)


def test_pivot_list() -> None:
df = pl.DataFrame({"a": [1, 2, 3], "b": [[1, 1], [2, 2], [3, 3]]})

Expand All @@ -47,7 +98,11 @@ def test_pivot_list() -> None:
}
)
out = df.pivot(
"b", index="a", columns="a", aggregate_function="first", sort_columns=True
values="b",
index="a",
columns="a",
aggregate_function="first",
sort_columns=True,
)
assert_frame_equal(out, expected)

Expand Down Expand Up @@ -319,7 +374,7 @@ def test_aggregate_function_deprecation_warning() -> None:
with pytest.raises(
pl.ComputeError, match="found multiple elements in the same group"
):
df.pivot("a", "b", "c")
df.pivot(values="a", index="b", columns="c")


def test_pivot_struct() -> None:
Expand Down