Skip to content

Commit

Permalink
fix(python): fix apply function over object dtype (#5206)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Oct 14, 2022
1 parent b995b36 commit 0355b7d
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 24 deletions.
59 changes: 35 additions & 24 deletions py-polars/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,23 @@ impl PySeries {

let output_type = output_type.map(|dt| dt.0);

// Routes an "apply"-style method call to the right chunked-array type.
//
// * `$self`   - the series expression to dispatch on (expanded twice: once
//               for the dtype check, once for the actual call).
// * `$method` - name of the method to invoke.
// * `$args`   - comma-separated arguments forwarded unchanged to `$method`.
//
// An `Object` series is unpacked into its `ObjectType<ObjectValue>` chunked
// array and `$method` is called on that directly; every other dtype is handed
// to `apply_method_all_arrow_series2!`.
// NOTE(review): presumably the arrow dispatch macro has no arm for `Object`,
// which is why it is special-cased here - confirm against that macro.
macro_rules! dispatch_apply {
    ($self:expr, $method:ident, $($args:expr),*) => {
        match $self.dtype() {
            DataType::Object(_) => {
                // The dtype was matched as `Object` just above, so the unpack
                // is expected to succeed.
                let ca = $self.0.unpack::<ObjectType<ObjectValue>>().unwrap();
                ca.$method($($args),*)
            }
            _ => apply_method_all_arrow_series2!($self, $method, $($args),*),
        }
    };
}

if matches!(
self.series.dtype(),
DataType::Datetime(_, _)
Expand All @@ -808,7 +825,7 @@ impl PySeries {

let out = match output_type {
Some(DataType::Int8) => {
let ca: Int8Chunked = apply_method_all_arrow_series2!(
let ca: Int8Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -819,7 +836,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::Int16) => {
let ca: Int16Chunked = apply_method_all_arrow_series2!(
let ca: Int16Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -830,7 +847,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::Int32) => {
let ca: Int32Chunked = apply_method_all_arrow_series2!(
let ca: Int32Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -841,7 +858,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::Int64) => {
let ca: Int64Chunked = apply_method_all_arrow_series2!(
let ca: Int64Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -852,7 +869,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::UInt8) => {
let ca: UInt8Chunked = apply_method_all_arrow_series2!(
let ca: UInt8Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -863,7 +880,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::UInt16) => {
let ca: UInt16Chunked = apply_method_all_arrow_series2!(
let ca: UInt16Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -874,7 +891,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::UInt32) => {
let ca: UInt32Chunked = apply_method_all_arrow_series2!(
let ca: UInt32Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -885,7 +902,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::UInt64) => {
let ca: UInt64Chunked = apply_method_all_arrow_series2!(
let ca: UInt64Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -896,7 +913,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::Float32) => {
let ca: Float32Chunked = apply_method_all_arrow_series2!(
let ca: Float32Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -907,7 +924,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::Float64) => {
let ca: Float64Chunked = apply_method_all_arrow_series2!(
let ca: Float64Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -918,7 +935,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::Boolean) => {
let ca: BooleanChunked = apply_method_all_arrow_series2!(
let ca: BooleanChunked = dispatch_apply!(
series,
apply_lambda_with_bool_out_type,
py,
Expand All @@ -929,7 +946,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::Date) => {
let ca: Int32Chunked = apply_method_all_arrow_series2!(
let ca: Int32Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -940,7 +957,7 @@ impl PySeries {
ca.into_date().into_series()
}
Some(DataType::Datetime(tu, tz)) => {
let ca: Int64Chunked = apply_method_all_arrow_series2!(
let ca: Int64Chunked = dispatch_apply!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -951,19 +968,20 @@ impl PySeries {
ca.into_datetime(tu, tz).into_series()
}
Some(DataType::Utf8) => {
let ca: Utf8Chunked = apply_method_all_arrow_series2!(
let ca = dispatch_apply!(
series,
apply_lambda_with_utf8_out_type,
py,
lambda,
0,
None
)?;

ca.into_series()
}
#[cfg(feature = "object")]
Some(DataType::Object(_)) => {
let ca: ObjectChunked<ObjectValue> = apply_method_all_arrow_series2!(
let ca = dispatch_apply!(
series,
apply_lambda_with_object_out_type,
py,
Expand All @@ -973,16 +991,9 @@ impl PySeries {
)?;
ca.into_series()
}
None => {
return apply_method_all_arrow_series2!(
series,
apply_lambda_unknown,
py,
lambda
);
}
None => return dispatch_apply!(series, apply_lambda_unknown, py, lambda),

_ => return apply_method_all_arrow_series2!(series, apply_lambda, py, lambda),
_ => return dispatch_apply!(series, apply_lambda, py, lambda),
};

Ok(PySeries::new(out))
Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/unit/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,3 +239,25 @@ def test_apply_skip_nulls() -> None:

assert s.apply(lambda x: some_map[x]).to_list() == [None, "b"]
assert s.apply(lambda x: some_map[x], skip_nulls=False).to_list() == ["a", "b"]


def test_apply_object_dtypes() -> None:
    """Check that ``apply`` works on a column of ``Object`` dtype.

    Three expressions are evaluated over the same mixed int/str column:
    one returning ``Object``, one with an explicit ``Boolean`` return dtype,
    and one without any ``return_dtype`` given.
    """
    frame = pl.DataFrame({"a": pl.Series([1, 2, "a", 4, 5], dtype=pl.Object)})

    # Overwrites column "a" (no alias): ints are doubled, the string is repeated.
    doubled = pl.col("a").apply(lambda x: x * 2, return_dtype=pl.Object)
    # Same predicate twice: once with the dtype stated, once left unspecified.
    numeric_declared = (
        pl.col("a")
        .apply(lambda x: isinstance(x, (int, float)), return_dtype=pl.Boolean)
        .alias("is_numeric1")
    )
    numeric_inferred = (
        pl.col("a")
        .apply(lambda x: isinstance(x, (int, float)))
        .alias("is_numeric_infer")
    )

    expected = {
        "a": [2, 4, "aa", 8, 10],
        "is_numeric1": [True, True, False, True, True],
        "is_numeric_infer": [True, True, False, True, True],
    }

    result = frame.with_columns(
        [doubled, numeric_declared, numeric_inferred]
    ).to_dict(False)
    assert result == expected

0 comments on commit 0355b7d

Please sign in to comment.