Skip to content

Commit

Permalink
Proposal: Enable mypy on tests (#1798)
Browse files Browse the repository at this point in the history
  • Loading branch information
zundertj committed Nov 17, 2021
1 parent 46a7dc2 commit a9570f3
Show file tree
Hide file tree
Showing 18 changed files with 368 additions and 346 deletions.
10 changes: 6 additions & 4 deletions py-polars/polars/eager/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,7 +701,7 @@ def to_pandas(

def to_csv(
self,
file: Optional[Union[TextIO, str, Path]] = None,
file: Optional[Union[TextIO, BytesIO, str, Path]] = None,
has_headers: bool = True,
sep: str = ",",
) -> Optional[str]:
Expand Down Expand Up @@ -740,7 +740,9 @@ def to_csv(
return None

def to_ipc(
self, file: Union[BinaryIO, str, Path], compression: str = "uncompressed"
self,
file: Union[BinaryIO, BytesIO, str, Path],
compression: str = "uncompressed",
) -> None:
"""
Write to Arrow IPC binary stream, or a feather file.
Expand Down Expand Up @@ -795,7 +797,7 @@ def transpose(

def to_parquet(
self,
file: Union[str, Path],
file: Union[str, Path, BytesIO],
compression: Optional[str] = "snappy",
use_pyarrow: bool = False,
**kwargs: Any,
Expand Down Expand Up @@ -1473,7 +1475,7 @@ def replace_at_idx(self, index: int, series: "pl.Series") -> None:

def sort(
self,
by: Union[str, "pl.Expr", tp.List["pl.Expr"]],
by: Union[str, "pl.Expr", tp.List[str], tp.List["pl.Expr"]],
reverse: Union[bool, tp.List[bool]] = False,
in_place: bool = False,
) -> Optional["DataFrame"]:
Expand Down
14 changes: 8 additions & 6 deletions py-polars/polars/eager/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1516,7 +1516,7 @@ def is_not_nan(self) -> "Series":
"""
return Series._from_pyseries(self._s.is_not_nan())

def is_in(self, other: "Series") -> "Series":
def is_in(self, other: Union["Series", tp.List]) -> "Series":
"""
Check if elements of this Series are in the right Series, or List values of the right Series.
Expand All @@ -1539,7 +1539,7 @@ def is_in(self, other: "Series") -> "Series":
"""
if type(other) is list:
other = Series("", other)
return wrap_s(self._s.is_in(other._s))
return wrap_s(self._s.is_in(other._s)) # type: ignore

def arg_true(self) -> "Series":
"""
Expand Down Expand Up @@ -2023,7 +2023,7 @@ def __copy__(self) -> "Series": # type: ignore
def __deepcopy__(self, memodict={}) -> "Series": # type: ignore
return self.clone()

def fill_null(self, strategy: Union[str, "pl.Expr"]) -> "Series":
def fill_null(self, strategy: Union[str, int, "pl.Expr"]) -> "Series":
"""
Fill null values with a filling strategy.
Expand Down Expand Up @@ -2319,7 +2319,9 @@ def shift(self, periods: int = 1) -> "Series":
"""
return wrap_s(self._s.shift(periods))

def shift_and_fill(self, periods: int, fill_value: "pl.Expr") -> "Series":
def shift_and_fill(
self, periods: int, fill_value: Union[int, "pl.Expr"]
) -> "Series":
"""
Shift the values by a given period and fill the parts that will be empty due to this operation
with the result of the `fill_value` expression.
Expand All @@ -2332,7 +2334,7 @@ def shift_and_fill(self, periods: int, fill_value: "pl.Expr") -> "Series":
Fill None values with the result of this expression.
"""
return self.to_frame().select(
pl.col(self.name).shift_and_fill(periods, fill_value)
pl.col(self.name).shift_and_fill(periods, fill_value) # type: ignore
)[self.name]

def zip_with(self, mask: "Series", other: "Series") -> "Series":
Expand Down Expand Up @@ -2999,7 +3001,7 @@ class StringNameSpace:
def __init__(self, series: "Series"):
self._s = series._s

def strptime(self, datatype: DataType, fmt: Optional[str] = None) -> Series:
def strptime(self, datatype: Type[DataType], fmt: Optional[str] = None) -> Series:
"""
Parse a Series of dtype Utf8 to a Date/Datetime Series.
Expand Down
6 changes: 3 additions & 3 deletions py-polars/polars/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def update_columns(df: "pl.DataFrame", new_columns: List[str]) -> "pl.DataFrame"


def read_csv(
file: Union[str, TextIO, Path, BinaryIO, bytes],
file: Union[str, TextIO, BytesIO, Path, BinaryIO, bytes],
infer_schema_length: Optional[int] = 100,
batch_size: int = 8192,
has_headers: bool = True,
Expand Down Expand Up @@ -542,7 +542,7 @@ def read_ipc_schema(


def read_ipc(
file: Union[str, BinaryIO, Path, bytes],
file: Union[str, BinaryIO, BytesIO, Path, bytes],
columns: Optional[List[str]] = None,
projection: Optional[List[int]] = None,
stop_after_n_rows: Optional[int] = None,
Expand Down Expand Up @@ -609,7 +609,7 @@ def read_ipc(


def read_parquet(
source: Union[str, List[str], Path, BinaryIO, bytes],
source: Union[str, List[str], Path, BinaryIO, BytesIO, bytes],
columns: Optional[List[str]] = None,
projection: Optional[List[int]] = None,
stop_after_n_rows: Optional[int] = None,
Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/lazy/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1972,7 +1972,7 @@ def __init__(self, expr: Expr):

def strptime(
self,
datatype: Union[Date, Datetime],
datatype: Union[Type[Date], Type[Datetime]],
fmt: Optional[str] = None,
) -> Expr:
"""
Expand Down Expand Up @@ -2329,7 +2329,7 @@ def timestamp(self) -> Expr:


def expr_to_lit_or_expr(
expr: Union[Expr, int, float, str, tp.List[Expr], "pl.Series"],
expr: Union[Expr, int, float, str, tp.List[Expr], tp.List[str], "pl.Series"],
str_to_lit: bool = True,
) -> Expr:
"""
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/lazy/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def inspect(s: "pl.DataFrame") -> "pl.DataFrame":

def sort(
self,
by: Union[str, "Expr", tp.List["Expr"]],
by: Union[str, "Expr", tp.List[str], tp.List["Expr"]],
reverse: Union[bool, tp.List[bool]] = False,
) -> "LazyFrame":
"""
Expand Down
2 changes: 1 addition & 1 deletion py-polars/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ profile = "black"
[tool.mypy]
ignore_missing_imports = true
disallow_untyped_defs = true
files = "polars"
files = ["polars", "tests"]
2 changes: 2 additions & 0 deletions py-polars/tests/db-benchmark/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# type: ignore

import sys
import time

Expand Down
8 changes: 4 additions & 4 deletions py-polars/tests/files/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import polars as pl


def test_date_datetime():
def test_date_datetime() -> None:
df = pl.DataFrame(
{
"year": [2001, 2002, 2003],
Expand All @@ -13,9 +13,9 @@ def test_date_datetime():

out = df.select(
[
pl.all(),
pl.datetime("year", "month", "day", "hour").dt.hour().alias("h2"),
pl.date("year", "month", "day").dt.day().alias("date"),
pl.all(), # type: ignore
pl.datetime("year", "month", "day", "hour").dt.hour().alias("h2"), # type: ignore
pl.date("year", "month", "day").dt.day().alias("date"), # type: ignore
]
)

Expand Down
12 changes: 7 additions & 5 deletions py-polars/tests/test_apply.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from typing import List, Optional

import polars as pl


def test_apply_none():
def test_apply_none() -> None:
df = pl.DataFrame(
{
"g": [1, 1, 1, 2, 2, 2, 5],
Expand All @@ -12,7 +14,7 @@ def test_apply_none():

out = (
df.groupby("g", maintain_order=True).agg(
pl.apply(
pl.apply( # type: ignore
exprs=["a", pl.col("b") ** 4, pl.col("a") / 4],
f=lambda x: x[0] * x[1] + x[2].sum(),
).alias("multiple")
Expand All @@ -21,19 +23,19 @@ def test_apply_none():
assert out[0].to_list() == [4.75, 326.75, 82.75]
assert out[1].to_list() == [238.75, 3418849.75, 372.75]

out = df.select(pl.map(exprs=["a", "b"], f=lambda s: s[0] * s[1]))
out = df.select(pl.map(exprs=["a", "b"], f=lambda s: s[0] * s[1])) # type: ignore
assert out["a"].to_list() == (df["a"] * df["b"]).to_list()

# check if we can return None
def func(s):
def func(s: List) -> Optional[int]:
if s[0][0] == 190:
return None
else:
return s[0]

out = (
df.groupby("g", maintain_order=True).agg(
pl.apply(exprs=["a", pl.col("b") ** 4, pl.col("a") / 4], f=func).alias(
pl.apply(exprs=["a", pl.col("b") ** 4, pl.col("a") / 4], f=func).alias( # type: ignore
"multiple"
)
)
Expand Down
28 changes: 14 additions & 14 deletions py-polars/tests/test_datelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import polars as pl


def test_fill_null():
def test_fill_null() -> None:
dt = datetime.strptime("2021-01-01", "%Y-%m-%d")
s = pl.Series("A", [dt, None])

for fill_val in (dt, pl.lit(dt)):
out = s.fill_null(fill_val)
out = s.fill_null(fill_val) # type: ignore

assert out.null_count() == 0
assert out.dt[0] == dt
Expand All @@ -18,17 +18,17 @@ def test_fill_null():
dt2 = date(2001, 1, 2)
dt3 = date(2001, 1, 3)
s = pl.Series("a", [dt1, dt2, dt3, None])
dt = date(2001, 1, 4)
for fill_val in (dt, pl.lit(dt)):
out = s.fill_null(fill_val)
dt_2 = date(2001, 1, 4)
for fill_val in (dt_2, pl.lit(dt_2)): # type: ignore
out = s.fill_null(fill_val) # type: ignore

assert out.null_count() == 0
assert out.dt[0] == dt1
assert out.dt[1] == dt2
assert out.dt[-1] == dt
assert out.dt[-1] == dt_2


def test_downsample():
def test_downsample() -> None:
s = pl.Series(
"datetime",
[
Expand Down Expand Up @@ -70,7 +70,7 @@ def test_downsample():
assert out["a"].dtype == "datetime64[ns]"


def test_filter_date():
def test_filter_date() -> None:
dataset = pl.DataFrame(
{"date": ["2020-01-02", "2020-01-03", "2020-01-04"], "index": [1, 2, 3]}
)
Expand All @@ -83,7 +83,7 @@ def test_filter_date():
assert df.filter(pl.col("date") < pl.lit(datetime(2020, 1, 5))).shape[0] == 3


def test_diff_datetime():
def test_diff_datetime() -> None:

df = pl.DataFrame(
{
Expand All @@ -104,7 +104,7 @@ def test_diff_datetime():
assert out[0] == out[1]


def test_timestamp():
def test_timestamp() -> None:
a = pl.Series("a", [10000, 20000, 30000], dtype=pl.Datetime)
assert a.dt.timestamp() == [10000, 20000, 30000]
out = a.dt.to_python_datetime()
Expand All @@ -117,7 +117,7 @@ def test_timestamp():
assert isinstance(df.row(0)[0], datetime)


def test_from_pydatetime():
def test_from_pydatetime() -> None:
dates = [
datetime(2021, 1, 1),
datetime(2021, 1, 2),
Expand All @@ -133,7 +133,7 @@ def test_from_pydatetime():
# fmt dates and nulls
print(s)

dates = [date(2021, 1, 1), date(2021, 1, 2), date(2021, 1, 3), None]
dates = [date(2021, 1, 1), date(2021, 1, 2), date(2021, 1, 3), None] # type: ignore
s = pl.Series("name", dates)
assert s.dtype == pl.Date
assert s.name == "name"
Expand All @@ -144,7 +144,7 @@ def test_from_pydatetime():
print(s)


def test_to_python_datetime():
def test_to_python_datetime() -> None:
df = pl.DataFrame({"a": [1, 2, 3]})
assert (
df.select(pl.col("a").cast(pl.Datetime).dt.to_python_datetime())["a"].dtype
Expand All @@ -155,7 +155,7 @@ def test_to_python_datetime():
)


def test_datetime_consistency():
def test_datetime_consistency() -> None:
dt = datetime(2021, 1, 1)
df = pl.DataFrame({"date": [dt]})
assert df["date"].dt[0] == dt
Expand Down

0 comments on commit a9570f3

Please sign in to comment.