# Hierarchical Reduction

In [1]:
import src.utils as utils
from enum import Enum
class PRes(str, Enum):
    """Verdicts a reduction predicate can report for a candidate input string.

    Members are also ``str`` instances, so they compare equal to their
    plain string values.
    """
    # The only verdict hierarchical_reduction treats as "keep this candidate".
    success = 'SUCCESS'
    # Returned by expr_double_paren when the input does not match its pattern.
    failed = 'FAILED'
    # Presumably "input was not well-formed"; hierarchical_reduction does not
    # special-case it (anything other than success is simply not kept).
    invalid = 'INVALID'
    # Presumably "predicate ran out of time"; not produced in the visible code.
    timeout = 'TIMEOUT'

In [2]:
# Running count of pushes made through add_to_pq; used as a heap tie-breaker.
ecount = 0


def add_to_pq(tup, q):
    """Push a ``(tree, path)`` work item onto the heap ``q``.

    The heap entry is ``(n, m, -ecount, tup)`` where ``n`` is the leaf count
    of the whole tree and ``m`` the leaf count of the subtree at ``path``.
    ``heapq`` pops the smallest entry first, so smaller trees are handled
    first; among equal ``(n, m)`` the most recently pushed item wins because
    its negated counter is smaller. The counter differs for every push made
    here, so comparison never falls through to ``tup`` itself.
    """
    global ecount
    whole_tree, path = tup
    focus = get_child(whole_tree, path)
    entry = (count_leaves(whole_tree), count_leaves(focus), -ecount, tup)
    heapq.heappush(q, entry)
    ecount += 1

In [3]:
def compatible_nodes(tree, grammar):
    """List the candidate replacements for the root of ``tree``.

    Returns ``(index, node)`` pairs for every node inside ``tree`` (the root
    included) that has the same key as the root, in the order ``nt_group``
    collected them. If the grammar has an empty expansion for that key, the
    pair ``(-1, (key, []))`` is placed first so the empty node is tried
    before any other candidate.
    """
    key, _children, *_ = tree
    # Design choice: candidates are taken only from within this subtree,
    # not from every node of the original input tree.
    same_key_nodes = nt_group(tree)[key]
    candidates = list(enumerate(same_key_nodes))

    # Offer the empty expansion first when the grammar permits it.
    if [] in grammar[key]:
        candidates.insert(0, (-1, (key, [])))
    return candidates

In [4]:
def get_child(tree, path):
    """Return the node reached from ``tree`` by following ``path``.

    ``path`` is a sequence of child indices; each step descends into
    ``node[1][index]``. An empty (or otherwise falsy) path yields ``tree``
    itself.
    """
    node = tree
    # `or ()` keeps the "falsy path means the tree itself" behaviour.
    for index in path or ():
        node = node[1][index]
    return node

In [5]:
def hierarchical_reduction(tree, grammar, predicate):
    """Search for a smaller derivation tree that still satisfies ``predicate``.

    Work items are ``(tree, path)`` pairs kept in a priority queue (see
    ``add_to_pq``). For each item, the subtree at ``path`` is replaced in
    turn by every candidate from ``compatible_nodes``; each resulting tree
    whose string form has not been tried before is passed to ``predicate``.
    Candidates judged ``PRes.success`` are queued again at the same path.
    Only when no candidate succeeds does the search descend into the
    nonterminal children of the subtree.

    Parameters:
        tree: derivation tree ``(name, children, *extra)`` to reduce.
        grammar: mapping from nonterminal to its expansions; used by
            ``compatible_nodes`` to decide whether an empty node is allowed.
        predicate: callable taking the string form of a tree and returning
            a ``PRes`` value.

    Returns:
        The tree with the fewest leaves among ``tree`` and all successful
        candidates seen (``tree`` itself if nothing smaller succeeded).

    Raises:
        AssertionError: if ``predicate`` does not report ``PRes.success``
            for the original tree.
    """
    work_queue = []
    add_to_pq((tree, []), work_queue)

    ostr = utils.tree_to_str(tree)
    assert predicate(ostr) == PRes.success
    # Verdict cache keyed by a tree's string form, so the predicate runs at
    # most once per distinct string. Seeded with the original input.
    tried = {ostr: PRes.success}

    min_tree, min_tree_size = tree, count_leaves(tree)
    while work_queue:
        # Only the payload matters here; the leading fields are sort keys.
        _n, _m, _ec, (dtree, F_path) = heapq.heappop(work_queue)
        stree = get_child(dtree, F_path)
        # Accept nodes carrying extra fields, as the other helpers in this
        # file do (replace_path preserves such fields when rebuilding).
        skey, schildren, *_ = stree
        found = False
        # Try replacing stree with each alternative node of the same key.
        for _idx, node in compatible_nodes(stree, grammar):
            # Build a copy of dtree with the candidate spliced in at F_path.
            ctree = replace_path(dtree, F_path, node)
            if ctree is None: continue # same node
            v = utils.tree_to_str(ctree)
            if v in tried: continue
            verdict = predicate(v)
            # Every non-success verdict (PRes.invalid included) is recorded
            # and otherwise ignored.
            tried[v] = verdict
            if verdict == PRes.success:
                found = True
                ctree_size = count_leaves(ctree)
                if ctree_size < min_tree_size:
                    min_tree, min_tree_size = ctree, ctree_size

                assert get_child(ctree, F_path) is not None
                add_to_pq((ctree, F_path), work_queue)

        # The CHOICE here is that we explore the children if and only if we
        # failed to find a node that can replace the current one.
        if found: continue
        if utils.is_token(skey): continue # do not follow children TOKEN optimization
        for i, child in enumerate(schildren):
            if not utils.is_nt(child[0]): continue
            child_path = F_path + [i]
            assert get_child(tree=dtree, path=child_path) is not None
            add_to_pq((dtree, child_path), work_queue)
    return min_tree

