In [7]:
from flash_ansr import ExpressionSpace, get_path
from flash_ansr.expressions.utils import flatten_nested_list

In [4]:
# Build the ExpressionSpace from the v7.20 configuration file; its `simplify`
# method is used below as the baseline for the constant-cancelling prototype.
space = ExpressionSpace.from_config(get_path('configs', 'v7.20', 'expression_space.yaml'))

In [5]:
# Prefix-notation token sequence, read as (<num> - x1) + (x2 + <num>).
# It contains two '<num>' placeholders joined only through '+' / '-'.
expression = ('+', '-', '<num>', 'x1', '+', 'x2', '<num>')

In [6]:
# Baseline: the existing simplifier returns this expression unchanged (see the
# output below), i.e. both '<num>' tokens are kept.
space.simplify(expression)

('+', '-', '<num>', 'x1', '+', 'x2', '<num>')

In [102]:
# Connection classes: name -> (operator tokens belonging to the class,
# value substituted for a pruned '<num>', i.e. the class's neutral element).
connection_classes = {
    'add': ({'+', '-'}, 0),
    'mult': ({'*', '/'}, 1),
}

# Reverse lookup: operator token -> name of its connection class.
operator_to_class = {'+': 'add', '-': 'add', '*': 'mult', '/': 'mult'}

# Every operator token that belongs to some connection class.
connectable_operators = set(operator_to_class)

In [120]:
def find_connected_num_paths(expression: list[str] | tuple[str, ...]):
    stack = []
    stack_annotations = []

    i = len(expression) - 1

    # Traverse the expression from right to left
    while i >= 0:
        token = expression[i]

        if token in connectable_operators:
            operator = token
            arity = 2
            operands = list(reversed(stack[-arity:]))
            operands_annotations_sets = list(reversed(stack_annotations[-arity:]))

            if all(operand[0] == '<num>' for operand in operands):
                # All operands are constants. Simplify to a single constant
                _ = [stack.pop() for _ in range(arity)]
                _ = [stack_annotations.pop() for _ in range(arity)]
                stack.append(['<num>'])
                stack_annotations.append([set('<num>')]) # A node can have multiple annotations
                i -= 1
                continue

            operator_annotation_set = set()
            num_in_operands = '<num>' in [operand[0] for operand in operands]
            connection_class = operator_to_class[operator]
            if num_in_operands or any(any(operand_an.startswith(connection_class) for operand_an in operand_ans[0]) for operand_ans in operands_annotations_sets):
                # The subtree has a constant or a constant connected by a path of + or -

                if all(any(operand_an.startswith(connection_class) for operand_an in operand_ans[0]) for operand_ans in operands_annotations_sets):
                    # Both operands are connected by a path of + or -
                    operator_annotation_set.add(f'{connection_class}_connected')
                else:
                    operator_annotation_set.add(connection_class)

            _ = [stack.pop() for _ in range(arity)]
            _ = [stack_annotations.pop() for _ in range(arity)]
            stack.append([operator, operands])
            stack_annotations.append([operator_annotation_set, operands_annotations_sets])
            i -= 1
            continue

        if token == '<num>':
            # If the token is a number, push it onto the stack
            stack.append([token])
            stack_annotations.append([set(['<num>'])])
            i -= 1
            continue

        stack.append([token])
        stack_annotations.append([set()])
        i -= 1

    return stack, stack_annotations

In [132]:
# Example with a '*' nested inside the right-hand '+': the '*' node is marked
# 'mult', both '+'/'-' children are marked 'add', and the root is marked
# 'add_connected' (see the output below).
expression = ('+', '-', '<num>', 'x1', '+', '*', '<num>', 'x2', '<num>')
expression_tree, annotated_expression_tree = find_connected_num_paths(expression)
expression_tree, annotated_expression_tree

([['+',
   [['-', [['<num>'], ['x1']]],
    ['+', [['*', [['<num>'], ['x2']]], ['<num>']]]]]],
 [[{'add_connected'},
   [[{'add'}, [[{'<num>'}], [set()]]],
    [{'add'}, [[{'mult'}, [[{'<num>'}], [set()]]], [{'<num>'}]]]]]])

