In [1]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
import sympy as sp
from src.utils import AttrDict
from src.envs import build_env
import linecache
from pathlib import Path

params = AttrDict({

    # environment parameters
    'env_name': 'char_sp',
    'int_base': 10,
    'balanced': False,
    'positive': True,
    'precision': 10,
    'n_variables': 1,
    'n_coefficients': 0,
    'leaf_probs': '0.75,0,0.25,0',
    'max_len': 1024,
    'max_int': 5,
    'max_ops': 15,
    'max_ops_G': 15,
    'clean_prefix_expr': True,
    'rewrite_functions': '',
    'tasks': 'prim_fwd',
    'operators': 'add:10,sub:3,mul:10,div:5,sqrt:4,pow2:4,pow3:2,pow4:1,pow5:1,ln:4,exp:4,sin:4,cos:4,tan:4,asin:1,acos:1,atan:1,sinh:1,cosh:1,tanh:1,asinh:1,acosh:1,atanh:1',
})

env = build_env(params)

In [2]:
model_path = "./ckpoints/checkpoint-320000"
tokenizer_path="./ckpoints/checkpoint-320000"

tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)
print(f"max length: {model.config.n_positions}")

device='cuda'
model=model.to(device)
model=model.eval()
def generate_summary(input_tokens):
    inputs = tokenizer(input_tokens, return_tensors="pt",is_split_into_words=True, padding=True, truncation=True)
    if device =='cuda':
        inputs = {k: v.to('cuda') for k, v in inputs.items()}
        outputs = model.generate(inputs['input_ids'], max_length=1024, num_beams=2, early_stopping=True)
    else:
        outputs = model.generate(inputs.input_ids, max_length=1024, num_beams=2, early_stopping=True)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

max length: 1024


In [None]:
count=0
true_index=[]
total_index=[]

i=1
n=0
count_len=0
num_times_one =0
num_scramb =1
delta_ns =2

data_type = "random_ns_nt"
folder_path = './data/test/random_ns_nt'

with open(folder_path / Path('infix_origin.txt'), 'r', encoding='utf-8') as f_origin:
    origin_lines = f_origin.readlines()

with open(folder_path / Path('infix_simple.txt'), 'r', encoding='utf-8') as f_simple:
    simple_lines = f_simple.readlines()

while True:
    if i-1 < len(origin_lines):
        test_input = origin_lines[i-1]
    else:
        test_input = ''

    #print(f"test: {test_input}")
    expr_sp_origin = sp.S(test_input, locals=env.local_dict)
    expr_prefix_origin = env.sympy_to_prefix(expr_sp_origin)

    if len(expr_prefix_origin) <= 1024:
        total_index.append(i)
        count_len += len(expr_prefix_origin)
        after_train_str = generate_summary(expr_prefix_origin)
        after_train_prefix = after_train_str.split(" ")
        try:
            after_train_infix = env.prefix_to_infix(after_train_prefix)
            after_train_sp = env.infix_to_sympy(after_train_infix)
            if i-1 < len(simple_lines):
                simple_input_infix = simple_lines[i-1]
            else:
                simple_input_infix = ''

            if sp.simplify(sp.S(str(after_train_sp))) == sp.simplify(sp.S(str(simple_input_infix))):
                count += 1
                true_index.append(i)
        except Exception as e:
            print(f"An error occurred: {e}")
        finally:
            i += 1
            n += 1
            #print(n)
            if n % 10 == 0:
                print(n/50000 )
            if n == 50000:
                break
    else:
        i += 1

with open("acc.txt", "a") as acc_f:
    acc_f.write("----data type: {}----\n".format(data_type))
    acc_f.write("length rate smaller than 1024: {}\n".format(n/i))
    acc_f.write("acc: {}\n".format(count/n))
    acc_f.write("average length:{}\n".format(count_len/n))
    acc_f.write("-------------------------\n\n")


0.0002
0.0004
0.0006
0.0008
0.001
0.0012
0.0014
0.0016
0.0018
0.002
0.0022
0.0024
0.0026
0.0028
0.003
0.0032
0.0034
0.0036
0.0038
0.004
0.0042
0.0044
0.0046
0.0048
0.005
0.0052
0.0054
0.0056
0.0058
0.006
0.0062
0.0064
0.0066
0.0068
0.007
0.0072
0.0074
0.0076
0.0078
0.008
0.0082
0.0084
0.0086
0.0088
0.009
0.0092
0.0094
0.0096
0.0098
0.01
0.0102
0.0104
0.0106
0.0108
0.011
0.0112
0.0114
0.0116
0.0118
0.012
0.0122
0.0124
0.0126
0.0128
0.013
0.0132
0.0134
0.0136
0.0138
0.014
0.0142
0.0144
0.0146
0.0148
0.015
0.0152
0.0154
0.0156
0.0158
0.016
0.0162
0.0164
0.0166
0.0168
0.017
0.0172
0.0174
0.0176
0.0178
0.018
0.0182
0.0184
0.0186
0.0188
0.019
0.0192
0.0194
0.0196
0.0198
0.02
0.0202
0.0204
0.0206
0.0208
0.021
0.0212
0.0214
0.0216
0.0218
0.022
0.0222
0.0224
0.0226
0.0228
0.023
0.0232
0.0234
0.0236
0.0238
0.024
0.0242
0.0244
0.0246
0.0248
0.025
0.0252
0.0254
0.0256
0.0258
0.026
0.0262
0.0264
0.0266
0.0268
0.027
0.0272
0.0274
0.0276
0.0278
0.028
0.0282
0.0284
0.0286
0.0288
0.029
0.0292
0.0294
0.

