In [1]:
from flash_ansr import ExpressionSpace, get_path
from flash_ansr.expressions.utils import flatten_nested_list

In [2]:
# Expression space loaded from the v7.20 config; used below only for `space.simplify` comparisons
space = ExpressionSpace.from_config(get_path('configs', 'v7.20', 'expression_space.yaml'))

In [3]:
# Prefix (Polish) notation for (<num> - x1) + (x2 + <num>)
expression = ('+', '-', '<num>', 'x1', '+', 'x2', '<num>')

In [4]:
# Reference result from the library's simplifier
space.simplify(expression)

('-', '+', 'x2', '<num>', 'x1')

In [5]:
# Operator classes along which constants can be merged.
# Each class maps to (its operators, its neutral element as a token).
connection_classes = {
    'add': ({'+', '-'}, "0"),
    'mult': ({'*', '/'}, "1"),
}

# Derived from `connection_classes` so the three tables cannot drift apart.
operator_to_class = {
    operator: class_name
    for class_name, (operators, _) in connection_classes.items()
    for operator in operators
}

connectable_operators = set(operator_to_class)

# Cancelling redundant `<num>` constants

In [6]:
def find_connected_num_paths(expression: list[str] | tuple[str, ...]):
    stack = []
    stack_annotations = []

    i = len(expression) - 1

    # Traverse the expression from right to left
    while i >= 0:
        token = expression[i]

        if token in connectable_operators:
            operator = token
            arity = 2
            operands = list(reversed(stack[-arity:]))
            operands_annotations_sets = list(reversed(stack_annotations[-arity:]))

            if all(operand[0] == '<num>' for operand in operands):
                # All operands are constants. Simplify to a single constant
                _ = [stack.pop() for _ in range(arity)]
                _ = [stack_annotations.pop() for _ in range(arity)]
                stack.append(['<num>'])
                stack_annotations.append([set('<num>')]) # A node can have multiple annotations
                i -= 1
                continue

            operator_annotation_set = set()
            num_in_operands = '<num>' in [operand[0] for operand in operands]
            connection_class = operator_to_class[operator]
            if num_in_operands:
                if any(any(operand_an.startswith(connection_class) for operand_an in operand_ans[0]) for operand_ans in operands_annotations_sets):
                    # Both operands are connected by a path of + or -
                    operator_annotation_set.add(f'{connection_class}_connected')
                else:
                    operator_annotation_set.add(connection_class)
            elif all(any(operand_an.startswith(connection_class) for operand_an in operand_ans[0]) for operand_ans in operands_annotations_sets):
                    operator_annotation_set.add(f'{connection_class}_connected')

            _ = [stack.pop() for _ in range(arity)]
            _ = [stack_annotations.pop() for _ in range(arity)]
            stack.append([operator, operands])
            stack_annotations.append([operator_annotation_set, operands_annotations_sets])
            i -= 1
            continue

        if token == '<num>':
            # If the token is a number, push it onto the stack
            stack.append([token])
            stack_annotations.append([set(['<num>'])])
            i -= 1
            continue

        stack.append([token])
        stack_annotations.append([set()])
        i -= 1

    return stack, stack_annotations

In [7]:
# c1 - (c2 + x1): parse the tree and show its connectivity annotations
expression = ('-', '<num>', '+', '<num>', 'x1')
expression_tree, annotated_expression_tree = find_connected_num_paths(expression)
expression_tree, annotated_expression_tree

([['-', [['<num>'], ['+', [['<num>'], ['x1']]]]]],
 [[{'add_connected'}, [[{'<num>'}], [{'add'}, [[{'<num>'}], [set()]]]]]])

In [8]:
# (c1 - x1) + (c2*x2 + c3): both '+'/'-' subtrees are tagged 'add', so the root is 'add_connected'
expression = ('+', '-', '<num>', 'x1', '+', '*', '<num>', 'x2', '<num>')
expression_tree, annotated_expression_tree = find_connected_num_paths(expression)
expression_tree, annotated_expression_tree

([['+',
   [['-', [['<num>'], ['x1']]],
    ['+', [['*', [['<num>'], ['x2']]], ['<num>']]]]]],
 [[{'add_connected'},
   [[{'add'}, [[{'<num>'}], [set()]]],
    [{'add'}, [[{'mult'}, [[{'<num>'}], [set()]]], [{'<num>'}]]]]]])

