Skip to content

Commit

Permalink
python allow returning a python object in pl.col(..).map(lambda x: ..)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Dec 3, 2021
1 parent 67c90e7 commit e7ec5a1
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 26 deletions.
55 changes: 29 additions & 26 deletions py-polars/src/lazy/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,30 @@ use polars::prelude::*;
use pyo3::prelude::*;
use pyo3::types::PyList;

/// Conversion of a Python-side value into a polars [`Series`].
trait ToSeries {
    /// Produce a [`Series`] from `self`.
    ///
    /// * `py` - token for the Python interpreter.
    /// * `py_polars_module` - Python object on which the `Series` constructor
    ///   is looked up when a new Python Series has to be built.
    /// * `name` - name handed to that constructor in the fallback case.
    fn to_series(&self, py: Python, py_polars_module: &PyObject, name: &str) -> Series;
}

impl ToSeries for PyObject {
    /// Turn the object returned by a Python lambda into a [`Series`].
    ///
    /// If the object exposes an `_s` attribute, that attribute is used
    /// directly. Otherwise the object is wrapped in a one-element list and
    /// passed, together with `name`, to `py_polars_module.Series(...)`; the
    /// `_s` attribute of the resulting Python Series is used instead.
    ///
    /// # Panics
    ///
    /// Panics when the fallback construction fails (missing `Series`
    /// attribute, failing constructor call, missing `_s` on the result) or
    /// when the `_s` value cannot be extracted as a `PySeries`.
    fn to_series(&self, py: Python, py_polars_module: &PyObject, name: &str) -> Series {
        let inner = self.getattr(py, "_s").unwrap_or_else(|_| {
            // No `_s` attribute: the lambda handed back something that is not
            // a Series wrapper, so build a Python Series around the value.
            let series_ctor = py_polars_module.getattr(py, "Series").unwrap();
            let wrapped = series_ctor
                .call1(py, (name, PyList::new(py, [self])))
                .unwrap();
            wrapped.getattr(py, "_s").unwrap()
        });

        // Downcast to the Rust wrapper and hand out the Series it holds.
        inner.extract::<PySeries>(py).unwrap().series
    }
}

fn get_output_type(obj: &PyAny) -> Option<DataType> {
match obj.is_none() {
true => None,
Expand Down Expand Up @@ -82,12 +106,7 @@ pub(crate) fn binary_lambda(lambda: &PyObject, a: Series, b: Series) -> Result<S
let s = out.select_at_idx(0).unwrap().clone();
PySeries::new(s)
} else {
// unpack the wrapper in a PySeries
let py_pyseries = result_series_wrapper.getattr(py, "_s").expect(
"Could net get series attribute '_s'. Make sure that you return a Series object.",
);
// Downcast to Rust
py_pyseries.extract::<PySeries>(py).unwrap()
return Ok(result_series_wrapper.to_series(py, &pypolars.into_py(py), ""));
};

// Finally get the actual Series
Expand Down Expand Up @@ -134,17 +153,9 @@ pub fn map_single(
let py = gil.python();

// this is a python Series
let out = call_lambda_with_series(py, s, &lambda, &pypolars);

// unpack the wrapper in a PySeries
let py_pyseries = out.getattr(py, "_s").expect(
"Could net get series attribute '_s'. \
Make sure that you return a Series object from a custom function.",
);
// Downcast to Rust
let pyseries = py_pyseries.extract::<PySeries>(py).unwrap();
// Finally get the actual Series
Ok(pyseries.series)
let out = call_lambda_with_series(py, s.clone(), &lambda, &pypolars);

Ok(out.to_series(py, &pypolars, s.name()))
};

let output_map = GetOutput::map_field(move |fld| match output_type {
Expand Down Expand Up @@ -209,15 +220,7 @@ pub fn map_mul(
return Err(PolarsError::NoData("".into()));
}

// unpack the wrapper in a PySeries
let py_pyseries = out.getattr(py, "_s").expect(
"Could net get series attribute '_s'. \
Make sure that you return a Series object from a custom function.",
);
// Downcast to Rust
let pyseries = py_pyseries.extract::<PySeries>(py).unwrap();
// Finally get the actual Series
Ok(pyseries.series)
Ok(out.to_series(py, &pypolars, ""))
};

let exprs = pyexpr.iter().map(|pe| pe.clone().inner).collect::<Vec<_>>();
Expand Down
9 changes: 9 additions & 0 deletions py-polars/tests/test_apply.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import reduce
from typing import List, Optional

import polars as pl
Expand Down Expand Up @@ -41,3 +42,11 @@ def func(s: List) -> Optional[int]:
)
)["multiple"]
assert out[1] is None


def test_apply_return_py_object() -> None:
    """A `map` callback may return a plain Python object instead of a Series.

    Every column of a two-column frame is folded with `+` into a single
    value; the selection is expected to have one row and two columns.
    """
    frame = pl.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})

    def fold_with_add(series):
        # Returns whatever `+` yields over the elements, not a Series.
        return reduce(lambda acc, item: acc + item, series)

    out = frame.select([pl.all().map(fold_with_add)])

    assert out.shape == (1, 2)

0 comments on commit e7ec5a1

Please sign in to comment.