In [1]:
from flash_ansr.expressions import ExpressionSpace
from flash_ansr.expressions.utils import codify, num_to_constants
from flash_ansr import get_path
import itertools
from tqdm import tqdm
from collections import defaultdict
import numpy as np
from typing import Generator, Callable
import warnings

In [2]:
MODEL = 'v7.0'

In [3]:
config = get_path('configs', MODEL, 'expression_space.yaml')

In [4]:
space = ExpressionSpace.from_config(config)

In [5]:
leaf_nodes = space.variables + ["<num>"]
non_leaf_nodes = space.operator_arity
non_leaf_nodes = dict(sorted(non_leaf_nodes.items(), key=lambda x: x[1]))

print(leaf_nodes)
print(non_leaf_nodes)

['x1', 'x2', 'x3', '<num>']
{'neg': 1, 'abs': 1, 'inv': 1, 'pow2': 1, 'pow3': 1, 'pow4': 1, 'pow5': 1, 'pow1_2': 1, 'pow1_3': 1, 'pow1_4': 1, 'pow1_5': 1, 'sin': 1, 'cos': 1, 'tan': 1, 'asin': 1, 'acos': 1, 'atan': 1, 'exp': 1, 'log': 1, '+': 2, '-': 2, '*': 2, '/': 2}


In [6]:
def apply_rule(X, A, B):
    result = []
    i = 0
    while i < len(X):
        # Check if sublist A is found at current position
        if i <= len(X) - len(A) and X[i:i+len(A)] == A:
            # Add replacement sublist B
            result.extend(B)
            # Skip past the matched sublist A
            i += len(A)
        else:
            # Add current element and move to next
            result.append(X[i])
            i += 1
    return result

# Example usage
X = ['+', 'cos', 'x', 'sin', '*', '*', 'x', 'x', '*', 'x', 'x']
A = ['*', 'x', 'x']
B = ['pow2', 'x']
print(apply_rule(X, A, B))


['+', 'cos', 'x', 'sin', '*', 'pow2', 'x', 'pow2', 'x']


In [7]:
def simplify(expression: list[str], rules: set[tuple[list[str], list[str]]]):
    if isinstance(expression, tuple):
        expression = list(expression)
    for pattern, replacement in rules:
        expression = apply_rule(expression, pattern, replacement)
    return expression

In [8]:
simplify(['+', 'cos', 'x', 'sin', '*', '*', 'x', 'x', '*', 'x', 'x'], [(['*', 'x', 'x'], ['pow2', 'x'])])

['+', 'cos', 'x', 'sin', '*', 'pow2', 'x', 'pow2', 'x']

In [9]:
simplify(['neg', '-', 'x1', 'x1'], [(['-', 'x1', 'x1'], ['0'])])

['neg', '0']

In [10]:
rules = []

In [11]:
simplify(['<num>'], rules)

['<num>']

In [12]:
def expression_generator(hashes_of_size: dict[int, list[tuple[str]]], non_leaf_nodes: dict[str, int]) -> Generator[tuple[str], None, None]:
    # Append existing trees to every operator
    for new_root_operator, arity in non_leaf_nodes.items():
        # Start with the smallest arity-tuples of trees
        for child_lengths in sorted(itertools.product(list(hashes_of_size.keys()), repeat=arity), key=lambda x: sum(x)):
            # Check all possible combinations of child trees
            for child_combination in itertools.product(*[hashes_of_size[child_length] for child_length in child_lengths]):
                yield (new_root_operator,) + tuple(itertools.chain.from_iterable(child_combination))

In [13]:
X = np.random.normal(loc=0, scale=5, size=(1024, space.n_variables))

In [14]:
size = 10_000
rules = [(['-', t, t], ['0']) for t in space.variables] + [(['/', t, t], ['1']) for t in space.variables] + [(['*', t, '1'], [t]) for t in space.variables]

