In [1]:
%load_ext autoreload
%autoreload 2

from treekit import TreeNode as nt
import json

In [2]:

# Create a tree by hand. `add_child` returns the newly created child, so deeper
# levels are grown by calling `add_child` on the returned handles.
# Nodes 2-11 all store `value == 10 * id`.
tree = nt.NestedTree(id=1, value=10)
print(tree)  # printed before any children exist, so `children` is still empty

child2, child3, child4 = (
    tree.add_child(id=node_id, value=10 * node_id) for node_id in (2, 3, 4)
)
child4_1, child4_2 = (
    child4.add_child(id=node_id, value=10 * node_id) for node_id in (5, 6)
)
child4_2_1, child4_2_2, child4_2_3 = (
    child4_2.add_child(id=node_id, value=10 * node_id) for node_id in (7, 8, 9)
)

# `add_child` also accepts a plain mapping, arbitrary keys, and partial records.
child4_2_3_1 = child4_2_3.add_child({"id":1000, "value":1000})
child4_2_3_2 = child4_2_3.add_child(whatever=1001)
child4_2_3_3 = child4_2_3.add_child(id=1002)

# Added last, yet they still appear under node 3 in the printed tree.
child3_1, child3_2 = (
    child3.add_child(id=node_id, value=10 * node_id) for node_id in (10, 11)
)

print(f'{tree.height()=}')
print(f'{tree.is_leaf()=}')
print(json.dumps(tree, indent=1))


{'id': 1, 'value': 10, 'children': []}
tree.height()=5
tree.is_leaf()=False
{
 "id": 1,
 "value": 10,
 "children": [
  {
   "id": 2,
   "value": 20,
   "children": []
  },
  {
   "id": 3,
   "value": 30,
   "children": [
    {
     "id": 10,
     "value": 100,
     "children": []
    },
    {
     "id": 11,
     "value": 110,
     "children": []
    }
   ]
  },
  {
   "id": 4,
   "value": 40,
   "children": [
    {
     "id": 5,
     "value": 50,
     "children": []
    },
    {
     "id": 6,
     "value": 60,
     "children": [
      {
       "id": 7,
       "value": 70,
       "children": []
      },
      {
       "id": 8,
       "value": 80,
       "children": []
      },
      {
       "id": 9,
       "value": 90,
       "children": [
        {
         "id": 1000,
         "value": 1000,
         "children": []
        },
        {
         "whatever": 1001,
         "children": []
        },
        {
         "id": 1002,
         "children": []
        }
       ]
      }
     ]

In [3]:
# An arithmetic expression tree built from plain dicts. Each node has a "type":
#   "op"    -- "value" names an operator ("+", "*", "max") applied to "children"
#   "var"   -- "value" is a variable name, looked up in a context when evaluated
#   "const" -- "value" is a literal number
# In prefix notation the whole expression is:
#   (+ (max (+ x x x) -100)
#      (+ (* x 1 y) 10 0 x y y)
#      (* z 0 2))
expr = {
    "value": "+",
    "type": "op",
    "children": [
        {
            # (max (+ x x x) -100)
            "value": "max",
            "type": "op",
            "children": [
                {
                    "value": "+", 
                    "type": "op",
                    "children": [{"type": "var", "value": "x"},
                                 {"type": "var", "value": "x"},
                                 {"type": "var", "value": "x"}]
                },
                { "type": "const", "value": -100 }
            ]
        },
        # (+ (* x 1 y) 10 0 x y y)
        {"type": "op",
         "value": "+",
         "children": [
             {"type": "op", "value": "*",
              "children": [{"type": "var", "value": "x"},
                           {"type": "const", "value": 1},
                           {"type": "var", "value": "y"}]},
             {"type": "const", "value": 10}, 
             {"type": "const", "value": 0},
             {"type": "var", "value": "x"},
             {"type": "var", "value": "y"},
             {"type": "var", "value": "y"}]
        },
        {
            # (* z 0 2)
            "type": "op",
            "value": "*",
            "children": [
                {"type": "var", "value": "z"},
                {"type": "const", "value": 0},
                {"type": "const", "value": 2}
            ]
        }
    ]
}
# Wrap the plain dict as a NestedTree. As the printed output shows, every node
# gains a "children" list (empty for leaves) and keeps its original key order.
expr_tree = nt.NestedTree(expr)
print(json.dumps(expr_tree, indent=4))

{
    "value": "+",
    "type": "op",
    "children": [
        {
            "value": "max",
            "type": "op",
            "children": [
                {
                    "value": "+",
                    "type": "op",
                    "children": [
                        {
                            "type": "var",
                            "value": "x",
                            "children": []
                        },
                        {
                            "type": "var",
                            "value": "x",
                            "children": []
                        },
                        {
                            "type": "var",
                            "value": "x",
                            "children": []
                        }
                    ]
                },
                {
                    "type": "const",
                    "value": -100,
                    "children": []
                }
        

`NestedTree` has a `visit` method, which performs a depth-first walk of the tree.

The method takes a function `fn`, which takes a `NestedTree` (a recursive data
structure of `NestedTree`s) and returns a value. At each node (the `NestedTree`
rooted at that node), the `visit` method calls `fn` with the node and the results
of recursively visiting its children. The return value of `fn` is the return
value of the `visit` method on that node. Any extra arguments given to `visit`,
such as `ctx` below, are passed on to `fn`.

If we do a post-order traversal of the tree, we apply `fn` to the children of
a node before applying it to the node itself. This is useful for problems where
the children need to be processed before the node itself. For example, if we
need to compute the height of a tree, we need to compute the height of the
children before we can compute the height of the node. A more complex example
is **evaluating** an expression tree, where the value of a node is a function of
the values of its children, such as an arithmetic expression tree.

If we do a pre-order traversal of the tree, we apply `fn` to the node before
applying it to the children. This is useful, for instance, if we need to
pretty-print the tree, or in a more complex example, if we need to rewrite
the tree in a different form, for instance algebraic simplification.

We can evaluate expression trees, for instance. Here is a simple way to
evaluate the previously defined expression tree `expr_tree`. It contains the
operations `+`, `*`, and `max`, numeric constants (which are self-evaluating),
and the variables `x`, `y`, and `z`, which must be replaced by constants or
other primitive values taken from a context. We define a function `tree_eval`
that builds a node evaluator; `visit` takes care of the recursion.

In [4]:
# Variable bindings used when evaluating `expr_tree`.
ctx = {
    "x": 0,
    "y": 1,
    "z": 2
}

from functools import reduce

# Operator table: each entry takes an iterable of already-evaluated operand
# values and returns the result. `+` and `*` return their identity elements
# (0 and 1) when given no operands; `max` still requires at least one operand.
ops = {
    '+': sum,
    # Explicit initializer 1 so an empty product is 1 instead of a TypeError.
    '*': lambda x: reduce(lambda a, b: a * b, x, 1),
    'max': max
}

def tree_eval(types=None):
    """
    Build a node evaluator for use with ``NestedTree.visit``.

    ``types`` maps a node's ``'type'`` field to a handler ``handler(node, ctx)``
    that returns the node's value. The default handlers are:

    - ``'const'``: the node's own ``'value'``.
    - ``'var'``: the value bound to the variable's name in ``ctx``.
    - ``'op'``: the operator ``ops[node['value']]`` applied to the ``'value'``
      of each child.

    The returned evaluator turns each node it is called on into a new
    ``'const'`` node holding the computed value. The ``'op'`` handler reads its
    children's ``'value'`` directly, so the children must already have been
    reduced to constants. The evaluator therefore only works with a post-order
    traversal, e.g.::

        tree.visit(fn=tree_eval(), ctx=ctx, order='post')

    :param types: Mapping from node type to handler. Defaults to the handlers
        listed above; pass a custom mapping to support other node types.
    :return: A function ``_eval(node, ctx)`` returning a const ``NestedTree``.
    :raises KeyError: (when the evaluator runs) for an unknown node type, an
        unknown operator, or a variable that is not bound in ``ctx``.
    """
    # Build the defaults per call instead of sharing one mutable default dict.
    if types is None:
        types = {
            'const': lambda node, _: node['value'],
            'var': lambda node, ctx: ctx[node['value']],
            # Children must already be const nodes (post-order evaluation).
            'op': lambda node, _: ops[node['value']](c['value'] for c in node.children()),
        }

    def _eval(node, ctx):
        # Dispatch on the node's type and wrap the result as a fresh const leaf.
        return nt.NestedTree(type='const', value=types[node['type']](node, ctx))

    return _eval


In [5]:
from copy import deepcopy

# Evaluate a deep copy, so `visit` never touches `expr_tree` itself.
post_eval = deepcopy(expr_tree).visit(
    fn=tree_eval(),
    ctx=ctx,
    order='post',
)
print(f'{post_eval=}')

post_eval={'type': 'const', 'value': 12, 'children': []}


We see that we get the expected result, $12$: with $x = 0$, $y = 1$, $z = 2$,
$\max(0 + 0 + 0, -100) + (0 \cdot 1 \cdot 1 + 10 + 0 + 0 + 1 + 1) + 2 \cdot 0 \cdot 2 = 0 + 12 + 0 = 12$.
Note that it is still a tree, but it has been reduced to a single leaf holding an
atomic value. We can evaluate it again, and we see that it cannot be rewritten
further. We call this state of affairs a **normal form**. Essentially, we can
think of the tree as a program that computes a value, and the normal form is the
result of running the program.

In [6]:
# A normal form re-evaluates to itself: running the evaluator again on the
# already-evaluated tree must give back an equal tree.
new_post_eval = deepcopy(post_eval).visit(
    fn=tree_eval(),
    ctx=ctx,
    order='post',
)
assert post_eval == new_post_eval
print(new_post_eval['value'])

12


What happens when we change the context so that not every variable is defined?

In [7]:
# Same bindings as `ctx`, except that `y` is deliberately left unbound.
open_ctx = {'x': 0, 'z': 2}

# The `var` handler raises KeyError when it looks up the unbound `y`;
# catch and display the error instead of letting the cell crash.
try:
    deepcopy(expr_tree).visit(fn=tree_eval(), ctx=open_ctx, order='post')
except KeyError as missing_name:
    print(f'Error: {missing_name}')

Error: 'y'


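We get a `KeyError`: the `var` handler looks up `y` in the context, and `y` is
not bound there. Even if unbound variables were simply left in the tree, our
operations in `ops` are not defined over variables.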

We would run into a similar problem if we used `pre` (pre-order traversal)
instead of `post` (post-order traversal). In that case, we would try to evaluate
each operation before its operands had been evaluated, and we would get an error.
Our operations only work over numbers (type `const`), so we need to evaluate the
operands first if we want to evaluate the operations.

Pre-order traversal is good for things like rewriting trees from the top down,
but your rewrite rules then need to be defined in terms of unnormalized
(not-yet-evaluated) expressions.

For instance, suppose that we add a `0` to a variable `x` in the expression
tree. We know that `x + 0` is the same as `x`, so we could add a rewrite rule
that maps the sub-tree `(+ x 0)` to `x`. We could add many rewrite rules to
implement, for instance, algebraic simplification (`simplify`), or implement a
compiler (`compile`) that translates the tree into a different form that could
be evaluated by a different set of rewrite rules.

SyntaxError: expected ':' (2503090772.py, line 75)

In [None]:
# Simplify to a fixed point: apply the pre-order `simplify` rewrite pass until
# a pass leaves the tree unchanged. The result is bound to `simplified_tree`
# instead of overwriting `expr_tree`. Other cells still see the original
# expression, and re-running this cell starts from the same input.
# NOTE(review): `simplify` must be defined by an earlier cell; the cell that
# was meant to define it raised a SyntaxError — confirm before running.
MAX_SIMPLIFY_PASSES = 100  # guard against rewrite rules that never converge
new_ctx = {
    "x": 0
}

simplified_tree = expr_tree
for pass_index in range(MAX_SIMPLIFY_PASSES):
    print("Applying `simplify`")
    candidate = deepcopy(simplified_tree).visit(simplify, new_ctx, order='pre')
    if candidate == simplified_tree:
        break  # fixed point reached: this pass changed nothing
    simplified_tree = candidate
else:
    raise RuntimeError(
        f"`simplify` did not reach a fixed point within {MAX_SIMPLIFY_PASSES} passes"
    )

print(json.dumps(simplified_tree, indent=4))

Applying `simplify`
Applying `simplify`
Applying `simplify`
Applying `simplify`
{
    "type": "+",
    "children": [
        {
            "type": "max",
            "children": [
                {
                    "type": "+",
                    "children": [
                        {
                            "type": "*",
                            "children": [
                                {
                                    "type": "const",
                                    "value": 3,
                                    "children": []
                                },
                                {
                                    "type": "var",
                                    "value": "x",
                                    "children": []
                                }
                            ]
                        }
                    ]
                },
                {
                    "type": "const",
                    "value": -100

In [None]:
# Evaluate a copy of the expression tree with the post-order evaluator and
# display the resulting const node. `tree_eval()` builds the node evaluator;
# Python's builtin `eval` only accepts source strings or code objects, so it
# cannot evaluate tree nodes.
deepcopy(expr_tree).visit(fn=tree_eval(), ctx=ctx, order='post')

{'type': 'const', 'value': 12}

In [None]:
import tree_simp as ts