Skip to content

Commit

Permalink
fix[python]: enforce strict rules in expr.apply (#4645)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 31, 2022
1 parent 8ea0b8e commit 803dd5b
Show file tree
Hide file tree
Showing 13 changed files with 83 additions and 23 deletions.
7 changes: 3 additions & 4 deletions polars/polars-core/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -670,11 +670,10 @@ fn _get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
#[cfg(all(feature = "dtype-date", feature = "dtype-time"))]
(Date, Time) => Some(Int64),

// everything can be cast to a string
(_, Utf8) => Some(Utf8),
// every known type can be cast to a string
(dt, Utf8) if dt != &DataType::Unknown => Some(Utf8),

(dt, Null) => Some(dt.clone()),
(Null, dt) => Some(dt.clone()),

#[cfg(all(feature = "dtype-duration", feature = "dtype-datetime"))]
(Duration(lu), Datetime(ru, Some(tz))) | (Datetime(lu, Some(tz)), Duration(ru)) => {
Expand Down Expand Up @@ -719,7 +718,7 @@ fn _get_supertype(l: &DataType, r: &DataType) -> Option<DataType> {
let st = _get_supertype(inner, other)?;
Some(DataType::List(Box::new(st)))
}
(dt, Unknown) => Some(dt.clone()),
(_, Unknown) => Some(Unknown),
_ => None,
}
}
Expand Down
17 changes: 17 additions & 0 deletions polars/polars-lazy/src/dsl/dt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,23 @@ impl DateLikeNameSpace {
)
}

/// Change the underlying [`TimeZone`] of the [`Series`]. This does not modify the data.
pub fn with_time_zone(self, tz: Option<TimeZone>) -> Expr {
self.0.map(
move |s| match s.dtype() {
DataType::Datetime(_, _) => {
let mut ca = s.datetime().unwrap().clone();
ca.set_time_zone(tz.clone());
Ok(ca.into_series())
}
dt => Err(PolarsError::ComputeError(
format!("Series of dtype {:?} has got no time zone", dt).into(),
)),
},
GetOutput::same_type(),
)
}

