Skip to content

Commit

Permalink
refactor(python): improve typing; many list types are better defined as `Sequence` (#5164)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed Oct 11, 2022
1 parent 4bd94b1 commit 6ffed3c
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 48 deletions.
78 changes: 47 additions & 31 deletions py-polars/polars/internals/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def _read_csv(
cls: type[DF],
file: str | Path | BinaryIO | bytes,
has_header: bool = True,
columns: list[int] | list[str] | None = None,
columns: Sequence[int] | Sequence[str] | None = None,
sep: str = ",",
comment_char: str | None = None,
quote_char: str | None = r'"',
Expand Down Expand Up @@ -563,6 +563,8 @@ def _read_csv(

processed_null_values = _process_null_values(null_values)

if isinstance(columns, str):
columns = [columns]
if isinstance(file, str) and "*" in file:
dtypes_dict = None
if dtype_list is not None:
Expand Down Expand Up @@ -638,7 +640,7 @@ def _read_csv(
def _read_parquet(
cls: type[DF],
file: str | Path | BinaryIO,
columns: list[int] | list[str] | None = None,
columns: Sequence[int] | Sequence[str] | None = None,
n_rows: int | None = None,
parallel: ParallelStrategy = "auto",
row_count_name: str | None = None,
Expand All @@ -657,6 +659,8 @@ def _read_parquet(
"""
if isinstance(file, (str, Path)):
file = format_path(file)
if isinstance(columns, str):
columns = [columns]

if isinstance(file, str) and "*" in file and pli._is_local_file(file):
from polars import scan_parquet
Expand Down Expand Up @@ -698,7 +702,7 @@ def _read_parquet(
def _read_avro(
cls: type[DF],
file: str | Path | BinaryIO,
columns: list[int] | list[str] | None = None,
columns: Sequence[int] | Sequence[str] | None = None,
n_rows: int | None = None,
) -> DF:
"""
Expand Down Expand Up @@ -729,7 +733,7 @@ def _read_avro(
def _read_ipc(
cls,
file: str | Path | BinaryIO,
columns: list[int] | list[str] | None = None,
columns: Sequence[int] | Sequence[str] | None = None,
n_rows: int | None = None,
row_count_name: str | None = None,
row_count_offset: int = 0,
Expand Down Expand Up @@ -766,6 +770,8 @@ def _read_ipc(
"""
if isinstance(file, (str, Path)):
file = format_path(file)
if isinstance(columns, str):
columns = [columns]

if isinstance(file, str) and "*" in file and pli._is_local_file(file):
from polars import scan_ipc
Expand Down Expand Up @@ -1395,7 +1401,9 @@ def __getitem__(
)

def __setitem__(
self, key: str | list[int] | list[str] | tuple[Any, str | int], value: Any
self,
key: str | Sequence[int] | Sequence[str] | tuple[Any, str | int],
value: Any,
) -> None: # pragma: no cover
# df["foo"] = series
if isinstance(key, str):
Expand Down Expand Up @@ -2616,7 +2624,7 @@ def replace_at_idx(self: DF, index: int, series: pli.Series) -> DF:

def sort(
self: DF,
by: str | pli.Expr | list[str] | list[pli.Expr],
by: str | pli.Expr | Sequence[str] | Sequence[pli.Expr],
reverse: bool | list[bool] = False,
nulls_last: bool = False,
) -> DF | DataFrame:
Expand Down Expand Up @@ -2676,7 +2684,7 @@ def sort(
└─────┴─────┴─────┘
"""
if type(by) is list or isinstance(by, pli.Expr):
if not isinstance(by, str) and isinstance(by, (Sequence, pli.Expr)):
df = (
self.lazy()
.sort(by, reverse, nulls_last)
Expand Down Expand Up @@ -2899,7 +2907,7 @@ def tail(self: DF, n: int = 5) -> DF:
"""
return self._from_pydf(self._df.tail(n))

def drop_nulls(self: DF, subset: str | list[str] | None = None) -> DF:
def drop_nulls(self: DF, subset: str | Sequence[str] | None = None) -> DF:
"""
Return a new DataFrame where the null values are dropped.
Expand Down Expand Up @@ -3164,7 +3172,7 @@ def groupby_rolling(
period: str,
offset: str | None = None,
closed: ClosedWindow = "right",
by: str | list[str] | pli.Expr | list[pli.Expr] | None = None,
by: str | Sequence[str] | pli.Expr | Sequence[pli.Expr] | None = None,
) -> RollingGroupBy[DF]:
"""
Create rolling groups based on a time column.
Expand Down Expand Up @@ -3277,7 +3285,7 @@ def groupby_dynamic(
truncate: bool = True,
include_boundaries: bool = False,
closed: ClosedWindow = "left",
by: str | list[str] | pli.Expr | list[pli.Expr] | None = None,
by: str | Sequence[str] | pli.Expr | Sequence[pli.Expr] | None = None,
) -> DynamicGroupBy[DF]:
"""
Group based on a time value (or index value of type Int32, Int64).
Expand Down Expand Up @@ -3689,9 +3697,9 @@ def join_asof(
left_on: str | None = None,
right_on: str | None = None,
on: str | None = None,
by_left: str | list[str] | None = None,
by_right: str | list[str] | None = None,
by: str | list[str] | None = None,
by_left: str | Sequence[str] | None = None,
by_right: str | Sequence[str] | None = None,
by: str | Sequence[str] | None = None,
strategy: AsofJoinStrategy = "backward",
suffix: str = "_right",
tolerance: str | int | float | None = None,
Expand Down Expand Up @@ -4219,7 +4227,7 @@ def extend(self: DF, other: DF) -> DF:
return self

@deprecated_alias(name="columns")
def drop(self: DF, columns: str | list[str]) -> DF:
def drop(self: DF, columns: str | Sequence[str]) -> DF:
"""
Remove column from DataFrame and return as new.
Expand Down Expand Up @@ -4552,7 +4560,7 @@ def fill_nan(self, fill_value: pli.Expr | int | float | None) -> DataFrame:

def explode(
self,
columns: str | list[str] | pli.Expr | list[pli.Expr],
columns: str | Sequence[str] | pli.Expr | Sequence[pli.Expr],
) -> DataFrame:
"""
Explode `DataFrame` to long format by exploding a column with Lists.
Expand Down Expand Up @@ -4618,9 +4626,9 @@ def explode(

def pivot(
self: DF,
values: list[str] | str,
index: list[str] | str,
columns: list[str] | str,
values: Sequence[str] | str,
index: Sequence[str] | str,
columns: Sequence[str] | str,
aggregate_fn: PivotAgg | pli.Expr = "first",
maintain_order: bool = True,
sort_columns: bool = False,
Expand Down Expand Up @@ -4712,8 +4720,8 @@ def pivot(

def melt(
self: DF,
id_vars: list[str] | str | None = None,
value_vars: list[str] | str | None = None,
id_vars: Sequence[str] | str | None = None,
value_vars: Sequence[str] | str | None = None,
variable_name: str | None = None,
value_name: str | None = None,
) -> DF:
Expand Down Expand Up @@ -4785,7 +4793,7 @@ def unstack(
self: DF,
step: int,
how: UnstackDirection = "vertical",
columns: str | list[str] | None = None,
columns: str | Sequence[str] | None = None,
fill_values: list[Any] | None = None,
) -> DF:
"""
Expand Down Expand Up @@ -4922,7 +4930,7 @@ def unstack(
@overload
def partition_by(
self: DF,
groups: str | list[str],
groups: str | Sequence[str],
maintain_order: bool = False,
*,
as_dict: Literal[False] = ...,
Expand All @@ -4932,7 +4940,7 @@ def partition_by(
@overload
def partition_by(
self: DF,
groups: str | list[str],
groups: str | Sequence[str],
maintain_order: bool = False,
*,
as_dict: Literal[True],
Expand All @@ -4942,7 +4950,7 @@ def partition_by(
@overload
def partition_by(
self: DF,
groups: str | list[str],
groups: str | Sequence[str],
maintain_order: bool,
*,
as_dict: bool,
Expand All @@ -4951,7 +4959,7 @@ def partition_by(

def partition_by(
self: DF,
groups: str | list[str],
groups: str | Sequence[str],
maintain_order: bool = True,
*,
as_dict: bool = False,
Expand Down Expand Up @@ -5012,6 +5020,8 @@ def partition_by(
"""
if isinstance(groups, str):
groups = [groups]
elif not isinstance(groups, list):
groups = list(groups)

if as_dict:
out: dict[Any, DF] = {}
Expand Down Expand Up @@ -5771,7 +5781,7 @@ def quantile(
"""
return self._from_pydf(self._df.quantile(quantile, interpolation))

def to_dummies(self: DF, *, columns: list[str] | None = None) -> DF:
def to_dummies(self: DF, *, columns: Sequence[str] | None = None) -> DF:
"""
Get one hot encoded dummy variables.
Expand Down Expand Up @@ -5805,12 +5815,14 @@ def to_dummies(self: DF, *, columns: list[str] | None = None) -> DF:
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘
"""
if isinstance(columns, str):
columns = [columns]
return self._from_pydf(self._df.to_dummies(columns))

def unique(
self: DF,
maintain_order: bool = True,
subset: str | list[str] | None = None,
subset: str | Sequence[str] | None = None,
keep: UniqueKeepStrategy = "first",
) -> DF:
"""
Expand All @@ -5829,7 +5841,7 @@ def unique(
subset
Subset to use to compare rows.
keep : {'first', 'last'}
Which of the duplicate rows to keep.
Which of the duplicate rows to keep (in conjunction with ``subset``).
Returns
-------
Expand Down Expand Up @@ -5863,8 +5875,12 @@ def unique(
└─────┴──────┴───────┘
"""
if subset is not None and not isinstance(subset, list):
subset = [subset]
if subset is not None:
if isinstance(subset, str):
subset = [subset]
elif not isinstance(subset, list):
subset = list(subset)

return self._from_pydf(self._df.unique(maintain_order, subset, keep))

def rechunk(self: DF) -> DF:
Expand Down Expand Up @@ -6299,7 +6315,7 @@ def to_struct(self, name: str) -> pli.Series:
"""
return pli.wrap_s(self._df.to_struct(name))

def unnest(self: DF, names: str | list[str]) -> DF:
def unnest(self: DF, names: str | Sequence[str]) -> DF:
"""
Decompose a struct into its fields.
Expand Down
8 changes: 4 additions & 4 deletions py-polars/polars/internals/dataframe/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,7 @@ def __init__(
period: str,
offset: str | None,
closed: ClosedWindow = "none",
by: str | list[str] | pli.Expr | list[pli.Expr] | None = None,
by: str | Sequence[str] | pli.Expr | Sequence[pli.Expr] | None = None,
):
self.df = df
self.time_column = index_column
Expand Down Expand Up @@ -875,7 +875,7 @@ def __init__(
truncate: bool = True,
include_boundaries: bool = True,
closed: ClosedWindow = "none",
by: str | list[str] | pli.Expr | list[pli.Expr] | None = None,
by: str | Sequence[str] | pli.Expr | Sequence[pli.Expr] | None = None,
):
self.df = df
self.time_column = index_column
Expand Down Expand Up @@ -911,8 +911,8 @@ class GBSelection(Generic[DF]):
def __init__(
self,
df: PyDataFrame,
by: str | list[str],
selection: list[str] | None,
by: str | Sequence[str],
selection: Sequence[str] | None,
dataframe_class: type[DF],
):
self._df = df
Expand Down
18 changes: 9 additions & 9 deletions py-polars/polars/internals/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1148,7 +1148,7 @@ def select(

def groupby(
self: LDF,
by: str | list[str] | pli.Expr | list[pli.Expr],
by: str | Sequence[str] | pli.Expr | Sequence[pli.Expr],
maintain_order: bool = False,
) -> LazyGroupBy[LDF]:
"""
Expand Down Expand Up @@ -1201,7 +1201,7 @@ def groupby_rolling(
period: str,
offset: str | None = None,
closed: ClosedWindow = "right",
by: str | list[str] | pli.Expr | list[pli.Expr] | None = None,
by: str | Sequence[str] | pli.Expr | Sequence[pli.Expr] | None = None,
) -> LazyGroupBy[LDF]:
"""
Create rolling groups based on a time column.
Expand Down Expand Up @@ -1321,7 +1321,7 @@ def groupby_dynamic(
truncate: bool = True,
include_boundaries: bool = False,
closed: ClosedWindow = "left",
by: str | list[str] | pli.Expr | list[pli.Expr] | None = None,
by: str | Sequence[str] | pli.Expr | Sequence[pli.Expr] | None = None,
) -> LazyGroupBy[LDF]:
"""
Group based on a time value (or index value of type Int32, Int64).
Expand Down Expand Up @@ -1420,9 +1420,9 @@ def join_asof(
left_on: str | None = None,
right_on: str | None = None,
on: str | None = None,
by_left: str | list[str] | None = None,
by_right: str | list[str] | None = None,
by: str | list[str] | None = None,
by_left: str | Sequence[str] | None = None,
by_right: str | Sequence[str] | None = None,
by: str | Sequence[str] | None = None,
strategy: AsofJoinStrategy = "backward",
suffix: str = "_right",
tolerance: str | int | float | None = None,
Expand Down Expand Up @@ -1507,13 +1507,13 @@ def join_asof(
if left_on is None or right_on is None:
raise ValueError("You should pass the column to join on as an argument.")

by_left_: list[str] | None
by_left_: Sequence[str] | None
if isinstance(by_left, str):
by_left_ = [by_left]
else:
by_left_ = by_left

by_right_: list[str] | None
by_right_: Sequence[str] | None
if isinstance(by_right, (str, pli.Expr)):
by_right_ = [by_right]
else:
Expand Down Expand Up @@ -2374,7 +2374,7 @@ def quantile(

def explode(
self: LDF,
columns: str | list[str] | pli.Expr | list[pli.Expr],
columns: str | Sequence[str] | pli.Expr | Sequence[pli.Expr],
) -> LDF:
"""
Explode lists to long format.
Expand Down
8 changes: 5 additions & 3 deletions py-polars/polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,13 +144,15 @@ def range_to_slice(rng: range) -> slice:


def handle_projection_columns(
columns: list[str] | list[int] | None,
columns: Sequence[str] | Sequence[int] | str | None,
) -> tuple[list[int] | None, list[str] | None]:
"""Disambiguates between columns specified as integers vs. strings."""
projection: list[int] | None = None
if columns:
if is_int_sequence(columns):
projection = columns # type: ignore[assignment]
if isinstance(columns, str):
columns = [columns]
elif is_int_sequence(columns):
projection = list(columns)
columns = None
elif not is_str_sequence(columns):
raise ValueError(
Expand Down

0 comments on commit 6ffed3c

Please sign in to comment.