# Compositionality with validation primitives

In [None]:
import os
import numpy as np
import sympy as sp
import torch
from tqdm.notebook import tqdm
import pandas as pd
import random
# Expose only GPU 0 to this process; this only takes effect if set before CUDA is initialised.
os.environ.update(CUDA_VISIBLE_DEVICES='0')

In [None]:
import admath.utils as utils

# Load the environment and the trained encoder/decoder via admath's helper.
# NOTE(review): first path is presumably the SymbolicMathematics root and the
# second the model checkpoint — confirm against utils.load_env.
ENV_ROOT = '/SymbolicMathematics'
CHECKPOINT_PATH = '/fwd_bwd_ibp.pth'
env, encoder, decoder = utils.load_env(ENV_ROOT, CHECKPOINT_PATH)

#### Collect short examples from the validation set

Specifically, keep only examples whose input and output contain none of the operators in `exclude`, and whose input, rendered as a SymPy expression string, is at most 20 characters long.

In [None]:
# Operators to filter out; matched as substrings of the prefix-notation text.
# ' I ' is space-padded so that it matches only a standalone `I` token
# (presumably the imaginary unit), not a capital I inside another token.
exclude = {
    ' I ', 'asin', 'acos', 'atan', 'sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh'
}

# Each line is expected to look like "<prefix>|<x>\t<y>".
with open('/prim_fwd.valid') as f:
    lines = f.readlines()

ds_ = []
n_failed = 0
for line in tqdm(lines, total=len(lines)):
    parts = line.strip().split('|')
    if len(parts) != 2:
        raise ValueError(f'expected exactly one "|" in line: {line!r}')
    parts = parts[1].split('\t')
    if len(parts) != 2:
        raise ValueError(f'expected exactly one tab after "|" in line: {line!r}')
    x, y = parts

    # Pad with spaces so space-delimited entries such as ' I ' also match a
    # token at the very start or end of x / y (strip() and split('\t') leave
    # no surrounding space there).
    if any(op in f' {s} ' for op in exclude for s in (x, y)):
        continue

    try:
        x_prefix = env.clean_prefix(x.replace("sub Y' ", '').split())
        y_prefix = env.clean_prefix(y.replace("sub Y' ", '').split())

        x_expr = env.infix_to_sympy(env.prefix_to_infix(x_prefix))
        # Keep only short inputs (measured on the SymPy string form).
        if len(str(x_expr)) <= 20:
            ds_.append({
                'x': x_expr,
                'y': env.prefix_to_infix(y_prefix),
            })
    except Exception:
        # Best-effort: skip lines the environment cannot parse/convert, but
        # count them instead of hiding them, and let KeyboardInterrupt through.
        n_failed += 1

if n_failed:
    print(f'skipped {n_failed} lines that failed to parse')
ds = ds_
len(ds)

### Run the model on sampled examples and keep the correctly solved ones as primitives

In [None]:
# Beam-search settings used by every run below, and the number of problems
# sampled per experiment.
beam_size = top_n = 50
N = 1000

# Metric dicts keyed by experiment name (e.g. 'comp_2'), dumped to JSON at the end.
results = dict()

In [None]:
def parse_result(out, problems):
    """Summarise one `run_and_check` run as a dict of metrics.

    Args:
        out: per-problem result dicts; each is read for its truthy/falsy
            'correct' and 'cancelled' entries.
        problems: the problem dicts that were run; only their 'x' entry is
            used, to compute the mean length of ``str(x)``.

    Returns:
        dict with keys 'n' (len(out)), 'accuracy' (fraction correct),
        'length' (mean len(str(problem['x']))), 'top_n' (the module-level
        ``top_n`` setting), 'cancelled' (fraction cancelled) and
        'nc_notcancelled' (fraction both correct and not cancelled; a ratio,
        despite the name). If ``out`` is empty the ratios are NaN instead of
        raising ZeroDivisionError; 'length' is NaN for empty ``problems``.
    """
    n = len(out)
    n_correct = sum(1 for r in out if r['correct'])
    n_cancelled = sum(1 for r in out if r['cancelled'])
    n_correct_live = sum(1 for r in out if r['correct'] and not r['cancelled'])

    def ratio(count):
        # NaN rather than ZeroDivisionError when nothing was run.
        return count / n if n else float('nan')

    lengths = [len(str(p['x'])) for p in problems]
    return {
        'n': n,
        'accuracy': ratio(n_correct),
        # Avoid np.mean's RuntimeWarning on an empty list.
        'length': np.mean(lengths) if lengths else float('nan'),
        'top_n': top_n,  # module-level global set in the settings cell
        'cancelled': ratio(n_cancelled),
        'nc_notcancelled': ratio(n_correct_live),
    }

