Skip to content

Commit

Permalink
Add tests and fix type annotations (#1860)
Browse files (browse the repository at this point in the history)
  • Loading branch information
zundertj committed Nov 22, 2021
1 parent c4782b6 commit 62557ee
Show file tree
Hide file tree
Showing 10 changed files with 263 additions and 22 deletions.
2 changes: 1 addition & 1 deletion py-polars/polars/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def from_dicts(dicts: Sequence[Dict[str, Any]]) -> DataFrame:


def from_arrow(
a: Union["pa.Table", "pa.Array"], rechunk: bool = True
a: Union["pa.Table", "pa.Array", "pa.ChunkedArray"], rechunk: bool = True
) -> Union[DataFrame, Series]:
"""
Create a DataFrame or Series from an Arrow Table or Array.
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/lazy_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def cache(
"""
return wrap_ldf(self._ldf.cache())

def filter(self, predicate: "pli.Expr") -> "LazyFrame":
def filter(self, predicate: Union["pli.Expr", str]) -> "LazyFrame":
"""
Filter the rows in the DataFrame based on a predicate expression.
Expand Down
22 changes: 21 additions & 1 deletion py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,16 @@ def last(column: Union[str, "pli.Series"]) -> "pli.Expr":
return col(column).last()


# Typing-only overload: a column name (str) yields an expression.
# `n` carries a `...` default marker because the implementation declares
# `n: Optional[int] = None`; without it, type checkers reject `head("a")`.
@tp.overload
def head(column: str, n: Optional[int] = ...) -> "pli.Expr":
    ...


# Typing-only overload: a Series input yields a Series.
# `n` carries a `...` default marker because the implementation declares
# `n: Optional[int] = None`; without it, type checkers reject `head(series)`.
@tp.overload
def head(column: "pli.Series", n: Optional[int] = ...) -> "pli.Series":
    ...


def head(
column: Union[str, "pli.Series"], n: Optional[int] = None
) -> Union["pli.Expr", "pli.Series"]:
Expand All @@ -449,6 +459,16 @@ def head(
return col(column).head(n)


@tp.overload
def tail(column: str, n: Optional[int]) -> "pli.Expr":
...


# Typing-only overload: a Series input yields a Series.
# `n` carries a `...` default marker because the implementation declares
# `n: Optional[int] = None`; without it, type checkers reject `tail(series)`.
@tp.overload
def tail(column: "pli.Series", n: Optional[int] = ...) -> "pli.Series":
    ...


def tail(
column: Union[str, "pli.Series"], n: Optional[int] = None
) -> Union["pli.Expr", "pli.Series"]:
Expand Down Expand Up @@ -861,7 +881,7 @@ def argsort_by(
Default is ascending.
"""
if not isinstance(reverse, list):
reverse = [reverse]
reverse = [reverse] * len(exprs)
exprs = pli._selection_to_pyexpr_list(exprs)
return pli.wrap_expr(pyargsort_by(exprs, reverse))

Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,7 +1123,7 @@ def append(self, other: "Series") -> None:
"""
self._s.append(other._s)

def filter(self, predicate: "Series") -> "Series":
def filter(self, predicate: Union["Series", list]) -> "Series":
"""
Filter elements by a boolean mask.
Expand Down
12 changes: 9 additions & 3 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_init_ndarray() -> None:

# 3D array
with pytest.raises(ValueError):
df = pl.DataFrame(np.random.randn(2, 2, 2))
_ = pl.DataFrame(np.random.randn(2, 2, 2))


# TODO: Remove this test case when removing deprecated behaviour
Expand Down Expand Up @@ -205,6 +205,10 @@ def test_init_records() -> None:
assert df.frame_equal(expected)
assert df.to_dicts() == dicts

df_cd = pl.DataFrame(dicts, columns=["c", "d"])
expected = pl.DataFrame({"c": [1, 2, 1], "d": [2, 1, 2]})
assert df_cd.frame_equal(expected)


def test_selection() -> None:
df = pl.DataFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0], "c": ["a", "b", "c"]})
Expand Down Expand Up @@ -786,7 +790,7 @@ def test_describe() -> None:
}
)
assert df.describe().shape != df.shape
assert set(df.describe().select_at_idx(2)) == set([1.0, 4.0, 5.0, 6.0])
assert set(df.describe().select_at_idx(2)) == {1.0, 4.0, 5.0, 6.0}


def test_string_cache_eager_lazy() -> None:
Expand Down Expand Up @@ -829,6 +833,9 @@ def test_argsort_by(df: pl.DataFrame) -> None:
a = df[pl.argsort_by(["int_nulls", "floats"], reverse=[False, True])]["int_nulls"]
assert a == [1, 0, 3]

a = df[pl.argsort_by(["int_nulls", "floats"], reverse=False)]["int_nulls"]
assert a == [1, 0, 2]


def test_literal_series() -> None:
df = pl.DataFrame(
Expand Down Expand Up @@ -1066,7 +1073,6 @@ def test_asof_join() -> None:
2016-05-25 13:30:00.075""".split(
"\n"
)
dates

ticker = """GOOG
MSFT
Expand Down
6 changes: 6 additions & 0 deletions py-polars/tests/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,9 @@ def test_horizontal_agg(fruits_cars: pl.DataFrame) -> None:

out = df.select(pl.min([pl.col("A"), pl.col("B")]))
assert out[:, 0].to_list() == [1, 2, 3, 2, 1]


def test_prefix(fruits_cars: pl.DataFrame) -> None:
    """Check that suffixing every column appends "_reverse" to each name.

    NOTE(review): the test name says "prefix" but the body exercises
    ``suffix`` -- confirm which of the two was intended.
    """
    expected_columns = [
        "A_reverse",
        "fruits_reverse",
        "B_reverse",
        "cars_reverse",
    ]
    renamed = fruits_cars.select([pl.all().suffix("_reverse")])
    assert renamed.columns == expected_columns
3 changes: 2 additions & 1 deletion py-polars/tests/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def test_from_pandas_datetime() -> None:

def test_arrow_list_roundtrip() -> None:
# https://github.com/pola-rs/polars/issues/1064
pl.from_arrow(pa.table({"a": [1], "b": [[1, 2]]})).to_arrow()
tbl = pa.table({"a": [1], "b": [[1, 2]]})
assert pl.from_arrow(tbl).to_arrow().shape == tbl.shape


def test_arrow_dict_to_polars() -> None:
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pickle
import zlib
from pathlib import Path
from typing import Callable, Dict, List, Tuple, Type
from typing import Dict, Type

import numpy as np
import pandas as pd
Expand Down

0 comments on commit 62557ee

Please sign in to comment.