Skip to content

Commit

Permalink
chore[python]: Enforce strict type equality in mypy check
Browse files: browse the repository at this point in the history
  • Loading branch information
zundertj committed Aug 14, 2022
1 parent 7e721d0 commit beb802a
Show file tree
Hide file tree
Showing 12 changed files with 53 additions and 41 deletions.
2 changes: 1 addition & 1 deletion py-polars/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ disallow_untyped_calls = true
warn_redundant_casts = true
# warn_return_any = true
no_implicit_reexport = true
# strict_equality = true
strict_equality = true
# TODO: When all flags are enabled, replace by strict = true
enable_error_code = [
"redundant-expr",
Expand Down
5 changes: 3 additions & 2 deletions py-polars/tests/db-benchmark/various.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# may contain many things that seemed to go wrong at scale

import time
from typing import cast

import numpy as np

Expand Down Expand Up @@ -50,8 +51,8 @@
computed = permuted.select(
[pl.col("id").min().alias("min"), pl.col("id").max().alias("max")]
)
assert computed[0, "min"] == minimum
assert computed[0, "max"] == maximum
assert cast(int, computed[0, "min"]) == minimum
assert cast(float, computed[0, "max"]) == maximum


def test_windows_not_cached() -> None:
Expand Down
10 changes: 7 additions & 3 deletions py-polars/tests/io/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import zlib
from datetime import date, datetime, time
from pathlib import Path
from typing import cast

import pytest

Expand Down Expand Up @@ -402,8 +403,8 @@ def test_csv_globbing(examples_dir: str) -> None:

df = pl.read_csv(path, columns=["category", "sugars_g"])
assert df.shape == (135, 2)
assert df.row(-1) == ("seafood", 1)
assert df.row(0) == ("vegetables", 2)
assert df.row(-1) == ("seafood", 1) # type: ignore[comparison-overlap]
assert df.row(0) == ("vegetables", 2) # type: ignore[comparison-overlap]

with pytest.raises(ValueError):
_ = pl.read_csv(path, dtypes=[pl.Utf8, pl.Int64, pl.Int64, pl.Int64])
Expand Down Expand Up @@ -509,7 +510,10 @@ def test_fallback_chrono_parser() -> None:
2021-10-10,2021-10-10
"""
)
assert pl.read_csv(data.encode(), parse_dates=True).null_count().row(0) == (0, 0)
assert cast(
tuple[int, int],
pl.read_csv(data.encode(), parse_dates=True).null_count().row(0),
) == (0, 0)


def test_csv_string_escaping() -> None:
Expand Down
3 changes: 2 additions & 1 deletion py-polars/tests/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ def test_init_dict() -> None:

# List of empty list/tuple
df = pl.DataFrame({"a": [[]], "b": [()]})
assert df.schema == {"a": pl.List(pl.Float64), "b": pl.List(pl.Float64)}
expected = {"a": pl.List(pl.Float64), "b": pl.List(pl.Float64)}
assert df.schema == expected # type: ignore[comparison-overlap]
assert df.rows() == [([], [])]

# Mixed dtypes
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/test_datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def test_dtype_init_equivalence() -> None:
if inspect.isclass(dtype) and issubclass(dtype, datatypes.DataType)
}
for dtype in all_datatypes:
assert dtype == dtype()
assert dtype == dtype() # type: ignore[comparison-overlap]


def test_dtype_temporal_units() -> None:
Expand Down
10 changes: 5 additions & 5 deletions py-polars/tests/test_datelike.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from __future__ import annotations

import io
import typing
from datetime import date, datetime, time, timedelta
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, no_type_check

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -154,7 +153,7 @@ def test_datetime_consistency() -> None:
pl.lit(dt).cast(pl.Datetime("ns")).alias("dt_ns"),
]
)
assert ddf.schema == {
assert ddf.schema == { # type: ignore[comparison-overlap]
"date": pl.Datetime("us"),
"dt": pl.Datetime("us"),
"dt_ms": pl.Datetime("ms"),
Expand Down Expand Up @@ -886,7 +885,7 @@ def test_agg_logical() -> None:
assert s.min() == dates[0]


@typing.no_type_check
@no_type_check
def test_from_time_arrow() -> None:
times = pa.array([10, 20, 30], type=pa.time32("s"))
times_table = pa.table([times], names=["times"])
Expand Down Expand Up @@ -1027,7 +1026,8 @@ def test_datetime_instance_selection() -> None:
],
)
for tu in DTYPE_TEMPORAL_UNITS:
assert df.select(pl.col([pl.Datetime(tu)])).dtypes == [pl.Datetime(tu)]
res = df.select(pl.col([pl.Datetime(tu)])).dtypes
assert res == [pl.Datetime(tu)] # type: ignore[comparison-overlap]
assert len(df.filter(pl.col(tu) == test_data[tu][0])) == 1


Expand Down
26 changes: 13 additions & 13 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def test_selection() -> None:

# select columns by mask
assert df[:2, :1].shape == (2, 1)
assert df[:2, "a"].shape == (2, 1)
assert df[:2, "a"].shape == (2, 1) # type: ignore[comparison-overlap]

# column selection by string(s) in first dimension
assert df["a"].to_list() == [1, 2, 3]
Expand All @@ -117,11 +117,11 @@ def test_selection() -> None:
assert df[[1, 2], [1, 2]].frame_equal(
pl.DataFrame({"b": [2.0, 3.0], "c": ["b", "c"]})
)
assert df[1, 2] == "b"
assert df[1, 1] == 2.0
assert df[2, 0] == 3
assert typing.cast(str, df[1, 2]) == "b"
assert typing.cast(float, df[1, 1]) == 2.0
assert typing.cast(int, df[2, 0]) == 3

assert df[[0, 1], "b"].shape == (2, 1)
assert df[[0, 1], "b"].shape == (2, 1) # type: ignore[comparison-overlap]
assert df[[2], ["a", "b"]].shape == (1, 2)
assert df.to_series(0).name == "a"
assert (df["a"] == df["a"]).sum() == 3
Expand All @@ -132,10 +132,10 @@ def test_selection() -> None:
assert df[1, [2]].frame_equal(expect)
expect = pl.DataFrame({"b": [1.0, 3.0]})
assert df[[0, 2], [1]].frame_equal(expect)
assert df[0, "c"] == "a"
assert df[1, "c"] == "b"
assert df[2, "c"] == "c"
assert df[0, "a"] == 1
assert typing.cast(str, df[0, "c"]) == "a"
assert typing.cast(str, df[1, "c"]) == "b"
assert typing.cast(str, df[2, "c"]) == "c"
assert typing.cast(int, df[0, "a"]) == 1

# more slicing
expect = pl.DataFrame({"a": [3, 2, 1], "b": [3.0, 2.0, 1.0], "c": ["c", "b", "a"]})
Expand Down Expand Up @@ -766,9 +766,9 @@ def test_df_fold() -> None:

def test_row_tuple() -> None:
df = pl.DataFrame({"a": ["foo", "bar", "2"], "b": [1, 2, 3], "c": [1.0, 2.0, 3.0]})
assert df.row(0) == ("foo", 1, 1.0)
assert df.row(1) == ("bar", 2, 2.0)
assert df.row(-1) == ("2", 3, 3.0)
assert df.row(0) == ("foo", 1, 1.0) # type: ignore[comparison-overlap]
assert df.row(1) == ("bar", 2, 2.0) # type: ignore[comparison-overlap]
assert df.row(-1) == ("2", 3, 3.0) # type: ignore[comparison-overlap]


def test_df_apply() -> None:
Expand Down Expand Up @@ -1058,7 +1058,7 @@ def dot_product() -> None:
df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [2, 2, 2, 2]})

assert df["a"].dot(df["b"]) == 20
assert df.select([pl.col("a").dot("b")])[0, "a"] == 20
assert typing.cast(int, df.select([pl.col("a").dot("b")])[0, "a"]) == 20


def test_hash_rows() -> None:
Expand Down
10 changes: 7 additions & 3 deletions py-polars/tests/test_exprs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import cast

import polars as pl
from polars.testing import assert_series_equal, verify_series_and_expr_api

Expand Down Expand Up @@ -82,7 +84,7 @@ def test_count_expr() -> None:

out = df.select(pl.count())
assert out.shape == (1, 1)
assert out[0, 0] == 5
assert cast(int, out[0, 0]) == 5

out = df.groupby("b", maintain_order=True).agg(pl.count())
assert out["b"].to_list() == ["a", "b"]
Expand Down Expand Up @@ -274,9 +276,11 @@ def test_regex_in_filter() -> None:
}
)

assert df.filter(
res = df.filter(
pl.fold(acc=False, f=lambda acc, s: acc | s, exprs=(pl.col("^nrs|flt*$") < 3))
).row(0) == (1, "foo", 1.0)
).row(0)
expected = (1, "foo", 1.0)
assert res == expected # type: ignore[comparison-overlap]


def test_arr_contains() -> None:
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def test_from_pandas_ns_resolution() -> None:
[pd.Timestamp(year=2021, month=1, day=1, hour=1, second=1, nanosecond=1)],
columns=["date"],
)
assert pl.from_pandas(df)[0, 0] == datetime(2021, 1, 1, 1, 0, 1)
assert cast(datetime, pl.from_pandas(df)[0, 0]) == datetime(2021, 1, 1, 1, 0, 1)


@no_type_check
Expand Down
20 changes: 11 additions & 9 deletions py-polars/tests/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,7 @@ def test_is_finite_is_infinite() -> None:

def test_len() -> None:
df = pl.DataFrame({"nrs": [1, 2, 3]})
assert df.select(col("nrs").len())[0, 0] == 3
assert cast(int, df.select(col("nrs").len())[0, 0]) == 3


def test_cum_agg() -> None:
Expand Down Expand Up @@ -505,7 +505,7 @@ def test_round() -> None:

def test_dot() -> None:
df = pl.DataFrame({"a": [1.8, 1.2, 3.0], "b": [3.2, 1, 2]})
assert df.select(pl.col("a").dot(pl.col("b")))[0, 0] == 12.96
assert cast(float, df.select(pl.col("a").dot(pl.col("b")))[0, 0]) == 12.96


def test_sort() -> None:
Expand Down Expand Up @@ -696,8 +696,8 @@ def test_rolling(fruits_cars: pl.DataFrame) -> None:
]
)

assert out_single_val_variance[0, "std"] == 0.0
assert out_single_val_variance[0, "var"] == 0.0
assert cast(float, out_single_val_variance[0, "std"]) == 0.0
assert cast(float, out_single_val_variance[0, "var"]) == 0.0


def test_rolling_apply() -> None:
Expand Down Expand Up @@ -993,16 +993,16 @@ def test_join_suffix() -> None:
def test_str_concat() -> None:
df = pl.DataFrame({"foo": [1, None, 2]})
df = df.select(pl.col("foo").str.concat("-"))
assert df[0, 0] == "1-null-2"
assert cast(str, df[0, 0]) == "1-null-2"


@pytest.mark.parametrize("no_optimization", [False, True])
def test_collect_all(df: pl.DataFrame, no_optimization: bool) -> None:
lf1 = df.lazy().select(pl.col("int").sum())
lf2 = df.lazy().select((pl.col("floats") * 2).sum())
out = pl.collect_all([lf1, lf2], no_optimization=no_optimization)
assert out[0][0, 0] == 6
assert out[1][0, 0] == 12.0
assert cast(int, out[0][0, 0]) == 6
assert cast(float, out[1][0, 0]) == 12.0


def test_spearman_corr() -> None:
Expand Down Expand Up @@ -1058,8 +1058,10 @@ def test_pearson_corr() -> None:


def test_cov(fruits_cars: pl.DataFrame) -> None:
assert fruits_cars.select(pl.cov("A", "B"))[0, 0] == -2.5
assert fruits_cars.select(pl.cov(pl.col("A"), pl.col("B")))[0, 0] == -2.5
assert cast(float, fruits_cars.select(pl.cov("A", "B"))[0, 0]) == -2.5
assert (
cast(float, fruits_cars.select(pl.cov(pl.col("A"), pl.col("B")))[0, 0]) == -2.5
)


def test_std(fruits_cars: pl.DataFrame) -> None:
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/test_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_dtype() -> None:
("dtm", pl.List(pl.Datetime)),
],
)
assert df.schema == {
assert df.schema == { # type: ignore[comparison-overlap]
"i": pl.List(pl.Int8),
"tm": pl.List(pl.Time),
"dt": pl.List(pl.Date),
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ def test_when_then_edge_cases_3994() -> None:
.groupby(["id"])
.agg(pl.col("type"))
.with_column(
pl.when(pl.col("type").arr.lengths == 0)
pl.when(pl.col("type").arr.lengths() == 0)
.then(pl.lit(None))
.otherwise(pl.col("type"))
.keep_name()
Expand Down

0 comments on commit beb802a

Please sign in to comment.