In [1]:
from flash_ansr.expressions import ExpressionSpace
from flash_ansr.expressions.utils import codify, num_to_constants
from flash_ansr import get_path
import itertools
from tqdm import tqdm
from collections import defaultdict
import numpy as np
from typing import Generator, Callable
import warnings

In [2]:
MODEL = 'v7.0'

In [3]:
config = get_path('configs', MODEL, 'expression_space.yaml')

In [4]:
space = ExpressionSpace.from_config(config)

In [5]:
leaf_nodes = space.variables + ["<num>"]
non_leaf_nodes = space.operator_arity
non_leaf_nodes = dict(sorted(non_leaf_nodes.items(), key=lambda x: x[1]))

print(leaf_nodes)
print(non_leaf_nodes)

['x1', 'x2', 'x3', '<num>']
{'neg': 1, 'abs': 1, 'inv': 1, 'pow2': 1, 'pow3': 1, 'pow4': 1, 'pow5': 1, 'pow1_2': 1, 'pow1_3': 1, 'pow1_4': 1, 'pow1_5': 1, 'sin': 1, 'cos': 1, 'tan': 1, 'asin': 1, 'acos': 1, 'atan': 1, 'exp': 1, 'log': 1, '+': 2, '-': 2, '*': 2, '/': 2}


In [6]:
def apply_rule(X, A, B):
    result = []
    i = 0
    while i < len(X):
        # Check if sublist A is found at current position
        if i <= len(X) - len(A) and X[i:i+len(A)] == A:
            # Add replacement sublist B
            result.extend(B)
            # Skip past the matched sublist A
            i += len(A)
        else:
            # Add current element and move to next
            result.append(X[i])
            i += 1
    return result

# Example usage
X = ['+', 'cos', 'x', 'sin', '*', '*', 'x', 'x', '*', 'x', 'x']
A = ['*', 'x', 'x']
B = ['pow2', 'x']
print(apply_rule(X, A, B))


['+', 'cos', 'x', 'sin', '*', 'pow2', 'x', 'pow2', 'x']


In [7]:
def simplify(expression: list[str], rules: set[tuple[list[str], list[str]]], max_iter: int = 1) -> list[str]:
    if isinstance(expression, tuple):
        expression = list(expression)

    previous_expression = None

    for _ in range(max_iter):
        for pattern, replacement in rules:
            expression = apply_rule(expression, pattern, replacement)
        if previous_expression == expression:
            break
        previous_expression = expression

    return expression

In [8]:
simplify(['+', 'cos', 'x', 'sin', '*', '*', 'x', 'x', '*', 'x', 'x'], [(['*', 'x', 'x'], ['pow2', 'x'])])

['+', 'cos', 'x', 'sin', '*', 'pow2', 'x', 'pow2', 'x']

In [9]:
simplify(['neg', '-', 'x1', 'x1'], [(['-', 'x1', 'x1'], ['0'])])

['neg', '0']

In [10]:
rules = []

In [11]:
simplify(['<num>'], rules)

['<num>']

In [12]:
def expression_generator(hashes_of_size: dict[int, list[tuple[str]]], non_leaf_nodes: dict[str, int]) -> Generator[tuple[str], None, None]:
    # Append existing trees to every operator
    for new_root_operator, arity in non_leaf_nodes.items():
        # Start with the smallest arity-tuples of trees
        for child_lengths in sorted(itertools.product(list(hashes_of_size.keys()), repeat=arity), key=lambda x: sum(x)):
            # Check all possible combinations of child trees
            for child_combination in itertools.product(*[hashes_of_size[child_length] for child_length in child_lengths]):
                yield (new_root_operator,) + tuple(itertools.chain.from_iterable(child_combination))

In [13]:
X = np.random.normal(loc=0, scale=5, size=(1024, space.n_variables))

In [14]:
size = 0
rules = []

