Skip to content

Commit

Permalink
python: add apply on struct dtype (#3003)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Mar 29, 2022
1 parent 1a2bf9b commit 829905b
Show file tree
Hide file tree
Showing 9 changed files with 306 additions and 17 deletions.
44 changes: 44 additions & 0 deletions polars/polars-core/src/chunked_array/iterator/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use crate::prelude::*;
#[cfg(feature = "dtype-struct")]
use crate::series::iterator::SeriesIter;
use crate::utils::CustomIterTools;
use arrow::array::*;
use std::convert::TryFrom;
Expand Down Expand Up @@ -345,6 +347,48 @@ impl<T: PolarsObject> ObjectChunked<T> {
}
}

/// Row-wise iteration over a `StructChunked`.
///
/// Make sure to call `rechunk` first!
///
/// Every step of the returned iterator yields a slice holding one `AnyValue` per
/// struct field.
#[cfg(feature = "dtype-struct")]
impl<'a> IntoIterator for &'a StructChunked {
    type Item = &'a [AnyValue<'a>];
    type IntoIter = StructIter<'a>;

    fn into_iter(self) -> Self::IntoIter {
        let field_iter: Vec<_> = self.fields().iter().map(|s| s.iter()).collect();
        // Each step pushes exactly one value per field, so sizing the scratch buffer
        // once here means it never has to grow while iterating.
        let buf = Vec::with_capacity(field_iter.len());

        StructIter { field_iter, buf }
    }
}

/// Iterator state used to walk a `StructChunked` one row at a time.
#[cfg(feature = "dtype-struct")]
pub struct StructIter<'a> {
// One value iterator per struct field.
field_iter: Vec<SeriesIter<'a>>,
// Scratch buffer that holds the values of the row currently being yielded.
buf: Vec<AnyValue<'a>>,
}

/// Yields one row per step as a slice with one `AnyValue` per entry of `field_iter`,
/// in `field_iter` order. Iteration ends as soon as any field iterator is exhausted.
///
/// NOTE(review): every yielded slice points into the internal `buf`, which is cleared
/// and refilled by the following `next` call. An item must therefore be fully consumed
/// before the iterator is advanced; slices kept across steps (e.g. through `collect`)
/// no longer describe the row they were produced for.
#[cfg(feature = "dtype-struct")]
impl<'a> Iterator for StructIter<'a> {
type Item = &'a [AnyValue<'a>];

fn next(&mut self) -> Option<Self::Item> {
// Reuse the buffer of the previous row instead of allocating a new one per step.
self.buf.clear();

for it in &mut self.field_iter {
// `?` stops the whole iteration once a field has no more values.
self.buf.push(it.next()?);
}
// Safety:
// Lifetime is bound to struct, we just cannot set the lifetime for the iterator trait
// (the `Iterator` trait has no way to tie `Item` to the `&mut self` borrow of `next`,
// so the borrow of `buf` is widened to `'a` by hand).
unsafe {
Some(std::mem::transmute::<&'_ [AnyValue], &'a [AnyValue]>(
&self.buf,
))
}
}
}

/// Wrapper struct that converts an iterator over `T` into one over `Option<T>`. It is useful when
/// implementing the `IntoIterator` trait, where every iterator is expected to yield `Option<T>`.
pub struct SomeIterator<I>(I)
Expand Down
3 changes: 3 additions & 0 deletions polars/polars-core/src/chunked_array/ops/chunkops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ fn slice(
break;
}
}
if new_chunks.is_empty() {
new_chunks.push(chunks[0].slice(0, 0).into());
}
new_chunks
}

Expand Down
2 changes: 1 addition & 1 deletion polars/polars-core/src/series/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ impl FromIterator<String> for Series {
}
}