In [9]:
def cancel_nums(expression_tree: list, expression_annotations_tree: list):
    stack = expression_tree
    stack_annotations = expression_annotations_tree

    expression = []

    while len(stack) > 0:
        subtree = stack.pop()
        subtree_annotation = stack_annotations.pop()

        # Leaf node
        if len(subtree) == 1:
            operand = subtree[0]
            operand_annotation = subtree_annotation[0]

            if operand == '<num>':
                pruned = False
                for connection_class in connection_classes:
                    if f'{connection_class}_prune' in operand_annotation:
                        expression.append(connection_classes[connection_class][1])  # Neural element
                        pruned = True
                        break
                if pruned:
                    continue
            
            expression.append(operand)
            continue
        
        # Non-leaf node
        operator, operands = subtree
        operator_annotation_set, operands_annotations_sets = subtree_annotation

        for connection_class in connection_classes:
            if f'{connection_class}_prune' in operator_annotation_set:
                # Promote children to '_prune'
                for operand_an in operands_annotations_sets:
                    operand_an[0].add(f'{connection_class}_prune')
                continue

            if f'{connection_class}_connected' in operator_annotation_set:
                # Promote children to 'connected'
                for operand_an in operands_annotations_sets:
                    if connection_class in operand_an[0] or '<num>' in operand_an[0]:
                        operand_an[0].add(f'{connection_class}_connected')
                
                # If both children are connected, promote the left child to '_prune'
                if all(any(operand_an.startswith(f'{connection_class}_connected') for operand_an in operand_ans[0]) for operand_ans in operands_annotations_sets):
                    # operands_annotations_sets[0][0].remove(f'{connection_class}_connected')  # Not neecessary because prune is checked first
                    operands_annotations_sets[0][0].add(f'{connection_class}_prune')
                continue

        # Add the operator to the expression
        expression.append(operator)

        # Add the children to the stack
        for operand, operand_an in zip(reversed(operands), reversed(operands_annotations_sets)):
            stack.append(operand)
            stack_annotations.append(operand_an)

    return expression

In [10]:
# c1 - (c2 + x1): the first constant is redundant and is replaced by the neutral element '0'
expression = ('-', '<num>', '+', '<num>', 'x1')
expression_tree, annotated_expression_tree = find_connected_num_paths(expression)
pruned_expression = cancel_nums(expression_tree, annotated_expression_tree)
pruned_expression

['-', '0', '+', '<num>', 'x1']

In [11]:
# Library simplifier on the original expression, for comparison
space.simplify(expression)

('-', '<num>', 'x1')

In [12]:
# Library simplifier on the pruned expression; should match the original's simplified form
space.simplify(pruned_expression)

['-', '<num>', 'x1']

In [13]:
expression = ('+', '-', '<num>', 'x1', '+', 'x2', '<num>')
expression_tree, annotated_expression_tree = find_connected_num_paths(expression)
cancel_nums(expression_tree, annotated_expression_tree)
# (c1 - x1) + (x2 + c2): the two '+'/'-' subtrees are connected, so c1 becomes '0'

['+', '-', '0', 'x1', '+', 'x2', '<num>']

In [14]:
# Same pattern for the multiplicative class: the redundant constant becomes '1'
expression = ('*', '/', '<num>', 'x1', '*', 'x2', '<num>')
expression_tree, annotated_expression_tree = find_connected_num_paths(expression)
cancel_nums(expression_tree, annotated_expression_tree)

['*', '/', '1', 'x1', '*', 'x2', '<num>']

# All terms: collecting subtree multiplicities