/// Get the year of a Date/Datetime
pub fn year(self) -> Expr {
let function = move |s: Series| s.year().map(|ca| ca.into_series());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,10 @@ impl OptimizationRule for TypeCoercionRule {
} => {
let input_schema = get_schema(lp_arena, lp_node);
let (left, type_left) = get_aexpr_and_type(expr_arena, node_left, &input_schema)?;

let (right, type_right) =
get_aexpr_and_type(expr_arena, node_right, &input_schema)?;
early_escape(&type_left, &type_right)?;

// don't coerce string with number comparisons. They must error
match (&type_left, &type_right, op) {
Expand Down Expand Up @@ -449,9 +451,6 @@ impl OptimizationRule for TypeCoercionRule {
get_aexpr_and_type(expr_arena, *other, &input_schema)?;

// early return until Unknown is set
if let DataType::Unknown = &type_other {
return None;
}
early_escape(&super_type, &type_other)?;
let new_st = get_supertype(&super_type, &type_other).ok()?;
super_type = modify_supertype(new_st, self_ae, other, &type_self, &type_other)
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/datatypes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ Other
Null
Object
Utf8
Unknown

Functions
~~~~~~~~~
Expand Down
2 changes: 2 additions & 0 deletions py-polars/polars/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def version() -> str:
UInt16,
UInt32,
UInt64,
Unknown,
Utf8,
get_idx_type,
)
Expand Down Expand Up @@ -169,6 +170,7 @@ def version() -> str:
"Field",
"Struct",
"Null",
"Unknown",
"PolarsDataType",
"get_idx_type",
# polars.io
Expand Down
4 changes: 4 additions & 0 deletions py-polars/polars/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ class Null(DataType):
"""Type representing Null / None values."""


class Unknown(DataType):
    """Type representing DataType values that could not be determined statically."""


class List(DataType):
def __init__(self, inner: type[DataType]):
"""
Expand Down
6 changes: 2 additions & 4 deletions py-polars/polars/internals/expr/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import TYPE_CHECKING

import polars.internals as pli
from polars.datatypes import DTYPE_TEMPORAL_UNITS, Date, Datetime, Int32
from polars.datatypes import DTYPE_TEMPORAL_UNITS, Date, Int32
from polars.utils import _timedelta_to_pl_duration

if TYPE_CHECKING:
Expand Down Expand Up @@ -879,9 +879,7 @@ def with_time_zone(self, tz: str | None) -> pli.Expr:
└─────────────────────┴─────────────────────────────┘
"""
return pli.wrap_expr(self._pyexpr).map(
lambda s: s.dt.with_time_zone(tz), return_dtype=Datetime
)
return pli.wrap_expr(self._pyexpr.dt_with_time_zone(tz))

def days(self) -> pli.Expr:
"""
Expand Down
2 changes: 2 additions & 0 deletions py-polars/polars/internals/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3025,6 +3025,8 @@ def apply(
Lambda/ function to apply.
return_dtype
Dtype of the output Series.
If not set, polars will assume that
the dtype remains unchanged.
Examples
--------
Expand Down
3 changes: 2 additions & 1 deletion py-polars/src/conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ impl ToPyObject for Wrap<DataType> {
struct_class.call1((fields,)).unwrap().into()
}
DataType::Null => pl.getattr("Null").unwrap().into(),
dt => panic!("{} not supported", dt),
DataType::Unknown => pl.getattr("Unknown").unwrap().into(),
}
}
}
Expand Down Expand Up @@ -321,6 +321,7 @@ impl FromPyObject<'_> for Wrap<DataType> {
"Object" => DataType::Object("unknown"),
"List" => DataType::List(Box::new(DataType::Boolean)),
"Null" => DataType::Null,
"Unknown" => DataType::Unknown,
dt => panic!("{} not expected as Python type for dtype conversion", dt),
}
}
Expand Down
24 changes: 17 additions & 7 deletions py-polars/src/lazy/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use pyo3::types::PyList;

use crate::lazy::dsl::PyExpr;
use crate::prelude::PyDataType;
use crate::py_modules::POLARS;
use crate::series::PySeries;

trait ToSeries {
Expand Down Expand Up @@ -117,29 +118,38 @@ pub(crate) fn binary_lambda(lambda: &PyObject, a: Series, b: Series) -> Result<S

pub fn map_single(
pyexpr: &PyExpr,
py: Python,
lambda: PyObject,
output_type: &PyAny,
agg_list: bool,
) -> PyExpr {
let output_type = get_output_type(output_type);
// get the pypolars module
// do the import outside of the function to prevent import side effects in a hot loop.
let pypolars = PyModule::import(py, "polars").unwrap().to_object(py);

let output_type2 = output_type.clone();
let function = move |s: Series| {
Python::with_gil(|py| {
let output_type = output_type2.clone().unwrap_or_else(|| DataType::Unknown);

// this is a python Series
let out = call_lambda_with_series(py, s.clone(), &lambda, &pypolars)
let out = call_lambda_with_series(py, s.clone(), &lambda, &POLARS)
.map_err(|e| PolarsError::ComputeError(format!("{e}").into()))?;
let s = out.to_series(py, &POLARS, s.name());

Ok(out.to_series(py, &pypolars, s.name()))
if !matches!(output_type, DataType::Unknown) && s.dtype() != &output_type {
Err(PolarsError::SchemaMisMatch(
format!("Expected output type: '{:?}', but got '{:?}'. Set 'return_dtype' to the proper datatype.", output_type, s.dtype()).into()))
} else {
Ok(s)
}
})
};

let output_map = GetOutput::map_field(move |fld| match output_type {
Some(ref dt) => Field::new(fld.name(), dt.clone()),
None => fld.clone(),
None => {
let mut fld = fld.clone();
fld.coerce(DataType::Unknown);
fld
}
});
if agg_list {
pyexpr.clone().inner.map_list(function, output_map).into()
Expand Down
8 changes: 6 additions & 2 deletions py-polars/src/lazy/dsl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,10 @@ impl PyExpr {
self.inner.clone().dt().with_time_unit(tu.0).into()
}

/// Build a new expression that applies `dt().with_time_zone(tz)` to a clone of
/// the wrapped expression.
pub fn dt_with_time_zone(&self, tz: Option<TimeZone>) -> PyExpr {
    let expr = self.inner.clone();
    let with_tz = expr.dt().with_time_zone(tz);
    with_tz.into()
}

pub fn dt_cast_time_unit(&self, tu: Wrap<TimeUnit>) -> PyExpr {
self.inner.clone().dt().cast_time_unit(tu.0).into()
}
Expand Down Expand Up @@ -1010,8 +1014,8 @@ impl PyExpr {
.into()
}

pub fn map(&self, py: Python, lambda: PyObject, output_type: &PyAny, agg_list: bool) -> PyExpr {
map_single(self, py, lambda, output_type, agg_list)
pub fn map(&self, lambda: PyObject, output_type: &PyAny, agg_list: bool) -> PyExpr {
map_single(self, lambda, output_type, agg_list)
}

pub fn dot(&self, other: PyExpr) -> PyExpr {
Expand Down
8 changes: 6 additions & 2 deletions py-polars/tests/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,12 @@ def test_apply_custom_function() -> None:
.groupby("fruits")
.agg(
[
pl.col("cars").apply(lambda groups: groups.len()).alias("custom_1"),
pl.col("cars").apply(lambda groups: groups.len()).alias("custom_2"),
pl.col("cars")
.apply(lambda groups: groups.len(), return_dtype=pl.Int64)
.alias("custom_1"),
pl.col("cars")
.apply(lambda groups: groups.len(), return_dtype=pl.Int64)
.alias("custom_2"),
pl.count("cars").alias("cars_count"),
]
)
Expand Down
19 changes: 19 additions & 0 deletions py-polars/tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,22 @@ def test_join_as_of_by_schema() -> None:
b = pl.DataFrame({"a": [1], "b": [2], "d": [4]}).lazy()
q = a.join_asof(b, on="a", by="b")
assert q.collect().columns == q.columns


def test_unknown_apply() -> None:
    """An ``apply`` without ``return_dtype`` reports ``Unknown`` in the lazy schema."""
    frame = pl.DataFrame(
        {"Amount": [10, 1, 1, 5], "Flour": ["1000g", "100g", "50g", "75g"]}
    )

    # The lambda's output dtype is never declared.
    flour_per_amount = pl.col("Flour").apply(lambda x: 100.0) / pl.col("Amount")
    query = frame.lazy().select([pl.col("Amount"), flour_per_amount])

    expected = {
        "Amount": [10, 1, 1, 5],
        "Flour": [10.0, 100.0, 100.0, 20.0],
    }
    assert query.collect().to_dict(False) == expected
    assert query.dtypes == [pl.Int64, pl.Unknown]

0 comments on commit 803dd5b

Please sign in to comment.