Skip to content

Commit

Permalink
use dict instead of namedtuple in apply struct (#3007)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Mar 30, 2022
1 parent 96c920c commit 82c71db
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 48 deletions.
60 changes: 16 additions & 44 deletions py-polars/src/apply/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::Wrap;
use polars::chunked_array::builder::get_list_builder;
use polars::prelude::*;
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyCFunction, PyFloat, PyInt, PyList, PyString, PyTuple};
use pyo3::types::{PyBool, PyCFunction, PyDict, PyFloat, PyInt, PyList, PyString, PyTuple};

/// Find the output type and dispatch to that implementation.
fn infer_and_finish<'a, A: ApplyLambda<'a>>(
Expand Down Expand Up @@ -1685,18 +1685,20 @@ impl<'a> ApplyLambda<'a> for ObjectChunked<ObjectValue> {
}
}

fn make_dict_arg(py: Python, names: &[&str], vals: &[AnyValue]) -> Py<PyDict> {
let dict = PyDict::new(py);
for (name, val) in names.iter().zip(to_wrapped(vals)) {
dict.set_item(name, val).unwrap()
}
dict.into_py(py)
}

impl<'a> ApplyLambda<'a> for StructChunked {
fn apply_lambda_unknown(&'a self, py: Python, lambda: &'a PyAny) -> PyResult<PySeries> {
let collections_module = PyModule::import(py, "collections").unwrap();
let namedtuple = collections_module.getattr("namedtuple").unwrap();
let names = self.fields().iter().map(|s| s.name()).collect::<Vec<_>>();
let names = names.join(" ");
let namedtuple_constructor = namedtuple.call1(("struct", names)).unwrap();

let mut null_count = 0;
for val in self.into_iter() {
let val = to_wrapped(val);
let arg = namedtuple_constructor.call1(PyTuple::new(py, val)).unwrap();
let arg = make_dict_arg(py, &names, val);
let out = lambda.call1((arg,))?;
if out.is_none() {
null_count += 1;
Expand All @@ -1720,16 +1722,11 @@ impl<'a> ApplyLambda<'a> for StructChunked {
init_null_count: usize,
first_value: AnyValue<'a>,
) -> PyResult<PySeries> {
let collections_module = PyModule::import(py, "collections").unwrap();
let namedtuple = collections_module.getattr("namedtuple").unwrap();
let names = self.fields().iter().map(|s| s.name()).collect::<Vec<_>>();
let names = names.join(" ");
let namedtuple_constructor = namedtuple.call1(("struct", names)).unwrap();

let skip = 1;
let it = self.into_iter().skip(init_null_count + skip).map(|val| {
let val = to_wrapped(val);
let arg = namedtuple_constructor.call1(PyTuple::new(py, val)).unwrap();
let arg = make_dict_arg(py, &names, val);
let out = lambda.call1((arg,)).unwrap();
Some(out)
});
Expand All @@ -1747,16 +1744,11 @@ impl<'a> ApplyLambda<'a> for StructChunked {
D: PyArrowPrimitiveType,
D::Native: ToPyObject + FromPyObject<'a>,
{
let collections_module = PyModule::import(py, "collections").unwrap();
let namedtuple = collections_module.getattr("namedtuple").unwrap();
let names = self.fields().iter().map(|s| s.name()).collect::<Vec<_>>();
let names = names.join(" ");
let namedtuple_constructor = namedtuple.call1(("struct", names)).unwrap();

let skip = if first_value.is_some() { 1 } else { 0 };
let it = self.into_iter().skip(init_null_count + skip).map(|val| {
let val = to_wrapped(val);
let arg = namedtuple_constructor.call1(PyTuple::new(py, val)).unwrap();
let arg = make_dict_arg(py, &names, val);
call_lambda_and_extract(py, lambda, arg).ok()
});

Expand All @@ -1776,16 +1768,11 @@ impl<'a> ApplyLambda<'a> for StructChunked {
init_null_count: usize,
first_value: Option<bool>,
) -> PyResult<BooleanChunked> {
let collections_module = PyModule::import(py, "collections").unwrap();
let namedtuple = collections_module.getattr("namedtuple").unwrap();
let names = self.fields().iter().map(|s| s.name()).collect::<Vec<_>>();
let names = names.join(" ");
let namedtuple_constructor = namedtuple.call1(("struct", names)).unwrap();

let skip = if first_value.is_some() { 1 } else { 0 };
let it = self.into_iter().skip(init_null_count + skip).map(|val| {
let val = to_wrapped(val);
let arg = namedtuple_constructor.call1(PyTuple::new(py, val)).unwrap();
let arg = make_dict_arg(py, &names, val);
call_lambda_and_extract(py, lambda, arg).ok()
});

Expand All @@ -1805,16 +1792,11 @@ impl<'a> ApplyLambda<'a> for StructChunked {
init_null_count: usize,
first_value: Option<&str>,
) -> PyResult<Utf8Chunked> {
let collections_module = PyModule::import(py, "collections").unwrap();
let namedtuple = collections_module.getattr("namedtuple").unwrap();
let names = self.fields().iter().map(|s| s.name()).collect::<Vec<_>>();
let names = names.join(" ");
let namedtuple_constructor = namedtuple.call1(("struct", names)).unwrap();

let skip = if first_value.is_some() { 1 } else { 0 };
let it = self.into_iter().skip(init_null_count + skip).map(|val| {
let val = to_wrapped(val);
let arg = namedtuple_constructor.call1(PyTuple::new(py, val)).unwrap();
let arg = make_dict_arg(py, &names, val);
call_lambda_and_extract(py, lambda, arg).ok()
});

Expand All @@ -1836,16 +1818,11 @@ impl<'a> ApplyLambda<'a> for StructChunked {
) -> PyResult<ListChunked> {
let skip = 1;

let collections_module = PyModule::import(py, "collections").unwrap();
let namedtuple = collections_module.getattr("namedtuple").unwrap();
let names = self.fields().iter().map(|s| s.name()).collect::<Vec<_>>();
let names = names.join(" ");
let namedtuple_constructor = namedtuple.call1(("struct", names)).unwrap();

let lambda = lambda.as_ref(py);
let it = self.into_iter().skip(init_null_count + skip).map(|val| {
let val = to_wrapped(val);
let arg = namedtuple_constructor.call1(PyTuple::new(py, val)).unwrap();
let arg = make_dict_arg(py, &names, val);
call_lambda_series_out(py, lambda, arg).ok()
});
Ok(iterator_to_list(
Expand All @@ -1864,16 +1841,11 @@ impl<'a> ApplyLambda<'a> for StructChunked {
init_null_count: usize,
first_value: Option<ObjectValue>,
) -> PyResult<ObjectChunked<ObjectValue>> {
let collections_module = PyModule::import(py, "collections").unwrap();
let namedtuple = collections_module.getattr("namedtuple").unwrap();
let names = self.fields().iter().map(|s| s.name()).collect::<Vec<_>>();
let names = names.join(" ");
let namedtuple_constructor = namedtuple.call1(("struct", names)).unwrap();

let skip = if first_value.is_some() { 1 } else { 0 };
let it = self.into_iter().skip(init_null_count + skip).map(|val| {
let val = to_wrapped(val);
let arg = namedtuple_constructor.call1(PyTuple::new(py, val)).unwrap();
let arg = make_dict_arg(py, &names, val);
call_lambda_and_extract(py, lambda, arg).ok()
});

Expand Down
8 changes: 4 additions & 4 deletions py-polars/tests/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,10 @@ def test_apply_struct() -> None:
)
out = df.with_column(pl.struct(df.columns).alias("struct")).select(
[
pl.col("struct").apply(lambda x: x.A).alias("A_field"),
pl.col("struct").apply(lambda x: x.B).alias("B_field"),
pl.col("struct").apply(lambda x: x.C).alias("C_field"),
pl.col("struct").apply(lambda x: x.D).alias("D_field"),
pl.col("struct").apply(lambda x: x["A"]).alias("A_field"),
pl.col("struct").apply(lambda x: x["B"]).alias("B_field"),
pl.col("struct").apply(lambda x: x["C"]).alias("C_field"),
pl.col("struct").apply(lambda x: x["D"]).alias("D_field"),
]
)
expected = pl.DataFrame(
Expand Down

0 comments on commit 82c71db

Please sign in to comment.