Skip to content

Commit

Permalink
fix(python): explicit output type in apply (#5328)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 25, 2022
1 parent 6d2ae61 commit de9bffa
Show file tree
Hide file tree
Showing 10 changed files with 34 additions and 21 deletions.
2 changes: 1 addition & 1 deletion py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3019,7 +3019,7 @@ def map(
def apply(
self,
f: Callable[[pli.Series], pli.Series] | Callable[[Any], Any],
return_dtype: type[DataType] | None = None,
return_dtype: PolarsDataType | None = None,
) -> Expr:
"""
Apply a custom/user-defined function (UDF) in a GroupBy or Projection context.
Expand Down
2 changes: 1 addition & 1 deletion py-polars/polars/internals/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3264,7 +3264,7 @@ def tanh(self) -> Series:
def apply(
self,
func: Callable[[Any], Any],
return_dtype: type[DataType] | None = None,
return_dtype: PolarsDataType | None = None,
skip_nulls: bool = True,
) -> Series:
"""
Expand Down
3 changes: 2 additions & 1 deletion py-polars/src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use pyo3::{PyAny, PyResult};
use crate::dataframe::PyDataFrame;
use crate::error::PyPolarsErr;
use crate::lazy::dataframe::PyLazyFrame;
use crate::object::OBJECT_NAME;
use crate::prelude::*;
use crate::py_modules::POLARS;
use crate::series::PySeries;
Expand Down Expand Up @@ -343,7 +344,7 @@ impl FromPyObject<'_> for Wrap<DataType> {
"Float32" => DataType::Float32,
"Float64" => DataType::Float64,
#[cfg(feature = "object")]
"Object" => DataType::Object("Object"),
"Object" => DataType::Object(OBJECT_NAME),
"List" => DataType::List(Box::new(DataType::Boolean)),
"Null" => DataType::Null,
"Unknown" => DataType::Unknown,
Expand Down
3 changes: 2 additions & 1 deletion py-polars/src/datatypes.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use polars::prelude::*;
use pyo3::{FromPyObject, PyAny, PyResult};

use crate::object::OBJECT_NAME;
use crate::Wrap;

// Don't change the order of these!
Expand Down Expand Up @@ -92,7 +93,7 @@ impl From<PyDataType> for DataType {
PyDataType::Duration(tu) => Duration(tu),
PyDataType::Time => Time,
#[cfg(feature = "object")]
PyDataType::Object => Object("object"),
PyDataType::Object => Object(OBJECT_NAME),
PyDataType::Categorical => Categorical(None),
PyDataType::Struct => Struct(vec![]),
}
Expand Down
19 changes: 5 additions & 14 deletions py-polars/src/lazy/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use pyo3::prelude::*;
use pyo3::types::PyList;

use crate::lazy::dsl::PyExpr;
use crate::prelude::PyDataType;
use crate::py_modules::POLARS;
use crate::series::PySeries;
use crate::Wrap;

trait ToSeries {
fn to_series(&self, py: Python, py_polars_module: &PyObject, name: &str) -> Series;
Expand Down Expand Up @@ -39,13 +39,6 @@ impl ToSeries for PyObject {
}
}

/// Converts an optional Python dtype argument into a polars [`DataType`].
///
/// Returns `None` when `obj` is Python `None`; otherwise extracts a
/// `PyDataType` from `obj` and converts it into a `DataType`.
///
/// # Panics
///
/// Panics if `obj` is not `None` and cannot be extracted as a `PyDataType`.
fn get_output_type(obj: &PyAny) -> Option<DataType> {
    // Python `None` means the caller did not request an explicit output type.
    if obj.is_none() {
        return None;
    }
    let py_dtype: PyDataType = obj.extract().unwrap();
    Some(py_dtype.into())
}

pub(crate) fn call_lambda_with_series(
py: Python,
s: Series,
Expand Down Expand Up @@ -122,10 +115,10 @@ pub(crate) fn binary_lambda(lambda: &PyObject, a: Series, b: Series) -> PolarsRe
pub fn map_single(
pyexpr: &PyExpr,
lambda: PyObject,
output_type: &PyAny,
output_type: Option<Wrap<DataType>>,
agg_list: bool,
) -> PyExpr {
let output_type = get_output_type(output_type);
let output_type = output_type.map(|wrap| wrap.0);

let output_type2 = output_type.clone();
let function = move |s: Series| {
Expand Down Expand Up @@ -191,11 +184,9 @@ pub fn map_mul(
pyexpr: &[PyExpr],
py: Python,
lambda: PyObject,
output_type: &PyAny,
output_type: Option<Wrap<DataType>>,
apply_groups: bool,
) -> PyExpr {
let output_type = get_output_type(output_type);

// get the pypolars module
// do the import outside of the function to prevent import side effects in a hot loop.
let pypolars = PyModule::import(py, "polars").unwrap().to_object(py);
Expand All @@ -217,7 +208,7 @@ pub fn map_mul(
let exprs = pyexpr.iter().map(|pe| pe.clone().inner).collect::<Vec<_>>();

let output_map = GetOutput::map_field(move |fld| match output_type {
Some(ref dt) => Field::new(fld.name(), dt.clone()),
Some(ref dt) => Field::new(fld.name(), dt.0.clone()),
None => fld.clone(),
});
if apply_groups {
Expand Down
7 changes: 6 additions & 1 deletion py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,12 @@ impl PyExpr {
.into()
}

pub fn map(&self, lambda: PyObject, output_type: &PyAny, agg_list: bool) -> PyExpr {
pub fn map(
&self,
lambda: PyObject,
output_type: Option<Wrap<DataType>>,
agg_list: bool,
) -> PyExpr {
map_single(self, lambda, output_type, agg_list)
}

Expand Down
2 changes: 1 addition & 1 deletion py-polars/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ pub fn map_mul(
py: Python,
pyexpr: Vec<PyExpr>,
lambda: PyObject,
output_type: &PyAny,
output_type: Option<Wrap<DataType>>,
apply_groups: bool,
) -> PyExpr {
lazy::map_mul(&pyexpr, py, lambda, output_type, apply_groups)
Expand Down
2 changes: 2 additions & 0 deletions py-polars/src/object.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ use pyo3::prelude::*;
use crate::prelude::ObjectValue;
use crate::Wrap;

/// Single shared name string supplied to `DataType::Object(..)` wherever the
/// Python bindings build an object dtype, so every conversion site uses the
/// same spelling instead of its own string literal.
pub(crate) const OBJECT_NAME: &str = "object";

pub(crate) fn register_object_builder() {
if !registry::is_object_builder_registered() {
let object_builder = Box::new(|name: &str, capacity: usize| {
Expand Down
2 changes: 1 addition & 1 deletion py-polars/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,7 @@ impl PySeries {
}
None => return dispatch_apply!(series, apply_lambda_unknown, py, lambda),

_ => return dispatch_apply!(series, apply_lambda, py, lambda),
_ => return dispatch_apply!(series, apply_lambda_unknown, py, lambda),
};

Ok(PySeries::new(out))
Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/unit/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,16 @@ def test_apply_object_dtypes() -> None:
"is_numeric1": [True, True, False, True, True],
"is_numeric_infer": [True, True, False, True, True],
}


# Checks that `apply` with an explicit `return_dtype=pl.List(pl.Int64)`
# produces a List(Int64) column when the UDF returns a Series for every row.
def test_apply_explicit_list_output_type() -> None:
out = pl.DataFrame({"str": ["a", "b"]}).with_columns(
[
pl.col("str").apply(
lambda _: pl.Series([1, 2, 3]), return_dtype=pl.List(pl.Int64)
)
]
)

# The column keeps its name "str"; each of the two rows holds the list the UDF returned.
assert out.dtypes == [pl.List(pl.Int64)]
assert out.to_dict(False) == {"str": [[1, 2, 3], [1, 2, 3]]}

0 comments on commit de9bffa

Please sign in to comment.