In [6]:
import heapq

def count_leaves(node):
    """Count the nodes under ``node`` (inclusive) that have no children.

    A node whose child list is empty counts as one leaf, so a childless
    root yields 1.
    """
    total = 0
    pending = [node]
    while pending:
        _name, kids, *_ = pending.pop()
        if kids:
            pending.extend(kids)
        else:
            total += 1
    return total

def count_nodes(node):
    """Count the nodes under ``node`` (inclusive) that have at least one child.

    Childless nodes contribute 0, so a lone leaf yields 0.
    """
    _name, kids, *_ = node
    return (1 + sum(map(count_nodes, kids))) if kids else 0

In [7]:
def nt_group(tree, all_nodes=None):
    """Group the nodes of ``tree`` by name, following nonterminals only.

    Walks the tree in pre-order. Every node for which ``utils.is_nt(name)``
    holds is appended to ``all_nodes[name]``; other nodes are skipped along
    with everything beneath them.

    Returns the mapping (the one passed in, or a fresh dict), or ``None``
    when the root of ``tree`` is itself not accepted by ``utils.is_nt``.
    """
    groups = {} if all_nodes is None else all_nodes
    label, kids, *_ = tree
    if not utils.is_nt(label):
        return None
    if label not in groups:
        groups[label] = []
    groups[label].append(tree)
    for kid in kids:
        nt_group(kid, groups)
    return groups

In [8]:
def replace_path(tree, path, new_node=None):
    """Return a rebuilt copy of ``tree`` with the node at ``path`` replaced.

    Nodes along ``path`` are rebuilt as new tuples (extra fields after the
    child list are carried over); siblings off the path are reused as-is.
    The replacement is a ``utils.deep_copy`` of ``new_node``. When
    ``new_node`` is omitted an empty list is used instead, and because
    falsy results are not appended the addressed child is dropped from its
    parent. An index in ``path`` that matches no child leaves that level
    unchanged.
    """
    replacement = [] if new_node is None else new_node
    if not path:
        return utils.deep_copy(replacement)

    head, *tail = path
    label, kids, *extra = tree
    rebuilt = []
    for idx, kid in enumerate(kids):
        candidate = replace_path(kid, tail, replacement) if idx == head else kid
        # A falsy candidate means "delete this child".
        if candidate:
            rebuilt.append(candidate)
    return (label, rebuilt, *extra)

In [9]:
import re
def expr_double_paren(inp):
    """Example predicate: succeed when ``inp`` contains ``((`` followed later by ``))``.

    Returns ``PRes.success`` if the pattern matches from the start of
    ``inp``, otherwise ``PRes.failed``.
    """
    matched = re.match(r'.*[(][(].*[)][)].*', inp) is not None
    return PRes.success if matched else PRes.failed

In [21]:
# Demo: reduce the parse tree of '1+((2*3/4))' while keeping a doubled pair
# of parentheses, then print the reduced string and its tree.
if __name__ == '__main__':
    # Project-local notebook import providing the parser and the expression grammar.
    import ipynb.fs.full.x0_2_parser as parser
    expr_parser = parser.EarleyParser(parser.EXPR_GRAMMAR)
    # parse_on's result is materialised with list(); only its first element is used.
    parsed_expr = list(expr_parser.parse_on('1+((2*3/4))', parser.EXPR_START))[0]
    reduced_expr_tree = hierarchical_reduction(parsed_expr, parser.EXPR_GRAMMAR, expr_double_paren)
    print(utils.tree_to_str(reduced_expr_tree))
    utils.display_tree_console(reduced_expr_tree)

((4))
<start>
+- <expr>
    +- <term>
        +- <factor>
            +- '('
            +- <expr>
            |   +- <term>
            |       +- <factor>
            |           +- '('
            |           +- <expr>
            |           |   +- <term>
            |           |       +- <factor>
            |           |           +- <integer>
            |           |               +- <digit>
            |           |                   +- '4'
            |           +- ')'
            +- ')'


# Done

In [None]:
#%tb