Skip to content

Commit

Permalink
add arr.concat singleton lists (#2279)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jan 6, 2022
1 parent ce34ba1 commit 81129eb
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 62 deletions.
179 changes: 129 additions & 50 deletions polars/polars-core/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,52 @@ use polars_arrow::kernels::list::sublist_get;
use polars_arrow::prelude::ValueSize;
use std::convert::TryFrom;

/// Prepares every right-hand-side series in `other`, in place, for concatenation
/// onto a list column whose dtype is `dtype` and whose inner dtype is `inner_type`.
///
/// For each series the following steps run in order:
/// 1. A series that is not a `DataType::List` is cast to `inner_type`.
/// 2. If it is still not a list and its dtype equals `inner_type`, it is reshaped
///    with `reshape(&[-1, 1])`.
/// 3. If its dtype still differs from `dtype`, it is cast to `dtype`.
/// 4. If its length differs from `length`: a length-1 series is expanded with
///    `expand_at_index(0, length)` when `allow_broadcast` is `true` and left
///    untouched otherwise; any other length is rejected.
///
/// # Errors
/// * Propagates the error of the initial cast to `inner_type` and of the reshape.
/// * `PolarsError::SchemaMisMatch` when the series cannot be cast to `dtype`.
/// * `PolarsError::ShapeMisMatch` when the length is neither `length` nor 1.
fn cast_rhs(
    other: &mut [Series],
    inner_type: &DataType,
    dtype: &DataType,
    length: usize,
    allow_broadcast: bool,
) -> Result<()> {
    for s in other.iter_mut() {
        // make sure that inner types match before we coerce into list
        if !matches!(s.dtype(), DataType::List(_)) {
            *s = s.cast(inner_type)?;
        }
        if !matches!(s.dtype(), DataType::List(_)) && s.dtype() == inner_type {
            // coerce to list JIT; a reshape failure is reported to the caller
            // instead of aborting the process with a panic
            *s = s.reshape(&[-1, 1])?;
        }
        if s.dtype() != dtype {
            // the underlying cast error is replaced by a schema mismatch that
            // names both the offending dtype and the target list dtype
            *s = s.cast(dtype).map_err(|_| {
                PolarsError::SchemaMisMatch(
                    format!("cannot concat {:?} into a list of {:?}", s.dtype(), dtype).into(),
                )
            })?;
        }

        match s.len() {
            // already the right height, nothing to do
            n if n == length => {}
            // unit length: broadcast JIT when requested, else leave as is so the
            // caller can handle the single value without materializing it
            1 => {
                if allow_broadcast {
                    *s = s.expand_at_index(0, length);
                }
            }
            n => {
                return Err(PolarsError::ShapeMisMatch(
                    format!("length {} does not match {}", n, length).into(),
                ));
            }
        }
    }
    Ok(())
}

impl ListChunked {
pub fn lst_max(&self) -> Series {
self.apply_amortized(|s| s.as_ref().max_as_series())
Expand Down Expand Up @@ -72,68 +118,101 @@ impl ListChunked {
}

pub fn lst_concat(&self, other: &[Series]) -> Result<ListChunked> {
let mut other = other.to_vec();
let other_len = other.len();
let mut iters = Vec::with_capacity(other.len() + 1);
let length = self.len();
let mut other = other.to_vec();
let dtype = self.dtype();
let inner_type = self.inner_dtype();
let length = self.len();

for s in other.iter_mut() {
if !matches!(s.dtype(), DataType::List(_)) && s.dtype() == &inner_type {
// coerce to list JIT
*s = s.reshape(&[-1, 1]).unwrap();
// broadcasting path in case all unit length
// this path will not expand the series, so saves memory
if other.iter().all(|s| s.len() == 1) && self.len() != 1 {
cast_rhs(&mut other, &inner_type, dtype, length, false)?;
let to_append = other
.iter()
.flat_map(|s| {
let lst = s.list().unwrap();
lst.get(0)
})
.collect::<Vec<_>>();
// there was a None, so all values will be None
if to_append.len() != other_len {
return Ok(Self::full_null_with_dtype(self.name(), length, &inner_type));
}
if s.dtype() != dtype {
return Err(PolarsError::SchemaMisMatch(
format!("cannot concat {:?} into a list of {:?}", s.dtype(), dtype).into(),
));
}
if s.len() != length {
return Err(PolarsError::ShapeMisMatch(
format!("length {} does not match {}", s.len(), length).into(),
));
}
iters.push(s.list()?.amortized_iter())
}
let mut first_iter = self.into_iter();
let mut builder = get_list_builder(
&self.inner_dtype(),
self.get_values_size() * (other_len + 1),
self.len(),
self.name(),
);

for _ in 0..self.len() {
let mut acc = match first_iter.next().unwrap() {
Some(s) => s,
None => {
builder.append_null();
// make sure that the iterators advance before we continue
for it in &mut iters {
it.next().unwrap();
}
continue;
}
};
let mut already_null = false;
for it in &mut iters {
match it.next().unwrap() {
Some(s) => {
acc.append(s.as_ref())?;

let vals_size_other = other
.iter()
.map(|s| s.list().unwrap().get_values_size())
.sum::<usize>();

let mut builder = get_list_builder(
&inner_type,
self.get_values_size() + vals_size_other + 1,
length,
self.name(),
);
self.into_iter().for_each(|opt_s| {
let opt_s = opt_s.map(|mut s| {
for append in &to_append {
s.append(append).unwrap();
}
s
});
builder.append_opt_series(opt_s.as_ref())
});
Ok(builder.finish())
} else {
// normal path which may contain same length list or unit length lists
cast_rhs(&mut other, &inner_type, dtype, length, true)?;

let vals_size_other = other
.iter()
.map(|s| s.list().unwrap().get_values_size())
.sum::<usize>();
let mut iters = Vec::with_capacity(other_len + 1);

for s in other.iter_mut() {
iters.push(s.list()?.amortized_iter())
}
let mut first_iter = self.into_iter();
let mut builder = get_list_builder(
&inner_type,
self.get_values_size() + vals_size_other + 1,
length,
self.name(),
);

for _ in 0..self.len() {
let mut acc = match first_iter.next().unwrap() {
Some(s) => s,
None => {
if !already_null {
builder.append_null();
already_null = true;
builder.append_null();
// make sure that the iterators advance before we continue
for it in &mut iters {
it.next().unwrap();
}

continue;
}
};
let mut already_null = false;
for it in &mut iters {
match it.next().unwrap() {
Some(s) => {
acc.append(s.as_ref())?;
}
None => {
if !already_null {
builder.append_null();
already_null = true;
}

continue;
}
}
}
builder.append_series(&acc);
}
builder.append_series(&acc);
Ok(builder.finish())
}
Ok(builder.finish())
}
}
14 changes: 11 additions & 3 deletions py-polars/polars/internals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2331,7 +2331,9 @@ def unique(self) -> "Expr":
"""
return wrap_expr(self._pyexpr.lst_unique())

def concat(self, other: Union[List[Union[Expr, str]], Expr, str]) -> "Expr":
def concat(
self, other: Union[List[Union[Expr, str]], Expr, str, "pli.Series", List[Any]]
) -> "Expr":
"""
Concat the arrays in a Series dtype List in linear time.
Expand All @@ -2340,11 +2342,17 @@ def concat(self, other: Union[List[Union[Expr, str]], Expr, str]) -> "Expr":
other
Columns to concat into a List Series
"""
other_list: List[Union[Expr, str]]
if isinstance(other, list) and (
not isinstance(other[0], (Expr, str, pli.Series))
):
return self.concat(pli.Series([other]))

other_list: List[Union[Expr, str, "pli.Series"]]
if not isinstance(other, list):
other_list = [other]
else:
other_list = copy.copy(other)
other_list = copy.copy(other) # type: ignore

other_list.insert(0, wrap_expr(self._pyexpr))
return pli.concat_list(other_list)

Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1091,7 +1091,7 @@ def format(fstring: str, *args: Union["pli.Expr", str]) -> "pli.Expr":
return concat_str(exprs, sep="")


def concat_list(exprs: Sequence[Union[str, "pli.Expr"]]) -> "pli.Expr":
def concat_list(exprs: Sequence[Union[str, "pli.Expr", "pli.Series"]]) -> "pli.Expr":
"""
Concat the arrays in a Series dtype List in linear time.
Expand Down
10 changes: 2 additions & 8 deletions py-polars/polars/internals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3672,7 +3672,7 @@ def unique(self) -> Series:
"""
return pli.select(pli.lit(wrap_s(self._s)).arr.unique()).to_series()

def concat(self, other: Union[List[Series], Series]) -> "Series":
def concat(self, other: Union[List[Series], Series, List[Any]]) -> "Series":
"""
Concat the arrays in a Series dtype List in linear time.
Expand All @@ -3681,14 +3681,8 @@ def concat(self, other: Union[List[Series], Series]) -> "Series":
other
Columns to concat into a List Series
"""
if not isinstance(other, list):
other = [other]
s = wrap_s(self._s)
names = [s.name for s in other]
names.insert(0, s.name)
df = pli.DataFrame(other)
df.insert_at_idx(0, s)
return df.select(pli.concat_list(names))[s.name]
return s.to_frame().select(pli.col(s.name).arr.concat(other)).to_series()

def get(self, index: int) -> "Series":
"""
Expand Down
6 changes: 6 additions & 0 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1762,3 +1762,9 @@ def test_to_dict(as_series: bool, inner_dtype: Any) -> None:
for v in s.values():
assert isinstance(v, inner_dtype)
assert len(v) == len(df)


def test_df_broadcast() -> None:
    """Adding a one-row list Series to a three-row frame yields a (3, 2) frame."""
    frame = pl.DataFrame({"a": [1, 2, 3]})
    unit_list = pl.Series([[1, 2]])
    result = frame.with_column(unit_list)
    assert result.shape == (3, 2)
13 changes: 13 additions & 0 deletions py-polars/tests/test_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,16 @@ def test_list_concat_rolling_window() -> None:
)
assert out.shape == (5, 5)
assert out["A_rolling"].dtype == pl.List


def test_list_append() -> None:
    """Check the first row of `arr.concat` for Series, plain-list and Series-namespace inputs."""
    df = pl.DataFrame({"a": [[1, 2], [1], [1, 2, 3]]})

    # right-hand side given as a one-row list Series
    rhs = pl.Series([[1, 2]])
    out = df.select([pl.col("a").arr.concat(rhs)])
    assert out["a"][0].to_list() == [1, 2, 1, 2]

    # right-hand side given as a plain Python list of values
    out = df.select([pl.col("a").arr.concat([1, 4])])
    assert out["a"][0].to_list() == [1, 2, 1, 4]

    # same operation through the Series `arr` namespace
    out_s = df["a"].arr.concat([4, 1])
    assert out_s[0].to_list() == [1, 2, 4, 1]

0 comments on commit 81129eb

Please sign in to comment.