Skip to content

Commit

Permalink
Python: fallback to anyvalue for logical lambda application (#3152)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Apr 15, 2022
1 parent 2255fe6 commit 25dde0c
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 10 deletions.
26 changes: 25 additions & 1 deletion polars/polars-core/src/chunked_array/ops/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,33 @@ pub(crate) unsafe fn arr_to_any_value<'a>(
.collect();
AnyValue::Struct(vals)
}
#[cfg(feature = "dtype-datetime")]
DataType::Datetime(tu, tz) => {
let arr = &*(arr as *const dyn Array as *const Int64Array);
let v = arr.value_unchecked(idx);
AnyValue::Datetime(v, *tu, tz)
}
#[cfg(feature = "dtype-date")]
DataType::Date => {
let arr = &*(arr as *const dyn Array as *const Int32Array);
let v = arr.value_unchecked(idx);
AnyValue::Date(v)
}
#[cfg(feature = "dtype-duration")]
DataType::Duration(tu) => {
let arr = &*(arr as *const dyn Array as *const Int64Array);
let v = arr.value_unchecked(idx);
AnyValue::Duration(v, *tu)
}
#[cfg(feature = "dtype-time")]
DataType::Time => {
let arr = &*(arr as *const dyn Array as *const Int64Array);
let v = arr.value_unchecked(idx);
AnyValue::Time(v)
}
#[cfg(feature = "object")]
DataType::Object(_) => panic!("should not be here"),
_ => unimplemented!(),
dt => panic!("not implemented for {:?}", dt),
}
}

Expand Down
8 changes: 6 additions & 2 deletions py-polars/src/apply/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,15 +172,19 @@ pub trait ApplyLambda<'a> {
) -> PyResult<ObjectChunked<ObjectValue>>;
}

fn call_lambda<'a, T>(py: Python, lambda: &'a PyAny, in_val: T) -> PyResult<&'a PyAny>
/// Invoke the Python callable `lambda` with `in_val` as its only positional argument.
///
/// `in_val` is packed into a one-element `PyTuple`, which is handed to `call1`
/// as the argument tuple. The outcome of the call (the returned object or the
/// Python error) is passed back to the caller unchanged.
pub fn call_lambda<'a, T>(py: Python, lambda: &'a PyAny, in_val: T) -> PyResult<&'a PyAny>
where
    T: ToPyObject,
{
    lambda.call1(PyTuple::new(py, &[in_val]))
}