hashes_of_size = defaultdict(list)

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=RuntimeWarning)

    # Create all leaf nodes
    for leaf in leaf_nodes[:size]:
        simplified_skeleton = simplify([leaf], rules)
        
        executable_prefix_expression = space.operators_to_realizations(simplified_skeleton)
        prefix_expression_with_constants, constants = num_to_constants(executable_prefix_expression)
        code_string = space.prefix_to_infix(prefix_expression_with_constants, realization=True)
        code = codify(code_string, space.variables + constants)

        hashes_of_size[len(simplified_skeleton)].append(tuple(simplified_skeleton))

    pbar = tqdm(total=size)
    n_scanned = 0

    while n_scanned < size:
        simplified_hashes_of_size = defaultdict(list)
        for l, hashes_list in hashes_of_size.items():
            for h in hashes_list:
                simplified_skeleton = simplify(h, rules)
                simplified_hashes_of_size[len(simplified_skeleton)].append(simplified_skeleton)
        hashes_of_size = simplified_hashes_of_size

        new_hashes_of_size = defaultdict(list)
        for combination in expression_generator(hashes_of_size, non_leaf_nodes):
            for i, rule in enumerate(rules):
                rules[i] = (rule[0], simplify(rule[1], rules))

            simplified_skeleton = simplify(list(combination), rules)
            h = tuple(simplified_skeleton)

            executable_prefix_expression = space.operators_to_realizations(simplified_skeleton)
            prefix_expression_with_constants, constants = num_to_constants(executable_prefix_expression)
            code_string = space.prefix_to_infix(prefix_expression_with_constants, realization=True)
            code = codify(code_string, space.variables + constants)

            # Record the image
            if len(constants) == 0:
                f = space.code_to_lambda(code)
                y = f(*X.T)

                for candidate_hashes_of_size in (hashes_of_size, new_hashes_of_size):
                    for l, candidate_hashes_list in candidate_hashes_of_size.items():
                        # Ignore simplification candidates that do not shorten the expression
                        if l >= len(h):
                            continue

                        for candidate_hash in candidate_hashes_list:
                            if candidate_hash == h:
                                continue
                            executable_prefix_candidate_hash = space.operators_to_realizations(candidate_hash)
                            prefix_candidate_hash_with_constants, constants_candidate_hash = num_to_constants(executable_prefix_candidate_hash)
                            code_string_candidate_hash = space.prefix_to_infix(prefix_candidate_hash_with_constants, realization=True)
                            code_candidate_hash = codify(code_string_candidate_hash, space.variables + constants_candidate_hash)

                            # Record the image
                            if len(constants_candidate_hash) == 0:
                                f_candidate = space.code_to_lambda(code_candidate_hash)
                                y_candidate = f_candidate(*X.T)

                                if np.allclose(y, y_candidate, equal_nan=True):
                                    rules.append((simplified_skeleton, list(candidate_hash)))

            new_hashes_of_size[len(h)].append(h)

            n_scanned += 1
            pbar.update(1)
            pbar.set_postfix_str(f"Rules found: {len(rules):,}")

            if n_scanned >= size:
                break

        hashes_of_size.update(new_hashes_of_size)

    pbar.close()

# Write the rules to a file
with open('./rules.txt', 'w') as f:
    for rule in rules:
        f.write(f"{rule[0]} -> {rule[1]}\n")

 22%|██▏       | 2204/10000 [00:18<01:59, 65.42it/s, Rules found: 120] 

KeyboardInterrupt: 

In [14]:
from scipy.optimize import curve_fit, OptimizeWarning

In [15]:
def exist_constants_that_fit(expression: list[str], X: np.ndarray, y_target: np.ndarray):
    if isinstance(expression, tuple):
        expression = list(expression)

    executable_prefix_expression = space.operators_to_realizations(expression)
    prefix_expression_with_constants, constants = num_to_constants(executable_prefix_expression)
    code_string = space.prefix_to_infix(prefix_expression_with_constants, realization=True)
    code = codify(code_string, space.variables + constants)
    f = space.code_to_lambda(code)

    def pred_function(X: np.ndarray, *constants: np.ndarray | None) -> float:
        if len(constants) == 0:
            return f(*X.T)
        return f(*X.T, *constants)

    p0 = np.random.normal(loc=0, scale=5, size=len(constants))

    is_valid = np.isfinite(X).all(axis=1) & np.isfinite(y_target)

    try:
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=OptimizeWarning)
            popt, _ = curve_fit(pred_function, X[is_valid], y_target[is_valid].flatten(), p0=p0)
    except RuntimeError:
        return False

    y = f(*X.T, *popt)

    return np.allclose(y_target, y, equal_nan=True)

In [16]:
C = np.random.normal(loc=0, scale=5, size=(1024, 128))
X_with_constants = np.hstack((X, C))

In [19]:
size = 10_000
constants_retries = 5
rules = [(['-', t, t], ['0']) for t in space.variables] + [(['/', t, t], ['1']) for t in space.variables] + [(['*', t, '1'], [t]) for t in space.variables]

hashes_of_size = defaultdict(list)

