In [14]:
from flash_ansr import SkeletonPool, get_path
from flash_ansr.expressions.skeleton_pool import NoValidSampleFoundError

import os
import warnings
import random
import pickle
import pandas as pd

from types import CodeType
from typing import Any, Callable

from tqdm import tqdm
from sklearn.model_selection import train_test_split
import numpy as np

from flash_ansr.utils import load_config, substitute_root_path, save_config
from flash_ansr.expressions.expression_space import ExpressionSpace
from flash_ansr.expressions.utils import codify, num_to_constants, generate_ubi_dist, get_distribution

In [2]:
training_pool = SkeletonPool.from_config(get_path('configs', 'v7.0', 'skeleton_pool_train.yaml'))

Compiling Skeletons: 100%|██████████| 200/200 [00:00<00:00, 44454.73it/s]
Compiling Skeletons: 100%|██████████| 43/43 [00:00<00:00, 43281.75it/s]
Compiling Skeletons: 100%|██████████| 10/10 [00:00<00:00, 28225.46it/s]
Compiling Skeletons: 100%|██████████| 4999/4999 [00:00<00:00, 27733.57it/s]
Compiling Skeletons: 100%|██████████| 5000/5000 [00:00<00:00, 19022.39it/s]
Compiling Skeletons: 100%|██████████| 200/200 [00:00<00:00, 48491.87it/s]
Compiling Skeletons: 100%|██████████| 43/43 [00:00<00:00, 30039.15it/s]
Compiling Skeletons: 100%|██████████| 10/10 [00:00<00:00, 31371.01it/s]
Compiling Skeletons: 100%|██████████| 4999/4999 [00:00<00:00, 26815.45it/s]


In [3]:
training_pool.holdout_pools = []
training_pool.holdout_skeletons = set()
training_pool.holdout_y = set()

In [4]:
test_pools = {
    'feynman': SkeletonPool.load(get_path('data', 'ansr-data', 'test_set', 'feynman', 'skeleton_pool'))[1],
    'soose_nc': SkeletonPool.load(get_path('data', 'ansr-data', 'test_set', 'soose_nc', 'skeleton_pool'))[1],
    'nguyen': SkeletonPool.load(get_path('data', 'ansr-data', 'test_set', 'nguyen', 'skeleton_pool'))[1],
    'pool_15': SkeletonPool.load(get_path('data', 'ansr-data', 'test_set', 'pool_15', 'skeleton_pool'))[1],
}

Compiling Skeletons: 100%|██████████| 43/43 [00:00<00:00, 37566.15it/s]
Compiling Skeletons: 100%|██████████| 200/200 [00:00<00:00, 48138.46it/s]
Compiling Skeletons: 100%|██████████| 10/10 [00:00<00:00, 15482.85it/s]
Compiling Skeletons: 100%|██████████| 4999/4999 [00:00<00:00, 30350.00it/s]


