Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Expose plan and expression nodes through NodeTraverser to Python #15776

Merged
merged 12 commits into from
Apr 30, 2024
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions LICENSE
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
Copyright (c) 2020 Ritchie Vink
Some portions Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
Expand Down
12 changes: 6 additions & 6 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ mod bounds;
#[cfg(feature = "business")]
mod business;
#[cfg(feature = "dtype-categorical")]
mod cat;
pub mod cat;
#[cfg(feature = "round_series")]
mod clip;
#[cfg(feature = "dtype-struct")]
Expand Down Expand Up @@ -38,13 +38,13 @@ mod nan;
mod peaks;
#[cfg(feature = "ffi_plugin")]
mod plugin;
mod pow;
pub mod pow;
#[cfg(feature = "random")]
mod random;
#[cfg(feature = "range")]
mod range;
#[cfg(feature = "rolling_window")]
mod rolling;
pub mod rolling;
#[cfg(feature = "round_series")]
mod round;
#[cfg(feature = "row_hash")]
Expand All @@ -63,7 +63,7 @@ mod struct_;
#[cfg(any(feature = "temporal", feature = "date_offset"))]
mod temporal;
#[cfg(feature = "trigonometry")]
mod trigonometry;
pub mod trigonometry;
mod unique;

use std::fmt::{Display, Formatter};
Expand All @@ -88,10 +88,10 @@ pub use self::boolean::BooleanFunction;
#[cfg(feature = "business")]
pub(super) use self::business::BusinessFunction;
#[cfg(feature = "dtype-categorical")]
pub(crate) use self::cat::CategoricalFunction;
pub use self::cat::CategoricalFunction;
#[cfg(feature = "temporal")]
pub(super) use self::datetime::TemporalFunction;
pub(super) use self::pow::PowFunction;
pub use self::pow::PowFunction;
#[cfg(feature = "range")]
pub(super) use self::range::RangeFunction;
#[cfg(feature = "rolling_window")]
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ pub mod dt;
mod expr;
mod expr_dyn_fn;
mod from;
pub(crate) mod function_expr;
pub mod function_expr;
pub mod functions;
mod list;
#[cfg(feature = "meta")]
Expand Down
1 change: 1 addition & 0 deletions py-polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ polars-lazy = { workspace = true, features = ["python"] }
polars-ops = { workspace = true }
polars-parquet = { workspace = true, optional = true }
polars-plan = { workspace = true }
polars-time = { workspace = true }
polars-utils = { workspace = true }

ahash = { workspace = true }
Expand Down
10 changes: 10 additions & 0 deletions py-polars/src/conversion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,16 @@ impl FromPyObject<'_> for Wrap<Schema> {
}
}

impl IntoPy<PyObject> for Wrap<&Schema> {
fn into_py(self, py: Python<'_>) -> PyObject {
let dict = PyDict::new(py);
for (k, v) in self.0.iter() {
dict.set_item(k.as_str(), Wrap(v.clone())).unwrap();
}
dict.into_py(py)
}
}

#[derive(Clone, Debug)]
#[repr(transparent)]
pub struct ObjectValue {
Expand Down
4 changes: 3 additions & 1 deletion py-polars/src/lazyframe/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod exitable;

mod visit;
pub(crate) mod visitor;
use std::collections::HashMap;
use std::io::BufWriter;
use std::num::NonZeroUsize;
Expand All @@ -13,6 +14,7 @@ use polars_core::prelude::*;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyDict, PyList};
pub(crate) use visit::PyExprIR;

use crate::arrow_interop::to_rust::pyarrow_schema_to_rust;
use crate::error::PyPolarsErr;
Expand Down
207 changes: 207 additions & 0 deletions py-polars/src/lazyframe/visit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
use std::sync::Mutex;

use polars_plan::logical_plan::{to_aexpr, Context, IR};
use polars_plan::prelude::expr_ir::ExprIR;
use polars_plan::prelude::{AExpr, PythonOptions};
use polars_utils::arena::{Arena, Node};
use pyo3::prelude::*;
use visitor::{expr_nodes, nodes};

use super::*;
use crate::raise_err;

#[derive(Clone)]
#[pyclass]
pub(crate) struct PyExprIR {
#[pyo3(get)]
node: usize,
#[pyo3(get)]
output_name: String,
}

impl From<ExprIR> for PyExprIR {
fn from(value: ExprIR) -> Self {
Self {
node: value.node().0,
output_name: value.output_name().into(),
}
}
}

impl From<&ExprIR> for PyExprIR {
fn from(value: &ExprIR) -> Self {
Self {
node: value.node().0,
output_name: value.output_name().into(),
}
}
}

#[pyclass]
struct NodeTraverser {
root: Node,
lp_arena: Arc<Mutex<Arena<IR>>>,
expr_arena: Arc<Mutex<Arena<AExpr>>>,
scratch: Vec<Node>,
expr_scratch: Vec<ExprIR>,
expr_mapping: Option<Vec<Node>>,
}