In [4]:
print("length rate smaller than 1024: {}".format(n/i))
print("acc: {}".format(count/len(total_index)))
print("average length:{}".format(count_len/n))
print(true_index, len(true_index))
print(total_index,len(total_index))

length rate smaller than 1024: 0.9803921568627451
acc: 0.66
average length:275.02
[5, 7, 8, 9, 10, 11, 13, 14, 16, 17, 18, 20, 21, 23, 24, 26, 27, 28, 31, 35, 36, 37, 39, 40, 41, 42, 44, 45, 47, 48, 49, 50, 51, 52, 53, 54, 55, 59, 60, 61, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 82, 85, 86, 88, 89, 90, 91, 92, 95, 96, 98, 99, 100] 66
[1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101] 100


In [4]:
n=25
folder_path = './data/test/random_ns_nt'
test_input=linecache.getline(str(folder_path/Path('infix_origin.txt')),n)

print('origin expression:\n{}\n'.format(test_input))

expr_sp_origin=sp.S(test_input,locals=env.local_dict)
expr_prefix_origin=env.sympy_to_prefix(expr_sp_origin)

if len(expr_prefix_origin) >1024:
    print('length of prefix is too long') 
#print('prefix:\n{}\n\n'.format(expr_prefix_origin))


after_train_str=generate_summary(expr_prefix_origin)
after_train_prefix=after_train_str.split(" ")
#print('after train:\n{}\n'.format(after_train_prefix))
after_train_infix=env.prefix_to_infix(after_train_prefix)
after_train_sp=env.infix_to_sympy(after_train_infix)
simple_input_infix=linecache.getline(str(folder_path/Path('infix_simple.txt')),n)
print("Simple Form we set :\n{}".format(simple_input_infix))
print("after train, we got :\n{}\n".format(after_train_sp))


origin expression:
egamma(-2*z/(s + 8), (s + 7)/(s + 8), (-s - t - 7)/(s + 8))*egamma(-z/(t - 1), (-s - 10*t + 8)/(t - 1), (4 - 5*t)/(t - 1))/(egamma(-2*z/(s + 8), (-s - t - 7)/(s + 8), (s + 7)/(s + 8))*egamma(-z/(t - 1), (-s - 5*t + 4)/(t - 1), (4 - 5*t)/(t - 1))*egamma(-z/(t - 1), (s + 5*t - 4)/(t - 1), (-s - 10*t + 8)/(t - 1))*egamma((-7*s + t - z)/(6*s + 1), (-7*s + t)/(6*s + 1), s/(6*s + 1)))


Simple Form we set :
egamma(-z/(6*s + 1), (7*s - t)/(6*s + 1), s/(6*s + 1))

after train, we got :
egamma(-z/(6*s + 1), (7*s - t)/(6*s + 1), s/(6*s + 1))



In [5]:
test_input_1='egamma((s - 2*z)/s, (-6*t - 5)/s, (t + 1)/s)*egamma((s - t - 2*z - 1)/(6*t + 5), s/(6*t + 5), (-t - 1)/(6*t + 5))'
print('origin expression:\n{}\n'.format(test_input_1))

expr_sp_origin=sp.S(test_input_1,locals=env.local_dict)
expr_prefix_origin=env.sympy_to_prefix(expr_sp_origin)

if len(expr_prefix_origin) >1024:
    print('length of prefix is too long') 
#print('prefix:\n{}\n\n'.format(expr_prefix_origin))


after_train_str=generate_summary(expr_prefix_origin)
after_train_prefix=after_train_str.split(" ")
#print('after train:\n{}\n'.format(after_train_prefix))
after_train_infix=env.prefix_to_infix(after_train_prefix)
after_train_sp=env.infix_to_sympy(after_train_infix)
#simple_input_infix=linecache.getline(str(folder_path/Path('simple_infix.txt')),n)
print("after train, we got :\n{}\n".format(after_train_sp))
#print("Simple Form we set :\n{}".format(simple_input_infix))

origin expression:
egamma((s - 2*z)/s, (-6*t - 5)/s, (t + 1)/s)*egamma((s - t - 2*z - 1)/(6*t + 5), s/(6*t + 5), (-t - 1)/(6*t + 5))

after train, we got :
egamma(2*z/(t + 1), s/(t + 1), (6*t + 5)/(t + 1))

