Skip to content

Commit

Permalink
feat[python]: Add string literal types for better type checking (#4400)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Aug 14, 2022
1 parent 1444cd1 commit f2829dc
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 29 deletions.
14 changes: 8 additions & 6 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

if TYPE_CHECKING:
from polars.internals.type_aliases import (
CategoricalOrdering,
ClosedWindow,
EpochTimeUnit,
FillNullStrategy,
Expand Down Expand Up @@ -7129,17 +7130,18 @@ class ExprCatNameSpace:
def __init__(self, expr: Expr):
self._pyexpr = expr._pyexpr

def set_ordering(self, ordering: str) -> Expr:
def set_ordering(self, ordering: CategoricalOrdering) -> Expr:
"""
Determine how this categorical series should be sorted.
Parameters
----------
ordering
One of:
- 'physical' -> use the physical representation of the categories to
determine the order (default)
- 'lexical' -. use the string values to determine the ordering
ordering : {'physical', 'lexical'}
Ordering type:
- 'physical' -> Use the physical representation of the categories to
determine the order (default).
- 'lexical' -> Use the string values to determine the ordering.
Examples
--------
Expand Down
3 changes: 2 additions & 1 deletion py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
FillNullStrategy,
InterpolationMethod,
IpcCompression,
JoinStrategy,
NullStrategy,
Orientation,
ParallelStrategy,
Expand Down Expand Up @@ -3504,7 +3505,7 @@ def join(
left_on: str | pli.Expr | list[str | pli.Expr] | None = None,
right_on: str | pli.Expr | list[str | pli.Expr] | None = None,
on: str | pli.Expr | list[str | pli.Expr] | None = None,
how: str = "inner",
how: JoinStrategy = "inner",
suffix: str = "_right",
) -> DataFrame:
"""
Expand Down
18 changes: 8 additions & 10 deletions py-polars/polars/internals/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
_DOCUMENTING = True

if TYPE_CHECKING:
from polars.internals.type_aliases import ClosedWindow, TimeUnit
from polars.internals.type_aliases import ClosedWindow, ConcatMethod, TimeUnit


def get_dummies(df: pli.DataFrame) -> pli.DataFrame:
Expand All @@ -44,7 +44,7 @@ def get_dummies(df: pli.DataFrame) -> pli.DataFrame:
def concat(
items: Sequence[pli.DataFrame],
rechunk: bool = True,
how: str = "vertical",
how: ConcatMethod = "vertical",
) -> pli.DataFrame:
...

Expand All @@ -53,7 +53,7 @@ def concat(
def concat(
items: Sequence[pli.Series],
rechunk: bool = True,
how: str = "vertical",
how: ConcatMethod = "vertical",
) -> pli.Series:
...

Expand All @@ -62,7 +62,7 @@ def concat(
def concat(
items: Sequence[pli.LazyFrame],
rechunk: bool = True,
how: str = "vertical",
how: ConcatMethod = "vertical",
) -> pli.LazyFrame:
...

Expand All @@ -71,7 +71,7 @@ def concat(
def concat(
items: Sequence[pli.Expr],
rechunk: bool = True,
how: str = "vertical",
how: ConcatMethod = "vertical",
) -> pli.Expr:
...

Expand All @@ -84,7 +84,7 @@ def concat(
| Sequence[pli.Expr]
),
rechunk: bool = True,
how: str = "vertical",
how: ConcatMethod = "vertical",
) -> pli.DataFrame | pli.Series | pli.LazyFrame | pli.Expr:
"""
Aggregate multiple DataFrames/Series to a single DataFrame/Series.
Expand All @@ -95,11 +95,9 @@ def concat(
DataFrames/Series/LazyFrames to concatenate.
rechunk
Rechunk the final DataFrame/Series.
how
how : {'vertical', 'diagonal', 'horizontal'}
Only used if the items are DataFrames.
One of {"vertical", "diagonal", "horizontal"}.
- Vertical: Applies multiple `vstack` operations.
- Diagonal: Finds a union between the column schemas and fills missing column
values with null.
Expand Down Expand Up @@ -137,7 +135,7 @@ def concat(
out = pli.wrap_df(_hor_concat_df(items))
else:
raise ValueError(
f"how should be one of {'vertical', 'diagonal'}, got {how}"
f"how must be one of {{'vertical', 'diagonal'}}, got {how}"
)
elif isinstance(first, pli.LazyFrame):
return pli.wrap_ldf(_concat_lf(items, rechunk))
Expand Down
3 changes: 2 additions & 1 deletion py-polars/polars/internals/lazy_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
CsvEncoding,
FillNullStrategy,
InterpolationMethod,
JoinStrategy,
ParallelStrategy,
UniqueKeepStrategy,
)
Expand Down Expand Up @@ -1342,7 +1343,7 @@ def join(
left_on: str | pli.Expr | list[str | pli.Expr] | None = None,
right_on: str | pli.Expr | list[str | pli.Expr] | None = None,
on: str | pli.Expr | list[str | pli.Expr] | None = None,
how: str = "inner",
how: JoinStrategy = "inner",
suffix: str = "_right",
allow_parallel: bool = True,
force_parallel: bool = False,
Expand Down
14 changes: 8 additions & 6 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@

if TYPE_CHECKING:
from polars.internals.type_aliases import (
CategoricalOrdering,
ComparisonOperator,
EpochTimeUnit,
FillNullStrategy,
Expand Down Expand Up @@ -5885,17 +5886,18 @@ class CatNameSpace:
def __init__(self, s: Series):
self._s = s

def set_ordering(self, ordering: str) -> Series:
def set_ordering(self, ordering: CategoricalOrdering) -> Series:
"""
Determine how this categorical series should be sorted.
Parameters
----------
ordering
One of:
- 'physical' -> use the physical representation of the categories to
determine the order (default)
- 'lexical' -. use the string values to determine the ordering
ordering : {'physical', 'lexical'}
Ordering type:
- 'physical' -> Use the physical representation of the categories to
determine the order (default).
- 'lexical' -> Use the string values to determine the ordering.
Examples
--------
Expand Down
5 changes: 5 additions & 0 deletions py-polars/polars/internals/type_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# User-facing string literal types
# The following all have an equivalent Rust enum with the same name
AvroCompression: TypeAlias = Literal["uncompressed", "snappy", "deflate"]
CategoricalOrdering: TypeAlias = Literal["physical", "lexical"]
ClosedWindow: TypeAlias = Literal["left", "right", "both", "none"]
CsvEncoding: TypeAlias = Literal["utf8", "utf8-lossy"]
FillNullStrategy: TypeAlias = Literal[
Expand All @@ -44,11 +45,15 @@
InterpolationMethod: TypeAlias = Literal[
"nearest", "higher", "lower", "midpoint", "linear"
] # QuantileInterpolOptions
JoinStrategy: TypeAlias = Literal[
"inner", "left", "outer", "semi", "anti", "cross"
] # JoinType
ToStructStrategy: TypeAlias = Literal[
"first_non_null", "max_width"
] # ListToStructWidthStrategy

# The following have no equivalent on the Rust side
ConcatMethod = Literal["vertical", "diagonal", "horizontal"]
EpochTimeUnit = Literal["ns", "us", "ms", "s", "d"]
Orientation: TypeAlias = Literal["col", "row"]
TransferEncoding: TypeAlias = Literal["hex", "base64"]
10 changes: 7 additions & 3 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import typing
from datetime import datetime, timedelta
from io import BytesIO
from typing import Any, Iterator
from typing import TYPE_CHECKING, Any, Iterator

import numpy as np
import pyarrow as pa
Expand All @@ -13,6 +13,9 @@
import polars as pl
from polars.testing import assert_frame_equal, assert_series_equal, columns

if TYPE_CHECKING:
from polars.internals.type_aliases import JoinStrategy


def test_version() -> None:
pl.__version__
Expand Down Expand Up @@ -690,7 +693,7 @@ def test_concat() -> None:
_ = pl.concat([])

with pytest.raises(ValueError):
pl.concat([df, df], how="rubbish")
pl.concat([df, df], how="rubbish") # type: ignore[call-overload]


def test_arg_where() -> None:
Expand Down Expand Up @@ -1753,7 +1756,8 @@ def test_join_suffixes() -> None:
df_a = pl.DataFrame({"A": [1], "B": [1]})
df_b = pl.DataFrame({"A": [1], "B": [1]})

for how in ["left", "inner", "outer", "cross"]:
join_strategies: list[JoinStrategy] = ["left", "inner", "outer", "cross"]
for how in join_strategies:
# no need for an assert, we error if wrong
df_a.join(df_b, on="A", suffix="_y", how=how)["B_y"]

Expand Down
10 changes: 8 additions & 2 deletions py-polars/tests/test_joins.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from __future__ import annotations

from datetime import datetime
from typing import TYPE_CHECKING

import numpy as np
import pytest

import polars as pl

if TYPE_CHECKING:
from polars.internals.type_aliases import JoinStrategy


def test_semi_anti_join() -> None:
df_a = pl.DataFrame({"key": [1, 2, 3], "payload": ["f", "i", None]})
Expand Down Expand Up @@ -82,7 +86,8 @@ def test_sorted_merge_joins() -> None:
df_a = df_a.select(pl.all().reverse())
df_b = df_b.select(pl.all().reverse())

for how in ["left", "inner"]:
join_strategies: list[JoinStrategy] = ["left", "inner"]
for how in join_strategies:
# hash join
out_hash_join = df_a.join(df_b, on="a", how=how)

Expand Down Expand Up @@ -274,7 +279,8 @@ def test_joins_dispatch() -> None:
[pl.col("date").str.strptime(pl.Date), pl.col("datetime").cast(pl.Datetime)]
)

for how in ["left", "inner", "outer"]:
join_strategies: list[JoinStrategy] = ["left", "inner", "outer"]
for how in join_strategies:
dfa.join(dfa, on=["a", "b", "date", "datetime"], how=how)
dfa.join(dfa, on=["date", "datetime"], how=how)
dfa.join(dfa, on=["date", "datetime", "a"], how=how)
Expand Down

0 comments on commit f2829dc

Please sign in to comment.