In [1]:
from flash_ansr import ExpressionSpace, get_path
from flash_ansr.expressions.utils import flatten_nested_list

In [2]:
space = ExpressionSpace.from_config(get_path('configs', 'v7.20', 'expression_space.yaml'))

In [3]:
expression = ('+', '-', '<num>', 'x1', '+', 'x2', '<num>')

In [4]:
space.simplify(expression)

('-', '+', 'x2', '<num>', 'x1')

In [5]:
connection_classes = {
    'add': (set(['+', '-']), "0"),
    'mult': (set(['*', '/']), "1"),
}

connection_classes_inverse = {
    'add': "neg",
    'mult': "inv",
}

operator_to_class = {
    '+': 'add',
    '-': 'add',
    '*': 'mult',
    '/': 'mult'
}

connectable_operators = set(['+', '-', '*', '/'])

# `<num>`

In [6]:
def find_connected_num_paths(expression: list[str] | tuple[str, ...]):
    stack = []
    stack_annotations = []

    i = len(expression) - 1

    # Traverse the expression from right to left
    while i >= 0:
        token = expression[i]

        if token in connectable_operators:
            operator = token
            arity = 2
            operands = list(reversed(stack[-arity:]))
            operands_annotations_sets = list(reversed(stack_annotations[-arity:]))

            if all(operand[0] == '<num>' for operand in operands):
                # All operands are constants. Simplify to a single constant
                _ = [stack.pop() for _ in range(arity)]
                _ = [stack_annotations.pop() for _ in range(arity)]
                stack.append(['<num>'])
                stack_annotations.append([set('<num>')]) # A node can have multiple annotations
                i -= 1
                continue

            operator_annotation_set = set()
            num_in_operands = '<num>' in [operand[0] for operand in operands]
            connection_class = operator_to_class[operator]
            if num_in_operands:
                if any(any(operand_an.startswith(connection_class) for operand_an in operand_ans[0]) for operand_ans in operands_annotations_sets):
                    # Both operands are connected by a path of + or -
                    operator_annotation_set.add(f'{connection_class}_connected')
                else:
                    operator_annotation_set.add(connection_class)
            elif all(any(operand_an.startswith(connection_class) for operand_an in operand_ans[0]) for operand_ans in operands_annotations_sets):
                    operator_annotation_set.add(f'{connection_class}_connected')

            _ = [stack.pop() for _ in range(arity)]
            _ = [stack_annotations.pop() for _ in range(arity)]
            stack.append([operator, operands])
            stack_annotations.append([operator_annotation_set, operands_annotations_sets])
            i -= 1
            continue

        if token == '<num>':
            # If the token is a number, push it onto the stack
            stack.append([token])
            stack_annotations.append([set(['<num>'])])
            i -= 1
            continue

        stack.append([token])
        stack_annotations.append([set()])
        i -= 1

    return stack, stack_annotations

In [7]:
expression = ('-', '<num>', '+', '<num>', 'x1')
expression_tree, annotated_expression_tree = find_connected_num_paths(expression)
expression_tree, annotated_expression_tree

([['-', [['<num>'], ['+', [['<num>'], ['x1']]]]]],
 [[{'add_connected'}, [[{'<num>'}], [{'add'}, [[{'<num>'}], [set()]]]]]])

In [8]:
expression = ('+', '-', '<num>', 'x1', '+', '*', '<num>', 'x2', '<num>')
expression_tree, annotated_expression_tree = find_connected_num_paths(expression)
expression_tree, annotated_expression_tree

([['+',
   [['-', [['<num>'], ['x1']]],
    ['+', [['*', [['<num>'], ['x2']]], ['<num>']]]]]],
 [[{'add_connected'},
   [[{'add'}, [[{'<num>'}], [set()]]],
    [{'add'}, [[{'mult'}, [[{'<num>'}], [set()]]], [{'<num>'}]]]]]])

