# Robustness with validation primitives

In [None]:
import os
import numpy as np
import sympy as sp
import torch
from tqdm.notebook import tqdm
import pandas as pd
import random

# Expose only GPU 0 to this process. CUDA reads this variable when it is
# first initialised, so it must be set before any CUDA call. torch is already
# imported above, presumably without having touched CUDA yet (verify this if
# the wrong device ends up being used).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [None]:
import admath.utils as utils

# Load the environment and the trained encoder/decoder pair. The first
# argument is a directory and the second a ``.pth`` file; their exact roles are
# defined by admath.utils.load_env (confirm there).
sm_root = '/SymbolicMathematics'
checkpoint_path = '/fwd_bwd_ibp.pth'
env, encoder, decoder = utils.load_env(sm_root, checkpoint_path)

#### Collect examples from the validation set

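Specifically, keep only the examples whose input and output contain none of the operators listed in `exclude` (matched as substrings), and whose expressions the environment can parse.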

In [None]:
# Operator substrings that disqualify an example (checked against both the
# input and the output string). ' I ' is space-padded so it only matches a
# standalone token, presumably the imaginary unit; confirm against the
# dataset's vocabulary.
exclude = {
    ' I ', 'asin', 'acos', 'atan', 'sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh'
}

# Each line has the form "<field>|<x>\t<y>"; only x is kept, as a SymPy expression.
with open('/prim_fwd.valid') as fh:
    ds = fh.readlines()

ds_ = []
n_unparseable = 0
for line in tqdm(ds, total=len(ds)):
    fields = line.strip().split('|')
    if len(fields) != 2:
        raise ValueError('expected exactly one "|" in line: %r' % line)
    pair = fields[1].split('\t')
    if len(pair) != 2:
        raise ValueError('expected exactly one tab after "|" in line: %r' % line)
    x, y = pair

    if any(op in x or op in y for op in exclude):
        continue

    try:
        x_prefix = env.clean_prefix(x.replace("sub Y' ", '').split())
        # y is parsed only so that examples with an unparseable target are
        # dropped as well; the parsed value itself is not used.
        env.clean_prefix(y.replace("sub Y' ", '').split())
        expr = env.infix_to_sympy(env.prefix_to_infix(x_prefix))
    except Exception:
        # Best-effort filter: skip examples the environment cannot parse.
        # Unlike a bare except, this still lets KeyboardInterrupt stop the loop.
        n_unparseable += 1
        continue
    ds_.append({'x': expr})

if n_unparseable:
    print('skipped %d unparseable examples' % n_unparseable)
ds = ds_
len(ds)

In [None]:
import random
# Sanity check: render 10 randomly chosen filtered examples as strings.
preview = random.sample(ds, 10)
list(map(lambda ex: str(ex['x']), preview))

### Run the model on sampled examples, then keep the ones it solved correctly as primitives

In [None]:
# Settings shared by every experiment below.
top_n = 50        # passed to run_and_check and recorded in each summary
beam_size = 50    # beam width passed to run_and_check
N = 1000          # number of problems sampled for the baseline run

n_primitives = 100  # for coeff experiments: primitives used, each paired with N // n_primitives coefficients

# Seed Python's and NumPy's RNGs so the problem sampling (random.sample) and
# the coefficient draws (np.random.choice) in later cells are reproducible.
seed = 0
random.seed(seed)
np.random.seed(seed)

results = {}

In [None]:
def parse_result(out, problems):
    """Summarise one experiment's outcomes as a JSON-serialisable dict.

    Args:
        out: sequence of per-problem result dicts; each must have a
            ``'correct'`` entry, which is interpreted by truthiness.
        problems: the problems that were run; each must have an ``'x'``
            entry whose ``str`` length is averaged.

    Returns:
        A dict with these keys:

        - ``n``: number of results.
        - ``accuracy`` / ``failure``: fractions of correct / incorrect results.
        - ``length``: mean string length of the problem expressions.
        - ``top_n``: value of the module-level ``top_n`` at call time.

        When ``out`` is empty, ``accuracy`` and ``failure`` are NaN instead of
        raising ZeroDivisionError. When ``problems`` is empty, ``length`` is
        NaN instead of triggering NumPy's empty-mean warning.
    """
    n = len(out)
    n_correct = sum(1 for x in out if x['correct'])
    accuracy = n_correct / n if n else float('nan')
    lengths = [len(str(x['x'])) for x in problems]
    return {
        'n': n,
        'accuracy': accuracy,
        'failure': 1.0 - accuracy,
        'length': np.mean(lengths) if lengths else float('nan'),
        # Reads the notebook-level setting so each summary records it.
        'top_n': top_n,
    }

In [None]:
from admath.utils import run_and_check

# Baseline: run the model on N randomly sampled validation examples.
# The seconds/beam_size keywords are forwarded to run_and_check; see
# admath.utils for their exact meaning.
device = torch.device('cuda')
problems = random.sample(ds, N)
out = run_and_check(
    problems, env, encoder, decoder, device, top_n,
    seconds=10, beam_size=beam_size,
)

parse_result(out, problems)

In [None]:
# Keep only the baseline problems the model got right; these are the
# "primitives" that the perturbation experiments below build on.
primitives = list(filter(lambda rec: rec['correct'], out))
len(primitives)

#### Coeff 1


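$k \cdot f(x)$, where $f$ is a primitive and $k$ is an integer drawn without replacement from each range in `ranges`.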

In [None]:
from admath.compositionality import random_tuples

# Scale each primitive f(x) by an integer k drawn without replacement from
# [start, end). Other ranges explored previously: (101, 200), (2**9, 2**10).
ranges = [(2, 100)]
for start, end in ranges:
    coeffs = np.random.choice(range(start, end), size=(N//n_primitives,), replace=False)

    # Convert each NumPy integer to an exact SymPy Integer explicitly rather
    # than relying on SymPy's implicit conversion of NumPy scalars.
    problems = [
        {'x': prim['x'] * sp.Integer(int(coeff))}
        for prim in primitives[:n_primitives]
        for coeff in coeffs
    ]

    out = run_and_check(problems, env, encoder, decoder, torch.device('cuda'), top_n, seconds=10, beam_size=beam_size)

    key = 'coeff_1_%d-%d' % (start, end)
    results[key] = parse_result(out, problems)
    print(results[key])

#### Coeff 2


$\frac{1}{k} f(x)$, with $k$ drawn as in Coeff 1.

In [None]:
# Divide each primitive f(x) by an integer k drawn without replacement from
# [start, end). Other ranges explored previously: (101, 200), (2**9, 2**10).
ranges = [(2, 100)]
for start, end in ranges:
    n_coeffs = N // n_primitives
    coeffs = np.random.choice(range(start, end), size=(n_coeffs,), replace=False)

    # Primitive-major order: every coefficient for one primitive, then the next.
    problems = [
        {'x': prim['x'] * 1 / sp.S(coeff)}
        for prim in primitives[:n_primitives]
        for coeff in coeffs
    ]

    out = run_and_check(
        problems, env, encoder, decoder, torch.device('cuda'), top_n,
        seconds=10, beam_size=beam_size,
    )

    label = 'coeff_2_%d-%d' % (start, end)
    results[label] = parse_result(out, problems)
    print(results[label])

#### Add-perturb

$f(x) + g(x)$ for $g \in \{e^x, \ln x\}$ (the coefficient on $g$ is fixed at 1).

In [None]:
# Add a fixed elementary function g(x), with coefficient 1, to each primitive f(x).
# NOTE(review): sp.S builds a plain, assumption-free Symbol('x'). That only
# combines with the primitive's variable if env uses the same symbol; confirm
# this in the environment code.
perturb_funcs = ['exp(x)', 'ln(x)']
coeff = 1
for pf in perturb_funcs:
    # Parse the perturbation once per function, not once per primitive.
    term = coeff * sp.S(pf)
    problems = [{'x': prim['x'] + term} for prim in primitives[:N]]

    out = run_and_check(problems, env, encoder, decoder, torch.device('cuda'), top_n, seconds=10, beam_size=beam_size)

    key = 'perturbfunc_%s' % (pf)
    results[key] = parse_result(out, problems)
    print(results[key])

#### Save

In [None]:
import json
# Persist every experiment summary; the filename records the top_n setting.
out_path = '../output/validation_robustness_top%d.json' % top_n
# Create the output directory if it does not exist, so the dump does not fail
# with FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
with open(out_path, 'w') as f:
    json.dump(results, f, indent=2)

In [None]:
from pprint import pprint
# Pretty-print the collected summaries (one entry per experiment key).
pprint(results)