hashes_of_size = defaultdict(list)

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=RuntimeWarning)

    # Create all leaf nodes
    for leaf in leaf_nodes[:size]:
        simplified_skeleton = simplify([leaf], rules)
        
        executable_prefix_expression = space.operators_to_realizations(simplified_skeleton)
        prefix_expression_with_constants, constants = num_to_constants(executable_prefix_expression)
        code_string = space.prefix_to_infix(prefix_expression_with_constants, realization=True)
        code = codify(code_string, space.variables + constants)

        hashes_of_size[len(simplified_skeleton)].append(tuple(simplified_skeleton))

    pbar = tqdm(total=size)
    n_scanned = 0

    while n_scanned < size:
        simplified_hashes_of_size = defaultdict(list)
        for l, hashes_list in hashes_of_size.items():
            for h in hashes_list:
                simplified_skeleton = simplify(h, rules)
                simplified_hashes_of_size[len(simplified_skeleton)].append(simplified_skeleton)
        hashes_of_size = simplified_hashes_of_size

        new_hashes_of_size = defaultdict(list)
        for combination in expression_generator(hashes_of_size, non_leaf_nodes):
            for i, rule in enumerate(rules):
                rules[i] = (rule[0], simplify(rule[1], rules))

            simplified_skeleton = simplify(list(combination), rules)
            h = tuple(simplified_skeleton)

            executable_prefix_expression = space.operators_to_realizations(simplified_skeleton)
            prefix_expression_with_constants, constants = num_to_constants(executable_prefix_expression)
            code_string = space.prefix_to_infix(prefix_expression_with_constants, realization=True)
            code = codify(code_string, space.variables + constants)

            # Record the image
            if len(constants) == 0:
                f = space.code_to_lambda(code)
                y = f(*X.T)

                for candidate_hashes_of_size in (hashes_of_size, new_hashes_of_size):
                    for l, candidate_hashes_list in candidate_hashes_of_size.items():
                        # Ignore simplification candidates that do not shorten the expression
                        if l >= len(h):
                            continue

                        for candidate_hash in candidate_hashes_list:
                            if candidate_hash == h:
                                continue
                            executable_prefix_candidate_hash = space.operators_to_realizations(candidate_hash)
                            prefix_candidate_hash_with_constants, constants_candidate_hash = num_to_constants(executable_prefix_candidate_hash)
                            code_string_candidate_hash = space.prefix_to_infix(prefix_candidate_hash_with_constants, realization=True)
                            code_candidate_hash = codify(code_string_candidate_hash, space.variables + constants_candidate_hash)

                            # Record the image
                            if len(constants_candidate_hash) == 0:
                                f_candidate = space.code_to_lambda(code_candidate_hash)
                                y_candidate = f_candidate(*X.T)

                                if np.allclose(y, y_candidate, equal_nan=True):
                                    rules.append((simplified_skeleton, list(candidate_hash)))

            new_hashes_of_size[len(h)].append(h)

            n_scanned += 1
            pbar.update(1)
            pbar.set_postfix_str(f"Rules found: {len(rules):,}")

            if n_scanned >= size:
                break

        hashes_of_size.update(new_hashes_of_size)

    pbar.close()

# Write the rules to a file
with open('./rules.txt', 'w') as f:
    for rule in rules:
        f.write(f"{rule[0]} -> {rule[1]}\n")

0it [00:00, ?it/s]


In [15]:
from scipy.optimize import curve_fit, OptimizeWarning

In [16]:
def exist_constants_that_fit(expression: list[str], variables: list[str], X: np.ndarray, y_target: np.ndarray, debug: bool = False):
    if isinstance(expression, tuple):
        expression = list(expression)

    executable_prefix_expression = space.operators_to_realizations(expression)
    prefix_expression_with_constants, constants = num_to_constants(executable_prefix_expression, convert_numbers_to_constant=False)
    code_string = space.prefix_to_infix(prefix_expression_with_constants, realization=True)
    code = codify(code_string, variables + constants)
    f = space.code_to_lambda(code)

    def pred_function(X: np.ndarray, *constants: np.ndarray | None) -> float:
        if len(constants) == 0:
            y = f(*X.T)
        y =  f(*X.T, *constants)

        # If the numbers are complex, return nan
        if np.iscomplexobj(y):
            return np.full(y.shape, np.nan)
        
        return y

    p0 = np.random.normal(loc=0, scale=5, size=len(constants))

    is_valid = np.isfinite(X).all(axis=1) & np.isfinite(y_target)

    if not np.any(is_valid):
        if debug:
            print("No valid data points")
        return False

    try:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=OptimizeWarning)
            popt, _ = curve_fit(pred_function, X[is_valid], y_target[is_valid].flatten(), p0=p0)
    except RuntimeError:
        if debug:
            print("RuntimeError")
        return False

    y = f(*X.T, *popt)
    if not isinstance(y, np.ndarray):
        y = np.full(X.shape[0], y)

    if debug:
        print(f"Constants: {popt}")
        print(f"y_target: {y_target}")
        print(f"y: {y}")
        print(f'All close: {np.allclose(y_target, y, equal_nan=True)}')

    return np.allclose(y_target, y, equal_nan=True)

In [17]:
C = np.random.normal(loc=0, scale=5, size=128)
print(C.shape)

(128,)


In [18]:
import string

In [19]:
def safe_f(f, X, constants = None):
    try:
        if constants is None:
            y = f(*X.T)
        else:
            y = f(*X.T, *constants)
        if not isinstance(y, np.ndarray) or y.shape[0] == 1:
            y = np.full(X.shape[0], y)
        return y
    except ZeroDivisionError:
        return np.full(X.shape[0], np.nan)

