Skip to content

Commit

Permalink
fix pivot args
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 29, 2022
1 parent 971365d commit a9752d3
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 0 deletions.
4 changes: 4 additions & 0 deletions py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4775,6 +4775,10 @@ def pivot(
values_column
Column that will be aggregated.
"""
if isinstance(pivot_column, str):
pivot_column = [pivot_column]
if isinstance(values_column, str):
values_column = [values_column]
return PivotOps(self._df, self.by, pivot_column, values_column)

def first(self) -> DataFrame:
Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,17 @@ def test_pivot() -> None:
out = df.pivot(values="c", index="b", columns="a", aggregate_fn=agg_fn)
assert out.shape == (2, 6)

# example in polars-book
df = pl.DataFrame(
{
"foo": ["A", "A", "B", "B", "C"],
"N": [1, 2, 2, 4, 2],
"bar": ["k", "l", "m", "n", "o"],
}
)
out = df.groupby("foo").pivot(pivot_column="bar", values_column="N").first()
assert out.shape == (3, 6)


def test_join() -> None:
df_left = pl.DataFrame(
Expand Down

0 comments on commit a9752d3

Please sign in to comment.