impl NodeTraverser {
fn fill_inputs(&mut self) {
let lp_arena = self.lp_arena.lock().unwrap();
let this_node = lp_arena.get(self.root);
self.scratch.clear();
this_node.copy_inputs(&mut self.scratch);
}

fn fill_expressions(&mut self) {
let lp_arena = self.lp_arena.lock().unwrap();
let this_node = lp_arena.get(self.root);
self.expr_scratch.clear();
this_node.copy_exprs(&mut self.expr_scratch);
}

fn scratch_to_list(&mut self) -> PyObject {
Python::with_gil(|py| {
PyList::new(py, self.scratch.drain(..).map(|node| node.0)).to_object(py)
})
}

fn expr_to_list(&mut self) -> PyObject {
Python::with_gil(|py| {
PyList::new(
py,
self.expr_scratch
.drain(..)
.map(|e| PyExprIR::from(e).into_py(py)),
)
.to_object(py)
})
}
}

#[pymethods]
impl NodeTraverser {
/// Get expression nodes
fn get_exprs(&mut self) -> PyObject {
self.fill_expressions();
self.expr_to_list()
}

/// Get input nodes
fn get_inputs(&mut self) -> PyObject {
self.fill_inputs();
self.scratch_to_list()
}

/// Get Schema of current node as python dict<str, pl.DataType>
fn get_schema(&self, py: Python<'_>) -> PyObject {
let lp_arena = self.lp_arena.lock().unwrap();
let schema = lp_arena.get(self.root).schema(&lp_arena);
Wrap(&**schema).into_py(py)
}

/// Get expression dtype.
fn get_dtype(&self, expr_node: usize, py: Python<'_>) -> PyResult<PyObject> {
let expr_node = Node(expr_node);
let lp_arena = self.lp_arena.lock().unwrap();
let schema = lp_arena.get(self.root).schema(&lp_arena);
let expr_arena = self.expr_arena.lock().unwrap();
let field = expr_arena
.get(expr_node)
.to_field(&schema, Context::Default, &expr_arena)
.map_err(PyPolarsErr::from)?;
Ok(Wrap(field.dtype).to_object(py))
}

/// Set the current node in the plan.
fn set_node(&mut self, node: usize) {
self.root = Node(node);
}

/// Set a python UDF that will replace the subtree location with this function src.
fn set_udf(&mut self, function: PyObject, schema: Wrap<Schema>) {
let ir = IR::PythonScan {
options: PythonOptions {
scan_fn: Some(function.into()),
schema: Arc::new(schema.0),
output_schema: None,
with_columns: None,
pyarrow: false,
predicate: None,
n_rows: None,
},
predicate: None,
};
let mut lp_arena = self.lp_arena.lock().unwrap();
lp_arena.replace(self.root, ir);
}

fn view_current_node(&self, py: Python<'_>) -> PyResult<PyObject> {
let lp_arena = self.lp_arena.lock().unwrap();
let lp_node = lp_arena.get(self.root);
nodes::into_py(py, lp_node)
}

fn view_expression(&self, py: Python<'_>, node: usize) -> PyResult<PyObject> {
let expr_arena = self.expr_arena.lock().unwrap();
let n = match &self.expr_mapping {
Some(mapping) => *mapping.get(node).unwrap(),
None => Node(node),
};
let expr = expr_arena.get(n);
expr_nodes::into_py(py, expr)
}

/// Add some expressions to the arena and return their new node ids as well
/// as the total number of nodes in the arena.
fn add_expressions(&mut self, expressions: Vec<PyExpr>) -> PyResult<(Vec<usize>, usize)> {
let mut expr_arena = self.expr_arena.lock().unwrap();
Ok((
expressions
.into_iter()
.map(|e| to_aexpr(e.inner, &mut expr_arena).0)
.collect(),
expr_arena.len(),
))
}

/// Set up a mapping of expression nodes used in `view_expression_node``.
/// With a mapping set, `view_expression_node(i)` produces the node for
/// `mapping[i]`.
fn set_expr_mapping(&mut self, mapping: Vec<usize>) -> PyResult<()> {
if mapping.len() != self.expr_arena.lock().unwrap().len() {
raise_err!("Invalid mapping length", ComputeError);
}
self.expr_mapping = Some(mapping.into_iter().map(Node).collect());
Ok(())
}

/// Unset the expression mapping (reinstates the identity map)
fn unset_expr_mapping(&mut self) {
self.expr_mapping = None;
}
}

#[pymethods]
#[allow(clippy::should_implement_trait)]
impl PyLazyFrame {
fn visit(&self) -> PyResult<NodeTraverser> {
let mut lp_arena = Arena::with_capacity(16);
let mut expr_arena = Arena::with_capacity(16);
let root = self
.ldf
.clone()
.optimize(&mut lp_arena, &mut expr_arena)
.map_err(PyPolarsErr::from)?;
Ok(NodeTraverser {
root,
lp_arena: Arc::new(Mutex::new(lp_arena)),
expr_arena: Arc::new(Mutex::new(expr_arena)),
scratch: vec![],
expr_scratch: vec![],
expr_mapping: None,
})
}
}
Loading
Loading