with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=RuntimeWarning)

    # Create all leaf nodes
    for leaf in leaf_nodes[:size]:
        simplified_skeleton = simplify([leaf], rules)
        
        executable_prefix_expression = space.operators_to_realizations(simplified_skeleton)
        prefix_expression_with_constants, constants = num_to_constants(executable_prefix_expression)
        code_string = space.prefix_to_infix(prefix_expression_with_constants, realization=True)
        code = codify(code_string, space.variables + constants)

        hashes_of_size[len(simplified_skeleton)].append(tuple(simplified_skeleton))

    pbar = tqdm(total=size)
    n_scanned = 0

    while n_scanned < size:
        simplified_hashes_of_size = defaultdict(list)
        for l, hashes_list in hashes_of_size.items():
            for h in hashes_list:
                simplified_skeleton = simplify(h, rules)
                simplified_hashes_of_size[len(simplified_skeleton)].append(simplified_skeleton)
        hashes_of_size = simplified_hashes_of_size

        new_hashes_of_size = defaultdict(list)
        for combination in expression_generator(hashes_of_size, non_leaf_nodes):
            for i, rule in enumerate(rules):
                rules[i] = (rule[0], simplify(rule[1], rules))

            simplified_skeleton = simplify(list(combination), rules)
            h = tuple(simplified_skeleton)

            executable_prefix_expression = space.operators_to_realizations(simplified_skeleton)
            prefix_expression_with_constants, constants = num_to_constants(executable_prefix_expression)
            code_string = space.prefix_to_infix(prefix_expression_with_constants, realization=True)
            code = codify(code_string, space.variables + constants)
            
            f = space.code_to_lambda(code)

            # Record the image
            if len(constants) == 0:
                y = f(*X.T)

                for candidate_hashes_of_size in (hashes_of_size, new_hashes_of_size):
                    for l, candidate_hashes_list in candidate_hashes_of_size.items():
                        # Ignore simplification candidates that do not shorten the expression
                        if l >= len(h):
                            continue

                        for candidate_hash in candidate_hashes_list:
                            if candidate_hash == h:
                                continue
                            executable_prefix_candidate_hash = space.operators_to_realizations(candidate_hash)
                            prefix_candidate_hash_with_constants, constants_candidate_hash = num_to_constants(executable_prefix_candidate_hash)
                            code_string_candidate_hash = space.prefix_to_infix(prefix_candidate_hash_with_constants, realization=True)
                            code_candidate_hash = codify(code_string_candidate_hash, space.variables + constants_candidate_hash)

                            # Record the image
                            if len(constants_candidate_hash) == 0:
                                f_candidate = space.code_to_lambda(code_candidate_hash)
                                y_candidate = f_candidate(*X.T)

                                if np.allclose(y, y_candidate, equal_nan=True):
                                    rules.append((simplified_skeleton, list(candidate_hash)))
                            else:
                                if any([exist_constants_that_fit(candidate_hash, X, y) for _ in range(constants_retries)]):
                                    rules.append((simplified_skeleton, list(candidate_hash)))
            else:
                # Create an image from X and randomly sampled constants
                # TODO: use multiple images here
                y = f(*X_with_constants[:, :len(space.variables) + len(constants)].T)

                for candidate_hashes_of_size in (hashes_of_size, new_hashes_of_size):
                    for l, candidate_hashes_list in candidate_hashes_of_size.items():
                        # Ignore simplification candidates that do not shorten the expression
                        if l >= len(h):
                            continue

                        for candidate_hash in candidate_hashes_list:
                            if candidate_hash == h:
                                continue
                            executable_prefix_candidate_hash = space.operators_to_realizations(candidate_hash)
                            prefix_candidate_hash_with_constants, constants_candidate_hash = num_to_constants(executable_prefix_candidate_hash)
                            code_string_candidate_hash = space.prefix_to_infix(prefix_candidate_hash_with_constants, realization=True)
                            code_candidate_hash = codify(code_string_candidate_hash, space.variables + constants_candidate_hash)

                            f_candidate = space.code_to_lambda(code_candidate_hash)

                            # Record the image
                            if len(constants_candidate_hash) == 0:
                                y_candidate = f_candidate(*X.T)

                                if np.allclose(y, y_candidate, equal_nan=True):
                                    rules.append((simplified_skeleton, list(candidate_hash)))
                            else:
                                if any([exist_constants_that_fit(candidate_hash, X, y) for _ in range(constants_retries)]):
                                    rules.append((simplified_skeleton, list(candidate_hash)))

            new_hashes_of_size[len(h)].append(h)

            n_scanned += 1
            pbar.update(1)
            pbar.set_postfix_str(f"Rules found: {len(rules):,}")

            if n_scanned >= size:
                break

        hashes_of_size.update(new_hashes_of_size)

    pbar.close()

# Write the rules to a file
with open('./rules_constants.txt', 'w') as f:
    for rule in rules:
        f.write(f"{rule[0]} -> {rule[1]}\n")

100%|██████████| 2000/2000 [01:22<00:00, 24.11it/s, Rules found: 131]
