Skip to content

Commit

Permalink
Add decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Aug 15, 2022
1 parent cd01961 commit 28f9575
Showing 1 changed file with 73 additions and 87 deletions.
160 changes: 73 additions & 87 deletions py-polars/polars/internals/series/string.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,43 @@
from __future__ import annotations

from typing import TYPE_CHECKING
import sys
from typing import TYPE_CHECKING, Callable

import polars.internals as pli
from polars.datatypes import Date, Datetime, Time

if TYPE_CHECKING:
from polars.internals.type_aliases import TransferEncoding
from polars.polars import PySeries

if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec

P = ParamSpec("P")


def expr(func: Callable[P, pli.Series]) -> Callable[P, pli.Series]:
    """
    Dispatch a ``Series.str`` method to the expression implementation.

    The body of the decorated function is never executed. Only its name is
    used, to look up the method of the same name on ``pli.col(<name>).str``.
    The resulting expression is evaluated on a single-column frame built from
    the Series, and the result is converted back to a Series.

    Parameters
    ----------
    func
        Method whose name matches a method on the expression string namespace.

    Returns
    -------
    Callable
        Wrapper with the same call signature that carries over the metadata
        (``__name__``, ``__doc__``, ...) of ``func``.

    """
    # Function-scope import keeps this decorator self-contained without
    # touching the module-level import section.
    from functools import wraps

    # Without ``wraps`` every decorated method would be named "wrapper" and
    # lose its docstring, hiding the user-facing documentation from help().
    @wraps(func)
    def wrapper(self: StringNameSpace, *args: P.args, **kwargs: P.kwargs) -> pli.Series:
        s = pli.wrap_s(self._s)
        f = getattr(pli.col(s.name).str, func.__name__)
        return s.to_frame().select(f(*args, **kwargs)).to_series()

    return wrapper  # type: ignore[return-value]


class StringNameSpace:
"""Series.str namespace."""

_s: PySeries

def __init__(self, series: pli.Series):
    """Keep a reference to the ``_s`` attribute of ``series`` for later dispatch."""
    self._s = series._s

@expr
def strptime(
self,
datatype: type[Date] | type[Datetime] | type[Time],
Expand Down Expand Up @@ -84,13 +107,9 @@ def strptime(
└────────────┘
"""
s = pli.wrap_s(self._s)
return (
s.to_frame()
.select(pli.col(s.name).str.strptime(datatype, fmt, strict, exact))
.to_series()
)
...

@expr
def lengths(self) -> pli.Series:
"""
Get length of the string values in the Series.
Expand All @@ -113,9 +132,9 @@ def lengths(self) -> pli.Series:
]
"""
s = pli.wrap_s(self._s)
return s.to_frame().select(pli.col(s.name).str.lengths()).to_series()
...

@expr
def concat(self, delimiter: str = "-") -> pli.Series:
"""
Vertically concat the values in the Series to a single string value.
Expand All @@ -135,9 +154,9 @@ def concat(self, delimiter: str = "-") -> pli.Series:
'1-null-2'
"""
s = pli.wrap_s(self._s)
return s.to_frame().select(pli.col(s.name).str.concat(delimiter)).to_series()
...

@expr
def contains(self, pattern: str, literal: bool = False) -> pli.Series:
"""
Check if strings in Series contain a substring that matches a regex.
Expand Down Expand Up @@ -176,13 +195,9 @@ def contains(self, pattern: str, literal: bool = False) -> pli.Series:
]
"""
s = pli.wrap_s(self._s)
return (
s.to_frame()
.select(pli.col(s.name).str.contains(pattern, literal))
.to_series()
)
...

@expr
def ends_with(self, sub: str) -> pli.Series:
"""
Check if string values end with a substring.
Expand Down Expand Up @@ -210,9 +225,9 @@ def ends_with(self, sub: str) -> pli.Series:
starts_with : Check if string values start with a substring.
"""
s = pli.wrap_s(self._s)
return s.to_frame().select(pli.col(s.name).str.ends_with(sub)).to_series()
...

@expr
def starts_with(self, sub: str) -> pli.Series:
"""
Check if string values start with a substring.
Expand Down Expand Up @@ -240,9 +255,9 @@ def starts_with(self, sub: str) -> pli.Series:
ends_with : Check if string values end with a substring.
"""
s = pli.wrap_s(self._s)
return s.to_frame().select(pli.col(s.name).str.starts_with(sub)).to_series()
...

@expr
def decode(self, encoding: TransferEncoding, strict: bool = False) -> pli.Series:
"""
Decode a value using the provided encoding.
Expand Down Expand Up @@ -270,13 +285,9 @@ def decode(self, encoding: TransferEncoding, strict: bool = False) -> pli.Series
]
"""
s = pli.wrap_s(self._s)
return (
s.to_frame()
.select(pli.col(s.name).str.decode(encoding, strict))
.to_series()
)
...

@expr
def encode(self, encoding: TransferEncoding) -> pli.Series:
"""
Encode a value using the provided encoding
Expand All @@ -303,9 +314,9 @@ def encode(self, encoding: TransferEncoding) -> pli.Series:
]
"""
s = pli.wrap_s(self._s)
return s.to_frame().select(pli.col(s.name).str.encode(encoding)).to_series()
...

@expr
def json_path_match(self, json_path: str) -> pli.Series:
"""
Extract the first match of json string with provided JSONPath expression.
Expand Down Expand Up @@ -342,13 +353,9 @@ def json_path_match(self, json_path: str) -> pli.Series:
]
"""
s = pli.wrap_s(self._s)
return (
s.to_frame()
.select(pli.col(s.name).str.json_path_match(json_path))
.to_series()
)
...

@expr
def extract(self, pattern: str, group_index: int = 1) -> pli.Series:
r"""
Extract the target capture group from provided patterns.
Expand Down Expand Up @@ -392,13 +399,9 @@ def extract(self, pattern: str, group_index: int = 1) -> pli.Series:
└─────────┘
"""
s = pli.wrap_s(self._s)
return (
s.to_frame()
.select(pli.col(s.name).str.extract(pattern, group_index))
.to_series()
)
...

@expr
def extract_all(self, pattern: str) -> pli.Series:
r"""
Extract each successive non-overlapping regex match in an individual string as
Expand Down Expand Up @@ -426,9 +429,9 @@ def extract_all(self, pattern: str) -> pli.Series:
]
"""
s = pli.wrap_s(self._s)
return s.to_frame().select(pli.col(s.name).str.extract_all(pattern)).to_series()
...

@expr
def count_match(self, pattern: str) -> pli.Series:
r"""
Count all successive non-overlapping regex matches.
Expand All @@ -455,9 +458,9 @@ def count_match(self, pattern: str) -> pli.Series:
]
"""
s = pli.wrap_s(self._s)
return s.to_frame().select(pli.col(s.name).str.count_match(pattern)).to_series()
...

@expr
def split(self, by: str, inclusive: bool = False) -> pli.Series:
"""
Split the string by a substring.
Expand All @@ -474,9 +477,9 @@ def split(self, by: str, inclusive: bool = False) -> pli.Series:
List of Utf8 type
"""
s = pli.wrap_s(self._s)
return s.to_frame().select(pli.col(s.name).str.split(by, inclusive)).to_series()
...

@expr
def split_exact(self, by: str, n: int, inclusive: bool = False) -> pli.Series:
"""
Split the string by a substring into a struct of ``n`` fields.
Expand Down Expand Up @@ -546,13 +549,9 @@ def split_exact(self, by: str, n: int, inclusive: bool = False) -> pli.Series:
Struct of Utf8 type
"""
s = pli.wrap_s(self._s)
return (
s.to_frame()
.select(pli.col(s.name).str.split_exact(by, n, inclusive))
.to_series()
)
...

@expr
def replace(self, pattern: str, value: str, literal: bool = False) -> pli.Series:
r"""
Replace first matching regex/literal substring with a new string value.
Expand Down Expand Up @@ -582,13 +581,9 @@ def replace(self, pattern: str, value: str, literal: bool = False) -> pli.Series
]
"""
s = pli.wrap_s(self._s)
return (
s.to_frame()
.select(pli.col(s.name).str.replace(pattern, value, literal))
.to_series()
)
...

@expr
def replace_all(
self, pattern: str, value: str, literal: bool = False
) -> pli.Series:
Expand Down Expand Up @@ -620,28 +615,24 @@ def replace_all(
]
"""
s = pli.wrap_s(self._s)
return (
s.to_frame()
.select(pli.col(s.name).str.replace_all(pattern, value, literal))
.to_series()
)
...

