Skip to content

Commit

Permalink
feat(rust, python): to_struct add upper_bound (#5714)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 4, 2022
1 parent a0a62ce commit c41eac3
Show file tree
Hide file tree
Showing 7 changed files with 76 additions and 7 deletions.
39 changes: 38 additions & 1 deletion polars/polars-lazy/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
#[cfg(feature = "list_to_struct")]
use std::sync::RwLock;

use polars_core::prelude::*;
#[cfg(feature = "diff")]
use polars_core::series::ops::NullBehavior;
Expand Down Expand Up @@ -208,11 +211,22 @@ impl ListNameSpace {
#[allow(clippy::wrong_self_convention)]
/// Convert this `List` to a `Series` of type `Struct`. The width will be determined according to
/// `ListToStructWidthStrategy` and the names of the fields determined by the given `name_generator`.
///
/// # Schema
///
/// A polars [`LazyFrame`] needs to know the schema at all time. The caller therefore must provide
/// an `upper_bound` of struct fields that will be set.
/// If this is incorrectly downstream operation may fail. For instance an `all().sum()` expression
/// will look in the current schema to determine which columns to select.
pub fn to_struct(
self,
n_fields: ListToStructWidthStrategy,
name_generator: Option<NameGenerator>,
upper_bound: usize,
) -> Expr {
// heap allocate the output type and fill it later
let out_dtype = Arc::new(RwLock::new(None::<DataType>));

self.0
.map(
move |s| {
Expand All @@ -221,7 +235,30 @@ impl ListNameSpace {
.map(|s| s.into_series())
},
// we don't yet know the fields
GetOutput::from_type(DataType::Struct(vec![])),
GetOutput::map_dtype(move |dt: &DataType| {
let out = out_dtype.read().unwrap();
match out.as_ref() {
// dtype already set
Some(dt) => dt.clone(),
// dtype still unknown, set it
None => {
drop(out);
let mut lock = out_dtype.write().unwrap();

let inner = dt.inner_dtype().unwrap();
let fields = (0..upper_bound)
.map(|i| {
let name = _default_struct_name_gen(i);
Field::from_owned(name, inner.clone())
})
.collect();
let dt = DataType::Struct(fields);

*lock = Some(dt.clone());
dt
}
}
}),
)
.with_fmt("arr.to_struct")
}
Expand Down
10 changes: 7 additions & 3 deletions polars/polars-ops/src/chunked_array/list/to_struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ fn det_n_fields(ca: &ListChunked, n_fields: ListToStructWidthStrategy) -> usize

pub type NameGenerator = Arc<dyn Fn(usize) -> String + Send + Sync>;

pub fn _default_struct_name_gen(idx: usize) -> String {
format!("field_{idx}")
}

pub trait ToStruct: AsList {
fn to_struct(
&self,
Expand All @@ -56,9 +60,9 @@ pub trait ToStruct: AsList {
let ca = self.as_list();
let n_fields = det_n_fields(ca, n_fields);

let default_name_gen = |idx| format!("field_{idx}");

let name_generator = name_generator.as_deref().unwrap_or(&default_name_gen);
let name_generator = name_generator
.as_deref()
.unwrap_or(&_default_struct_name_gen);

if n_fields == 0 {
Err(PolarsError::ComputeError(
Expand Down
2 changes: 1 addition & 1 deletion py-polars/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 10 additions & 1 deletion py-polars/polars/internals/expr/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,7 @@ def to_struct(
self,
n_field_strategy: ToStructStrategy = "first_non_null",
name_generator: Callable[[int], str] | None = None,
upper_bound: int = 0,
) -> pli.Expr:
"""
Convert the series of type ``List`` to a series of type ``Struct``.
Expand All @@ -605,6 +606,14 @@ def to_struct(
name_generator
A custom function that can be used to generate the field names.
Default field names are `field_0, field_1 .. field_n`
upper_bound
A polars `LazyFrame` needs to know the schema at all time.
The caller therefore must provide an `upper_bound` of
struct fields that will be set.
If this is incorrectly downstream operation may fail.
For instance an `all().sum()` expression will look in
the current schema to determine which columns to select.
It is adviced to set this value in a lazy query.
Examples
--------
Expand Down Expand Up @@ -632,7 +641,7 @@ def to_struct(
"""
return pli.wrap_expr(
self._pyexpr.lst_to_struct(n_field_strategy, name_generator)
self._pyexpr.lst_to_struct(n_field_strategy, name_generator, upper_bound)
)

def eval(self, expr: pli.Expr, parallel: bool = False) -> pli.Expr:
Expand Down
12 changes: 12 additions & 0 deletions py-polars/polars/internals/series/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,18 @@ def to_struct(
{'col_name_0': 1, 'col_name_1': 2, 'col_name_2': None}]
"""
# We set the upper bound to 0.
# No need to create the proper schema in eager mode.
s = pli.wrap_s(self)
return (
s.to_frame()
.select(
pli.col(s.name).arr.to_struct(
n_field_strategy, name_generator, upper_bound=0
)
)
.to_series()
)

def eval(self, expr: pli.Expr, parallel: bool = False) -> pli.Series:
"""
Expand Down
3 changes: 2 additions & 1 deletion py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1415,6 +1415,7 @@ impl PyExpr {
&self,
width_strat: Wrap<ListToStructWidthStrategy>,
name_gen: Option<PyObject>,
upper_bound: usize,
) -> PyResult<Self> {
let name_gen = name_gen.map(|lambda| {
Arc::new(move |idx: usize| {
Expand All @@ -1429,7 +1430,7 @@ impl PyExpr {
.inner
.clone()
.arr()
.to_struct(width_strat.0, name_gen)
.to_struct(width_strat.0, name_gen, upper_bound)
.into())
}

Expand Down
6 changes: 6 additions & 0 deletions py-polars/tests/unit/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,12 @@ def test_list_to_struct() -> None:
{"field_0": 1, "field_1": 2, "field_2": 3},
]

# set upper bound
df = pl.DataFrame({"lists": [[1, 1, 1], [0, 1, 0], [1, 0, 0]]})
assert df.lazy().select(pl.col("lists").arr.to_struct(upper_bound=3)).unnest(
"lists"
).sum().collect().columns == ["field_0", "field_1", "field_2"]


def test_sort_df_with_list_struct() -> None:
assert pl.DataFrame([{"a": 1, "b": [{"c": 1}]}]).sort("a").to_dict(False) == {
Expand Down

0 comments on commit c41eac3

Please sign in to comment.