In [20]:
size = 10_000
constants_retries = 5
max_simplify_steps = 5
rules = []

hashes_of_size = defaultdict(list)

leaf_nodes = space.variables + ["<num>"] + ['0', '1', '2', '(-1)', '(-2)', 'float("inf")', 'float("-inf")', 'float("nan")']

try:
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=RuntimeWarning)

        # Create all leaf nodes
        for leaf in leaf_nodes[:size]:
            simplified_skeleton = simplify([leaf], rules, max_iter=max_simplify_steps)
            
            executable_prefix_expression = space.operators_to_realizations(simplified_skeleton)
            prefix_expression_with_constants, constants = num_to_constants(executable_prefix_expression, convert_numbers_to_constant=False)
            code_string = space.prefix_to_infix(prefix_expression_with_constants, realization=True)
            code = codify(code_string, space.variables + constants)

            hashes_of_size[len(simplified_skeleton)].append(tuple(simplified_skeleton))

        pbar = tqdm(total=size)
        n_scanned = 0

        while n_scanned < size:
            simplified_hashes_of_size = defaultdict(set)
            for l, hashes_list in hashes_of_size.items():
                for h in hashes_list:
                    simplified_skeleton = simplify(h, rules, max_iter=max_simplify_steps)
                    simplified_hashes_of_size[len(simplified_skeleton)].add(tuple(simplified_skeleton))
            hashes_of_size = {l: list(h) for l, h in simplified_hashes_of_size.items()}

            new_hashes_of_size = defaultdict(list)
            for combination in expression_generator(hashes_of_size, non_leaf_nodes):
                for i, rule in enumerate(rules):
                    rules[i] = (rule[0], simplify(rule[1], rules, max_iter=max_simplify_steps))

                simplified_skeleton = simplify(list(combination), rules, max_iter=max_simplify_steps)
                h = tuple(simplified_skeleton)

                pbar.set_postfix_str(f"Rules found: {len(rules):,}, Current Expression: {combination} -> {simplified_skeleton} -> ...")

                executable_prefix_expression = space.operators_to_realizations(simplified_skeleton)
                prefix_expression_with_constants, constants = num_to_constants(executable_prefix_expression, convert_numbers_to_constant=False)
                code_string = space.prefix_to_infix(prefix_expression_with_constants, realization=True)
                code = codify(code_string, space.variables + constants)
                
                f = space.code_to_lambda(code)

                # Record the image
                if len(constants) == 0:
                    y = safe_f(f, X)
                    if not isinstance(y, np.ndarray):
                        y = np.full(X.shape[0], y)

                    new_rule_candidates = []
                    for candidate_hashes_of_size in (hashes_of_size, new_hashes_of_size):
                        for l, candidate_hashes_list in candidate_hashes_of_size.items():
                            # Ignore simplification candidates that do not shorten the expression
                            if l >= len(h):
                                continue

                            for candidate_hash in candidate_hashes_list:
                                if candidate_hash == h:
                                    continue
                                executable_prefix_candidate_hash = space.operators_to_realizations(candidate_hash)
                                prefix_candidate_hash_with_constants, constants_candidate_hash = num_to_constants(executable_prefix_candidate_hash, convert_numbers_to_constant=False)
                                code_string_candidate_hash = space.prefix_to_infix(prefix_candidate_hash_with_constants, realization=True)
                                code_candidate_hash = codify(code_string_candidate_hash, space.variables + constants_candidate_hash)

                                # Record the image
                                if len(constants_candidate_hash) == 0:
                                    f_candidate = space.code_to_lambda(code_candidate_hash)
                                    y_candidate = safe_f(f_candidate, X)
                                    if not isinstance(y_candidate, np.ndarray):
                                        y_candidate = np.full(X.shape[0], y_candidate)

                                    if np.allclose(y, y_candidate, equal_nan=True):
                                        new_rule_candidates.append((simplified_skeleton, list(candidate_hash)))
                                else:
                                    if any([exist_constants_that_fit(candidate_hash, X, y) for _ in range(constants_retries)]):
                                        new_rule_candidates.append((simplified_skeleton, list(candidate_hash)))
                            
                    # Find the shortest rule
                    if len(new_rule_candidates) > 0:
                        new_rule_candidates = sorted(new_rule_candidates, key=lambda x: len(x[1]))
                        new_rule_candidates_of_minimum_length = [c for c in new_rule_candidates if len(c[1]) == len(new_rule_candidates[0][1])]
                        # If there are rules with and without <num>, prefer the ones without
                        new_rule_candidates_of_minimum_length_without_num = [c for c in new_rule_candidates_of_minimum_length if '<num>' not in c[1]]
                        if len(new_rule_candidates_of_minimum_length_without_num) > 0:
                            new_rule_candidates_of_minimum_length = new_rule_candidates_of_minimum_length_without_num
                        rules.append(new_rule_candidates_of_minimum_length[0])

                else:
                    # Create an image from X and randomly sampled constants
                    y = safe_f(f, X, C[:len(constants)])

                    new_rule_candidates = []
                    for candidate_hashes_of_size in (hashes_of_size, new_hashes_of_size):
                        for l, candidate_hashes_list in candidate_hashes_of_size.items():
                            # Ignore simplification candidates that do not shorten the expression
                            if l >= len(h):
                                continue

                            for candidate_hash in candidate_hashes_list:
                                if candidate_hash == h:
                                    continue
                                executable_prefix_candidate_hash = space.operators_to_realizations(candidate_hash)
                                prefix_candidate_hash_with_constants, constants_candidate_hash = num_to_constants(executable_prefix_candidate_hash, convert_numbers_to_constant=False)
                                code_string_candidate_hash = space.prefix_to_infix(prefix_candidate_hash_with_constants, realization=True)
                                code_candidate_hash = codify(code_string_candidate_hash, space.variables + constants_candidate_hash)

                                f_candidate = space.code_to_lambda(code_candidate_hash)
                                
                                # Record the image
                                if len(constants_candidate_hash) == 0:
                                    y_candidate = safe_f(f_candidate, X)
                                    if combination == ('*', '<num>', '<num>') and candidate_hash == ('<num>',):
                                        print(y_candidate)
                                
                                    if not isinstance(y_candidate, np.ndarray):
                                        y_candidate = np.full(X.shape[0], y_candidate)

                                    if np.allclose(y, y_candidate, equal_nan=True):
                                        new_rule_candidates.append((simplified_skeleton, list(candidate_hash)))
                                else:
                                    if any([exist_constants_that_fit(candidate_hash, X, y) for _ in range(constants_retries)]):
                                        new_rule_candidates.append((simplified_skeleton, list(candidate_hash)))

                    # Find the shortest rule
                    if len(new_rule_candidates) > 0:
                        new_rule_candidates = sorted(new_rule_candidates, key=lambda x: len(x[1]))
                        new_rule_candidates_of_minimum_length = [c for c in new_rule_candidates if len(c[1]) == len(new_rule_candidates[0][1])]
                        # If there are rules with and without <num>, prefer the ones without
                        new_rule_candidates_of_minimum_length_without_num = [c for c in new_rule_candidates_of_minimum_length if '<num>' not in c[1]]
                        if len(new_rule_candidates_of_minimum_length_without_num) > 0:
                            new_rule_candidates_of_minimum_length = new_rule_candidates_of_minimum_length_without_num
                        rules.append(new_rule_candidates_of_minimum_length[0])

                new_hashes_of_size[len(h)].append(h)

                n_scanned += 1
                pbar.update(1)

                if n_scanned >= size:
                    break

            hashes_of_size.update(new_hashes_of_size)

        pbar.close()