In [9]:
def cancel_nums(expression_tree: list, expression_annotations_tree: list):
    stack = expression_tree
    stack_annotations = expression_annotations_tree

    expression = []

    while len(stack) > 0:
        subtree = stack.pop()
        subtree_annotation = stack_annotations.pop()

        # Leaf node
        if len(subtree) == 1:
            operand = subtree[0]
            operand_annotation = subtree_annotation[0]

            if operand == '<num>':
                pruned = False
                for connection_class in connection_classes:
                    if f'{connection_class}_prune' in operand_annotation:
                        expression.append(connection_classes[connection_class][1])  # Neural element
                        pruned = True
                        break
                if pruned:
                    continue
            
            expression.append(operand)
            continue
        
        # Non-leaf node
        operator, operands = subtree
        operator_annotation_set, operands_annotations_sets = subtree_annotation

        for connection_class in connection_classes:
            if f'{connection_class}_prune' in operator_annotation_set:
                # Promote children to '_prune'
                for operand_an in operands_annotations_sets:
                    operand_an[0].add(f'{connection_class}_prune')
                continue

            if f'{connection_class}_connected' in operator_annotation_set:
                # Promote children to 'connected'
                for operand_an in operands_annotations_sets:
                    if connection_class in operand_an[0] or '<num>' in operand_an[0]:
                        operand_an[0].add(f'{connection_class}_connected')
                
                # If both children are connected, promote the left child to '_prune'
                if all(any(operand_an.startswith(f'{connection_class}_connected') for operand_an in operand_ans[0]) for operand_ans in operands_annotations_sets):
                    # operands_annotations_sets[0][0].remove(f'{connection_class}_connected')  # Not neecessary because prune is checked first
                    operands_annotations_sets[0][0].add(f'{connection_class}_prune')
                continue

        # Add the operator to the expression
        expression.append(operator)

        # Add the children to the stack
        for operand, operand_an in zip(reversed(operands), reversed(operands_annotations_sets)):
            stack.append(operand)
            stack_annotations.append(operand_an)

    return expression

In [10]:
expression = ('-', '<num>', '+', '<num>', 'x1')
expression_tree, annotated_expression_tree = find_connected_num_paths(expression)
pruned_expression = cancel_nums(expression_tree, annotated_expression_tree)
pruned_expression

['-', '0', '+', '<num>', 'x1']

In [11]:
space.simplify(expression)

('-', '<num>', 'x1')

In [12]:
space.simplify(pruned_expression)

['-', '<num>', 'x1']

In [13]:
expression = ('+', '-', '<num>', 'x1', '+', 'x2', '<num>')
expression_tree, annotated_expression_tree = find_connected_num_paths(expression)
cancel_nums(expression_tree, annotated_expression_tree)
# annotated_expression_tree

['+', '-', '0', 'x1', '+', 'x2', '<num>']

In [14]:
expression = ('*', '/', '<num>', 'x1', '*', 'x2', '<num>')
expression_tree, annotated_expression_tree = find_connected_num_paths(expression)
cancel_nums(expression_tree, annotated_expression_tree)

['*', '/', '1', 'x1', '*', 'x2', '<num>']

# All terms

In [55]:
def collect_multiplicities(expression: list[str] | tuple[str, ...]):
    stack = []
    stack_annotations: list[dict[str, dict[tuple[str, ...], list[int, int]]]] = []
    stack_labels = []

    i = len(expression) - 1

    # Traverse the expression from right to left
    while i >= 0:
        token = expression[i]

        if token in connectable_operators:
            operator = token
            arity = 2
            operands = list(reversed(stack[-arity:]))
            operands_annotations_dicts = list(reversed(stack_annotations[-arity:]))
            operands_labels = list(reversed(stack_labels[-arity:]))

            operator_annotation_dict: dict[str, dict[tuple[str, ...], list[int, int]]] = {cc: {} for cc in connection_classes}
            # print()
            for operand_annotations_dict in operands_annotations_dicts:
                # print(operand_annotations_dict)
                cc = operator_to_class[operator]
                for subtree_hash in operand_annotations_dict[0][cc]:
                    if subtree_hash not in operator_annotation_dict[cc]:
                        operator_annotation_dict[cc][subtree_hash] = [0, 0]

                    for p in range(2):
                        # print(f'Adding {operand_annotations_dict[0][cc][subtree_hash][p]} to {operator_annotation_dict[cc][subtree_hash][p]}')
                        operator_annotation_dict[cc][subtree_hash][p] += operand_annotations_dict[0][cc][subtree_hash][p]

                    if operator in {'-', '/'}:
                        operator_annotation_dict[cc][subtree_hash][0], operator_annotation_dict[cc][subtree_hash][1] = operator_annotation_dict[cc][subtree_hash][1], operator_annotation_dict[cc][subtree_hash][0]

            # Add subtree hashes for both operand subtrees
            operand_tuple_0 = tuple(flatten_nested_list(operands[0])[::-1])
            if operand_tuple_0 not in operator_annotation_dict[cc]:
                operator_annotation_dict[cc][operand_tuple_0] = [1, 0]
            
            operand_tuple_1 = tuple(flatten_nested_list(operands[1])[::-1])
            if operand_tuple_1 not in operator_annotation_dict[cc]:
                index = int(operator in {'+', '*'})
                operator_annotation_dict[cc][operand_tuple_1] = [index, 1 - index]

            # Label each subtree with its own hash to know which to prune later
            _ = [stack.pop() for _ in range(arity)]
            _ = [stack_annotations.pop() for _ in range(arity)]
            _ = [stack_labels.pop() for _ in range(arity)]
            stack.append([operator, operands])
            stack_annotations.append([operator_annotation_dict, operands_annotations_dicts])
            new_label = tuple(flatten_nested_list([operator, operands])[::-1])
            stack_labels.append([new_label, operands_labels])
            i -= 1
            continue

        stack.append([token])
        stack_annotations.append([{cc: {tuple([token]): [1, 0]} for cc in connection_classes}])
        stack_labels.append([tuple([token])])
        i -= 1

    return stack, stack_annotations, stack_labels

In [60]:
expression = ('*', 'x1', '*', 'x1', 'x2')
expression_tree, annotated_expression_tree, stack_labels = collect_multiplicities(expression)
annotated_expression_tree

[[{'add': {},
   'mult': {('x1',): [2, 0], ('x2',): [1, 0], ('*', 'x1', 'x2'): [1, 0]}},
  [[{'add': {('x1',): [1, 0]}, 'mult': {('x1',): [1, 0]}}],
   [{'add': {}, 'mult': {('x1',): [1, 0], ('x2',): [1, 0]}},
    [[{'add': {('x1',): [1, 0]}, 'mult': {('x1',): [1, 0]}}],
     [{'add': {('x2',): [1, 0]}, 'mult': {('x2',): [1, 0]}}]]]]]]

In [17]:
flatten_nested_list(stack_labels)[::-1]

[('*', 'x1', '*', 'x1', 'x2'), ('x1',), ('*', 'x1', 'x2'), ('x1',), ('x2',)]

In [61]:
expression = ('*', '/', '<num>', 'x1', '*', 'x1', 'x2')
expression_tree, annotated_expression_tree, stack_labels = collect_multiplicities(expression)
annotated_expression_tree

[[{'add': {},
   'mult': {('<num>',): [0, 1],
    ('x1',): [1, 1],
    ('x2',): [1, 0],
    ('/', '<num>', 'x1'): [1, 0],
    ('*', 'x1', 'x2'): [1, 0]}},
  [[{'add': {}, 'mult': {('<num>',): [0, 1], ('x1',): [0, 1]}},
    [[{'add': {('<num>',): [1, 0]}, 'mult': {('<num>',): [1, 0]}}],
     [{'add': {('x1',): [1, 0]}, 'mult': {('x1',): [1, 0]}}]]],
   [{'add': {}, 'mult': {('x1',): [1, 0], ('x2',): [1, 0]}},
    [[{'add': {('x1',): [1, 0]}, 'mult': {('x1',): [1, 0]}}],
     [{'add': {('x2',): [1, 0]}, 'mult': {('x2',): [1, 0]}}]]]]]]

In [19]:
flatten_nested_list(stack_labels)[::-1]

[('*', '/', '<num>', 'x1', '*', 'x1', 'x2'),
 ('/', '<num>', 'x1'),
 ('<num>',),
 ('x1',),
 ('*', 'x1', 'x2'),
 ('x1',),
 ('x2',)]

In [20]:
# TODO:
# - Traverse dfs
# - take the longest nontrivial tuple (where both positions are not 0, i.e. the largest term thet has been added and removed somewhere in the tree)
# - depending on multiplicity, figure out what the replacement will be, pow, neg, inv, multiplication with a constant
# - find indices of the tuple in the flattened list
# - replace all instances with the neutral element but (the last/first) one

In [21]:
import math
def is_prime(n: int) -> bool:
    '''
    Check if a number is prime.

    Parameters
    ----------
    n : int
        The number to check.

    Returns
    -------
    bool
        True if the number is prime, False otherwise.
    '''
    if n % 2 == 0 and n > 2:
        return False
    return all(n % i for i in range(3, int(math.sqrt(n)) + 1, 2))

In [165]:
def cancel_terms(expression_tree: list, expression_annotations_tree: list, stack_labels: list):
    stack = expression_tree
    stack_annotations = expression_annotations_tree
    stack_parity = [{cc: 1 for cc in connection_classes} for _ in range(len(stack_labels))]

    expression = []

    argmax_class = None
    argmax_subtree = None
    argmax_multiplicity = None
    argmax_multiplicity_sum = None
    max_subtree_length = 0
    n_replaced = 0

    while len(stack) > 0:
        subtree = stack.pop()
        subtree_annotation = stack_annotations.pop()
        subtree_labels = stack_labels.pop()
        subtree_parities = stack_parity.pop()

        if argmax_subtree == subtree_labels[0]:
            current_parity = subtree_parities[argmax_class]
            inverse_operator = connection_classes_inverse[argmax_class]
            print(f'Current parity: {current_parity}')

            if current_parity * argmax_multiplicity_sum < 0:
                inverse_operator_prefix = (inverse_operator,)
            else:
                inverse_operator_prefix = ()

            neutral_element = connection_classes[argmax_class][1]
            if argmax_multiplicity_sum == 0:
                # Term is cancelled entirely. Replace all occurences with the neutral element
                print('Full cancellation: ', argmax_subtree)
                first_replacement = (neutral_element,)
                other_replacements = neutral_element
         
            if argmax_multiplicity_sum == 1:
                # Term occurs once. Replace every occurence after the first one with the neutral element
                print('One remaining: ', argmax_subtree)
                first_replacement = inverse_operator_prefix + argmax_subtree
                other_replacements = (neutral_element,)

            if argmax_multiplicity_sum == -1:
                # Term occurs once but inverted. Replace the first occurence with the inverse of the term. Replace every occurence after the first one with the neutral element
                print('One inverted: ', argmax_subtree)
                first_replacement = inverse_operator_prefix + argmax_subtree
                other_replacements = (neutral_element,)

            # if multiplicity_sum > 1:
            #     # Term occurs multiple times. Replace the first occurence with a multiplication or power of the term. Replace every occurence after the first one with the neutral element
            #     print(f'{multiplicity_sum} remaining')
            #     replacement = connection_classes[argmax_class][1]
            #     if argmax_class == 'mult':
            #         if multiplicity_sum > 2 and is_prime(multiplicity_sum):
            #             powers = space.factorize_to_at_most(multiplicity_sum - 1, space.max_power)
            #             summarized_term = ('*',) + tuple(f'pow_{p}' for p in powers) + (expression[indices[0]],) + (expression[indices[0]],)
            #         else:
            #             powers = space.factorize_to_at_most(multiplicity_sum, space.max_power)
            #             summarized_term = tuple(f'pow_{p}' for p in powers) + (expression[indices[0]],)

            #     expression[indices[0]] = summarized_term

            #     for index in indices[1:]:
            #         expression[index] = (replacement,)
            #     break

            # Leaf node
            if len(subtree) == 1:
                if n_replaced == 0:
                    expression.extend(first_replacement)
                else:
                    expression.extend(other_replacements)
                n_replaced += 1
                continue


        # Leaf node
        if len(subtree) == 1:
            operand = subtree[0]
            operand_annotation = subtree_annotation[0]
            operand_label = subtree_labels[0]
            operand_parity = subtree_parities
            
            expression.append(operand)
            continue

        # Non-leaf node
        operator, operands = subtree
        operator_annotation_set, operands_annotations_sets = subtree_annotation
        operator_label, operands_labels = subtree_labels
        operator_parity = subtree_parities  # No operand parity information yet

        # TODO: Switch polarity for inverse operators

        if len(operands) == 2:
            propagated_operand_parities = [{}, {}]
            for cc, (operator_set, _) in connection_classes.items():
                if operator in operator_set:
                    propagated_operand_parities[0][cc] = operator_parity[cc]
                    propagated_operand_parities[1][cc] = operator_parity[cc] * (-1 if operator in {'-', '/'} else 1)
                else:
                    propagated_operand_parities[0][cc] = operator_parity[cc]
                    propagated_operand_parities[1][cc] = operator_parity[cc]

            # If no cancellation candidate has been identified yet, try to find one in the current subtree
            if argmax_class is None:
                for cc in connection_classes:
                    for subtree_hash, multiplicity in subtree_annotation[0][cc].items():
                        if len(subtree_hash) > max_subtree_length and sum(abs(m) for m in multiplicity) > 1:
                            argmax_class = cc
                            argmax_subtree = subtree_hash
                            argmax_multiplicity = multiplicity
                            argmax_multiplicity_sum = multiplicity[0] - multiplicity[1]
                        
        # Add the operator to the expression
        expression.append(operator)

        # Add the children to the stack
        for operand, operand_an, operand_label, propagated_operand_parity in zip(reversed(operands), reversed(operands_annotations_sets), reversed(operands_labels), reversed(propagated_operand_parities)):
            stack.append(operand)
            stack_annotations.append(operand_an)
            stack_labels.append(operand_label)
            stack_parity.append(propagated_operand_parity)

    return expression

In [166]:
expression = ('*', '/', '<num>', 'x1', '*', 'x1', 'x2')
expression_tree, annotated_expression_tree, stack_labels = collect_multiplicities(expression)
cancel_terms(expression_tree, annotated_expression_tree, stack_labels)

Parities: {'add': 1, 'mult': 1}
Parities: {'add': 1, 'mult': 1}
Parities: {'add': 1, 'mult': 1}
Parities: {'add': 1, 'mult': -1}
Current parity: -1
Full cancellation:  ('x1',)
Parities: {'add': 1, 'mult': 1}
Parities: {'add': 1, 'mult': 1}
Current parity: 1
Full cancellation:  ('x1',)
Parities: {'add': 1, 'mult': 1}


['*', '/', '<num>', '1', '*', '1', 'x2']

In [167]:
# TODO: Need to really traverse the tree to keep track of the current parity

In [168]:
expression = ('*', '/', '<num>', 'x1', '*', 'x1', '*', 'x1', 'x2')
expression_tree, annotated_expression_tree, stack_labels = collect_multiplicities(expression)
cancel_terms(expression_tree, annotated_expression_tree, stack_labels)

Parities: {'add': 1, 'mult': 1}
Parities: {'add': 1, 'mult': 1}
Parities: {'add': 1, 'mult': 1}
Parities: {'add': 1, 'mult': -1}
Current parity: -1
One remaining:  ('x1',)
Parities: {'add': 1, 'mult': 1}
Parities: {'add': 1, 'mult': 1}
Current parity: 1
One remaining:  ('x1',)
Parities: {'add': 1, 'mult': 1}
Parities: {'add': 1, 'mult': 1}
Current parity: 1
One remaining:  ('x1',)
Parities: {'add': 1, 'mult': 1}


['*', '/', '<num>', 'inv', 'x1', '*', '1', '*', '1', 'x2']

In [169]:
expression = ('*', '/', '<num>', 'x1', '/', 'x1', '*', 'x1', 'x2')
expression_tree, annotated_expression_tree, stack_labels = collect_multiplicities(expression)
cancel_terms(expression_tree, annotated_expression_tree, stack_labels)

Parities: {'add': 1, 'mult': 1}
Parities: {'add': 1, 'mult': 1}
Parities: {'add': 1, 'mult': 1}
Parities: {'add': 1, 'mult': -1}
Current parity: -1
One inverted:  ('x1',)
Parities: {'add': 1, 'mult': 1}
Parities: {'add': 1, 'mult': 1}
Current parity: 1
One inverted:  ('x1',)
Parities: {'add': 1, 'mult': -1}
Parities: {'add': 1, 'mult': -1}
Current parity: -1
One inverted:  ('x1',)
Parities: {'add': 1, 'mult': -1}


['*', '/', '<num>', 'x1', '/', '1', '*', '1', 'x2']