Skip to content

Commit

Permalink
Fix typing & add some tests (#1937)
Browse files Browse the repository at this point in the history

A number of `# type: ignore` statements were not needed any more, or could easily be fixed. Also added some tests to cover some corner cases not tested yet.

* Remove literals
  • Loading branch information
zundertj committed Dec 2, 2021
1 parent 762aad3 commit 8a8bbe1
Show file tree
Hide file tree
Showing 11 changed files with 159 additions and 136 deletions.
2 changes: 2 additions & 0 deletions py-polars/polars/datatypes_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Int32,
Int64,
Object,
Time,
UInt8,
UInt16,
UInt32,
Expand Down Expand Up @@ -44,6 +45,7 @@
UInt64: PySeries.new_opt_u64,
Date: PySeries.new_opt_i32,
Datetime: PySeries.new_opt_i32,
Time: PySeries.new_opt_i32,
Boolean: PySeries.new_opt_bool,
Utf8: PySeries.new_str,
Object: PySeries.new_object,
Expand Down
13 changes: 4 additions & 9 deletions py-polars/polars/internals/construction.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,14 +110,9 @@ def sequence_to_pyseries(
if dtype is not None:
constructor = polars_type_to_constructor(dtype)
pyseries = constructor(name, values, strict)
if dtype == Date:
pyseries = pyseries.cast(str(Date), True)
elif dtype == Datetime:
pyseries = pyseries.cast(str(Datetime), True)
elif dtype == Time:
pyseries = pyseries.cast(str(Time), True)
elif dtype == Categorical:
pyseries = pyseries.cast(str(Categorical), True)

if dtype in (Date, Datetime, Time, Categorical):
pyseries = pyseries.cast(str(dtype), True)

return pyseries

Expand Down Expand Up @@ -155,7 +150,7 @@ def sequence_to_pyseries(
else:
try:
nested_arrow_dtype = py_type_to_arrow_type(nested_dtype)
except ValueError as e:
except ValueError as e: # pragma: no cover
raise ValueError(
f"Cannot construct Series from sequence of {nested_dtype}."
) from e
Expand Down
108 changes: 57 additions & 51 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


def _selection_to_pyexpr_list(
exprs: Union[str, "Expr", Sequence[Union[str, "Expr"]]]
exprs: Union[str, "Expr", Sequence[Union[str, "Expr"]], "pli.Series"]
) -> tp.List["PyExpr"]:
pyexpr_list: tp.List[PyExpr]
if isinstance(exprs, Sequence) and not isinstance(exprs, str):
Expand All @@ -47,7 +47,7 @@ class Expr:
"""

def __init__(self) -> None:
self._pyexpr: PyExpr
self._pyexpr: PyExpr # pragma: no cover

@staticmethod
def _from_pyexpr(pyexpr: "PyExpr") -> "Expr":
Expand Down Expand Up @@ -166,7 +166,7 @@ def lt(self, other: "Expr") -> "Expr":
return wrap_expr(self._pyexpr.lt(other._pyexpr))

def __neg__(self) -> "Expr":
return pli.lit(0) - self # type: ignore
return pli.lit(0) - self

def __array_ufunc__(
self, ufunc: Callable[..., Any], method: str, *inputs: Any, **kwargs: Any
Expand All @@ -175,10 +175,11 @@ def __array_ufunc__(
Numpy universal functions.
"""
out_type = ufunc(np.array([1])).dtype
dtype: Optional[Type[DataType]]
if "float" in str(out_type):
dtype = Float64 # type: ignore
dtype = Float64
else:
dtype = None # type: ignore
dtype = None

args = [inp for inp in inputs if not isinstance(inp, Expr)]

Expand Down Expand Up @@ -789,12 +790,12 @@ def take(
Values taken by index
"""
if isinstance(index, (list, np.ndarray)):
index = pli.lit(pli.Series("", index, dtype=UInt32)) # type: ignore
index_lit = pli.lit(pli.Series("", index, dtype=UInt32))
elif isinstance(index, pli.Series):
index = pli.lit(index) # type: ignore
index_lit = pli.lit(index)
else:
index = pli.expr_to_lit_or_expr(index, str_to_lit=False) # type: ignore
return pli.wrap_expr(self._pyexpr.take(index._pyexpr)) # type: ignore
index_lit = pli.expr_to_lit_or_expr(index, str_to_lit=False)
return pli.wrap_expr(self._pyexpr.take(index_lit._pyexpr))

def shift(self, periods: int = 1) -> "Expr":
"""
Expand All @@ -808,7 +809,9 @@ def shift(self, periods: int = 1) -> "Expr":
"""
return wrap_expr(self._pyexpr.shift(periods))

def shift_and_fill(self, periods: int, fill_value: "Expr") -> "Expr":
def shift_and_fill(
self, periods: int, fill_value: Union[int, float, bool, str, "Expr"]
) -> "Expr":
"""
Shift the values by a given period and fill the parts that will be empty due to this operation
with the result of the `fill_value` expression.
Expand All @@ -823,7 +826,7 @@ def shift_and_fill(self, periods: int, fill_value: "Expr") -> "Expr":
fill_value = expr_to_lit_or_expr(fill_value, str_to_lit=True)
return wrap_expr(self._pyexpr.shift_and_fill(periods, fill_value._pyexpr))

def fill_null(self, fill_value: Union[str, int, float, "Expr"]) -> "Expr":
def fill_null(self, fill_value: Union[int, float, bool, str, "Expr"]) -> "Expr":
"""
Fill none value with a fill value or strategy
Expand Down Expand Up @@ -853,7 +856,7 @@ def fill_null(self, fill_value: Union[str, int, float, "Expr"]) -> "Expr":
fill_value = expr_to_lit_or_expr(fill_value, str_to_lit=True)
return wrap_expr(self._pyexpr.fill_null(fill_value._pyexpr))

def fill_nan(self, fill_value: Union[str, int, float, "Expr"]) -> "Expr":
def fill_nan(self, fill_value: Union[str, int, float, bool, "Expr"]) -> "Expr":
"""
Fill none value with a fill value
"""
Expand Down Expand Up @@ -969,7 +972,7 @@ def over(self, expr: Union[str, "Expr", tp.List[Union["Expr", str]]]) -> "Expr":
Examples
--------
>>> df = DataFrame(
>>> df = pl.DataFrame(
... {
... "groups": [1, 1, 2, 2, 1, 2, 3, 3, 1],
... "values": [1, 2, 3, 4, 5, 6, 7, 8, 8],
Expand Down Expand Up @@ -1261,27 +1264,6 @@ def is_between(
expr = self
return ((expr > start) & (expr < end)).alias("is_between")

@property
def dt(self) -> "ExprDateTimeNameSpace":
"""
Create an object namespace of all datetime related methods.
"""
return ExprDateTimeNameSpace(self)

@property
def str(self) -> "ExprStringNameSpace":
"""
Create an object namespace of all string related methods.
"""
return ExprStringNameSpace(self)

@property
def arr(self) -> "ExprListNameSpace":
"""
Create an object namespace of all datetime related methods.
"""
return ExprListNameSpace(self)

def hash(self, k0: int = 0, k1: int = 1, k2: int = 2, k3: int = 3) -> "Expr":
"""
Hash the Series.
Expand Down Expand Up @@ -1315,15 +1297,15 @@ def reinterpret(self, signed: bool) -> "Expr":
"""
return wrap_expr(self._pyexpr.reinterpret(signed))

def inspect(self, fmt: str = "{}") -> "Expr": # type: ignore
def inspect(self, fmt: str = "{}") -> "Expr":
"""
Prints the value that this expression evaluates to and passes on the value.
>>> df.select(pl.col("foo").cumsum().inspect("value is: {}").alias("bar"))
"""

def inspect(s: "pli.Series") -> "pli.Series":
print(fmt.format(s)) # type: ignore
print(fmt.format(s))
return s

return self.map(inspect, return_dtype=None, agg_list=True)
Expand Down Expand Up @@ -1665,7 +1647,7 @@ def argsort(self, reverse: bool = False) -> "Expr":
"""
return pli.argsort_by([self], [reverse])

def rank(self, method: str = "average") -> "Expr": # type: ignore
def rank(self, method: str = "average") -> "Expr":
"""
Assign ranks to data, dealing with ties appropriately.
Expand All @@ -1692,7 +1674,7 @@ def rank(self, method: str = "average") -> "Expr": # type: ignore
"""
return wrap_expr(self._pyexpr.rank(method))

def diff(self, n: int = 1, null_behavior: str = "ignore") -> "Expr": # type: ignore
def diff(self, n: int = 1, null_behavior: str = "ignore") -> "Expr":
"""
Calculate the n-th discrete difference.
Expand Down Expand Up @@ -1770,14 +1752,14 @@ def clip(self, min_val: Union[int, float], max_val: Union[int, float]) -> "Expr"
min_val, max_val
Minimum and maximum value.
"""
min_val = pli.lit(min_val) # type: ignore
max_val = pli.lit(max_val) # type: ignore
min_val_lit = pli.lit(min_val)
max_val_lit = pli.lit(max_val)

return (
pli.when(self < min_val) # type: ignore
.then(min_val)
.when(self > max_val)
.then(max_val)
pli.when(self < min_val_lit)
.then(min_val_lit)
.when(self > max_val_lit)
.then(max_val_lit)
.otherwise(self)
).keep_name()

Expand All @@ -1793,7 +1775,7 @@ def upper_bound(self) -> "Expr":
"""
return wrap_expr(self._pyexpr.upper_bound())

def str_concat(self, delimiter: str = "-") -> "Expr": # type: ignore
def str_concat(self, delimiter: str = "-") -> "Expr":
"""
Vertically concat the values in the Series to a single string value.
Expand Down Expand Up @@ -1965,6 +1947,30 @@ def reshape(self, dims: tp.Tuple[int, ...]) -> "Expr":
"""
return wrap_expr(self._pyexpr.reshape(dims))

# Below are the namespaces defined. Keep these at the end of the definition of Expr, as to not confuse mypy with
# the type annotation `str` with the namespace "str"

@property
def dt(self) -> "ExprDateTimeNameSpace":
"""
Create an object namespace of all datetime related methods.
"""
return ExprDateTimeNameSpace(self)

@property
def str(self) -> "ExprStringNameSpace":
"""
Create an object namespace of all string related methods.
"""
return ExprStringNameSpace(self)

@property
def arr(self) -> "ExprListNameSpace":
"""
Create an object namespace of all datetime related methods.
"""
return ExprListNameSpace(self)


class ExprListNameSpace:
"""
Expand Down Expand Up @@ -2022,7 +2028,7 @@ def unique(self) -> "Expr":
"""
return wrap_expr(self._pyexpr.lst_unique())

def concat(self, other: Union[tp.List[Expr], Expr, str, tp.List[str]]) -> "Expr":
def concat(self, other: Union[tp.List[Union[Expr, str]], Expr, str]) -> "Expr":
"""
Concat the arrays in a Series dtype List in linear time.
Expand All @@ -2031,13 +2037,13 @@ def concat(self, other: Union[tp.List[Expr], Expr, str, tp.List[str]]) -> "Expr"
other
Columns to concat into a List Series
"""
other_list: tp.List[Union[Expr, str]]
if not isinstance(other, list):
other = [other] # type: ignore
other_list = [other]
else:
other = copy.copy(other)
# mypy does not understand we have a list by now
other.insert(0, wrap_expr(self._pyexpr)) # type: ignore
return pli.concat_list(other) # type: ignore
other_list = copy.copy(other)
other_list.insert(0, wrap_expr(self._pyexpr))
return pli.concat_list(other_list)


class ExprStringNameSpace:
Expand Down
27 changes: 13 additions & 14 deletions py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,9 +757,7 @@ def to_csv(
def to_ipc(
self,
file: Union[BinaryIO, BytesIO, str, Path],
compression: Optional[
Union[Literal["uncompressed", "lz4", "zstd"], str]
] = "uncompressed",
compression: Optional[Literal["uncompressed", "lz4", "zstd"]] = "uncompressed",
) -> None:
"""
Write to Arrow IPC binary stream, or a feather file.
Expand Down Expand Up @@ -2162,18 +2160,18 @@ def upsample(self, by: str, interval: timedelta) -> "DataFrame":
low = bounds["low"].dt[0]
high = bounds["high"].dt[0]
upsampled = pli.date_range(low, high, interval, name=by)
return DataFrame(upsampled).join(self, on=by, how="left") # type: ignore
return DataFrame(upsampled).join(self, on=by, how="left")

def join(
self,
df: "DataFrame",
left_on: Optional[
Union[str, "pli.Expr", tp.List[str], tp.List["pli.Expr"]]
Union[str, "pli.Expr", tp.List[Union[str, "pli.Expr"]]]
] = None,
right_on: Optional[
Union[str, "pli.Expr", tp.List[str], tp.List["pli.Expr"]]
Union[str, "pli.Expr", tp.List[Union[str, "pli.Expr"]]]
] = None,
on: Optional[Union[str, tp.List[str]]] = None,
on: Optional[Union[str, "pli.Expr", tp.List[Union[str, "pli.Expr"]]]] = None,
how: str = "inner",
suffix: str = "_right",
asof_by: Optional[Union[str, tp.List[str]]] = None,
Expand Down Expand Up @@ -2264,15 +2262,15 @@ def join(
if how == "cross":
return wrap_df(self._df.join(df._df, [], [], how, suffix))

left_on_: Union[tp.List[str], tp.List[pli.Expr], None]
left_on_: Optional[tp.List[Union[str, pli.Expr]]]
if isinstance(left_on, (str, pli.Expr)):
left_on_ = [left_on] # type: ignore[assignment]
left_on_ = [left_on]
else:
left_on_ = left_on

right_on_: Union[tp.List[str], tp.List[pli.Expr], None]
right_on_: Optional[tp.List[Union[str, pli.Expr]]]
if isinstance(right_on, (str, pli.Expr)):
right_on_ = [right_on] # type: ignore[assignment]
right_on_ = [right_on]
else:
right_on_ = right_on

Expand Down Expand Up @@ -2864,6 +2862,7 @@ def select(
Sequence[bool],
Sequence[int],
Sequence[float],
"pli.Series",
],
) -> "DataFrame":
"""
Expand Down Expand Up @@ -3449,7 +3448,7 @@ def interpolate(self) -> "DataFrame":
"""
Interpolate intermediate values. The interpolation method is linear.
"""
return self.select(pli.col("*").interpolate()) # type: ignore
return self.select(pli.col("*").interpolate())

def is_empty(self) -> bool:
"""
Expand Down Expand Up @@ -3735,7 +3734,7 @@ def head(self, n: int = 5) -> DataFrame:
wrap_df(self._df)
.lazy()
.groupby(self.by, self.maintain_order)
.head(n) # type: ignore[arg-type]
.head(n)
.collect(no_optimization=True, string_cache=False)
)

Expand Down Expand Up @@ -3799,7 +3798,7 @@ def tail(self, n: int = 5) -> DataFrame:
wrap_df(self._df)
.lazy()
.groupby(self.by, self.maintain_order)
.tail(n) # type: ignore[arg-type]
.tail(n)
.collect(no_optimization=True, string_cache=False)
)

Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def concat(
out = pli.wrap_s(_concat_series(items))

if rechunk:
return out.rechunk() # type: ignore
return out.rechunk()
return out


Expand Down

0 comments on commit 8a8bbe1

Please sign in to comment.