In [148]:
def cancel_nums(expression_tree: list, expression_annotations_tree: list):
    """Flatten an annotated expression tree to prefix tokens, cancelling redundant constants.

    The inputs are the two parallel structures returned by
    ``find_connected_num_paths``. The tree is walked depth-first. At every
    operator marked ``'<class>_connected'`` (both operands reach a ``'<num>'``
    through operators of that class) the LEFT operand is pruned: each
    ``'<num>'`` reachable in it through operators of the same class is replaced
    by the class's neutral element (``connection_classes[<class>][1]``). The
    right operand keeps its constant, so one constant per connected group
    survives.

    The input trees and their annotation sets are not modified.

    Args:
        expression_tree: List of root subtrees; a leaf is ``[token]``, an inner
            node is ``[operator, [operand_subtrees]]``.
        expression_annotations_tree: Parallel annotations; a leaf is
            ``[marks]``, an inner node is ``[marks, [operand_annotations]]``.
            ``'<class>_prune'`` / ``'<class>_connected'`` marks already present
            on a node are honoured.

    Returns:
        ``(expression, remaining_annotations)``: the flat prefix token list
        (neutral elements are inserted as numbers, so the list may mix ``str``
        and ``int``) and the leftover annotation work stack, which is empty
        after a complete traversal and is returned only for backward
        compatibility.
    """
    # Work on shallow copies so the caller's root lists are not consumed.
    stack = list(expression_tree)
    stack_annotations = list(expression_annotations_tree)
    # Prune marks handed down by the parent, parallel to `stack`. They travel
    # on their own stack instead of being added to the caller's annotation sets.
    stack_inherited = [frozenset() for _ in stack]

    expression = []

    while stack:
        # Depth-first traversal
        subtree = stack.pop()
        subtree_annotation = stack_annotations.pop()
        marks = subtree_annotation[0] | stack_inherited.pop()

        # Leaf node
        if len(subtree) == 1:
            token = subtree[0]
            if token == '<num>':
                pruned_class = next(
                    (name for name in connection_classes if f'{name}_prune' in marks),
                    None,
                )
                if pruned_class is not None:
                    # Pruned constant: emit the neutral element of its class.
                    expression.append(connection_classes[pruned_class][1])
                    continue

            expression.append(token)
            continue

        # Non-leaf node
        operator, operands = subtree
        _, operands_annotations = subtree_annotation

        child_marks = [set() for _ in operands]
        for connection_class, (class_operators, _neutral) in connection_classes.items():
            if operator not in class_operators:
                # A prune only travels through operators of its own class: a
                # constant below e.g. '*' inside a pruned '+' subtree scales a
                # term and cannot be absorbed by an additive constant.
                continue

            prune_mark = f'{connection_class}_prune'
            if prune_mark in marks:
                # Already inside a pruned subtree: hand the prune to every child.
                for child in child_marks:
                    child.add(prune_mark)
            elif f'{connection_class}_connected' in marks and child_marks:
                # Both operands reach a constant: prune the left one only.
                # The children are deliberately NOT marked as connected;
                # doing so would make the right operand prune its own left
                # side as well and could remove every constant.
                child_marks[0].add(prune_mark)

        # Add the operator to the expression
        expression.append(operator)

        # Push the children right-to-left so the left child is popped first
        for operand, operand_annotation, inherited in zip(
            reversed(operands), reversed(operands_annotations), reversed(child_marks)
        ):
            stack.append(operand)
            stack_annotations.append(operand_annotation)
            stack_inherited.append(frozenset(inherited))

    return expression, stack_annotations

In [149]:
# Additive case: the two '<num>' constants are joined only through '+' / '-',
# so the left one is replaced by 0 and the right one is kept (see output).
expression = ('+', '-', '<num>', 'x1', '+', 'x2', '<num>')
expression_tree, annotated_expression_tree = find_connected_num_paths(expression)
cancel_nums(expression_tree, annotated_expression_tree)

(['+', '-', 0, 'x1', '+', 'x2', '<num>'], [])

In [150]:
# Multiplicative case: same structure with '*' / '/', where the left '<num>'
# is replaced by 1 (see output).
expression = ('*', '/', '<num>', 'x1', '*', 'x2', '<num>')
expression_tree, annotated_expression_tree = find_connected_num_paths(expression)
cancel_nums(expression_tree, annotated_expression_tree)

(['*', '/', 1, 'x1', '*', 'x2', '<num>'], [])