Skip to content

Commit

Permalink
fix auto explode of exprs and wrong function namespaces
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 28, 2022
1 parent 62728c9 commit a4554f9
Show file tree
Hide file tree
Showing 11 changed files with 106 additions and 65 deletions.
26 changes: 20 additions & 6 deletions polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -790,20 +790,34 @@ impl Expr {

/// Get the index value that has the minimum value
pub fn arg_min(self) -> Self {
self.apply(
|s| Ok(Series::new(s.name(), &[s.arg_min().map(|idx| idx as u32)])),
let options = FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
input_wildcard_expansion: false,
auto_explode: true,
fmt_str: "arg_min",
};

self.function_with_options(
move |s: Series| Ok(Series::new(s.name(), &[s.arg_min().map(|idx| idx as u32)])),
GetOutput::from_type(DataType::UInt32),
options,
)
.with_fmt("arg_min")
}

/// Get the index value that has the maximum value
pub fn arg_max(self) -> Self {
self.apply(
|s| Ok(Series::new(s.name(), &[s.arg_max().map(|idx| idx as u32)])),
let options = FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
input_wildcard_expansion: false,
auto_explode: true,
fmt_str: "arg_max",
};

self.function_with_options(
move |s: Series| Ok(Series::new(s.name(), &[s.arg_max().map(|idx| idx as u32)])),
GetOutput::from_type(DataType::UInt32),
options,
)
.with_fmt("arg_max")
}

/// Get the index values that would sort this expression.
Expand Down
20 changes: 14 additions & 6 deletions polars/polars-lazy/src/dsl/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,19 @@ impl StringNameSpace {
/// * `delimiter` - A string that will act as delimiter between values.
pub fn concat(self, delimiter: &str) -> Expr {
let delimiter = delimiter.to_owned();
self.0
.apply(
move |s| Ok(s.str_concat(&delimiter).into_series()),
GetOutput::from_type(DataType::Utf8),
)
.with_fmt("str_concat")
let function = NoEq::new(Arc::new(move |s: &mut [Series]| {
Ok(s[0].str_concat(&delimiter).into_series())
}) as Arc<dyn SeriesUdf>);
Expr::Function {
input: vec![self.0],
function,
output_type: GetOutput::from_type(DataType::Utf8),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyGroups,
input_wildcard_expansion: false,
auto_explode: true,
fmt_str: "str_concat",
},
}
}
}
6 changes: 3 additions & 3 deletions polars/polars-lazy/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ pub fn argsort_by<E: AsRef<[Expr]>>(by: E, reverse: &[bool]) -> Expr {

#[cfg(feature = "concat_str")]
#[cfg_attr(docsrs, doc(cfg(feature = "concat_str")))]
/// Concat string columns in linear time
/// Horizontally concat string columns in linear time
pub fn concat_str(s: Vec<Expr>, sep: &str) -> Expr {
let sep = sep.to_string();
let function = NoEq::new(Arc::new(move |s: &mut [Series]| {
Expand All @@ -150,9 +150,9 @@ pub fn concat_str(s: Vec<Expr>, sep: &str) -> Expr {
function,
output_type: GetOutput::from_type(DataType::Utf8),
options: FunctionOptions {
collect_groups: ApplyOptions::ApplyFlat,
collect_groups: ApplyOptions::ApplyGroups,
input_wildcard_expansion: true,
auto_explode: false,
auto_explode: true,
fmt_str: "concat_by",
},
}
Expand Down
3 changes: 2 additions & 1 deletion py-polars/docs/source/reference/expression.rst
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,6 @@ Manipulation/ selection
Expr.clip
Expr.lower_bound
Expr.upper_bound
Expr.str_concat
Expr.reshape
Expr.to_physical
Expr.shuffle
Expand Down Expand Up @@ -282,6 +281,7 @@ The following methods are available under the `Expr.str` attribute.
ExprStringNameSpace.lengths
ExprStringNameSpace.to_uppercase
ExprStringNameSpace.to_lowercase
ExprStringNameSpace.concat
ExprStringNameSpace.strip
ExprStringNameSpace.lstrip
ExprStringNameSpace.rstrip
Expand All @@ -305,6 +305,7 @@ The following methods are available under the `expr.arr` attribute.

ExprListNameSpace.concat
ExprListNameSpace.lengths
ExprListNameSpace.lengths
ExprListNameSpace.sum
ExprListNameSpace.min
ExprListNameSpace.max
Expand Down
2 changes: 1 addition & 1 deletion py-polars/docs/source/reference/series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,6 @@ Manipulation/ selection
Series.zip_with
Series.interpolate
Series.clip
Series.str_concat
Series.reshape
Series.to_dummies
Series.shuffle
Expand Down Expand Up @@ -244,6 +243,7 @@ The following methods are available under the `Series.str` attribute.

StringNameSpace.strptime
StringNameSpace.lengths
StringNameSpace.concat
StringNameSpace.contains
StringNameSpace.json_path_match
StringNameSpace.extract
Expand Down
52 changes: 26 additions & 26 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2014,32 +2014,6 @@ def upper_bound(self) -> "Expr":
"""
return wrap_expr(self._pyexpr.upper_bound())

def str_concat(self, delimiter: str = "-") -> "Expr":
"""
Vertically concat the values in the Series to a single string value.
Returns
-------
Series of dtype Utf8
Examples
--------
>>> df = pl.DataFrame({"foo": [1, None, 2]})
>>> df = df.select(pl.col("foo").str_concat("-"))
>>> df
shape: (1, 1)
┌──────────┐
│ foo │
│ --- │
│ str │
╞══════════╡
│ 1-null-2 │
└──────────┘
"""
return wrap_expr(self._pyexpr.str_concat(delimiter))

def sin(self) -> "Expr":
"""
Compute the element-wise value for Trigonometric sine on an array
Expand Down Expand Up @@ -2531,6 +2505,32 @@ def lengths(self) -> Expr:
"""
return wrap_expr(self._pyexpr.str_lengths())

def concat(self, delimiter: str = "-") -> "Expr":
    """
    Vertically concat the values of the column into a single string value.

    Parameters
    ----------
    delimiter
        String that is placed between the individual values.

    Returns
    -------
    Expression that evaluates to a Series of dtype Utf8

    Examples
    --------
    >>> df = pl.DataFrame({"foo": [1, None, 2]})
    >>> df = df.select(pl.col("foo").str.concat("-"))
    >>> df
    shape: (1, 1)
    ┌──────────┐
    │ foo      │
    │ ---      │
    │ str      │
    ╞══════════╡
    │ 1-null-2 │
    └──────────┘

    """
    # Delegate to the underlying expression, then wrap the result again.
    joined = self._pyexpr.str_concat(delimiter)
    return wrap_expr(joined)

def to_uppercase(self) -> Expr:
"""
Transform to uppercase variant.
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,7 @@ def _date(

def concat_str(exprs: Sequence[Union["pli.Expr", str]], sep: str = "") -> "pli.Expr":
"""
Concat Utf8 Series in linear time. Non utf8 columns are cast to utf8.
Horizontally concat Utf8 Series in linear time. Non-Utf8 columns are cast to Utf8.
Parameters
----------
Expand Down
35 changes: 17 additions & 18 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3233,24 +3233,6 @@ def clip(self, min_val: Union[int, float], max_val: Union[int, float]) -> "Serie
self.name
]

def str_concat(self, delimiter: str = "-") -> "Series":
"""
Vertically concat the values in the Series to a single string value.
Returns
-------
Series of dtype Utf8
Examples
--------
>>> pl.Series([1, None, 2]).str_concat("-")[0]
'1-null-2'
"""
return self.to_frame().select(pli.col(self.name).str_concat(delimiter))[
self.name
]

def reshape(self, dims: Tuple[int, ...]) -> "Series":
"""
Reshape this Series to a flat series, shape: (len,)
Expand Down Expand Up @@ -3502,6 +3484,23 @@ def lengths(self) -> Series:
"""
return wrap_s(self._s.str_lengths())

def concat(self, delimiter: str = "-") -> "Series":
    """
    Vertically concat the values in the Series to a single string value.

    Parameters
    ----------
    delimiter
        String that is placed between the individual values.

    Returns
    -------
    Series of dtype Utf8

    Examples
    --------
    >>> pl.Series([1, None, 2]).str.concat("-")[0]
    '1-null-2'

    """
    # Route through the expression API: put the Series in a one-column frame,
    # select the concat expression on that column and unwrap the single result.
    series = wrap_s(self._s)
    concat_expr = pli.col(series.name).str.concat(delimiter)
    return series.to_frame().select(concat_expr).to_series()

def contains(self, pattern: str) -> Series:
"""
Check if strings in Series contain regex pattern.
Expand Down
10 changes: 8 additions & 2 deletions py-polars/tests/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,7 +837,7 @@ def test_clip() -> None:


def test_argminmax() -> None:
df = pl.DataFrame({"a": [1, 2, 3, 4, 5]})
df = pl.DataFrame({"a": [1, 2, 3, 4, 5], "b": [1, 1, 2, 2, 2]})
out = df.select(
[
pl.col("a").arg_min().alias("min"),
Expand All @@ -847,6 +847,12 @@ def test_argminmax() -> None:
assert out["max"][0] == 4
assert out["min"][0] == 0

out = df.groupby("b", maintain_order=True).agg(
[pl.col("a").arg_min().alias("min"), pl.col("a").arg_max().alias("max")]
)
assert out["max"][0] == 1
assert out["min"][0] == 0


def test_expr_bool_cmp() -> None:
# Since expressions are lazy they should not be evaluated as
Expand Down Expand Up @@ -943,7 +949,7 @@ def test_join_suffix() -> None:

def test_str_concat() -> None:
df = pl.DataFrame({"foo": [1, None, 2]})
df = df.select(pl.col("foo").str_concat("-"))
df = df.select(pl.col("foo").str.concat("-"))
assert df[0, 0] == "1-null-2"


Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,7 @@ def test_shrink_to_fit() -> None:

def test_str_concat() -> None:
s = pl.Series(["1", None, "2"])
result = s.str_concat()
result = s.str.concat()
expected = pl.Series(["1-null-2"])
testing.assert_series_equal(result, expected)

Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/test_strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,16 @@ def test_extract_binary() -> None:
df = pl.DataFrame({"foo": ["aron", "butler", "charly", "david"]})
out = df.filter(pl.col("foo").str.extract("^(a)", 1) == "a").to_series()
assert out[0] == "aron"


def test_auto_explode() -> None:
    # `str.concat` used as a groupby aggregation must yield one Utf8 value per
    # group (auto-exploded), not a list column.
    df = pl.DataFrame(
        [pl.Series("val", ["A", "B", "C", "D"]), pl.Series("id", [1, 1, 2, 2])]
    )
    grouped = (
        df.groupby("id")
        .agg(pl.col("val").str.concat(delimiter=",").alias("grouped"))
        .get_column("grouped")
    )
    assert grouped.dtype == pl.Utf8
    # No `maintain_order` is requested, so compare the groups order-independently.
    assert sorted(grouped.to_list()) == ["A,B", "C,D"]

0 comments on commit a4554f9

Please sign in to comment.