except:
    pbar.close()
    raise

  0%|          | 0/10000 [00:00<?, ?it/s, Rules found: 0, Current Expression: ('neg', '<num>') -> ['neg', '<num>'] -> ...]


TypeError: exist_constants_that_fit() missing 1 required positional argument: 'y_target'

In [21]:
size = 1_000
constants_retries = 5
max_simplify_steps = 5
rules = []

hashes_of_size = defaultdict(list)

dummy_variables = [f"x{i}" for i in range(10)]

X = np.random.normal(loc=0, scale=5, size=(1024, len(dummy_variables)))
C = np.random.normal(loc=0, scale=5, size=128)

leaf_nodes = dummy_variables + ["<num>"] + ['0', '1', '2', '(-1)', '(-2)', 'float("inf")', 'float("-inf")', 'float("nan")']

try:
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=RuntimeWarning)

        # Create all leaf nodes
        for leaf in leaf_nodes[:size]:
            simplified_skeleton = simplify([leaf], rules, max_iter=max_simplify_steps)
            
            executable_prefix_expression = space.operators_to_realizations(simplified_skeleton)
            prefix_expression_with_constants, constants = num_to_constants(executable_prefix_expression, convert_numbers_to_constant=False)
            code_string = space.prefix_to_infix(prefix_expression_with_constants, realization=True)
            code = codify(code_string, dummy_variables + constants)

            hashes_of_size[len(simplified_skeleton)].append(tuple(simplified_skeleton))

        pbar = tqdm(total=size)
        n_scanned = 0

        while n_scanned < size:
            simplified_hashes_of_size = defaultdict(set)
            for l, hashes_list in hashes_of_size.items():
                for h in hashes_list:
                    simplified_skeleton = simplify(h, rules, max_iter=max_simplify_steps)
                    simplified_hashes_of_size[len(simplified_skeleton)].add(tuple(simplified_skeleton))
            hashes_of_size = {l: list(h) for l, h in simplified_hashes_of_size.items()}

            new_hashes_of_size = defaultdict(list)
            for combination in expression_generator(hashes_of_size, non_leaf_nodes):
                for i, rule in enumerate(rules):
                    rules[i] = (rule[0], simplify(rule[1], rules, max_iter=max_simplify_steps))

                simplified_skeleton = simplify(list(combination), rules, max_iter=max_simplify_steps)
                h = tuple(simplified_skeleton)

                pbar.set_postfix_str(f"Rules found: {len(rules):,}, Current Expression: {combination} -> {simplified_skeleton} -> ...")

                executable_prefix_expression = space.operators_to_realizations(simplified_skeleton)
                prefix_expression_with_constants, constants = num_to_constants(executable_prefix_expression, convert_numbers_to_constant=False)
                code_string = space.prefix_to_infix(prefix_expression_with_constants, realization=True)
                code = codify(code_string, dummy_variables + constants)
                
                f = space.code_to_lambda(code)

                # Record the image
                if len(constants) == 0:
                    y = safe_f(f, X)
                    if not isinstance(y, np.ndarray):
                        y = np.full(X.shape[0], y)

                    new_rule_candidates = []
                    for candidate_hashes_of_size in (hashes_of_size, new_hashes_of_size):
                        for l, candidate_hashes_list in candidate_hashes_of_size.items():
                            # Ignore simplification candidates that do not shorten the expression
                            if l >= len(h):
                                continue

                            for candidate_hash in candidate_hashes_list:
                                if candidate_hash == h:
                                    continue
                                executable_prefix_candidate_hash = space.operators_to_realizations(candidate_hash)
                                prefix_candidate_hash_with_constants, constants_candidate_hash = num_to_constants(executable_prefix_candidate_hash, convert_numbers_to_constant=False)
                                code_string_candidate_hash = space.prefix_to_infix(prefix_candidate_hash_with_constants, realization=True)
                                code_candidate_hash = codify(code_string_candidate_hash, dummy_variables + constants_candidate_hash)

                                # Record the image
                                if len(constants_candidate_hash) == 0:
                                    f_candidate = space.code_to_lambda(code_candidate_hash)
                                    y_candidate = safe_f(f_candidate, X)
                                    if not isinstance(y_candidate, np.ndarray):
                                        y_candidate = np.full(X.shape[0], y_candidate)

                                    if np.allclose(y, y_candidate, equal_nan=True):
                                        new_rule_candidates.append((simplified_skeleton, list(candidate_hash)))
                                else:
                                    if any([exist_constants_that_fit(candidate_hash, dummy_variables, X, y) for _ in range(constants_retries)]):
                                        new_rule_candidates.append((simplified_skeleton, list(candidate_hash)))
                            
                    # Find the shortest rule
                    if len(new_rule_candidates) > 0:
                        new_rule_candidates = sorted(new_rule_candidates, key=lambda x: len(x[1]))
                        new_rule_candidates_of_minimum_length = [c for c in new_rule_candidates if len(c[1]) == len(new_rule_candidates[0][1])]
                        # If there are rules with and without <num>, prefer the ones without
                        new_rule_candidates_of_minimum_length_without_num = [c for c in new_rule_candidates_of_minimum_length if '<num>' not in c[1]]
                        if len(new_rule_candidates_of_minimum_length_without_num) > 0:
                            new_rule_candidates_of_minimum_length = new_rule_candidates_of_minimum_length_without_num
                        rules.append(new_rule_candidates_of_minimum_length[0])

                else:
                    # Create an image from X and randomly sampled constants
                    y = safe_f(f, X, C[:len(constants)])

                    new_rule_candidates = []
                    for candidate_hashes_of_size in (hashes_of_size, new_hashes_of_size):
                        for l, candidate_hashes_list in candidate_hashes_of_size.items():
                            # Ignore simplification candidates that do not shorten the expression
                            if l >= len(h):
                                continue

                            for candidate_hash in candidate_hashes_list:
                                if candidate_hash == h:
                                    continue
                                executable_prefix_candidate_hash = space.operators_to_realizations(candidate_hash)
                                prefix_candidate_hash_with_constants, constants_candidate_hash = num_to_constants(executable_prefix_candidate_hash, convert_numbers_to_constant=False)
                                code_string_candidate_hash = space.prefix_to_infix(prefix_candidate_hash_with_constants, realization=True)
                                code_candidate_hash = codify(code_string_candidate_hash, dummy_variables + constants_candidate_hash)

                                f_candidate = space.code_to_lambda(code_candidate_hash)
                                
                                # Record the image
                                if len(constants_candidate_hash) == 0:
                                    y_candidate = safe_f(f_candidate, X)                                
                                    if not isinstance(y_candidate, np.ndarray):
                                        y_candidate = np.full(X.shape[0], y_candidate)

                                    if np.allclose(y, y_candidate, equal_nan=True):
                                        new_rule_candidates.append((simplified_skeleton, list(candidate_hash)))
                                else:
                                    if any([exist_constants_that_fit(candidate_hash, dummy_variables, X, y) for _ in range(constants_retries)]):
                                        new_rule_candidates.append((simplified_skeleton, list(candidate_hash)))

                    # Find the shortest rule
                    if len(new_rule_candidates) > 0:
                        new_rule_candidates = sorted(new_rule_candidates, key=lambda x: len(x[1]))
                        new_rule_candidates_of_minimum_length = [c for c in new_rule_candidates if len(c[1]) == len(new_rule_candidates[0][1])]
                        # If there are rules with and without <num>, prefer the ones without
                        new_rule_candidates_of_minimum_length_without_num = [c for c in new_rule_candidates_of_minimum_length if '<num>' not in c[1]]
                        if len(new_rule_candidates_of_minimum_length_without_num) > 0:
                            new_rule_candidates_of_minimum_length = new_rule_candidates_of_minimum_length_without_num
                        rules.append(new_rule_candidates_of_minimum_length[0])

                new_hashes_of_size[len(h)].append(h)

                n_scanned += 1
                pbar.update(1)

                if n_scanned >= size:
                    break

            hashes_of_size.update(new_hashes_of_size)

        pbar.close()
except:
    pbar.close()
    raise

100%|██████████| 1000/1000 [00:40<00:00, 24.97it/s, Rules found: 468, Current Expression: ('-', '1', '2') -> ['-', '1', '2'] -> ...]                                               


In [61]:
def remap_expression(source_expression, variable_mapping: dict | None = None):
    if variable_mapping is None:
        variable_mapping = {}
        for i, token in enumerate(source_expression):
            if token in dummy_variables:
                if token not in variable_mapping:
                    variable_mapping[token] = f'_{len(variable_mapping)}'
    
    for i, token in enumerate(source_expression):
        if token in dummy_variables:
            source_expression[i] = variable_mapping[token]

    return source_expression, variable_mapping

In [62]:
# Deduplicate the rules
deduplicated_rules = set()
for rule in rules:
    # Rename variables in the source expression
    remapped_source, variable_mapping = remap_expression(list(rule[0]))
    remapped_target, _ = remap_expression(list(rule[1]), variable_mapping)
    deduplicated_rules.add((tuple(remapped_source), tuple(remapped_target)))

print(f'Number of rules: {len(rules)}')
print(f'Number of unique rules: {len(deduplicated_rules)} ({100 * len(deduplicated_rules) / len(rules):.2f}%)')

Number of rules: 469
Number of unique rules: 331 (70.58%)


In [215]:
def prefix_to_tree(expr, operator_arity):
    """
    Convert a prefix notation expression to a nested list representing the expression tree.

    Args:
        expr: List representing the expression in prefix notation
        operator_arity: Dictionary mapping operators to their arities

    Returns:
        Nested list representing the expression tree
    """
    def build_tree(index):
        if index >= len(expr):
            return None, index

        token = expr[index]

        # If token is not an operator or is an operator with arity 0
        if isinstance(token, dict) or token not in operator_arity or operator_arity[token] == 0:
            return [token], index + 1

        # If token is an operator
        operands = []
        current_index = index + 1

        # Process operands based on the operator's arity
        for _ in range(operator_arity[token]):
            if current_index >= len(expr):
                break

            subtree, current_index = build_tree(current_index)
            if subtree:
                operands.append(subtree)

        return [token, operands], current_index

    result, _ = build_tree(0)
    return result

In [216]:
prefix_to_tree(['x1'], space.operator_arity_compat)

['x1']

In [217]:
prefix_to_tree(['+', 'x1', 'x2'], space.operator_arity_compat)

['+', [['x1'], ['x2']]]

In [218]:
# %%timeit
prefix_to_tree(['*', 'x1', 'cos', 'x2'], space.operator_arity_compat)

['*', [['x1'], ['cos', [['x2']]]]]

In [219]:
deduplicated_rules_of_arity = defaultdict(list)
for rule in deduplicated_rules:
    arity = space.operator_arity[rule[0][0]]
    deduplicated_rules_of_arity[arity].append(rule)

deduplicated_rules_of_arity = dict(deduplicated_rules_of_arity)

In [220]:
rules_trees = {a: [
    (
        prefix_to_tree(list(rule[0]), space.operator_arity_compat),
        prefix_to_tree(list(rule[1]), space.operator_arity_compat)
    )
        for rule in deduplicated_rules_of_arity_a] for a, deduplicated_rules_of_arity_a in deduplicated_rules_of_arity.items()}

In [221]:
from typing import Any

In [222]:
def pattern_match(tree: list, pattern: list, mapping: dict[str, Any] | None = None) -> tuple[bool, dict[str, Any]]:
    print(f'tree: {tree}')
    print(f'pattern: {pattern}')
    print()

    if mapping is None:
        mapping = {}

    if len(pattern) == 1 and isinstance(pattern[0], str):
        if pattern[0].startswith('_'):
            if pattern[0] not in mapping:
                mapping[pattern[0]] = tree
            elif mapping[pattern[0]] != tree:
                return False, mapping
            return True, mapping
        
        if tree != pattern:
            return False, mapping
        return True, mapping

    print(f'tree: {tree}')
    print(f'pattern: {pattern}')
    tree_operator, tree_operands = tree
    print(f'tree_operator: {tree_operator}')
    print(f'tree_operands: {tree_operands}')
    pattern_operator, pattern_operands = pattern
    print(f'pattern_operator: {pattern_operator}')
    print(f'pattern_operands: {pattern_operands}')
    print()

    if tree_operator != pattern_operator:
        return False, mapping
    
    for tree_operand, pattern_operand in zip(tree_operands, pattern_operands):
        if isinstance(pattern_operand, str):
            print(f'pattern_operand: {pattern_operand}')
            if pattern_operand not in mapping:
                mapping[pattern_operand] = tree_operand
            elif mapping[pattern_operand] != tree_operand:
                return False, mapping
        else:
            does_match, mapping = pattern_match(tree_operand, pattern_operand, mapping)
            if not does_match:
                return False, mapping

    return True, mapping

In [223]:
pattern_match(prefix_to_tree(['+', 'cos', 'x3', 'x2'], space.operator_arity), prefix_to_tree(['+', 'cos', '_1', '_2'], space.operator_arity))

tree: ['+', [['cos', [['x3']]], ['x2']]]
pattern: ['+', [['cos', [['_1']]], ['_2']]]

tree: ['+', [['cos', [['x3']]], ['x2']]]
pattern: ['+', [['cos', [['_1']]], ['_2']]]
tree_operator: +
tree_operands: [['cos', [['x3']]], ['x2']]
pattern_operator: +
pattern_operands: [['cos', [['_1']]], ['_2']]

tree: ['cos', [['x3']]]
pattern: ['cos', [['_1']]]

tree: ['cos', [['x3']]]
pattern: ['cos', [['_1']]]
tree_operator: cos
tree_operands: [['x3']]
pattern_operator: cos
pattern_operands: [['_1']]

tree: ['x3']
pattern: ['_1']

tree: ['x2']
pattern: ['_2']



(True, {'_1': ['x3'], '_2': ['x2']})

In [224]:
prefix_to_tree(['+', 'cos', 'x1', 'x2'], space.operator_arity)

['+', [['cos', [['x1']]], ['x2']]]

In [225]:
pattern_match(prefix_to_tree(['+', 'cos', 'x1', 'x2'], space.operator_arity), prefix_to_tree(['+', '_1', '_2'], space.operator_arity))

tree: ['+', [['cos', [['x1']]], ['x2']]]
pattern: ['+', [['_1'], ['_2']]]

tree: ['+', [['cos', [['x1']]], ['x2']]]
pattern: ['+', [['_1'], ['_2']]]
tree_operator: +
tree_operands: [['cos', [['x1']]], ['x2']]
pattern_operator: +
pattern_operands: [['_1'], ['_2']]

tree: ['cos', [['x1']]]
pattern: ['_1']

tree: ['x2']
pattern: ['_2']



(True, {'_1': ['cos', [['x1']]], '_2': ['x2']})

In [226]:
pattern_match(prefix_to_tree(['+', '+', 'cos', 'x1', 'x2', 'x2'], space.operator_arity), prefix_to_tree(['+', '_1', '_2'], space.operator_arity))

tree: ['+', [['+', [['cos', [['x1']]], ['x2']]], ['x2']]]
pattern: ['+', [['_1'], ['_2']]]

tree: ['+', [['+', [['cos', [['x1']]], ['x2']]], ['x2']]]
pattern: ['+', [['_1'], ['_2']]]
tree_operator: +
tree_operands: [['+', [['cos', [['x1']]], ['x2']]], ['x2']]
pattern_operator: +
pattern_operands: [['_1'], ['_2']]

tree: ['+', [['cos', [['x1']]], ['x2']]]
pattern: ['_1']

tree: ['x2']
pattern: ['_2']



(True, {'_1': ['+', [['cos', [['x1']]], ['x2']]], '_2': ['x2']})

In [227]:
pattern_match(prefix_to_tree(['-', '+', 'cos', 'x1', 'x2', 'x2'], space.operator_arity), prefix_to_tree(['+', '_1', '_2'], space.operator_arity))

tree: ['-', [['+', [['cos', [['x1']]], ['x2']]], ['x2']]]
pattern: ['+', [['_1'], ['_2']]]

tree: ['-', [['+', [['cos', [['x1']]], ['x2']]], ['x2']]]
pattern: ['+', [['_1'], ['_2']]]
tree_operator: -
tree_operands: [['+', [['cos', [['x1']]], ['x2']]], ['x2']]
pattern_operator: +
pattern_operands: [['_1'], ['_2']]



(False, {})

In [228]:
pattern_match(prefix_to_tree(['+', '-', 'cos', 'x1', 'x2', 'x2'], space.operator_arity), prefix_to_tree(['+', '_1', '_2'], space.operator_arity))

tree: ['+', [['-', [['cos', [['x1']]], ['x2']]], ['x2']]]
pattern: ['+', [['_1'], ['_2']]]

tree: ['+', [['-', [['cos', [['x1']]], ['x2']]], ['x2']]]
pattern: ['+', [['_1'], ['_2']]]
tree_operator: +
tree_operands: [['-', [['cos', [['x1']]], ['x2']]], ['x2']]
pattern_operator: +
pattern_operands: [['_1'], ['_2']]

tree: ['-', [['cos', [['x1']]], ['x2']]]
pattern: ['_1']

tree: ['x2']
pattern: ['_2']



(True, {'_1': ['-', [['cos', [['x1']]], ['x2']]], '_2': ['x2']})

In [48]:
def _subtree_simplify(expression: list[str], rules_trees: dict[int, set[tuple[list[str], list[str]]]]) -> list[str]:
    stack: list = []
    i = len(expression) - 1

    while i >= 0:
        token = expression[i]

        if token in space.operator_arity_compat or token in space.operator_aliases:
            operator = space.operator_aliases.get(token, token)
            arity = space.operator_arity_compat[operator]
            operands = list(reversed(stack[-arity:]))

            # Check if a pattern matches the current subtree
            print(rules_trees[1][0])
            

            _ = [stack.pop() for _ in range(arity)]
            stack.append([operator, operands])
            i -= 1
            continue

        stack.append([token])
        
        i -= 1

    return stack

In [70]:
_subtree_simplify(['-', 'x1', 'x1'], rules_trees)

(['pow1_2', ['<num>']], ['<num>'])


[['-', [['x1'], ['x1']]]]

In [88]:
# Write the rules to a file
with open(f'./rules_constants_{size}.txt', 'w') as f:
    for rule in rules:
        f.write(f"{rule[0]} -> {rule[1]}\n")

In [89]:
import json

In [51]:
with open(f'./rules_constants_{size}.json', 'w') as f:
    json.dump(rules, f, indent=4)

In [103]:
from copy import deepcopy

In [137]:
example_expression_positive = ['*', 'x1', 'cos', 'x2']
example_tree_positive = prefix_to_tree(example_expression_positive, space.operator_arity_compat)

example_pattern_expression = [(t if t not in dummy_variables else _) for t in example_expression_positive]
example_pattern_tree = prefix_to_tree(example_pattern_expression, space.operator_arity_compat)
example_pattern_tree_backup = deepcopy(example_pattern_tree)

example_expression_negative = ['*', 'x1', 'sin', 'x2']
example_tree_negative = prefix_to_tree(example_expression_negative, space.operator_arity_compat)

print()
print(example_tree_positive)
print(example_pattern_tree)
match example_tree_positive:
    case example_pattern_tree:
        print("matched")

print(f'Applying backup {example_pattern_tree_backup} to filled {example_pattern_tree}')
example_pattern_tree = example_pattern_tree_backup
print()
print(example_tree_negative)
print(example_pattern_tree)
match example_tree_negative:
    case example_pattern_tree:
        print("also matched")


('*', ('x1',), ('cos', ('x2',)))
('*', ({},), ('cos', ({},)))
matched
Applying backup ('*', ({},), ('cos', ({},))) to filled ('*', ('x1',), ('cos', ('x2',)))

('*', ('x1',), ('sin', ('x2',)))
('*', ({},), ('cos', ({},)))
also matched


In [139]:
def iterate_tree(tree: list):
    if isinstance(tree, (list, tuple)):
        yield tree[0]
        for subtree in tree[1:]:
            yield from iterate_tree(subtree)

In [112]:
match ['*', ['x1'], ['sin', ['x2']]]:
    case ['*', [{}], ['cos', [{}]]]:
        print("also matched")