@expr
def strip(self) -> pli.Series:
    """Remove leading and trailing whitespace."""
    # Placeholder body: ``expr`` never calls this function, it only uses the
    # method name to dispatch to the expression method of the same name.
    ...

@expr
def lstrip(self) -> pli.Series:
    """Remove leading whitespace."""
    # Placeholder body: ``expr`` never calls this function, it only uses the
    # method name to dispatch to the expression method of the same name.
    ...

@expr
def rstrip(self) -> pli.Series:
    """Remove trailing whitespace."""
    # Placeholder body: ``expr`` never calls this function, it only uses the
    # method name to dispatch to the expression method of the same name.
    ...

@expr
def zfill(self, alignment: int) -> pli.Series:
"""
Return a copy of the string left filled with ASCII '0' digits to make a string
Expand All @@ -655,9 +646,9 @@ def zfill(self, alignment: int) -> pli.Series:
Fill the value up to this length.
"""
s = pli.wrap_s(self._s)
return s.to_frame().select(pli.col(s.name).str.zfill(alignment)).to_series()
...

@expr
def ljust(self, width: int, fillchar: str = " ") -> pli.Series:
"""
Return the string left justified in a string of length ``width``.
Expand Down Expand Up @@ -686,11 +677,9 @@ def ljust(self, width: int, fillchar: str = " ") -> pli.Series:
]
"""
s = pli.wrap_s(self._s)
return (
s.to_frame().select(pli.col(s.name).str.ljust(width, fillchar)).to_series()
)
...

@expr
def rjust(self, width: int, fillchar: str = " ") -> pli.Series:
"""
Return the string right justified in a string of length ``width``.
Expand Down Expand Up @@ -719,21 +708,19 @@ def rjust(self, width: int, fillchar: str = " ") -> pli.Series:
]
"""
s = pli.wrap_s(self._s)
return (
s.to_frame().select(pli.col(s.name).str.rjust(width, fillchar)).to_series()
)
...

@expr
def to_lowercase(self) -> pli.Series:
    """Modify the strings to their lowercase equivalent."""
    # Placeholder body: ``expr`` never calls this function, it only uses the
    # method name to dispatch to the expression method of the same name.
    ...

@expr
def to_uppercase(self) -> pli.Series:
    """Modify the strings to their uppercase equivalent."""
    # Placeholder body: ``expr`` never calls this function, it only uses the
    # method name to dispatch to the expression method of the same name.
    ...

@expr
def slice(self, start: int, length: int | None = None) -> pli.Series:
"""
Create subslices of the string values of a Utf8 Series.
Expand Down Expand Up @@ -777,5 +764,4 @@ def slice(self, start: int, length: int | None = None) -> pli.Series:
]
"""
s = pli.wrap_s(self._s)
return s.to_frame().select(pli.col(s.name).str.slice(start, length)).to_series()
...

0 comments on commit 28f9575

Please sign in to comment.