feat[python]: allow ".row(idx)" to index with a predicate (#4785)
alexander-beedie committed Sep 9, 2022
1 parent 04e3526 commit f1726fd
Showing 7 changed files with 99 additions and 14 deletions.
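A minimal usage sketch of the behaviour this commit adds, reusing the columns from the docstring example in frame.py below:

import polars as pl

df = pl.DataFrame(
    {
        "foo": [1, 2, 3],
        "bar": [6, 7, 8],
        "ham": ["a", "b", "c"],
    }
)

# select a row by position, as before
df.row(2)  # => (3, 8, 'c')

# new: select the single row matching a predicate (keyword-only)
df.row(by_predicate=(pl.col("ham") == "b"))  # => (2, 7, 'b')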
3 changes: 3 additions & 0 deletions py-polars/docs/source/reference/exceptions.rst
@@ -12,6 +12,9 @@ Exceptions
DuplicateError
NoDataError
NotFoundError
NoRowsReturned
PanicException
RowsException
SchemaError
ShapeError
TooManyRowsReturned
15 changes: 15 additions & 0 deletions py-polars/polars/exceptions.py
@@ -38,6 +38,18 @@ class PanicException(Exception): # type: ignore[no-redef]
"""Exception raised when an unexpected state causes a panic in the underlying Rust library.""" # noqa: E501


class RowsException(Exception):
"""Exception raised when the number of returned rows does not match expectation."""


class NoRowsReturned(RowsException):
"""Exception raised when no rows are returned, but at least one row is expected."""


class TooManyRowsReturned(RowsException):
"""Exception raised when more rows than expected are returned."""


__all__ = [
"ArrowError",
"ComputeError",
@@ -47,4 +59,7 @@ class PanicException(Exception): # type: ignore[no-redef]
"ShapeError",
"DuplicateError",
"PanicException",
"RowsException",
"NoRowsReturned",
"TooManyRowsReturned",
]
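A sketch of how downstream code might handle the new exception classes defined above; the frame and predicates here are illustrative:

import polars as pl
from polars.exceptions import NoRowsReturned, RowsException, TooManyRowsReturned

df = pl.DataFrame({"a": ["x", "y", "y"]})

try:
    df.row(by_predicate=pl.col("a") == "y")  # matches two rows
except TooManyRowsReturned as exc:
    print(f"ambiguous predicate: {exc}")

try:
    df.row(by_predicate=pl.col("a") == "z")  # matches no rows
except NoRowsReturned as exc:
    print(f"no match: {exc}")

# or catch either case via the shared parent class
try:
    df.row(by_predicate=pl.col("a") == "z")
except RowsException:
    pass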
45 changes: 41 additions & 4 deletions py-polars/polars/internals/dataframe/frame.py
@@ -39,6 +39,7 @@
get_idx_type,
py_type_to_dtype,
)
from polars.exceptions import NoRowsReturned, TooManyRowsReturned
from polars.internals.construction import (
arrow_to_pydf,
dict_to_pydf,
@@ -6009,14 +6010,27 @@ def fold(
acc = operation(acc, self.to_series(i))
return acc

def row(self, index: int) -> tuple[Any]:
def row(
self, index: int | None = None, *, by_predicate: pli.Expr | None = None
) -> tuple[Any, ...]:
"""
Get a row as tuple.
Get a row as tuple, either by index or by predicate.
Parameters
----------
index
Row index.
by_predicate
Select the row according to a given expression/predicate.
Notes
-----
The `index` and `by_predicate` params are mutually exclusive. Additionally,
to ensure clarity, the `by_predicate` parameter must be supplied by keyword.
When using `by_predicate` it is an error condition if anything other than
one row is returned; more than one row raises `TooManyRowsReturned`, and
zero rows will raise `NoRowsReturned` (both inherit from `RowsException`).
Examples
--------
@@ -6027,13 +6041,36 @@ def row(self, index: int) -> tuple[Any]:
... "ham": ["a", "b", "c"],
... }
... )
>>> # return the row at the given index
>>> df.row(2)
(3, 8, 'c')
>>> # return the row that matches the given predicate
>>> df.row(by_predicate=(pl.col("ham") == "b"))
(2, 7, 'b')
"""
return self._df.row_tuple(index)
if index is not None and by_predicate is not None:
raise ValueError(
"Cannot set both 'index' and 'by_predicate'; mutually exclusive"
)
elif isinstance(index, pli.Expr):
raise TypeError("Expressions should be passed to the 'by_predicate' param")
elif isinstance(index, int):
return self._df.row_tuple(index)
elif isinstance(by_predicate, pli.Expr):
rows = self.filter(by_predicate).rows()
n_rows = len(rows)
if n_rows > 1:
raise TooManyRowsReturned(
f"Predicate <{by_predicate!s}> returned {n_rows} rows"
)
elif n_rows == 0:
raise NoRowsReturned(f"Predicate <{by_predicate!s}> returned no rows")
return rows[0]
else:
raise ValueError("One of 'index' or 'by_predicate' must be set")

def rows(self) -> list[tuple[object, ...]]:
def rows(self) -> list[tuple[Any, ...]]:
"""
Convert columnar data to rows as python tuples.
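For context, the predicate path above is a thin wrapper over filter(...).rows() plus a row-count check; a hedged before/after sketch using the docstring's example columns:

# previously: filter manually, with no row-count check; a multi-row match
# silently takes the first row, and an empty match raises IndexError
row = df.filter(pl.col("ham") == "b").rows()[0]

# now: the exactly-one-row contract is explicit and enforced
row = df.row(by_predicate=pl.col("ham") == "b")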
6 changes: 3 additions & 3 deletions py-polars/tests/parametric/test_testing.py
@@ -133,9 +133,9 @@ def test_strategy_null_probability(
assert df2.null_count().fold(sum).sum() < df3.null_count().fold(sum).sum()

nulls_col0, nulls_col1 = df2.null_count().rows()[0]
assert nulls_col0 > nulls_col1 # type: ignore[operator]
assert nulls_col0 < 50 # type: ignore[operator]
assert nulls_col0 > nulls_col1
assert nulls_col0 < 50

nulls_col0, nulls_colx = df3.null_count().rows()[0]
assert nulls_col0 > nulls_colx # type: ignore[operator]
assert nulls_col0 > nulls_colx
assert nulls_col0 == 50
6 changes: 3 additions & 3 deletions py-polars/tests/unit/io/test_csv.py
@@ -472,8 +472,8 @@ def test_csv_globbing(examples_dir: str) -> None:

df = pl.read_csv(path, columns=["category", "sugars_g"])
assert df.shape == (135, 2)
assert df.row(-1) == ("seafood", 1) # type: ignore[comparison-overlap]
assert df.row(0) == ("vegetables", 2) # type: ignore[comparison-overlap]
assert df.row(-1) == ("seafood", 1)
assert df.row(0) == ("vegetables", 2)

with pytest.raises(ValueError):
_ = pl.read_csv(path, dtypes=[pl.Utf8, pl.Int64, pl.Int64, pl.Int64])
@@ -581,7 +581,7 @@ def test_fallback_chrono_parser() -> None:
"""
)
df = pl.read_csv(data.encode(), parse_dates=True)
assert df.null_count().row(0) == (0, 0) # type: ignore[comparison-overlap]
assert df.null_count().row(0) == (0, 0)


def test_csv_string_escaping() -> None:
36 changes: 33 additions & 3 deletions py-polars/tests/unit/test_df.py
@@ -11,6 +11,7 @@
import pytest

import polars as pl
from polars.exceptions import NoRowsReturned, TooManyRowsReturned
from polars.testing import assert_frame_equal, assert_series_equal, columns

if TYPE_CHECKING:
@@ -826,9 +827,38 @@ def test_df_fold() -> None:

def test_row_tuple() -> None:
df = pl.DataFrame({"a": ["foo", "bar", "2"], "b": [1, 2, 3], "c": [1.0, 2.0, 3.0]})
assert df.row(0) == ("foo", 1, 1.0) # type: ignore[comparison-overlap]
assert df.row(1) == ("bar", 2, 2.0) # type: ignore[comparison-overlap]
assert df.row(-1) == ("2", 3, 3.0) # type: ignore[comparison-overlap]

# return row by index
assert df.row(0) == ("foo", 1, 1.0)
assert df.row(1) == ("bar", 2, 2.0)
assert df.row(-1) == ("2", 3, 3.0)

# return row by predicate
assert df.row(by_predicate=pl.col("a") == "bar") == ("bar", 2, 2.0)
assert df.row(by_predicate=pl.col("b").is_in([2, 4, 6])) == ("bar", 2, 2.0)

# expected error conditions
with pytest.raises(TooManyRowsReturned):
df.row(by_predicate=pl.col("b").is_in([1, 3, 5]))

with pytest.raises(NoRowsReturned):
df.row(by_predicate=pl.col("a") == "???")

# cannot set both 'index' and 'by_predicate'
with pytest.raises(ValueError):
df.row(0, by_predicate=pl.col("a") == "bar")

# must call 'by_predicate' by keyword
with pytest.raises(TypeError):
df.row(None, pl.col("a") == "bar") # type: ignore[misc]

# cannot pass predicate into 'index'
with pytest.raises(TypeError):
df.row(pl.col("a") == "bar") # type: ignore[arg-type]

# at least one of 'index' and 'by_predicate' must be set
with pytest.raises(ValueError):
df.row()


def test_df_apply() -> None:
2 changes: 1 addition & 1 deletion py-polars/tests/unit/test_exprs.py
@@ -352,7 +352,7 @@ def test_regex_in_filter() -> None:
pl.fold(acc=False, f=lambda acc, s: acc | s, exprs=(pl.col("^nrs|flt*$") < 3))
).row(0)
expected = (1, "foo", 1.0)
assert res == expected # type: ignore[comparison-overlap]
assert res == expected


def test_arr_contains() -> None: