In [1]:
from flash_ansr.expressions import ExpressionSpace
from flash_ansr.expressions.utils import codify, num_to_constants, flatten_nested_list
from flash_ansr import get_path
import itertools
from tqdm import tqdm
from collections import defaultdict
import numpy as np
from typing import Generator, Callable, Any
import json
import warnings
from scipy.optimize import curve_fit, OptimizeWarning
from copy import deepcopy
import time

In [2]:
MODEL = 'v7.0'

In [3]:
config = get_path('configs', MODEL, 'expression_space.yaml')

In [4]:
space = ExpressionSpace.from_config(config)

In [102]:
def expression_generator(hashes_of_size: dict[int, list[tuple[str]]], non_leaf_nodes: dict[str, int]) -> Generator[tuple[str], None, None]:
    # Append existing trees to every operator
    for new_root_operator, arity in non_leaf_nodes.items():
        # Start with the smallest arity-tuples of trees
        for child_lengths in sorted(itertools.product(list(hashes_of_size.keys()), repeat=arity), key=lambda x: sum(x)):
            # Check all possible combinations of child trees
            for child_combination in itertools.product(*[hashes_of_size[child_length] for child_length in child_lengths]):
                yield (new_root_operator,) + tuple(itertools.chain.from_iterable(child_combination))

def exist_constants_that_fit(expression: list[str], variables: list[str], X: np.ndarray, y_target: np.ndarray, debug: bool = False):
    if isinstance(expression, tuple):
        expression = list(expression)

    executable_prefix_expression = space.operators_to_realizations(expression)
    prefix_expression_with_constants, constants = num_to_constants(executable_prefix_expression, convert_numbers_to_constant=False)
    code_string = space.prefix_to_infix(prefix_expression_with_constants, realization=True)
    code = codify(code_string, variables + constants)
    f = space.code_to_lambda(code)

    def pred_function(X: np.ndarray, *constants: np.ndarray | None) -> float:
        if len(constants) == 0:
            y = f(*X.T)
        y =  f(*X.T, *constants)

        # If the numbers are complex, return nan
        if np.iscomplexobj(y):
            return np.full(y.shape, np.nan)
        
        return y

    p0 = np.random.normal(loc=0, scale=5, size=len(constants))

    is_valid = np.isfinite(X).all(axis=1) & np.isfinite(y_target)

    if not np.any(is_valid):
        if debug:
            print("No valid data points")
        return False

    try:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=OptimizeWarning)
            popt, _ = curve_fit(pred_function, X[is_valid], y_target[is_valid].flatten(), p0=p0)
    except RuntimeError:
        if debug:
            print("RuntimeError")
        return False

    y = f(*X.T, *popt)
    if not isinstance(y, np.ndarray):
        y = np.full(X.shape[0], y)

    if debug:
        print(f"Constants: {popt}")
        print(f"y_target: {y_target}")
        print(f"y: {y}")
        print(f'All close: {np.allclose(y_target, y, equal_nan=True)}')

    return np.allclose(y_target, y, equal_nan=True)

def remap_expression(source_expression: list[str], dummy_variables: list[str], variable_mapping: dict | None = None):
    source_expression = deepcopy(source_expression)
    if variable_mapping is None:
        variable_mapping = {}
        for i, token in enumerate(source_expression):
            if token in dummy_variables:
                if token not in variable_mapping:
                    variable_mapping[token] = f'_{len(variable_mapping)}'
    
    for i, token in enumerate(source_expression):
        if token in dummy_variables:
            source_expression[i] = variable_mapping[token]

    return source_expression, variable_mapping

def prefix_to_tree(expression: list[str], operator_arity: dict[str, int]) -> list:
    def build_tree(index):
        if index >= len(expression):
            return None, index

        token = expression[index]

        # If token is not an operator or is an operator with arity 0
        if isinstance(token, dict) or token not in operator_arity or operator_arity[token] == 0:
            return [token], index + 1

        # If token is an operator
        operands = []
        current_index = index + 1

        # Process operands based on the operator's arity
        for _ in range(operator_arity[token]):
            if current_index >= len(expression):
                break

            subtree, current_index = build_tree(current_index)
            if subtree:
                operands.append(subtree)

        return [token, operands], current_index

    result, _ = build_tree(0)
    return result

def safe_f(f: Callable, X: np.ndarray, constants: np.ndarray | None = None):
    try:
        if constants is None:
            y = f(*X.T)
        else:
            y = f(*X.T, *constants)
        if not isinstance(y, np.ndarray) or y.shape[0] == 1:
            y = np.full(X.shape[0], y)
        return y
    except ZeroDivisionError:
        return np.full(X.shape[0], np.nan)
    
def pattern_match(tree: list, pattern: list, mapping: dict[str, Any] | None = None) -> tuple[bool, dict[str, Any]]:
    if mapping is None:
        mapping = {}

    if len(tree) == 1 and isinstance(tree[0], str) and len(pattern) != 1:
        return False, mapping

    if len(pattern) == 1 and isinstance(pattern[0], str):
        if pattern[0].startswith('_'):
            if pattern[0] not in mapping:
                mapping[pattern[0]] = tree
            elif mapping[pattern[0]] != tree:
                return False, mapping
            return True, mapping
        
        if tree != pattern:
            return False, mapping
        return True, mapping

    tree_operator, tree_operands = tree
    pattern_operator, pattern_operands = pattern

    if tree_operator != pattern_operator:
        return False, mapping
    
    for tree_operand, pattern_operand in zip(tree_operands, pattern_operands):
        if isinstance(pattern_operand, str):
            if pattern_operand not in mapping:
                mapping[pattern_operand] = tree_operand
            elif mapping[pattern_operand] != tree_operand:
                return False, mapping
        else:
            does_match, mapping = pattern_match(tree_operand, pattern_operand, mapping)
            if not does_match:
                return False, mapping

    return True, mapping

def apply_mapping(tree: list, mapping: dict[str, Any]) -> list:
    if len(tree) == 1 and isinstance(tree[0], str):
        if tree[0].startswith('_'):
            return mapping[tree[0]]
        return tree

    operator, operands = tree
    return [operator] + [apply_mapping(operand, mapping) for operand in operands]

def _subtree_simplify(expression: list[str], rules_trees: dict[int, set[tuple[list[str], list[str]]]]) -> list[str]:
    stack: list = []
    i = len(expression) - 1

    while i >= 0:
        token = expression[i]
        applied_rule = False

        if token in space.operator_arity_compat or token in space.operator_aliases:
            operator = space.operator_aliases.get(token, token)
            arity = space.operator_arity_compat[operator]
            operands = list(reversed(stack[-arity:]))

            # Check if a pattern matches the current subtree
            for rule in rules_trees.get(arity, []):
                subtree = [operator, operands]
                does_match, mapping = pattern_match(subtree, rule[0], mapping=None)
                if does_match:

                    # Replace the placeholders (keys of the mapping) with the actual subtrees (values of the mapping) in the entire subtree at any depth
                    _ = [stack.pop() for _ in range(arity)]
                    stack.append(apply_mapping(deepcopy(rule[1]), mapping))
                    i -= 1
                    applied_rule = True
                    break

            if not applied_rule:
                # print('No rule applied')
                _ = [stack.pop() for _ in range(arity)]
                stack.append([operator, operands])
                i -= 1
                continue

        if not applied_rule:
            stack.append([token])
            i -= 1

    return flatten_nested_list(stack)[::-1]

def subtree_simplify(expression: list[str], rules_trees: dict[int, set[tuple[list[str], list[str]]]], max_iter: int = 1) -> list[str]:
    new_expression = expression
    for _ in range(max_iter):
        new_expression = _subtree_simplify(new_expression, rules_trees)
        if new_expression == expression:
            break
        expression = new_expression
    return new_expression

In [38]:
prefix_to_tree(['x1'], space.operator_arity_compat)

['x1']

In [39]:
prefix_to_tree(['+', 'x1', 'x2'], space.operator_arity_compat)

['+', [['x1'], ['x2']]]

In [40]:
# %%timeit
prefix_to_tree(['*', 'x1', 'cos', 'x2'], space.operator_arity_compat)

['*', [['x1'], ['cos', [['x2']]]]]

In [41]:
pattern_match(prefix_to_tree(['+', 'cos', 'x3', 'x2'], space.operator_arity), prefix_to_tree(['+', 'cos', '_1', '_2'], space.operator_arity))

(True, {'_1': ['x3'], '_2': ['x2']})

In [42]:
prefix_to_tree(['+', 'cos', 'x1', 'x2'], space.operator_arity)

['+', [['cos', [['x1']]], ['x2']]]

In [43]:
pattern_match(prefix_to_tree(['+', 'cos', 'x1', 'x2'], space.operator_arity), prefix_to_tree(['+', '_1', '_2'], space.operator_arity))

(True, {'_1': ['cos', [['x1']]], '_2': ['x2']})

In [44]:
pattern_match(prefix_to_tree(['+', '+', 'cos', 'x1', 'x2', 'x2'], space.operator_arity), prefix_to_tree(['+', '_1', '_2'], space.operator_arity))

(True, {'_1': ['+', [['cos', [['x1']]], ['x2']]], '_2': ['x2']})

In [45]:
pattern_match(prefix_to_tree(['-', '+', 'cos', 'x1', 'x2', 'x2'], space.operator_arity), prefix_to_tree(['+', '_1', '_2'], space.operator_arity))

(False, {})

In [46]:
pattern_match(prefix_to_tree(['+', '-', 'cos', 'x1', 'x2', 'x2'], space.operator_arity), prefix_to_tree(['+', '_1', '_2'], space.operator_arity))

(True, {'_1': ['-', [['cos', [['x1']]], ['x2']]], '_2': ['x2']})

In [47]:
example_rules = [
    [['-', 'x1', 'x1'], ['0']],
    [['*', 'x1', 'x1'], ['pow2', 'x1']],
]

In [48]:
# Deduplicate the rules
deduplicated_rules = set()
for rule in example_rules:
    # Rename variables in the source expression
    remapped_source, variable_mapping = remap_expression(list(rule[0]), dummy_variables=['x1'])
    remapped_target, _ = remap_expression(list(rule[1]), variable_mapping)
    deduplicated_rules.add((tuple(remapped_source), tuple(remapped_target)))

print(f'Number of rules: {len(example_rules)}')
print(f'Number of unique rules: {len(deduplicated_rules)} ({100 * len(deduplicated_rules) / len(example_rules):.2f}%)')

Number of rules: 2
Number of unique rules: 2 (100.00%)


In [49]:
deduplicated_rules_of_arity = defaultdict(list)
for rule in deduplicated_rules:
    arity = space.operator_arity[rule[0][0]]
    deduplicated_rules_of_arity[arity].append(rule)

deduplicated_rules_of_arity = dict(deduplicated_rules_of_arity)
deduplicated_rules_of_arity

{2: [(('-', '_0', '_0'), ('0',)), (('*', '_0', '_0'), ('pow2', '_0'))]}

In [50]:
rules_trees = {a: [
    (
        prefix_to_tree(list(rule[0]), space.operator_arity_compat),
        prefix_to_tree(list(rule[1]), space.operator_arity_compat)
    )
        for rule in deduplicated_rules_of_arity_a] for a, deduplicated_rules_of_arity_a in deduplicated_rules_of_arity.items()}

In [69]:
def deduplicate_rules(rules_list: list[tuple[list[str], list[str]]], dummy_variables: list[str]) -> list[tuple[list[str], list[str]]]:
    deduplicated_rules = set()
    for rule in rules_list:
        # Rename variables in the source expression
        remapped_source, variable_mapping = remap_expression(list(rule[0]), dummy_variables=dummy_variables)
        remapped_target, _ = remap_expression(list(rule[1]), variable_mapping)
        deduplicated_rules.add((tuple(remapped_source), tuple(remapped_target)))

    return list(deduplicated_rules)

In [70]:
def rules_trees_from_rules_list(rules_list: list[tuple[list[str], list[str]]], dummy_variables: list[str]) -> dict[int, set[tuple[list, list]]]:
    deduplicated_rules = deduplicate_rules(rules_list, dummy_variables)

    deduplicated_rules_of_arity = defaultdict(list)
    for rule in deduplicated_rules:
        arity = space.operator_arity[rule[0][0]]
        deduplicated_rules_of_arity[arity].append(rule)

    deduplicated_rules_of_arity = dict(deduplicated_rules_of_arity)

    rules_trees = {a: [
        (
            prefix_to_tree(list(rule[0]), space.operator_arity_compat),
            prefix_to_tree(list(rule[1]), space.operator_arity_compat)
        )
            for rule in deduplicated_rules_of_arity_a] for a, deduplicated_rules_of_arity_a in deduplicated_rules_of_arity.items()}
    
    return rules_trees

In [71]:
subtree_simplify(['*', '-', 'x2', 'x2', 'x3'], rules_trees)

['*', '-', 'x2', 'x2', 'x3']

In [72]:
subtree_simplify(['-', '*', 'x2', 'x2', '*', 'x2', 'x2'], rules_trees)

['-', '*', 'x2', 'x2', '*', 'x2', 'x2']

In [73]:
subtree_simplify(['x1'], rules_trees, max_iter=5)

['x1']

In [103]:
size = 200_000
constants_retries = 5
max_simplify_steps = 5
# timeout_seconds = 60 * 60 * 10  # 10 hours

rules = []
rules_trees = {}

hashes_of_size = defaultdict(list)

dummy_variables = [f"x{i}" for i in range(10)]

X = np.random.normal(loc=0, scale=5, size=(1024, len(dummy_variables)))
C = np.random.normal(loc=0, scale=5, size=128)

leaf_nodes = dummy_variables + ["<num>"] + ['0', '1', '2', '(-1)', '(-2)', 'float("inf")', 'float("-inf")', 'float("nan")']
non_leaf_nodes = dict(sorted(space.operator_arity.items(), key=lambda x: x[1]))

pbar = tqdm(total=size)

# start_time = time.time()

try:
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=RuntimeWarning)
        # Create all leaf nodes
        for leaf in leaf_nodes[:size]:
            simplified_skeleton = subtree_simplify([leaf], rules_trees, max_iter=max_simplify_steps)
            
            executable_prefix_expression = space.operators_to_realizations(simplified_skeleton)
            prefix_expression_with_constants, constants = num_to_constants(executable_prefix_expression, convert_numbers_to_constant=False)
            code_string = space.prefix_to_infix(prefix_expression_with_constants, realization=True)
            code = codify(code_string, dummy_variables + constants)

            hashes_of_size[len(simplified_skeleton)].append(tuple(simplified_skeleton))

        n_scanned = 0

        while n_scanned < size:
            simplified_hashes_of_size = defaultdict(set)
            for l, hashes_list in hashes_of_size.items():
                for h in hashes_list:
                    simplified_skeleton = subtree_simplify(h, rules_trees, max_iter=max_simplify_steps)
                    simplified_hashes_of_size[len(simplified_skeleton)].add(tuple(simplified_skeleton))
            hashes_of_size = {l: list(h) for l, h in simplified_hashes_of_size.items()}

            new_hashes_of_size = defaultdict(list)
            for combination in expression_generator(hashes_of_size, non_leaf_nodes):
                # TODO: Think about when to simplify the rules
                for i, rule in enumerate(rules):
                    rules[i] = (rule[0], subtree_simplify(rule[1], rules_trees, max_iter=max_simplify_steps))

                rules = deduplicate_rules(rules, dummy_variables)
                rules_trees = rules_trees_from_rules_list(rules, dummy_variables)

                simplified_skeleton = subtree_simplify(list(combination), rules_trees, max_iter=max_simplify_steps)
                h = tuple(simplified_skeleton)

                pbar.set_postfix_str(f"Rules found: {len(rules):,}, Current Expression: {combination} -> {simplified_skeleton} -> ...")

                executable_prefix_expression = space.operators_to_realizations(simplified_skeleton)
                prefix_expression_with_constants, constants = num_to_constants(executable_prefix_expression, convert_numbers_to_constant=False)
                code_string = space.prefix_to_infix(prefix_expression_with_constants, realization=True)
                code = codify(code_string, dummy_variables + constants)
                
                f = space.code_to_lambda(code)

                # Record the image
                if len(constants) == 0:
                    y = safe_f(f, X)
                    if not isinstance(y, np.ndarray):
                        y = np.full(X.shape[0], y)

                    new_rule_candidates = []
                    for candidate_hashes_of_size in (hashes_of_size, new_hashes_of_size):
                        for l, candidate_hashes_list in candidate_hashes_of_size.items():
                            # Ignore simplification candidates that do not shorten the expression
                            if l >= len(h):
                                continue

                            for candidate_hash in candidate_hashes_list:
                                if candidate_hash == h:
                                    continue
                                executable_prefix_candidate_hash = space.operators_to_realizations(candidate_hash)
                                prefix_candidate_hash_with_constants, constants_candidate_hash = num_to_constants(executable_prefix_candidate_hash, convert_numbers_to_constant=False)
                                code_string_candidate_hash = space.prefix_to_infix(prefix_candidate_hash_with_constants, realization=True)
                                code_candidate_hash = codify(code_string_candidate_hash, dummy_variables + constants_candidate_hash)

                                # Record the image
                                if len(constants_candidate_hash) == 0:
                                    f_candidate = space.code_to_lambda(code_candidate_hash)
                                    y_candidate = safe_f(f_candidate, X)
                                    if not isinstance(y_candidate, np.ndarray):
                                        y_candidate = np.full(X.shape[0], y_candidate)

                                    if np.allclose(y, y_candidate, equal_nan=True):
                                        new_rule_candidates.append((simplified_skeleton, list(candidate_hash)))
                                else:
                                    if any([exist_constants_that_fit(candidate_hash, dummy_variables, X, y) for _ in range(constants_retries)]):
                                        new_rule_candidates.append((simplified_skeleton, list(candidate_hash)))
                            
                    # Find the shortest rule
                    if len(new_rule_candidates) > 0:
                        new_rule_candidates = sorted(new_rule_candidates, key=lambda x: len(x[1]))
                        new_rule_candidates_of_minimum_length = [c for c in new_rule_candidates if len(c[1]) == len(new_rule_candidates[0][1])]
                        # If there are rules with and without <num>, prefer the ones without
                        new_rule_candidates_of_minimum_length_without_num = [c for c in new_rule_candidates_of_minimum_length if '<num>' not in c[1]]
                        if len(new_rule_candidates_of_minimum_length_without_num) > 0:
                            new_rule_candidates_of_minimum_length = new_rule_candidates_of_minimum_length_without_num
                        rules.append(new_rule_candidates_of_minimum_length[0])

                else:
                    # Create an image from X and randomly sampled constants
                    y = safe_f(f, X, C[:len(constants)])

                    new_rule_candidates = []
                    for candidate_hashes_of_size in (hashes_of_size, new_hashes_of_size):
                        for l, candidate_hashes_list in candidate_hashes_of_size.items():
                            # Ignore simplification candidates that do not shorten the expression
                            if l >= len(h):
                                continue

                            for candidate_hash in candidate_hashes_list:
                                if candidate_hash == h:
                                    continue
                                executable_prefix_candidate_hash = space.operators_to_realizations(candidate_hash)
                                prefix_candidate_hash_with_constants, constants_candidate_hash = num_to_constants(executable_prefix_candidate_hash, convert_numbers_to_constant=False)
                                code_string_candidate_hash = space.prefix_to_infix(prefix_candidate_hash_with_constants, realization=True)
                                code_candidate_hash = codify(code_string_candidate_hash, dummy_variables + constants_candidate_hash)

                                f_candidate = space.code_to_lambda(code_candidate_hash)
                                
                                # Record the image
                                if len(constants_candidate_hash) == 0:
                                    y_candidate = safe_f(f_candidate, X)                                
                                    if not isinstance(y_candidate, np.ndarray):
                                        y_candidate = np.full(X.shape[0], y_candidate)

                                    if np.allclose(y, y_candidate, equal_nan=True):
                                        new_rule_candidates.append((simplified_skeleton, list(candidate_hash)))
                                else:
                                    if any([exist_constants_that_fit(candidate_hash, dummy_variables, X, y) for _ in range(constants_retries)]):
                                        new_rule_candidates.append((simplified_skeleton, list(candidate_hash)))

                    # Find the shortest rule
                    if len(new_rule_candidates) > 0:
                        new_rule_candidates = sorted(new_rule_candidates, key=lambda x: len(x[1]))
                        new_rule_candidates_of_minimum_length = [c for c in new_rule_candidates if len(c[1]) == len(new_rule_candidates[0][1])]
                        # If there are rules with and without <num>, prefer the ones without
                        new_rule_candidates_of_minimum_length_without_num = [c for c in new_rule_candidates_of_minimum_length if '<num>' not in c[1]]
                        if len(new_rule_candidates_of_minimum_length_without_num) > 0:
                            new_rule_candidates_of_minimum_length = new_rule_candidates_of_minimum_length_without_num
                        rules.append(new_rule_candidates_of_minimum_length[0])
                    

                new_hashes_of_size[len(h)].append(h)

                n_scanned += 1
                pbar.update(1)

                if n_scanned >= size:
                    break

            hashes_of_size.update(new_hashes_of_size)

        pbar.close()
except:
    pbar.close()
    raise

100%|██████████| 10000/10000 [35:04<00:00,  4.75it/s, Rules found: 406, Current Expression: ('pow1_3', '-', 'x5', '(-1)') -> ['pow1_3', '-', 'x5', '(-1)'] -> ...]                 


In [105]:
# Simplify the rules one last time
for i, rule in enumerate(rules):
    rules[i] = (rule[0], subtree_simplify(rule[1], rules_trees, max_iter=max_simplify_steps))

rules_trees = rules_trees_from_rules_list(rules, dummy_variables)

In [106]:
subtree_simplify(['-', 'x1', 'x1'], rules_trees)

['0']

In [107]:
# Write the rules to a file
with open(f'./rules_constants_{size}.txt', 'w') as f:
    for rule in rules:
        f.write(f"{rule[0]} -> {rule[1]}\n")

with open(f'./rules_constants_{size}.json', 'w') as f:
    json.dump(rules, f, indent=4)