In [None]:
from admath.utils import run_and_check

# Sample up to N filtered validation problems. random.sample raises ValueError
# when asked for more items than the population holds, so cap at len(ds).
problems = random.sample(ds, min(N, len(ds)))
out = run_and_check(problems, env, encoder, decoder, torch.device('cuda'),
                    top_n, seconds=10, beam_size=beam_size)

parse_result(out, problems)

In [None]:
# Problems the model solved become the building blocks for composition.
primitives = list(filter(lambda res: res['correct'], out))
len(primitives)

#### Compositionality 2: sums of two primitives

In [None]:
from admath.compositionality import random_tuples

# Index pool over the primitives; random_tuples presumably draws N index
# pairs from it — confirm against admath.compositionality.
lst = list(range(len(primitives)))

problems = [
    {'x': primitives[idxs[0]]['x'] + primitives[idxs[1]]['x']}
    for idxs in random_tuples(lst, 2, N)
]

out = run_and_check(
    problems, env, encoder, decoder, torch.device('cuda'),
    top_n, seconds=10, beam_size=beam_size,
)

results['comp_2'] = parse_result(out, problems)
print(results['comp_2'])

#### Compositionality 3: sums of three primitives

In [None]:
# Imported here too so this cell does not depend on earlier cells having run.
from admath.compositionality import random_tuples
from admath.utils import run_and_check

# Build problems, each the sum of three primitive expressions whose indices
# come from random_tuples (presumably N 3-tuples — confirm), then score them.
lst = list(range(len(primitives)))

problems = [
    {'x': (primitives[idxs[0]]['x']
           + primitives[idxs[1]]['x']
           + primitives[idxs[2]]['x'])}
    for idxs in random_tuples(lst, 3, N)
]

out = run_and_check(problems, env, encoder, decoder, torch.device('cuda'),
                    top_n, seconds=10, beam_size=beam_size)

results['comp_3'] = parse_result(out, problems)
print(results['comp_3'])

#### Compositionality 4: sums of four primitives

In [None]:
# Imported here too so this cell does not depend on earlier cells having run.
from admath.compositionality import random_tuples
from admath.utils import run_and_check

# Build problems, each the sum of four primitive expressions whose indices
# come from random_tuples (presumably N 4-tuples — confirm), then score them.
lst = list(range(len(primitives)))

problems = [
    {'x': (primitives[idxs[0]]['x']
           + primitives[idxs[1]]['x']
           + primitives[idxs[2]]['x']
           + primitives[idxs[3]]['x'])}
    for idxs in random_tuples(lst, 4, N)
]

out = run_and_check(problems, env, encoder, decoder, torch.device('cuda'),
                    top_n, seconds=10, beam_size=beam_size)

results['comp_4'] = parse_result(out, problems)
print(results['comp_4'])

#### Save

In [None]:
import json
import os

# Write all experiment metrics to ../output (relative to the notebook's
# working directory). Create the directory first so the dump does not fail
# with FileNotFoundError on a fresh checkout.
output_path = '../output/validation_compositionality_top%d.json' % top_n
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'w') as f:
    json.dump(results, f)

In [None]:
print(str(results))  # final summary of every experiment's metrics