Skip to content

Commit

Permalink
A restricted expression replacement scheme
Browse files Browse the repository at this point in the history
Rather than rewriting the entire expression, allow
the user to add some expressions to the existing
arena and wire in an indirection map for
view_expression_node.
  • Loading branch information
wence- committed Apr 23, 2024
1 parent 8fe7627 commit b766402
Showing 1 changed file with 34 additions and 20 deletions.
54 changes: 34 additions & 20 deletions py-polars/src/lazyframe/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ struct NodeTraverser {
expr_arena: Arc<RwLock<Arena<AExpr>>>,
scratch: Vec<Node>,
expr_scratch: Vec<ExprIR>,
expr_mapping: Option<Vec<Node>>,
}

impl NodeTraverser {
Expand Down Expand Up @@ -149,30 +150,42 @@ impl NodeTraverser {

fn view_expression(&self, py: Python<'_>, node: usize) -> PyResult<PyObject> {
let expr_arena = self.expr_arena.read().unwrap();
let expr = expr_arena.get(Node(node));
let n = match &self.expr_mapping {
Some(mapping) => *mapping.get(node).unwrap(),
None => Node(node),
};
let expr = expr_arena.get(n);
expr_nodes::into_py(py, expr)
}

fn replace_expressions(&self, expressions: Vec<(usize, PyExpr)>) -> PyResult<Self> {
let mut expr_arena = self.expr_arena.read().unwrap().to_owned();
let nexprs = expr_arena.len();
for (idx, pyexpr) in expressions.iter() {
if *idx >= nexprs {
raise_err!(
format!("Attempting to replace out of bounds index {}", *idx),
OutOfBounds
);
}
let expr = to_aexpr(pyexpr.inner.clone(), &mut expr_arena);
expr_arena.swap(Node(*idx), expr);
/// Add some expressions to the arena and return their new node ids as well
/// as the total number of nodes in the arena.
fn add_expressions(&mut self, expressions: Vec<PyExpr>) -> PyResult<(Vec<usize>, usize)> {
let mut expr_arena: std::sync::RwLockWriteGuard<'_, Arena<AExpr>> =
self.expr_arena.write().unwrap();
Ok((
expressions
.iter()
.map(|e| to_aexpr(e.inner.clone(), &mut expr_arena).0)
.collect(),
expr_arena.len(),
))
}

/// Set up a mapping of expression nodes used in `view_expression_node``.
/// With a mapping set, `view_expression_node(i)` produces the node for
/// `mapping[i]`.
fn set_expr_mapping(&mut self, mapping: Vec<usize>) -> PyResult<()> {
if mapping.len() != self.expr_arena.read().unwrap().len() {
raise_err!("Invalid mapping length", ComputeError);
}
Ok(NodeTraverser {
root: self.root,
lp_arena: self.lp_arena.clone(),
expr_arena: Arc::new(RwLock::new(expr_arena)),
scratch: vec![],
expr_scratch: vec![],
})
self.expr_mapping = Some(mapping.into_iter().map(Node).collect());
Ok(())
}

/// Unset the expression mapping (reinstates the identity map)
fn unset_expr_mapping(&mut self) {
self.expr_mapping = None;
}
}

Expand All @@ -193,6 +206,7 @@ impl PyLazyFrame {
expr_arena: Arc::new(RwLock::new(expr_arena)),
scratch: vec![],
expr_scratch: vec![],
expr_mapping: None,
})
}
}

0 comments on commit b766402

Please sign in to comment.