Skip to content

Commit

Permalink
fix apply with list output type
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 20, 2021
1 parent 942caf2 commit fec76cf
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 15 deletions.
25 changes: 21 additions & 4 deletions py-polars/src/apply/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::series::PySeries;
use polars::prelude::*;
use pyo3::conversion::FromPyObject;
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyFloat, PyInt, PyString, PyTuple};
use pyo3::types::{PyBool, PyFloat, PyInt, PyString, PyTuple, PyList};

pub fn apply_lambda_unknown<'a>(
df: &'a DataFrame,
Expand Down Expand Up @@ -69,7 +69,11 @@ pub fn apply_lambda_unknown<'a>(
dt,
)
.into_series());
} else {
} else if out.is_instance::<PyList>().unwrap() {
return Err(PyPolarsEr::Other("A list output type is invalid. Do you mean to create polars List Series?\
Then return a Series object.".into()).into());
}
else {
return Err(PyPolarsEr::Other("Could not determine output type".into()).into());
}
}
Expand Down Expand Up @@ -175,9 +179,22 @@ pub fn apply_lambda_with_list_out_type<'a>(
} else {
let iter = ((init_null_count + skip)..df.height()).map(|idx| {
let iter = columns.iter().map(|s: &Series| Wrap(s.get(idx)));
let tpl = PyTuple::new(py, iter);
let tpl = (PyTuple::new(py, iter),);
match lambda.call1(tpl) {
Ok(val) => val.extract::<PySeries>().ok().map(|ps| ps.series),
Ok(val) => {
match val.getattr("_s") {
Ok(val) => {
val.extract::<PySeries>().ok().map(|ps| ps.series)
}
Err(_) => {
if val.is_none() {
None
} else {
panic!("should return a Series, got a {:?}", val)
}
}
}
},
Err(e) => panic!("python function failed {}", e),
}
});
Expand Down
6 changes: 1 addition & 5 deletions py-polars/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@ use polars::prelude::*;
#[cfg(feature = "downsample")]
use polars_core::frame::groupby::resample::SampleRule;

use crate::apply::dataframe::{
apply_lambda_unknown, apply_lambda_with_bool_out_type, apply_lambda_with_primitive_out_type,
apply_lambda_with_utf8_out_type,
};
use crate::apply::dataframe::{apply_lambda_unknown, apply_lambda_with_bool_out_type, apply_lambda_with_primitive_out_type, apply_lambda_with_utf8_out_type};
use crate::conversion::{ObjectValue, Wrap};
use crate::datatypes::PyDataType;
use crate::file::get_mmap_bytes_reader;
Expand Down Expand Up @@ -891,7 +888,6 @@ impl PyDataFrame {
Some(DataType::Utf8) => {
apply_lambda_with_utf8_out_type(df, py, lambda, 0, None).into_series()
}

_ => apply_lambda_unknown(df, py, lambda)?,
};

Expand Down
18 changes: 12 additions & 6 deletions py-polars/tests/test_df.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,12 +1061,12 @@ def test_filter_date():
{"date": ["2020-01-02", "2020-01-03", "2020-01-04"], "index": [1, 2, 3]}
)
df = dataset.with_column(pl.col("date").str.strptime(pl.Date32, "%Y-%m-%d"))
assert df.filter(col("date") <= pl.lit_date(datetime(2019, 1, 3))).is_empty()
assert df.filter(col("date") < pl.lit_date(datetime(2020, 1, 4))).shape[0] == 2
assert df.filter(col("date") < pl.lit_date(datetime(2020, 1, 5))).shape[0] == 3
assert df.filter(col("date") <= pl.lit(datetime(2019, 1, 3))).is_empty()
assert df.filter(col("date") < pl.lit(datetime(2020, 1, 4))).shape[0] == 2
assert df.filter(col("date") < pl.lit(datetime(2020, 1, 5))).shape[0] == 3
assert df.filter(pl.col("date") <= pl.lit_date(datetime(2019, 1, 3))).is_empty()
assert df.filter(pl.col("date") < pl.lit_date(datetime(2020, 1, 4))).shape[0] == 2
assert df.filter(pl.col("date") < pl.lit_date(datetime(2020, 1, 5))).shape[0] == 3
assert df.filter(pl.col("date") <= pl.lit(datetime(2019, 1, 3))).is_empty()
assert df.filter(pl.col("date") < pl.lit(datetime(2020, 1, 4))).shape[0] == 2
assert df.filter(pl.col("date") < pl.lit(datetime(2020, 1, 5))).shape[0] == 3


def test_slicing():
Expand All @@ -1084,3 +1084,9 @@ def test_slicing():
2,
1,
)


def test_apply_list_return():
df = pl.DataFrame({"start": [1, 2], "end": [3, 5]})
out = df.apply(lambda r: pl.Series(range(r[0], r[1] + 1)))
assert out.to_list() == [[1, 2, 3], [2, 3, 4, 5]]

0 comments on commit fec76cf

Please sign in to comment.