Skip to content

Commit

Permalink
feat: Add .arr.to_list expression (#12136)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Nov 1, 2023
1 parent b7bc205 commit 6f76b70
Show file tree
Hide file tree
Showing 12 changed files with 158 additions and 33 deletions.
6 changes: 6 additions & 0 deletions crates/polars-plan/src/dsl/array.rs
Expand Up @@ -37,4 +37,10 @@ impl ArrayNameSpace {
self.0
.map_private(FunctionExpr::ArrayExpr(ArrayFunction::Unique(true)))
}

/// Convert the Array column into a List column, keeping the inner data type.
pub fn to_list(self) -> Expr {
    // Build the function node first, then attach it to the wrapped expression.
    let function = FunctionExpr::ArrayExpr(ArrayFunction::ToList);
    self.0.map_private(function)
}
}
42 changes: 41 additions & 1 deletion crates/polars-plan/src/dsl/function_expr/array.rs
@@ -1,30 +1,65 @@
use polars_ops::chunked_array::array::*;

use super::*;
use crate::map;

/// Functions that can be applied to Array-typed columns.
///
/// Each variant is displayed as `arr.<name>` and is dispatched to the
/// same-named function in this module.
#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum ArrayFunction {
/// Dispatched to `min`.
Min,
/// Dispatched to `max`.
Max,
/// Dispatched to `sum`.
Sum,
/// Cast the Array column to a List column with the same inner data type.
ToList,
/// Dispatched to `unique`; the flag is forwarded as its `stable` argument.
Unique(bool),
}

impl ArrayFunction {
pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
use ArrayFunction::*;
match self {
Min | Max => mapper.map_to_list_and_array_inner_dtype(),
Sum => mapper.nested_sum_type(),
ToList => mapper.try_map_dtype(map_array_dtype_to_list_dtype),
Unique(_) => mapper.try_map_dtype(map_array_dtype_to_list_dtype),
}
}
}

/// Translate an `Array` data type into the `List` data type with the same inner type.
///
/// # Errors
/// Returns a `ComputeError` when `datatype` is not `DataType::Array`.
fn map_array_dtype_to_list_dtype(datatype: &DataType) -> PolarsResult<DataType> {
    match datatype {
        // The second field of `Array` is dropped; only the inner dtype is carried over.
        DataType::Array(inner, _) => Ok(DataType::List(inner.clone())),
        _ => polars_bail!(ComputeError: "expected array dtype"),
    }
}

impl Display for ArrayFunction {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use ArrayFunction::*;
let name = match self {
Min => "min",
Max => "max",
Sum => "sum",
ToList => "to_list",
Unique(_) => "unique",
};

write!(f, "arr.{name}")
}
}

impl From<ArrayFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
fn from(func: ArrayFunction) -> Self {
use ArrayFunction::*;
match func {
Min => map!(min),
Max => map!(max),
Sum => map!(sum),
ToList => map!(to_list),
Unique(stable) => map!(unique, stable),
}
}
}

/// Apply `array_max` to the array view of `s`.
///
/// # Errors
/// Propagates the error from `s.array()`.
pub(super) fn max(s: &Series) -> PolarsResult<Series> {
    let ca = s.array()?;
    Ok(ca.array_max())
}
Expand All @@ -46,3 +81,8 @@ pub(super) fn unique(s: &Series, stable: bool) -> PolarsResult<Series> {
};
out.map(|ca| ca.into_series())
}

