Skip to content

Commit

Permalink
Adapt to latest IR updates
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- committed Apr 29, 2024
1 parent 1c32963 commit bf2b301
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 112 deletions.
8 changes: 2 additions & 6 deletions py-polars/src/lazyframe/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ impl NodeTraverser {
self.scratch_to_list()
}

/// Get Schema as python dict<str, pl.DataType>
/// Get Schema of current node as python dict<str, pl.DataType>
fn get_schema(&self, py: Python<'_>) -> PyObject {
let lp_arena = self.lp_arena.read().unwrap();
let schema = lp_arena.get(self.root).schema(&lp_arena).into_owned();
Expand Down Expand Up @@ -139,12 +139,8 @@ impl NodeTraverser {
}

fn view_current_node(&self, py: Python<'_>) -> PyResult<PyObject> {
self.view_node(py, self.root.0)
}

fn view_node(&self, py: Python<'_>, node: usize) -> PyResult<PyObject> {
let lp_arena = self.lp_arena.read().unwrap();
let lp_node = lp_arena.get(Node(node));
let lp_node = lp_arena.get(self.root);
nodes::into_py(py, lp_node)
}

Expand Down
40 changes: 32 additions & 8 deletions py-polars/src/lazyframe/visitor/expr_nodes.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use polars_core::series::IsSorted;
use polars_plan::dsl::function_expr::rolling::RollingFunction;
use polars_plan::dsl::function_expr::trigonometry::TrigonometricFunction;
use polars_plan::dsl::BooleanFunction;
use polars_plan::prelude::{
AAggExpr, AExpr, FunctionExpr, GroupbyOptions, LiteralValue, Operator, PowFunction,
WindowMapping, WindowType,
};
use polars_rs::series::IsSorted;
use polars_time::prelude::RollingGroupOptions;
use pyo3::exceptions::PyNotImplementedError;
use pyo3::prelude::*;
Expand Down Expand Up @@ -168,7 +168,8 @@ pub struct SortBy {
#[pyo3(get)]
by: Vec<usize>,
#[pyo3(get)]
descending: Vec<bool>,
/// descending, nulls_last, maintain_order
sort_options: (Vec<bool>, bool, bool),
}

#[pyclass]
Expand Down Expand Up @@ -334,6 +335,10 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
use LiteralValue::*;
let dtype: PyObject = Wrap(lit.get_datatype()).to_object(py);
match lit {
Float(v) => Literal {
value: v.to_object(py),
dtype,
},
Float32(v) => Literal {
value: v.to_object(py),
dtype,
Expand All @@ -342,6 +347,10 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
value: v.to_object(py),
dtype,
},
Int(v) => Literal {
value: v.to_object(py),
dtype,
},
Int8(v) => Literal {
value: v.to_object(py),
dtype,
Expand Down Expand Up @@ -378,6 +387,10 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
value: v.to_object(py),
dtype,
},
StrCat(v) => Literal {
value: v.to_object(py),
dtype,
},
String(v) => Literal {
value: v.to_object(py),
dtype,
Expand Down Expand Up @@ -442,11 +455,15 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
AExpr::SortBy {
expr,
by,
descending,
sort_options,
} => SortBy {
expr: expr.0,
by: by.iter().map(|n| n.0).collect(),
descending: descending.clone(),
sort_options: (
sort_options.descending.clone(),
sort_options.nulls_last,
sort_options.maintain_order,
),
}
.into_py(py),
AExpr::Agg(aggexpr) => match aggexpr {
Expand Down Expand Up @@ -606,9 +623,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
.to_object(py),
FunctionExpr::Atan2 => ("atan2",).to_object(py),
FunctionExpr::Sign => ("sign",).to_object(py),
FunctionExpr::FillNull { super_type: _ } => {
return Err(PyNotImplementedError::new_err("fill null"))
},
FunctionExpr::FillNull => return Err(PyNotImplementedError::new_err("fill null")),
FunctionExpr::RollingExpr(rolling) => match rolling {
RollingFunction::Min(_) => {
return Err(PyNotImplementedError::new_err("rolling min"))
Expand Down Expand Up @@ -679,7 +694,9 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
has_max: _,
} => return Err(PyNotImplementedError::new_err("clip")),
FunctionExpr::AsStruct => return Err(PyNotImplementedError::new_err("as struct")),
FunctionExpr::TopK(_) => return Err(PyNotImplementedError::new_err("top k")),
FunctionExpr::TopK { sort_options: _ } => {
return Err(PyNotImplementedError::new_err("top k"))
},
FunctionExpr::CumCount { reverse } => ("cumcount", reverse).to_object(py),
FunctionExpr::CumSum { reverse } => ("cumsum", reverse).to_object(py),
FunctionExpr::CumProd { reverse } => ("cumprod", reverse).to_object(py),
Expand Down Expand Up @@ -806,6 +823,13 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult<PyObject> {
FunctionExpr::Business(_) => {
return Err(PyNotImplementedError::new_err("business"))
},
FunctionExpr::TopKBy { sort_options: _ } => {
return Err(PyNotImplementedError::new_err("top_k_by"))
},
FunctionExpr::EwmMeanBy {
half_life: _,
check_sorted: _,
} => return Err(PyNotImplementedError::new_err("ewm_mean_by")),
},
options: py.None(),
}
Expand Down

0 comments on commit bf2b301

Please sign in to comment.