Skip to content

Commit

Permalink
add arr.join(separator: str) (#2550)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Feb 5, 2022
1 parent d070bd4 commit 36a4336
Show file tree
Hide file tree
Showing 11 changed files with 256 additions and 85 deletions.
45 changes: 45 additions & 0 deletions polars/polars-core/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use crate::prelude::*;
use polars_arrow::kernels::list::sublist_get;
use polars_arrow::prelude::ValueSize;
use std::convert::TryFrom;
use std::fmt::Write;

fn cast_rhs(
other: &mut [Series],
Expand Down Expand Up @@ -51,6 +52,50 @@ fn cast_rhs(
}

impl ListChunked {
/// In case the inner dtype [`DataType::Utf8`], the individual items will be joined into a
/// single string separated by `separator`.
pub fn lst_join(&self, separator: &str) -> Result<Utf8Chunked> {
match self.inner_dtype() {
DataType::Utf8 => {
// used to amortize heap allocs
let mut buf = String::with_capacity(128);

let mut builder = Utf8ChunkedBuilder::new(
self.name(),
self.len(),
self.get_values_size() + separator.len() * self.len(),
);

self.amortized_iter().for_each(|opt_s| {
let opt_val = opt_s.map(|s| {
// make sure that we don't write values of previous iteration
buf.clear();
let ca = s.as_ref().utf8().unwrap();
let iter = ca.into_iter().map(|opt_v| opt_v.unwrap_or("null"));

for val in iter {
buf.write_str(val).unwrap();
buf.write_str(separator).unwrap();
}
// last value should not have a separator, so slice that off
// saturating sub because there might have been nothing written.
&buf[..buf.len().saturating_sub(separator.len())]
});
builder.append_option(opt_val)
});
Ok(builder.finish())
}
dt => Err(PolarsError::SchemaMisMatch(
format!(
"cannot call lst.join on Series with dtype {:?}.\
Inner type must be Utf8",
dt
)
.into(),
)),
}
}

pub fn lst_max(&self) -> Series {
self.apply_amortized(|s| s.as_ref().max_as_series())
.explode()
Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/src/dsl/dt.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::*;

/// Specialized expressions for [`Series`] with dates/datetimes.
pub struct DateLikeNameSpace(pub(crate) Expr);

impl DateLikeNameSpace {
Expand Down
143 changes: 143 additions & 0 deletions polars/polars-lazy/src/dsl/list.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
use crate::prelude::*;
use polars_core::prelude::*;

/// Specialized expressions for [`Series`] of [`DataType::List`].
pub struct ListNameSpace(pub(crate) Expr);

impl ListNameSpace {
/// Get lengths of the arrays in the List type.
pub fn lengths(self) -> Expr {
let function = |s: Series| {
let ca = s.list()?;
Ok(ca.lst_lengths().into_series())
};
self.0
.map(function, GetOutput::from_type(DataType::UInt32))
.with_fmt("arr.len")
}

/// Compute the maximum of the items in every sublist.
pub fn max(self) -> Expr {
self.0
.map(
|s| Ok(s.list()?.lst_max()),
GetOutput::map_field(|f| {
if let DataType::List(adt) = f.data_type() {
Field::new(f.name(), *adt.clone())
} else {
// inner type
f.clone()
}
}),
)
.with_fmt("arr.max")
}

/// Compute the minimum of the items in every sublist.
pub fn min(self) -> Expr {
self.0
.map(
|s| Ok(s.list()?.lst_min()),
GetOutput::map_field(|f| {
if let DataType::List(adt) = f.data_type() {
Field::new(f.name(), *adt.clone())
} else {
// inner type
f.clone()
}
}),
)
.with_fmt("arr.min")
}

/// Compute the sum the items in every sublist.
pub fn sum(self) -> Expr {
self.0
.map(
|s| Ok(s.list()?.lst_sum()),
GetOutput::map_field(|f| {
if let DataType::List(adt) = f.data_type() {
Field::new(f.name(), *adt.clone())
} else {
// inner type
f.clone()
}
}),
)
.with_fmt("arr.sum")
}

/// Compute the mean of every sublist and return a `Series` of dtype `Float64`
pub fn mean(self) -> Expr {
self.0
.map(
|s| Ok(s.list()?.lst_mean().into_series()),
GetOutput::from_type(DataType::Float64),
)
.with_fmt("arr.mean")
}

/// Sort every sublist.
pub fn sort(self, reverse: bool) -> Expr {
self.0
.map(
move |s| Ok(s.list()?.lst_sort(reverse).into_series()),
GetOutput::same_type(),
)
.with_fmt("arr.sort")
}

/// Reverse every sublist
pub fn reverse(self) -> Expr {
self.0
.map(
move |s| Ok(s.list()?.lst_reverse().into_series()),
GetOutput::same_type(),
)
.with_fmt("arr.reverse")
}

/// Keep only the unique values in every sublist.
pub fn unique(self) -> Expr {
self.0
.map(
move |s| Ok(s.list()?.lst_unique()?.into_series()),
GetOutput::same_type(),
)
.with_fmt("arr.unique")
}

/// Get items in every sublist by index.
pub fn get(self, index: i64) -> Expr {
self.0.map(
move |s| s.list()?.lst_get(index),
GetOutput::map_field(|field| match field.data_type() {
DataType::List(inner) => Field::new(field.name(), *inner.clone()),
_ => panic!("should be a list type"),
}),
)
}

/// Get first item of every sublist.
pub fn first(self) -> Expr {
self.get(0)
}

/// Get last item of every sublist.
pub fn last(self) -> Expr {
self.get(-1)
}

/// Join all string items in a sublist and place a separator between them.
/// # Error
/// This errors if inner type of list `!= DataType::Utf8`.
pub fn join(self, separator: &str) -> Expr {
let separator = separator.to_string();
self.0
.map(
move |s| s.list()?.lst_join(&separator).map(|ca| ca.into_series()),
GetOutput::from_type(DataType::Utf8),
)
.with_fmt("arr.join")
}
}
6 changes: 6 additions & 0 deletions polars/polars-lazy/src/dsl/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! Domain specific language for the Lazy api.
#[cfg(feature = "temporal")]
mod dt;
#[cfg(feature = "list")]
mod list;
mod options;
#[cfg(feature = "strings")]
pub mod string;
Expand Down Expand Up @@ -2056,6 +2058,10 @@ impl Expr {
pub fn dt(self) -> dt::DateLikeNameSpace {
dt::DateLikeNameSpace(self)
}
#[cfg(feature = "list")]
pub fn arr(self) -> list::ListNameSpace {
list::ListNameSpace(self)
}
}

// Arithmetic ops
Expand Down
1 change: 1 addition & 0 deletions polars/polars-lazy/src/dsl/string.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::*;
use polars_arrow::array::ValueSize;

/// Specialized expressions for [`Series`] of [`DataType::Utf8`].
pub struct StringNameSpace(pub(crate) Expr);

impl StringNameSpace {
Expand Down
2 changes: 1 addition & 1 deletion py-polars/docs/source/reference/expression.rst
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,6 @@ The following methods are available under the `expr.arr` attribute.

ExprListNameSpace.concat
ExprListNameSpace.lengths
ExprListNameSpace.lengths
ExprListNameSpace.sum
ExprListNameSpace.min
ExprListNameSpace.max
Expand All @@ -319,3 +318,4 @@ The following methods are available under the `expr.arr` attribute.
ExprListNameSpace.first
ExprListNameSpace.last
ExprListNameSpace.contains
ExprListNameSpace.join
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/series.rst
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,4 @@ The following methods are available under the `Series.arr` attribute.
ListNameSpace.first
ListNameSpace.last
ListNameSpace.contains
ListNameSpace.join
17 changes: 17 additions & 0 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2470,6 +2470,23 @@ def contains(self, item: Union[float, str, bool, int, date, datetime]) -> "Expr"
"""
return wrap_expr(self._pyexpr).map(lambda s: s.arr.contains(item))

def join(self, separator: str) -> "Expr":
"""
Join all string items in a sublist and place a separator between them.
This errors if inner type of list `!= Utf8`.
Parameters
----------
separator
string to separate the items with
Returns
-------
Series of dtype Utf8
"""

return wrap_expr(self._pyexpr.lst_join(separator))


class ExprStringNameSpace:
"""
Expand Down
17 changes: 17 additions & 0 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3876,6 +3876,23 @@ def get(self, index: int) -> "Series":
"""
return pli.select(pli.lit(wrap_s(self._s)).arr.get(index)).to_series()

def join(self, separator: str) -> "Series":
"""
Join all string items in a sublist and place a separator between them.
This errors if inner type of list `!= Utf8`.
Parameters
----------
separator
string to separate the items with
Returns
-------
Series of dtype Utf8
"""

return pli.select(pli.lit(wrap_s(self._s)).arr.join(separator)).to_series()

def first(self) -> "Series":
"""
Get the first value of the sublists.
Expand Down

0 comments on commit 36a4336

Please sign in to comment.