/// Cast `s` to the List data type that corresponds to its Array data type.
///
/// # Errors
/// Fails when the dtype of `s` is not an Array, or when the cast itself fails.
pub(super) fn to_list(s: &Series) -> PolarsResult<Series> {
    s.cast(&map_array_dtype_to_list_dtype(s.dtype())?)
}
10 changes: 1 addition & 9 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Expand Up @@ -863,15 +863,7 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
}
},
#[cfg(feature = "dtype-array")]
ArrayExpr(lf) => {
use ArrayFunction::*;
match lf {
Min => map!(array::min),
Max => map!(array::max),
Sum => map!(array::sum),
Unique(stable) => map!(array::unique, stable),
}
},
ArrayExpr(func) => func.into(),
#[cfg(feature = "dtype-struct")]
StructExpr(sf) => {
use StructFunction::*;
Expand Down
15 changes: 1 addition & 14 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Expand Up @@ -96,20 +96,7 @@ impl FunctionExpr {
}
},
#[cfg(feature = "dtype-array")]
ArrayExpr(af) => {
use ArrayFunction::*;
match af {
Min | Max => mapper.map_to_list_and_array_inner_dtype(),
Sum => mapper.nested_sum_type(),
Unique(_) => mapper.try_map_dtype(|dt| {
if let DataType::Array(inner, _) = dt {
Ok(DataType::List(inner.clone()))
} else {
polars_bail!(ComputeError: "expected array dtype")
}
}),
}
},
ArrayExpr(func) => func.get_field(mapper),
#[cfg(feature = "dtype-struct")]
AsStruct => Ok(Field::new(
fields[0].name(),
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expressions/array.rst
Expand Up @@ -12,4 +12,5 @@ The following methods are available under the `expr.arr` attribute.
Expr.arr.max
Expr.arr.min
Expr.arr.sum
Expr.arr.to_list
Expr.arr.unique
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/series/array.rst
Expand Up @@ -12,4 +12,5 @@ The following methods are available under the `Series.arr` attribute.
Series.arr.max
Series.arr.min
Series.arr.sum
Series.arr.to_list
Series.arr.unique
39 changes: 34 additions & 5 deletions py-polars/polars/expr/array.py
Expand Up @@ -38,7 +38,7 @@ def min(self) -> Expr:
└─────┘
"""
return wrap_expr(self._pyexpr.array_min())
return wrap_expr(self._pyexpr.arr_min())

def max(self) -> Expr:
"""
Expand All @@ -62,7 +62,7 @@ def max(self) -> Expr:
└─────┘
"""
return wrap_expr(self._pyexpr.array_max())
return wrap_expr(self._pyexpr.arr_max())

def sum(self) -> Expr:
"""
Expand All @@ -86,7 +86,7 @@ def sum(self) -> Expr:
└─────┘
"""
return wrap_expr(self._pyexpr.array_sum())
return wrap_expr(self._pyexpr.arr_sum())

def unique(self, *, maintain_order: bool = False) -> Expr:
"""
Expand All @@ -103,7 +103,7 @@ def unique(self, *, maintain_order: bool = False) -> Expr:
... {
... "a": [[1, 1, 2]],
... },
... schema_overrides={"a": pl.Array(inner=pl.Int64, width=3)},
... schema={"a": pl.Array(inner=pl.Int64, width=3)},
... )
>>> df.select(pl.col("a").arr.unique())
shape: (1, 1)
Expand All @@ -116,4 +116,33 @@ def unique(self, *, maintain_order: bool = False) -> Expr:
└───────────┘
"""
return wrap_expr(self._pyexpr.array_unique(maintain_order))
return wrap_expr(self._pyexpr.arr_unique(maintain_order))

def to_list(self) -> Expr:
    """
    Turn an Array column into a List column; the inner data type is kept.

    Returns
    -------
    Expr
        Expression of data type :class:`List`.

    Examples
    --------
    >>> df = pl.DataFrame(
    ...     data={"a": [[1, 2], [3, 4]]},
    ...     schema={"a": pl.Array(inner=pl.Int8, width=2)},
    ... )
    >>> df.select(pl.col("a").arr.to_list())
    shape: (2, 1)
    ┌──────────┐
    │ a        │
    │ ---      │
    │ list[i8] │
    ╞══════════╡
    │ [1, 2]   │
    │ [3, 4]   │
    └──────────┘

    """
    py_expr = self._pyexpr.arr_to_list()
    return wrap_expr(py_expr)
22 changes: 22 additions & 0 deletions py-polars/polars/series/array.py
Expand Up @@ -107,3 +107,25 @@ def unique(self, *, maintain_order: bool = False) -> Series:
└───────────┘
"""

def to_list(self) -> Series:
    """
    Convert an Array column into a List column with the same inner data type.

    Returns
    -------
    Series
        Series of data type :class:`List`.

    Examples
    --------
    >>> s = pl.Series([[1, 2], [3, 4]], dtype=pl.Array(inner=pl.Int8, width=2))
    >>> s.arr.to_list()
    shape: (2,)
    Series: '' [list[i8]]
    [
            [1, 2]
            [3, 4]
    ]

    """
13 changes: 9 additions & 4 deletions py-polars/src/expr/array.rs
Expand Up @@ -4,22 +4,27 @@ use crate::expr::PyExpr;

#[pymethods]
impl PyExpr {
fn array_max(&self) -> Self {
fn arr_max(&self) -> Self {
self.inner.clone().arr().max().into()
}

fn array_min(&self) -> Self {
fn arr_min(&self) -> Self {
self.inner.clone().arr().min().into()
}

fn array_sum(&self) -> Self {
fn arr_sum(&self) -> Self {
self.inner.clone().arr().sum().into()
}
fn array_unique(&self, maintain_order: bool) -> Self {

/// Unique values per array; `maintain_order` selects `unique_stable` over `unique`.
fn arr_unique(&self, maintain_order: bool) -> Self {
    let arr = self.inner.clone().arr();
    let expr = if maintain_order {
        arr.unique_stable()
    } else {
        arr.unique()
    };
    expr.into()
}

/// Expose the `arr().to_list()` expression to Python.
fn arr_to_list(&self) -> Self {
    let expr = self.inner.clone();
    expr.arr().to_list().into()
}
}
Empty file.
42 changes: 42 additions & 0 deletions py-polars/tests/unit/namespaces/array/test_to_list.py
@@ -0,0 +1,42 @@
from __future__ import annotations

import polars as pl
from polars.testing import assert_frame_equal, assert_series_equal


def test_arr_to_list() -> None:
    # Eager path: the Series namespace method should yield a List of the same inner dtype.
    values = [[1, 2], [4, 3]]
    array_series = pl.Series("a", values, dtype=pl.Array(inner=pl.Int8, width=2))
    expected = pl.Series("a", values, dtype=pl.List(pl.Int8))

    assert_series_equal(array_series.arr.to_list(), expected)


def test_arr_to_list_lazy() -> None:
    # Lazy path: the expression is applied through a LazyFrame select.
    values = [[1, 2], [4, 3]]
    array_series = pl.Series("a", values, dtype=pl.Array(inner=pl.Int8, width=2))
    list_series = pl.Series("a", values, dtype=pl.List(pl.Int8))

    result = array_series.to_frame().lazy().select(pl.col("a").arr.to_list())

    expected = list_series.to_frame().lazy()
    assert_frame_equal(result, expected)


def test_arr_to_list_nested_array_preserved() -> None:
    # Only the outer Array becomes a List; the inner Array dtype must stay an Array.
    values = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
    inner_dtype = pl.Array(inner=pl.Int8, width=2)

    nested = pl.Series("a", values, dtype=pl.Array(inner=inner_dtype, width=2))
    result = nested.to_frame().lazy().select(pl.col("a").arr.to_list())

    expected_series = pl.Series("a", values).cast(pl.List(inner_dtype))
    expected = expected_series.to_frame().lazy()
    assert_frame_equal(result, expected)

0 comments on commit 6f76b70

Please sign in to comment.