Skip to content

Commit

Permalink
add reference visitor APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
peter-toth committed May 11, 2024
1 parent 2a15614 commit e844799
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 8 deletions.
71 changes: 67 additions & 4 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,16 @@ pub trait TreeNode: Sized {
.visit_parent(|| visitor.f_up(self))
}

fn visit_ref<'n, V: TreeNodeRefVisitor<'n, Node = Self>>(
&'n self,
visitor: &mut V,
) -> Result<TreeNodeRecursion> {
visitor
.f_down(self)?
.visit_children(|| self.apply_children_ref(|c| c.visit_ref(visitor)))?
.visit_parent(|| visitor.f_up(self))
}

/// Rewrite the tree node with a [`TreeNodeRewriter`], performing a
/// depth-first walk of the node and its children.
///
Expand Down Expand Up @@ -204,6 +214,24 @@ pub trait TreeNode: Sized {
apply_impl(self, &mut f)
}

fn apply_ref<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
mut f: F,
) -> Result<TreeNodeRecursion> {
fn apply_ref_impl<
'n,
N: TreeNode,
F: FnMut(&'n N) -> Result<TreeNodeRecursion>,
>(
node: &'n N,
f: &mut F,
) -> Result<TreeNodeRecursion> {
f(node)?.visit_children(|| node.apply_children_ref(|c| apply_ref_impl(c, f)))
}

apply_ref_impl(self, &mut f)
}

/// Recursively rewrite the node's children and then the node using `f`
/// (a bottom-up post-order traversal).
///
Expand Down Expand Up @@ -430,6 +458,13 @@ pub trait TreeNode: Sized {
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
f: F,
) -> Result<TreeNodeRecursion> {
self.apply_children_ref(f)
}

fn apply_children_ref<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion>;

/// Low-level API used to implement other APIs.
Expand Down Expand Up @@ -483,6 +518,23 @@ pub trait TreeNodeVisitor: Sized {
}
}

pub trait TreeNodeRefVisitor<'n>: Sized {
/// The node type which is visitable.
type Node: TreeNode;

/// Invoked while traversing down the tree, before any children are visited.
/// Default implementation continues the recursion.
fn f_down(&mut self, _node: &'n Self::Node) -> Result<TreeNodeRecursion> {
Ok(TreeNodeRecursion::Continue)
}

/// Invoked while traversing up the tree after children are visited. Default
/// implementation continues the recursion.
fn f_up(&mut self, _node: &'n Self::Node) -> Result<TreeNodeRecursion> {
Ok(TreeNodeRecursion::Continue)
}
}

/// A [Visitor](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively
/// rewriting [`TreeNode`]s via [`TreeNode::rewrite`].
///
Expand Down Expand Up @@ -857,6 +909,10 @@ pub trait DynTreeNode {
/// Returns all children of the specified `TreeNode`.
fn arc_children(&self) -> Vec<Arc<Self>>;

fn children(&self) -> Vec<&Arc<Self>> {
panic!("DynTreeNode::children is not implemented yet")
}

/// Constructs a new node with the specified children.
fn with_new_arc_children(
&self,
Expand All @@ -875,6 +931,13 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
self.arc_children().iter().apply_until_stop(f)
}

fn apply_children_ref<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
self.children().into_iter().apply_until_stop(f)
}

fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
f: F,
Expand Down Expand Up @@ -913,8 +976,8 @@ pub trait ConcreteTreeNode: Sized {
}

impl<T: ConcreteTreeNode> TreeNode for T {
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
fn apply_children_ref<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
self.children().into_iter().apply_until_stop(f)
Expand Down Expand Up @@ -959,8 +1022,8 @@ mod tests {
}

impl<T> TreeNode for TestTreeNode<T> {
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
fn apply_children_ref<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
self.children.iter().apply_until_stop(f)
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/logical_plan/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ use datafusion_common::{
};

impl TreeNode for LogicalPlan {
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
fn apply_children_ref<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
self.inputs().into_iter().apply_until_stop(f)
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use datafusion_common::tree_node::{
use datafusion_common::{internal_err, map_until_stop_and_collect, Result};

impl TreeNode for Expr {
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
fn apply_children_ref<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
let children = match self {
Expand Down

0 comments on commit e844799

Please sign in to comment.