Skip to content

Commit

Permalink
python: apply all numeric types
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 14, 2022
1 parent d945663 commit df60596
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 20 deletions.
3 changes: 1 addition & 2 deletions py-polars/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
#![feature(vec_into_raw_parts)]
#![allow(clippy::nonstandard_macro_braces)] // needed because clippy does not understand proc macro of pyo3
#![allow(clippy::transmute_undefined_repr)]
#[macro_use]
extern crate polars;
extern crate core;
extern crate polars;

use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
Expand Down
36 changes: 18 additions & 18 deletions py-polars/src/series.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use crate::error::PyPolarsErr;
use crate::list_construction::py_seq_to_list;
use crate::utils::reinterpret;
use crate::{
arrow_interop,
apply_method_all_arrow_series2, arrow_interop,
npy::{aligned_array, get_refcnt},
prelude::*,
};
Expand Down Expand Up @@ -921,7 +921,7 @@ impl PySeries {

let out = match output_type {
Some(DataType::Int8) => {
let ca: Int8Chunked = apply_method_all_arrow_series!(
let ca: Int8Chunked = apply_method_all_arrow_series2!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -932,7 +932,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::Int16) => {
let ca: Int16Chunked = apply_method_all_arrow_series!(
let ca: Int16Chunked = apply_method_all_arrow_series2!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -943,7 +943,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::Int32) => {
let ca: Int32Chunked = apply_method_all_arrow_series!(
let ca: Int32Chunked = apply_method_all_arrow_series2!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -954,7 +954,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::Int64) => {
let ca: Int64Chunked = apply_method_all_arrow_series!(
let ca: Int64Chunked = apply_method_all_arrow_series2!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -965,7 +965,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::UInt8) => {
let ca: UInt8Chunked = apply_method_all_arrow_series!(
let ca: UInt8Chunked = apply_method_all_arrow_series2!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -976,7 +976,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::UInt16) => {
let ca: UInt16Chunked = apply_method_all_arrow_series!(
let ca: UInt16Chunked = apply_method_all_arrow_series2!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -987,7 +987,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::UInt32) => {
let ca: UInt32Chunked = apply_method_all_arrow_series!(
let ca: UInt32Chunked = apply_method_all_arrow_series2!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -998,7 +998,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::UInt64) => {
let ca: UInt64Chunked = apply_method_all_arrow_series!(
let ca: UInt64Chunked = apply_method_all_arrow_series2!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -1009,7 +1009,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::Float32) => {
let ca: Float32Chunked = apply_method_all_arrow_series!(
let ca: Float32Chunked = apply_method_all_arrow_series2!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -1020,7 +1020,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::Float64) => {
let ca: Float64Chunked = apply_method_all_arrow_series!(
let ca: Float64Chunked = apply_method_all_arrow_series2!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -1031,7 +1031,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::Boolean) => {
let ca: BooleanChunked = apply_method_all_arrow_series!(
let ca: BooleanChunked = apply_method_all_arrow_series2!(
series,
apply_lambda_with_bool_out_type,
py,
Expand All @@ -1042,7 +1042,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::Date) => {
let ca: Int32Chunked = apply_method_all_arrow_series!(
let ca: Int32Chunked = apply_method_all_arrow_series2!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -1053,7 +1053,7 @@ impl PySeries {
ca.into_date().into_series()
}
Some(DataType::Datetime(tu, tz)) => {
let ca: Int64Chunked = apply_method_all_arrow_series!(
let ca: Int64Chunked = apply_method_all_arrow_series2!(
series,
apply_lambda_with_primitive_out_type,
py,
Expand All @@ -1064,7 +1064,7 @@ impl PySeries {
ca.into_datetime(tu, tz).into_series()
}
Some(DataType::Utf8) => {
let ca: Utf8Chunked = apply_method_all_arrow_series!(
let ca: Utf8Chunked = apply_method_all_arrow_series2!(
series,
apply_lambda_with_utf8_out_type,
py,
Expand All @@ -1075,7 +1075,7 @@ impl PySeries {
ca.into_series()
}
Some(DataType::Object(_)) => {
let ca: ObjectChunked<ObjectValue> = apply_method_all_arrow_series!(
let ca: ObjectChunked<ObjectValue> = apply_method_all_arrow_series2!(
series,
apply_lambda_with_object_out_type,
py,
Expand All @@ -1086,10 +1086,10 @@ impl PySeries {
ca.into_series()
}
None => {
return apply_method_all_arrow_series!(series, apply_lambda_unknown, py, lambda);
return apply_method_all_arrow_series2!(series, apply_lambda_unknown, py, lambda);
}

_ => return apply_method_all_arrow_series!(series, apply_lambda, py, lambda),
_ => return apply_method_all_arrow_series2!(series, apply_lambda, py, lambda),
};

Ok(PySeries::new(out))
Expand Down
26 changes: 26 additions & 0 deletions py-polars/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,29 @@ pub fn reinterpret(s: &Series, signed: bool) -> polars::prelude::Result<Series>
)),
}
}

/// Dispatches `$method` to the typed chunked array that backs a series.
///
/// Local stand-in for polars' `apply_method_all_arrow_series!`. According to the
/// original author's note it was redefined in this crate because the feature
/// flags required by the upstream macro could not be activated.
///
/// # Arguments
///
/// * `$self` - expression exposing `dtype()` plus the typed accessors used below
///   (`bool()`, `utf8()`, `u8()`, ..., `struct_()`). It is evaluated exactly once.
/// * `$method` - name of the method to call on the unwrapped accessor result.
/// * `$args` - arguments forwarded to `$method`; only the selected arm runs, so
///   they are evaluated once as well.
///
/// The value of the expansion is whatever `$method` returns. `DataType` is
/// resolved at the call site, so it has to be in scope wherever the macro is used.
///
/// # Panics
///
/// Panics with `dtype {:?} not supported` when the dtype has no arm below, and
/// whenever the `unwrap()` on the selected accessor fails.
#[macro_export]
macro_rules! apply_method_all_arrow_series2 {
    ($self:expr, $method:ident, $($args:expr),*) => {
        // Bind the series through a `match` instead of repeating `$self` in the
        // scrutinee and in every arm: the expression is then evaluated only once,
        // and temporaries it creates stay alive for the whole dispatch.
        match &$self {
            series => match series.dtype() {
                DataType::Boolean => series.bool().unwrap().$method($($args),*),
                DataType::Utf8 => series.utf8().unwrap().$method($($args),*),
                DataType::UInt8 => series.u8().unwrap().$method($($args),*),
                DataType::UInt16 => series.u16().unwrap().$method($($args),*),
                DataType::UInt32 => series.u32().unwrap().$method($($args),*),
                DataType::UInt64 => series.u64().unwrap().$method($($args),*),
                DataType::Int8 => series.i8().unwrap().$method($($args),*),
                DataType::Int16 => series.i16().unwrap().$method($($args),*),
                DataType::Int32 => series.i32().unwrap().$method($($args),*),
                DataType::Int64 => series.i64().unwrap().$method($($args),*),
                DataType::Float32 => series.f32().unwrap().$method($($args),*),
                DataType::Float64 => series.f64().unwrap().$method($($args),*),
                DataType::Date => series.date().unwrap().$method($($args),*),
                DataType::Datetime(_, _) => series.datetime().unwrap().$method($($args),*),
                DataType::List(_) => series.list().unwrap().$method($($args),*),
                DataType::Struct(_) => series.struct_().unwrap().$method($($args),*),
                // Every remaining dtype is rejected loudly rather than mis-dispatched.
                dt => panic!("dtype {:?} not supported", dt),
            },
        }
    };
}
16 changes: 16 additions & 0 deletions py-polars/tests/test_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,3 +181,19 @@ def test_apply_list_anyvalue_fallback() -> None:
assert df.select(pl.col("text").apply(json.loads)).to_dict(False) == {
"text": [[], [{"x": 1, "y": 2}, {"x": 3, "y": 4}], [{"x": 1, "y": 2}]]
}


def test_apply_all_types() -> None:
    """Apply an identity lambda to a Series of every listed integer dtype.

    Nothing is asserted about the result; the test passes as long as
    ``apply`` does not panic for any of the dtypes.
    """
    integer_dtypes = (
        pl.UInt8,
        pl.UInt16,
        pl.UInt32,
        pl.UInt64,
        pl.Int8,
        pl.Int16,
        pl.Int32,
        pl.Int64,
    )
    for integer_dtype in integer_dtypes:
        # Build the Series first so a failure points at either construction or apply.
        values = pl.Series([1, 2, 3, 4, 5], dtype=integer_dtype)
        values.apply(lambda x: x)

0 comments on commit df60596

Please sign in to comment.