fn call_lambda_and_extract<'a, T, S>(py: Python, lambda: &'a PyAny, in_val: T) -> PyResult<S>
pub(crate) fn call_lambda_and_extract<'a, T, S>(
py: Python,
lambda: &'a PyAny,
in_val: T,
) -> PyResult<S>
where
T: ToPyObject,
S: FromPyObject<'a>,
Expand Down
34 changes: 28 additions & 6 deletions py-polars/src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,8 @@ impl IntoPy<PyObject> for Wrap<AnyValue<'_>> {
todo!()
}
let pl = PyModule::import(py, "polars").unwrap();
let pli = pl.getattr("internals").unwrap();
let m_series = pli.getattr("series").unwrap();
let convert = m_series.getattr("_to_python_datetime").unwrap();
let utils = pl.getattr("utils").unwrap();
let convert = utils.getattr("_to_python_datetime").unwrap();
let py_datetime_dtype = pl.getattr("Datetime").unwrap();
match tu {
TimeUnit::Nanoseconds => convert
Expand All @@ -214,9 +213,8 @@ impl IntoPy<PyObject> for Wrap<AnyValue<'_>> {
}
AnyValue::Duration(v, tu) => {
let pl = PyModule::import(py, "polars").unwrap();
let pli = pl.getattr("internals").unwrap();
let m_series = pli.getattr("series").unwrap();
let convert = m_series.getattr("_to_python_datetime").unwrap();
let utils = pl.getattr("utils").unwrap();
let convert = utils.getattr("_to_python_timedelta").unwrap();
match tu {
TimeUnit::Nanoseconds => convert.call1((v, "ns")).unwrap().into_py(py),
TimeUnit::Microseconds => convert.call1((v, "us")).unwrap().into_py(py),
Expand Down Expand Up @@ -543,6 +541,30 @@ impl<'s> FromPyObject<'s> for Wrap<AnyValue<'s>> {
let py_pyseries = ob.getattr("_s").unwrap();
let series = py_pyseries.extract::<PySeries>().unwrap().series;
Ok(Wrap(AnyValue::List(series)))
} else if ob.get_type().name()?.contains("date") {
let gil = Python::acquire_gil();
let py = gil.python();
let pypolars = PyModule::import(py, "polars").unwrap().to_object(py);
let utils = pypolars.getattr(py, "utils").unwrap();
let utils = utils
.getattr(py, "_date_to_pl_date")
.unwrap()
.call1(py, (ob,))
.unwrap();
let v = utils.extract::<i32>(py).unwrap();
Ok(Wrap(AnyValue::Date(v)))
} else if ob.get_type().name()?.contains("timedelta") {
let gil = Python::acquire_gil();
let py = gil.python();
let pypolars = PyModule::import(py, "polars").unwrap().to_object(py);
let utils = pypolars.getattr(py, "utils").unwrap();
let utils = utils
.getattr(py, "_timedelta_to_pl_timedelta")
.unwrap()
.call1(py, (ob, "us"))
.unwrap();
let v = utils.extract::<i64>(py).unwrap();
Ok(Wrap(AnyValue::Duration(v, TimeUnit::Microseconds)))
} else {
Err(PyErr::from(PyPolarsErr::Other(format!(
"row type not supported {:?}",
Expand Down
21 changes: 20 additions & 1 deletion py-polars/src/series.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::apply::series::ApplyLambda;
use crate::apply::series::{call_lambda_and_extract, ApplyLambda};
use crate::arrow_interop::to_rust::array_to_rust;
use crate::dataframe::PyDataFrame;
use crate::error::PyPolarsErr;
Expand Down Expand Up @@ -835,6 +835,25 @@ impl PySeries {

let output_type = output_type.map(|dt| dt.0);

if matches!(
self.series.dtype(),
DataType::Datetime(_, _)
| DataType::Date
| DataType::Duration(_)
| DataType::Categorical(_)
| DataType::Time
) {
let mut avs = Vec::with_capacity(self.series.len());
let iter = self.series.iter().map(|av| {
let input = Wrap(av);
call_lambda_and_extract::<_, Wrap<AnyValue>>(py, lambda, input)
.unwrap()
.0
});
avs.extend(iter);
return Ok(Series::new(self.name(), &avs).into());
}

let out = match output_type {
Some(DataType::Int8) => {
let ca: Int8Chunked = apply_method_all_arrow_series!(
Expand Down
10 changes: 10 additions & 0 deletions py-polars/tests/test_apply.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from datetime import date, datetime, timedelta
from functools import reduce
from typing import List, Optional

Expand Down Expand Up @@ -143,3 +144,12 @@ def test_apply_numpy_int_out() -> None:
.apply(lambda cols: np.left_shift(cols["col1"], cols["shift"]))
.alias("result")
).frame_equal(pl.DataFrame({"result": [4, 8, 32, 64]}))


def test_datelike_identity() -> None:
    """Applying an identity lambda to a date-like Series must give back the same values."""
    samples = (
        datetime(year=2000, month=1, day=1),
        timedelta(hours=2),
        date(year=2000, month=1, day=1),
    )
    for sample in samples:
        # One single-element Series per date-like Python type.
        series = pl.Series([sample])
        assert series.apply(lambda value: value).to_list() == series.to_list()
1 change: 1 addition & 0 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1286,6 +1286,7 @@ def test_join_dates() -> None:
)
dts = (
pl.from_pandas(date_times)
.cast(int)
.apply(lambda x: x + np.random.randint(1_000 * 60, 60_000 * 60))
.cast(pl.Datetime)
)
Expand Down

0 comments on commit 25dde0c

Please sign in to comment.