In [53]:
def collect_multiplicities(expression: list[str] | tuple[str, ...]):
    """Parse a prefix expression and count subtree occurrences per connection class.

    Tokens are read right to left with a stack, so each connectable operator
    takes the two subtrees to its right. Three parallel trees are built:

    - ``stack``: ``[token]`` leaves and ``[operator, [left, right]]`` nodes.
    - ``stack_annotations``: ``[counts]`` / ``[counts, [left_ann, right_ann]]``,
      where ``counts`` maps each connection class to ``{subtree_tuple: [n_0, n_1]}``.
      ``n_1`` counts occurrences as the right operand of ``'-'`` or ``'/'``;
      ``n_0`` counts all other occurrences. An operator fills only the entry of
      its own class, inheriting its operands' counts for that class.
    - ``stack_labels``: ``[subtree_tuple]`` / ``[subtree_tuple, [left_labels, right_labels]]``.

    Subtree tuples are built with ``flatten_nested_list(...)[::-1]``.

    NOTE(review): inherited counts are added unchanged, i.e. being inside the
    right operand of '-'/'/' does not flip them. Confirm this is intended.

    Returns:
        ``(stack, stack_annotations, stack_labels)``.
    """
    stack: list[list] = []
    stack_annotations: list[list] = []
    stack_labels: list[list] = []

    i = len(expression) - 1

    # Traverse the expression from right to left
    while i >= 0:
        token = expression[i]
        i -= 1

        if token not in connectable_operators:
            # Leaf: it counts itself once in every class
            stack.append([token])
            stack_annotations.append([{cc: {tuple([token]): [1, 0]} for cc in connection_classes}])
            stack_labels.append([tuple([token])])
            continue

        operator = token
        arity = 2
        cc = operator_to_class[operator]
        operands = list(reversed(stack[-arity:]))
        operands_annotations_dicts = list(reversed(stack_annotations[-arity:]))
        operands_labels = list(reversed(stack_labels[-arity:]))
        operand_tuples = [tuple(flatten_nested_list(operand)[::-1]) for operand in operands]

        operator_annotation_dict: dict[str, dict[tuple[str, ...], list[int]]] = {
            class_name: {} for class_name in connection_classes
        }
        counts = operator_annotation_dict[cc]

        # Inherit the operands' counts for this class. A leaf operand's entry for
        # itself is skipped because it is re-counted below at its actual position.
        for operand_tuple, operand_annotation in zip(operand_tuples, operands_annotations_dicts):
            for subtree_tuple, (n_0, n_1) in operand_annotation[0][cc].items():
                if subtree_tuple == operand_tuple:
                    continue
                entry = counts.setdefault(subtree_tuple, [0, 0])
                entry[0] += n_0
                entry[1] += n_1

        # Count the operands themselves: [1, 0] for the left operand and for the
        # right operand of '+'/'*', [0, 1] for the right operand of '-'/'/'.
        # Adding (not assigning) keeps counts of subtrees that also occur deeper,
        # e.g. x1 * (x1 * x2) gives x1 -> [2, 0].
        index = int(operator in {'+', '*'})
        for operand_tuple, (n_0, n_1) in zip(operand_tuples, ([1, 0], [index, 1 - index])):
            entry = counts.setdefault(operand_tuple, [0, 0])
            entry[0] += n_0
            entry[1] += n_1

        # Replace the operands by the new subtree, labelled with its own tuple
        for _ in range(arity):
            stack.pop()
            stack_annotations.pop()
            stack_labels.pop()
        stack.append([operator, operands])
        stack_annotations.append([operator_annotation_dict, operands_annotations_dicts])
        new_label = tuple(flatten_nested_list([operator, operands])[::-1])
        stack_labels.append([new_label, operands_labels])

    return stack, stack_annotations, stack_labels

In [54]:
# (c1 / x1) * (x1 * x2): x1 occurs once multiplied and once as a divisor,
# so the root's 'mult' entry for x1 is [1, 1]
expression = ('*', '/', '<num>', 'x1', '*', 'x1', 'x2')
expression_tree, annotated_expression_tree, stack_labels = collect_multiplicities(expression)
annotated_expression_tree

[[{'add': {},
   'mult': {('<num>',): [1, 0],
    ('x1',): [1, 1],
    ('x2',): [1, 0],
    ('/', '<num>', 'x1'): [1, 0],
    ('*', 'x1', 'x2'): [1, 0]}},
  [[{'add': {}, 'mult': {('<num>',): [1, 0], ('x1',): [0, 1]}},
    [[{'add': {('<num>',): [1, 0]}, 'mult': {('<num>',): [1, 0]}}],
     [{'add': {('x1',): [1, 0]}, 'mult': {('x1',): [1, 0]}}]]],
   [{'add': {}, 'mult': {('x1',): [1, 0], ('x2',): [1, 0]}},
    [[{'add': {('x1',): [1, 0]}, 'mult': {('x1',): [1, 0]}}],
     [{'add': {('x2',): [1, 0]}, 'mult': {('x2',): [1, 0]}}]]]]]]

In [56]:
# All subtree labels, root first (pre-order)
flatten_nested_list(stack_labels)[::-1]

[('*', '/', '<num>', 'x1', '*', 'x1', 'x2'),
 ('/', '<num>', 'x1'),
 ('<num>',),
 ('x1',),
 ('*', 'x1', 'x2'),
 ('x1',),
 ('x2',)]

In [None]:
# TODO:
# - Traverse the tree depth-first
# - Take the longest nontrivial tuple (where both positions are non-zero, i.e. the largest term that has been added and removed somewhere in the tree)
# - Depending on the multiplicity, figure out the replacement: pow, neg, inv, or multiplication with a constant
# - Find the indices of the tuple in the flattened list
# - Replace all instances with the neutral element except one (the first or the last)

In [None]:
def cancel_terms(expression_tree: list, expression_annotations_tree: list):
    """Work-in-progress: cancel repeated terms in an annotated expression tree.

    Not implemented yet. It pops every element from both arguments in
    lockstep (consuming them) and returns an empty token list.
    """
    result: list = []
    while expression_tree:
        subtree = expression_tree.pop()
        subtree_annotation = expression_annotations_tree.pop()
        ...  # TODO: per-subtree cancellation logic
    return result