#[cfg(feature = "rows")]
#[cfg(any(feature = "rows", feature = "dtype-struct"))]
impl Series {
pub fn iter(&self) -> SeriesIter<'_> {
let dtype = self.dtype();
Expand Down
1 change: 1 addition & 0 deletions polars/polars-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,7 @@ macro_rules! apply_method_all_arrow_series {
DataType::Date => $self.date().unwrap().$method($($args),*),
DataType::Datetime(_, _) => $self.datetime().unwrap().$method($($args),*),
DataType::List(_) => $self.list().unwrap().$method($($args),*),
DataType::Struct(_) => $self.struct_().unwrap().$method($($args),*),
dt => panic!("dtype {:?} not supported", dt)
}
}
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/lazy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1350,7 +1350,7 @@ def select(
return pli.DataFrame([]).select(exprs)


def struct(exprs: Union[Sequence["pli.Expr"], "pli.Expr"]) -> "pli.Expr":
def struct(exprs: Union[Sequence[Union["pli.Expr", str]], "pli.Expr"]) -> "pli.Expr":
"""
Collect several columns into a Series of dtype Struct
Expand Down
225 changes: 214 additions & 11 deletions py-polars/src/apply/series.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::*;
use crate::conversion::to_wrapped;
use crate::series::PySeries;
use crate::Wrap;
use polars::chunked_array::builder::get_list_builder;
Expand All @@ -14,15 +15,10 @@ fn infer_and_finish<'a, A: ApplyLambda<'a>>(
out: &'a PyAny,
null_count: usize,
) -> PyResult<PySeries> {
if out.is_instance::<PyInt>().unwrap() {
let first_value = out.extract::<i64>().unwrap();
if out.is_instance::<PyBool>().unwrap() {
let first_value = out.extract::<bool>().unwrap();
applyer
.apply_lambda_with_primitive_out_type::<Int64Type>(
py,
lambda,
null_count,
Some(first_value),
)
.apply_lambda_with_bool_out_type(py, lambda, null_count, Some(first_value))
.map(|ca| ca.into_series().into())
} else if out.is_instance::<PyFloat>().unwrap() {
let first_value = out.extract::<f64>().unwrap();
Expand All @@ -34,10 +30,15 @@ fn infer_and_finish<'a, A: ApplyLambda<'a>>(
Some(first_value),
)
.map(|ca| ca.into_series().into())
} else if out.is_instance::<PyBool>().unwrap() {
let first_value = out.extract::<bool>().unwrap();
} else if out.is_instance::<PyInt>().unwrap() {
let first_value = out.extract::<i64>().unwrap();
applyer
.apply_lambda_with_bool_out_type(py, lambda, null_count, Some(first_value))
.apply_lambda_with_primitive_out_type::<Int64Type>(
py,
lambda,
null_count,
Some(first_value),
)
.map(|ca| ca.into_series().into())
} else if out.is_instance::<PyString>().unwrap() {
let first_value = out.extract::<&str>().unwrap();
Expand Down Expand Up @@ -1683,3 +1684,205 @@ impl<'a> ApplyLambda<'a> for ObjectChunked<ObjectValue> {
}
}
}

/// Builds the Python `collections.namedtuple` class (type name `struct`) whose
/// attributes are the field names of `ca`.
///
/// Every row handed to a user lambda is an instance of this class. Failures (import
/// error, field names that `namedtuple` rejects) are returned as a `PyErr` instead of
/// panicking.
fn struct_namedtuple_constructor<'py>(
    py: Python<'py>,
    ca: &StructChunked,
) -> PyResult<&'py PyAny> {
    let names = ca
        .fields()
        .iter()
        .map(|s| s.name())
        .collect::<Vec<_>>()
        .join(" ");
    PyModule::import(py, "collections")?
        .getattr("namedtuple")?
        .call1(("struct", names))
}

/// Applies Python lambdas to a `StructChunked`, passing each row to the lambda as a
/// `namedtuple` keyed by the struct's field names.
///
/// In the typed `apply_*` methods, `init_null_count` rows plus the row that produced
/// `first_value` (when there is one) are skipped, since the caller has already
/// evaluated them. Errors raised while preparing the namedtuple class are propagated
/// to the caller; the per-row closures keep their original behaviour.
impl<'a> ApplyLambda<'a> for StructChunked {
    /// Calls `lambda` row by row until it returns something other than `None`, then
    /// hands that first value to `infer_and_finish` together with the number of
    /// leading `None` results.
    fn apply_lambda_unknown(&'a self, py: Python, lambda: &'a PyAny) -> PyResult<PySeries> {
        let namedtuple_constructor = struct_namedtuple_constructor(py, self)?;

        let mut null_count = 0;
        for val in self.into_iter() {
            let val = to_wrapped(val);
            let arg = namedtuple_constructor.call1(PyTuple::new(py, val))?;
            let out = lambda.call1((arg,))?;
            if out.is_none() {
                null_count += 1;
                continue;
            }
            return infer_and_finish(self, py, lambda, out, null_count);
        }

        // TODO: when the lambda returned `None` for every row (or there are no rows)
        // the input is returned unchanged; a full-null result is not produced yet.
        Ok(self.clone().into_series().into())
    }

    /// Same as `apply_lambda_unknown`: the output type is always inferred.
    fn apply_lambda(&'a self, py: Python, lambda: &'a PyAny) -> PyResult<PySeries> {
        self.apply_lambda_unknown(py, lambda)
    }

    /// Applies `lambda` to the remaining rows and collects the results with
    /// `iterator_to_struct`, seeded with `first_value`.
    fn apply_to_struct(
        &'a self,
        py: Python,
        lambda: &'a PyAny,
        init_null_count: usize,
        first_value: AnyValue<'a>,
    ) -> PyResult<PySeries> {
        let namedtuple_constructor = struct_namedtuple_constructor(py, self)?;

        // A first value is always present here, so exactly one extra row is skipped.
        let skip = 1;
        let it = self.into_iter().skip(init_null_count + skip).map(|val| {
            let val = to_wrapped(val);
            let arg = namedtuple_constructor.call1(PyTuple::new(py, val)).unwrap();
            // NOTE: a Python exception raised by the lambda still panics here, because
            // the iterator has no channel to report an error.
            let out = lambda.call1((arg,)).unwrap();
            Some(out)
        });
        iterator_to_struct(it, init_null_count, first_value, self.name(), self.len())
    }

    /// Applies `lambda` to the remaining rows and collects the results into a primitive
    /// `ChunkedArray<D>`. Rows whose call or extraction fails become `None`.
    fn apply_lambda_with_primitive_out_type<D>(
        &'a self,
        py: Python,
        lambda: &'a PyAny,
        init_null_count: usize,
        first_value: Option<D::Native>,
    ) -> PyResult<ChunkedArray<D>>
    where
        D: PyArrowPrimitiveType,
        D::Native: ToPyObject + FromPyObject<'a>,
    {
        let namedtuple_constructor = struct_namedtuple_constructor(py, self)?;

        let skip = if first_value.is_some() { 1 } else { 0 };
        let it = self.into_iter().skip(init_null_count + skip).map(|val| {
            let val = to_wrapped(val);
            let arg = namedtuple_constructor.call1(PyTuple::new(py, val)).unwrap();
            call_lambda_and_extract(py, lambda, arg).ok()
        });

        Ok(iterator_to_primitive(
            it,
            init_null_count,
            first_value,
            self.name(),
            self.len(),
        ))
    }

    /// Applies `lambda` to the remaining rows and collects the results into a
    /// `BooleanChunked`. Rows whose call or extraction fails become `None`.
    fn apply_lambda_with_bool_out_type(
        &'a self,
        py: Python,
        lambda: &'a PyAny,
        init_null_count: usize,
        first_value: Option<bool>,
    ) -> PyResult<BooleanChunked> {
        let namedtuple_constructor = struct_namedtuple_constructor(py, self)?;

        let skip = if first_value.is_some() { 1 } else { 0 };
        let it = self.into_iter().skip(init_null_count + skip).map(|val| {
            let val = to_wrapped(val);
            let arg = namedtuple_constructor.call1(PyTuple::new(py, val)).unwrap();
            call_lambda_and_extract(py, lambda, arg).ok()
        });

        Ok(iterator_to_bool(
            it,
            init_null_count,
            first_value,
            self.name(),
            self.len(),
        ))
    }

    /// Applies `lambda` to the remaining rows and collects the results into a
    /// `Utf8Chunked`. Rows whose call or extraction fails become `None`.
    fn apply_lambda_with_utf8_out_type(
        &'a self,
        py: Python,
        lambda: &'a PyAny,
        init_null_count: usize,
        first_value: Option<&str>,
    ) -> PyResult<Utf8Chunked> {
        let namedtuple_constructor = struct_namedtuple_constructor(py, self)?;

        let skip = if first_value.is_some() { 1 } else { 0 };
        let it = self.into_iter().skip(init_null_count + skip).map(|val| {
            let val = to_wrapped(val);
            let arg = namedtuple_constructor.call1(PyTuple::new(py, val)).unwrap();
            call_lambda_and_extract(py, lambda, arg).ok()
        });

        Ok(iterator_to_utf8(
            it,
            init_null_count,
            first_value,
            self.name(),
            self.len(),
        ))
    }

    /// Applies `lambda` to the remaining rows, expecting a Series from every call, and
    /// collects the results into a `ListChunked` with inner dtype `dt`. Rows whose call
    /// fails become `None`.
    fn apply_lambda_with_list_out_type(
        &'a self,
        py: Python,
        lambda: PyObject,
        init_null_count: usize,
        first_value: &Series,
        dt: &DataType,
    ) -> PyResult<ListChunked> {
        // A first value is always present here, so exactly one extra row is skipped.
        let skip = 1;

        let namedtuple_constructor = struct_namedtuple_constructor(py, self)?;

        let lambda = lambda.as_ref(py);
        let it = self.into_iter().skip(init_null_count + skip).map(|val| {
            let val = to_wrapped(val);
            let arg = namedtuple_constructor.call1(PyTuple::new(py, val)).unwrap();
            call_lambda_series_out(py, lambda, arg).ok()
        });
        Ok(iterator_to_list(
            dt,
            it,
            init_null_count,
            Some(first_value),
            self.name(),
            self.len(),
        ))
    }

    /// Applies `lambda` to the remaining rows and collects the results into an
    /// `ObjectChunked<ObjectValue>`. Rows whose call or extraction fails become `None`.
    fn apply_lambda_with_object_out_type(
        &'a self,
        py: Python,
        lambda: &'a PyAny,
        init_null_count: usize,
        first_value: Option<ObjectValue>,
    ) -> PyResult<ObjectChunked<ObjectValue>> {
        let namedtuple_constructor = struct_namedtuple_constructor(py, self)?;

        let skip = if first_value.is_some() { 1 } else { 0 };
        let it = self.into_iter().skip(init_null_count + skip).map(|val| {
            let val = to_wrapped(val);
            let arg = namedtuple_constructor.call1(PyTuple::new(py, val)).unwrap();
            call_lambda_and_extract(py, lambda, arg).ok()
        });

        Ok(iterator_to_object(
            it,
            init_null_count,
            first_value,
            self.name(),
            self.len(),
        ))
    }
}
6 changes: 6 additions & 0 deletions py-polars/src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ use pyo3::{PyAny, PyResult};
use std::fmt::{Display, Formatter};
use std::hash::{Hash, Hasher};

/// Reinterprets a slice of `T` as a slice of `Wrap<T>` without copying.
///
/// The returned slice aliases the same memory as `slice`, has the same length and
/// lives exactly as long as the input borrow.
pub(crate) fn to_wrapped<T>(slice: &[T]) -> &[Wrap<T>] {
    // SAFETY: `Wrap<T>` is `#[repr(transparent)]` over its single field `T`, so both
    // types have identical size, alignment and validity. Casting the raw slice pointer
    // keeps the element count, and the resulting shared reference is tied to the
    // lifetime of `slice` by the function signature.
    unsafe { &*(slice as *const [T] as *const [Wrap<T>]) }
}

/// Transparent newtype around `T`; guaranteed to share the memory layout of `T`.
#[repr(transparent)]
pub struct Wrap<T>(pub T);

Expand Down
16 changes: 12 additions & 4 deletions py-polars/src/lazy/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,20 @@ impl ToSeries for PyObject {
Ok(s) => s,
// the lambda did not return a series, we try to create a new python Series
_ => {
let python_s = py_polars_module
let res = py_polars_module
.getattr(py, "Series")
.unwrap()
.call1(py, (name, PyList::new(py, [self])))
.unwrap();
python_s.getattr(py, "_s").unwrap()
.call1(py, (name, PyList::new(py, [self])));

match res {
Ok(python_s) => python_s.getattr(py, "_s").unwrap(),
Err(_) => {
panic!(
"expected a something that could convert to a `Series` but got: {}",
self.as_ref(py).get_type()
)
}
}
}
};
let pyseries = py_pyseries.extract::<PySeries>(py).unwrap();
Expand Down

0 comments on commit 829905b

Please sign in to comment.