Skip to content

Commit

Permalink
Improve handling of encoding literal argument (#4293)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Aug 6, 2022
1 parent 517e945 commit 093d55c
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 38 deletions.
2 changes: 1 addition & 1 deletion py-polars/polars/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal # pragma: no cover
from typing_extensions import Literal


def from_dict(
Expand Down
33 changes: 22 additions & 11 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import copy
import math
import random
import sys
from datetime import date, datetime, timedelta
from typing import Any, Callable, List, Sequence

Expand Down Expand Up @@ -35,6 +36,11 @@
except ImportError:
_NUMPY_AVAILABLE = False

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal


def selection_to_pyexpr_list(
exprs: str | Expr | Sequence[str | Expr | pli.Series] | pli.Series,
Expand Down Expand Up @@ -6105,18 +6111,19 @@ def json_path_match(self, json_path: str) -> Expr:
"""
return wrap_expr(self._pyexpr.str_json_path_match(json_path))

def decode(self, encoding: str, strict: bool = False) -> Expr:
def decode(self, encoding: Literal["hex", "base64"], strict: bool = False) -> Expr:
"""
Decode a value using the provided encoding.
Parameters
----------
encoding
'hex' or 'base64'
encoding : {'hex', 'base64'}
The encoding to use.
strict
how to handle invalid inputs
- True: method will throw error if unable to decode a value
- False: unhandled values will be replaced with `None`
How to handle invalid inputs:
- ``True``: An error will be thrown if unable to decode a value.
- ``False``: Unhandled values will be replaced with `None`.
Examples
--------
Expand All @@ -6141,16 +6148,18 @@ def decode(self, encoding: str, strict: bool = False) -> Expr:
elif encoding == "base64":
return wrap_expr(self._pyexpr.str_base64_decode(strict))
else:
raise ValueError("supported encodings are 'hex' and 'base64'")
raise ValueError(
f"encoding must be one of {{'hex', 'base64'}}, got {encoding}"
)

def encode(self, encoding: str) -> Expr:
def encode(self, encoding: Literal["hex", "base64"]) -> Expr:
"""
Encode a value using the provided encoding.
Parameters
----------
encoding
'hex' or 'base64'
encoding : {'hex', 'base64'}
The encoding to use.
Returns
-------
Expand Down Expand Up @@ -6179,7 +6188,9 @@ def encode(self, encoding: str) -> Expr:
elif encoding == "base64":
return wrap_expr(self._pyexpr.str_base64_encode())
else:
raise ValueError("supported encodings are 'hex' and 'base64'")
raise ValueError(
f"encoding must be one of {{'hex', 'base64'}}, got {encoding}"
)

def extract(self, pattern: str, group_index: int = 1) -> Expr:
r"""
Expand Down
13 changes: 11 additions & 2 deletions py-polars/polars/internals/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ def _read_csv(
infer_schema_length: int | None = 100,
batch_size: int = 8192,
n_rows: int | None = None,
encoding: str = "utf8",
encoding: Literal["utf8", "utf8-lossy"] = "utf8",
low_memory: bool = False,
rechunk: bool = True,
skip_rows_after_header: int = 0,
Expand All @@ -580,7 +580,16 @@ def _read_csv(
sample_size: int = 1024,
eol_char: str = "\n",
) -> DF:
"""See pl.read_csv."""
"""
Read a CSV file into a DataFrame.
Use ``pl.read_csv`` to dispatch to this method.
See Also
--------
polars.io.read_csv
"""
self = cls.__new__(cls)

path: str | None
Expand Down
9 changes: 7 additions & 2 deletions py-polars/polars/internals/lazy_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def scan_csv(
with_column_names: Callable[[list[str]], list[str]] | None = None,
infer_schema_length: int | None = 100,
n_rows: int | None = None,
encoding: str = "utf8",
encoding: Literal["utf8", "utf8-lossy"] = "utf8",
low_memory: bool = False,
rechunk: bool = True,
skip_rows_after_header: int = 0,
Expand All @@ -136,9 +136,14 @@ def scan_csv(
eol_char: str = "\n",
) -> LDF:
"""
Lazily read from a CSV file or multiple files via glob patterns.
Use ``pl.scan_csv`` to dispatch to this method.
See Also
--------
scan_csv
polars.io.scan_csv
"""
dtype_list: list[tuple[str, type[DataType]]] | None = None
if dtypes is not None:
Expand Down
29 changes: 18 additions & 11 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -4500,18 +4500,21 @@ def starts_with(self, sub: str) -> Series:
s = wrap_s(self._s)
return s.to_frame().select(pli.col(s.name).str.starts_with(sub)).to_series()

def decode(self, encoding: str, strict: bool = False) -> Series:
def decode(
self, encoding: Literal["hex", "base64"], strict: bool = False
) -> Series:
"""
Decode a value using the provided encoding.
Parameters
----------
encoding
'hex' or 'base64'
encoding : {'hex', 'base64'}
The encoding to use.
strict
how to handle invalid inputs
- True: method will throw error if unable to decode a value
- False: unhandled values will be replaced with `None`
How to handle invalid inputs:
- ``True``: An error will be thrown if unable to decode a value.
- ``False``: Unhandled values will be replaced with `None`.
Examples
--------
Expand All @@ -4531,16 +4534,18 @@ def decode(self, encoding: str, strict: bool = False) -> Series:
elif encoding == "base64":
return wrap_s(self._s.str_base64_decode(strict))
else:
raise ValueError("supported encodings are 'hex' and 'base64'")
raise ValueError(
f"encoding must be one of {{'hex', 'base64'}}, got {encoding}"
)

def encode(self, encoding: str) -> Series:
def encode(self, encoding: Literal["hex", "base64"]) -> Series:
"""
Encode a value using the provided encoding
Parameters
----------
encoding
'hex' or 'base64'
encoding : {'hex', 'base64'}
The encoding to use.
Returns
-------
Expand All @@ -4564,7 +4569,9 @@ def encode(self, encoding: str) -> Series:
elif encoding == "base64":
return wrap_s(self._s.str_base64_encode())
else:
raise ValueError("supported encodings are 'hex' and 'base64'")
raise ValueError(
f"encoding must be one of {{'hex', 'base64'}}, got {encoding}"
)

def json_path_match(self, json_path: str) -> Series:
"""
Expand Down
12 changes: 9 additions & 3 deletions py-polars/polars/io.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Functions for reading and writing data."""
from __future__ import annotations

import sys
from io import BytesIO, IOBase, StringIO
from pathlib import Path
from typing import Any, BinaryIO, Callable, Mapping, TextIO
Expand Down Expand Up @@ -31,6 +32,11 @@
except ImportError:
_WITH_CX = False

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal


def _check_arg_is_1byte(
arg_name: str, arg: str | None, can_be_empty: bool = False
Expand Down Expand Up @@ -77,7 +83,7 @@ def read_csv(
infer_schema_length: int | None = 100,
batch_size: int = 8192,
n_rows: int | None = None,
encoding: str = "utf8",
encoding: Literal["utf8", "utf8-lossy"] = "utf8",
low_memory: bool = False,
rechunk: bool = True,
use_pyarrow: bool = False,
Expand All @@ -90,7 +96,7 @@ def read_csv(
**kwargs: Any,
) -> DataFrame:
"""
Read a CSV file into a Dataframe.
Read a CSV file into a DataFrame.
Parameters
----------
Expand Down Expand Up @@ -413,7 +419,7 @@ def scan_csv(
with_column_names: Callable[[list[str]], list[str]] | None = None,
infer_schema_length: int | None = 100,
n_rows: int | None = None,
encoding: str = "utf8",
encoding: Literal["utf8", "utf8-lossy"] = "utf8",
low_memory: bool = False,
rechunk: bool = True,
skip_rows_after_header: int = 0,
Expand Down
7 changes: 4 additions & 3 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,10 @@ impl PyDataFrame {
"utf8" => CsvEncoding::Utf8,
"utf8-lossy" => CsvEncoding::LossyUtf8,
e => {
return Err(
PyPolarsErr::Other(format!("encoding not {} not implemented.", e)).into(),
)
return Err(PyValueError::new_err(format!(
"encoding must be one of {{'utf8', 'utf8-lossy'}}, got {}",
e
)))
}
};

Expand Down
7 changes: 4 additions & 3 deletions py-polars/src/lazy/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,10 @@ impl PyLazyFrame {
"utf8" => CsvEncoding::Utf8,
"utf8-lossy" => CsvEncoding::LossyUtf8,
e => {
return Err(
PyPolarsErr::Other(format!("encoding not {} not implemented.", e)).into(),
)
return Err(PyValueError::new_err(format!(
"encoding must be one of {{'utf8', 'utf8-lossy'}}, got {}",
e
)))
}
};

Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1312,7 +1312,7 @@ def test_str_encode() -> None:
verify_series_and_expr_api(s, hex_encoded, "str.encode", "hex")
verify_series_and_expr_api(s, base64_encoded, "str.encode", "base64")
with pytest.raises(ValueError):
s.str.encode("utf8")
s.str.encode("utf8") # type: ignore[arg-type]


def test_str_decode() -> None:
Expand All @@ -1331,7 +1331,7 @@ def test_str_decode_exception() -> None:
with pytest.raises(Exception):
s.str.decode(encoding="base64", strict=True)
with pytest.raises(ValueError):
s.str.decode("utf8")
s.str.decode("utf8") # type: ignore[arg-type]


def test_str_replace_str_replace_all() -> None:
Expand Down

0 comments on commit 093d55c

Please sign in to comment.