In [5]:
def register_holdout_pool(skeleton_pool: SkeletonPool, holdout_pool: SkeletonPool) -> None:
    '''
    Register a holdout pool to exclude from sampling: Cache the skeletons and their images to compare against when sampling.

    Parameters
    ----------
    holdout_pool : SkeletonPool or str
        The holdout pool to register.
    '''
    if isinstance(holdout_pool, str):
        _, holdout_pool = SkeletonPool.load(holdout_pool)

    skeleton_pool.holdout_skeletons_better = set()
    skeleton_pool.holdout_y_better = set()

    for skeleton in holdout_pool.skeletons:
        # Remove constants since permutations are not detected as duplicates
        executable_prefix_expression = skeleton_pool.expression_space.operators_to_realizations(skeleton)
        prefix_expression_with_constants, constants = num_to_constants(executable_prefix_expression, inplace=True)
        code_string = skeleton_pool.expression_space.prefix_to_infix(prefix_expression_with_constants, realization=True)
        code = codify(code_string, skeleton_pool.expression_space.variables + constants)

        # Evaluate the Expression and store the result
        f = skeleton_pool.expression_space.code_to_lambda(code)
        X_with_constants = np.concatenate([skeleton_pool.holdout_X[:, :skeleton_pool.expression_space.n_variables], skeleton_pool.holdout_C[:, :len(constants)]], axis=1)
        warnings.filterwarnings("ignore", category=RuntimeWarning)
        try:
            expression_image = f(*X_with_constants.T).round(4)
            expression_image[np.isnan(expression_image)] = 0  # Cannot compare NaNs
        except OverflowError:
            skeleton_pool.holdout_skeletons.add(skeleton)
            continue

        skeleton_pool.holdout_skeletons.add(skeleton)
        skeleton_pool.holdout_y.add(tuple(expression_image))

        # Remove constants since permutations are not detected as duplicates
        no_constant_expression = skeleton_pool.expression_space.remove_num(skeleton)
        # print(no_constant_expression)
        executable_prefix_expression = skeleton_pool.expression_space.operators_to_realizations(no_constant_expression)
        prefix_expression_with_constants, constants = num_to_constants(executable_prefix_expression, inplace=True)
        code_string = skeleton_pool.expression_space.prefix_to_infix(prefix_expression_with_constants, realization=True)
        code = codify(code_string, skeleton_pool.expression_space.variables + constants)

        # Evaluate the Expression and store the result
        f = skeleton_pool.expression_space.code_to_lambda(code)
        X_with_constants = np.concatenate([skeleton_pool.holdout_X[:, :skeleton_pool.expression_space.n_variables], skeleton_pool.holdout_C[:, :len(constants)]], axis=1)
        warnings.filterwarnings("ignore", category=RuntimeWarning)
        try:
            expression_image = f(*X_with_constants.T).round(4)
            expression_image[np.isnan(expression_image)] = 0  # Cannot compare NaNs
        except OverflowError:
            skeleton_pool.holdout_skeletons_better.add(tuple(no_constant_expression))
            continue

        skeleton_pool.holdout_skeletons_better.add(tuple(no_constant_expression))
        skeleton_pool.holdout_y_better.add(tuple(expression_image))

In [6]:
register_holdout_pool(training_pool, test_pools['feynman'])
register_holdout_pool(training_pool, test_pools['soose_nc'])
register_holdout_pool(training_pool, test_pools['nguyen'])
register_holdout_pool(training_pool, test_pools['pool_15'])

In [7]:
def is_held_out(skeleton_pool: SkeletonPool, skeleton: tuple[str] | list[str], constants: list[str], code: CodeType | None = None, verbose=False) -> tuple[bool, bool]:
    if constants is None:
        raise ValueError("Need constants for test of functional equivalence")

    result = [tuple(skeleton) in skeleton_pool.holdout_skeletons, tuple(skeleton) in skeleton_pool.holdout_skeletons_better]

    if verbose:
        print(f'Symbolic check: {result}')

    if code is None:
        executable_prefix_expression = skeleton_pool.expression_space.operators_to_realizations(skeleton)
        prefix_expression_with_constants, constants = num_to_constants(executable_prefix_expression, inplace=True)
        code_string = skeleton_pool.expression_space.prefix_to_infix(prefix_expression_with_constants, realization=True)
        code = codify(code_string, skeleton_pool.expression_space.variables + constants)

        # Remove constants since permutations are not detected as duplicates
        no_constant_expression = skeleton_pool.expression_space.remove_num(skeleton)
        executable_prefix_expression_better = skeleton_pool.expression_space.operators_to_realizations(no_constant_expression)
        prefix_expression_with_constants_better, constants_better = num_to_constants(executable_prefix_expression_better, inplace=True)
        code_string_better = skeleton_pool.expression_space.prefix_to_infix(prefix_expression_with_constants_better, realization=True)
        code_better = codify(code_string_better, skeleton_pool.expression_space.variables + constants_better)

    if verbose:
        print(f'Code: {code_string}')
        print(f'Code Better: {code_string_better}')

    # Evaluate the expression and check if its image is in the holdout images (functional equivalence)
    f = skeleton_pool.expression_space.code_to_lambda(code)
    f_better = skeleton_pool.expression_space.code_to_lambda(code_better)

    warnings.filterwarnings("ignore", category=RuntimeWarning)
    X_with_constants = np.concatenate([skeleton_pool.holdout_X[:, :skeleton_pool.expression_space.n_variables], skeleton_pool.holdout_C[:, :len(constants)]], axis=1)
    X_with_constants_better = np.concatenate([skeleton_pool.holdout_X[:, :skeleton_pool.expression_space.n_variables], skeleton_pool.holdout_C[:, :len(constants_better)]], axis=1)
    try:
        expression_image = f(*X_with_constants.T).round(4)
        expression_image[np.isnan(expression_image)] = 0  # Cannot compare NaNs
    except OverflowError:
        result[0] = True
    
    try:
        expression_image_better = f_better(*X_with_constants_better.T).round(4)
        expression_image_better[np.isnan(expression_image_better)] = 0  # Cannot compare NaNs
    except OverflowError:
        result[1] = True

    if verbose:
        print(f'Overflow check: {result}')

    if tuple(expression_image) in skeleton_pool.holdout_y:
        result[0] = True

    if tuple(expression_image_better) in skeleton_pool.holdout_y_better:
        result[1] = True

    if verbose:
        print(f'Functional check: {result}')

    return result

In [8]:
try:
    skeleton, code, constants = training_pool.sample_skeleton(decontaminate=False)
    is_held_out(training_pool, skeleton, constants, verbose=True)
except NoValidSampleFoundError:
    pass

Symbolic check: [False, False]
Code: x1 + (nsrops.pow1_2(numpy.arcsin(numpy.arcsin(x1))) + (C_0 / (x1 + (x2 + numpy.tan(x3)))))
Code Better: x1 + (nsrops.pow1_2(numpy.arcsin(numpy.arcsin(x1))) + (x1 + (x2 + numpy.tan(x3))))
Overflow check: [False, False]
Functional check: [False, False]


In [9]:
results_list = []
skeletons_list = []
constants_list = []

pbar = tqdm(total=100_000)
while len(results_list) < 100_000:
    try:
        skeleton, code, constants = training_pool.sample_skeleton(decontaminate=False)
        results_list.append(is_held_out(training_pool, skeleton, constants))
        skeletons_list.append(skeleton)
        constants_list.append(constants)
    except NoValidSampleFoundError:
        pass
    pbar.update(1)
pbar.close()

results_list = np.array(results_list)

100072it [00:54, 1842.75it/s]                           


In [10]:
results_list.mean(axis=0)

array([0.0622 , 0.14168])

In [15]:
confusion_matrix = np.array([
    [results_list.prod(axis=1).sum(), (results_list[:, 0] & ~results_list[:, 1]).sum()],
    [(~results_list[:, 0] & results_list[:, 1]).sum(), (~results_list).prod(axis=1).sum()]
])
pd.DataFrame(confusion_matrix, index=['Old Detected', 'Old not Detected'], columns=['New Detected', 'New not Detected'])

Unnamed: 0,New Detected,New not Detected
Old Detected,6096,124
Old not Detected,8072,85708


In [12]:
for result, skeleton, constants in zip(results_list, skeletons_list, constants_list):
    if result[0] != result[1]:
        print(result)
        print(skeleton)
        print()

[False  True]
(np.str_('*'), 'x2', np.str_('*'), 'x2', np.str_('+'), '<num>', '/', 'x3', '<num>')

[False  True]
(np.str_('cos'), 'x1')

[False  True]
(np.str_('*'), np.str_('pow2'), 'x1', np.str_('+'), '<num>', np.str_('+'), '<num>', 'x3')

[False  True]
(np.str_('*'), '<num>', np.str_('*'), 'x1', np.str_('*'), 'x3', '/', 'x3', np.str_('+'), '<num>', 'x1')

[False  True]
('neg', np.str_('/'), 'x3', 'x1')

[False  True]
(np.str_('inv'), np.str_('*'), 'x3', 'pow2', 'x2')

[False  True]
(np.str_('+'), 'x3', np.str_('/'), '<num>', 'x1')

[False  True]
('*', 'x1', '*', 'x2', np.str_('/'), np.str_('+'), 'x2', 'x3', np.str_('*'), 'x1', 'x2')

[False  True]
(np.str_('/'), '<num>', np.str_('+'), 'x1', np.str_('*'), '<num>', np.str_('*'), 'x1', '-', '<num>', 'x3')

[False  True]
('/', '/', np.str_('+'), '<num>', 'x3', 'x2', np.str_('pow2'), 'x3')

[False  True]
(np.str_('*'), '<num>', np.str_('*'), 'x2', 'x3')

[False  True]
(np.str_('*'), '<num>', np.str_('+'), '<num>', np.str_('+'), np.str_('