In [1]:
import linecache

import mpmath as mp
import numpy as np
import sympy as sp
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

from src.envs import build_env
from src.utils import AttrDict

# Configuration handed to build_env(). The resulting `env` is what later cells use
# for sympy <-> prefix <-> infix conversion (env.local_dict, env.sympy_to_prefix, ...).
# NOTE(review): the meaning of the individual options is defined by the project's
# environment code, which is not visible in this notebook.
params = AttrDict(dict(
    # environment parameters
    env_name='char_sp',
    int_base=10,
    balanced=False,
    positive=True,
    precision=10,
    n_variables=1,
    n_coefficients=0,
    leaf_probs='0.75,0,0.25,0',
    max_len=512,
    max_int=5,
    max_ops=15,
    max_ops_G=15,
    clean_prefix_expr=True,
    rewrite_functions='',
    tasks='prim_fwd',
    # Comma-separated "name:value" pairs; the literal is only split across lines
    # for readability and concatenates to a single string.
    operators=(
        'add:10,sub:3,mul:10,div:5,'
        'sqrt:4,pow2:4,pow3:2,pow4:1,pow5:1,'
        'ln:4,exp:4,'
        'sin:4,cos:4,tan:4,'
        'asin:1,acos:1,atan:1,'
        'sinh:1,cosh:1,tanh:1,'
        'asinh:1,acosh:1,atanh:1'
    ),
))

env = build_env(params)



In [2]:
# Checkpoint directory and tokenizer directory (currently the same location).
# The path gets its own name so that `model` only ever refers to the loaded network.
model_path = "./results/random_ns_nt"
tokenizer_path = "./results/random_ns_nt"

tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing in `.to('cuda')`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the weights, place them on the chosen device and switch to inference mode.
model = T5ForConditionalGeneration.from_pretrained(model_path).to(device).eval()
def generate_summary(input_tokens):
    """Run the loaded seq2seq model on one token sequence and decode the result.

    Args:
        input_tokens: the already-split input passed to the tokenizer with
            ``is_split_into_words=True`` (presumably the list of prefix tokens
            produced by ``env.sympy_to_prefix`` -- confirm against the caller).

    Returns:
        The first generated sequence decoded to a string, special tokens removed.

    Uses the notebook-level ``tokenizer``, ``model`` and ``device``.
    """
    encoded = tokenizer(
        input_tokens,
        return_tensors="pt",
        is_split_into_words=True,
        padding=True,
        truncation=True,
    )
    # Always move the ids to the configured device. The previous exact comparison
    # with the literal 'cuda' left the inputs on the CPU for values such as
    # 'cuda:0' or 'mps', which would not match the model's device.
    input_ids = encoded["input_ids"].to(device)
    generated = model.generate(input_ids, max_length=512, num_beams=2, early_stopping=True)
    return tokenizer.decode(generated[0], skip_special_tokens=True)


In [3]:
def theta_q_jtheta(z, tau):
    """Evaluate a normalised Jacobi theta_1 at (z, tau) with mpmath.

    With nome ``q = exp(i*pi*tau)`` the value returned is

        jtheta(1, pi*z, q) / ( i * q**0.25 * exp(-i*pi*z) * qp(q**2) )

    NOTE(review): this looks like the multiplicative "q-theta" normalisation of
    theta_1 (triple-product form) -- confirm against the project's definition.

    Args:
        z: argument, converted with ``mp.mpc``.
        tau: modular parameter, converted with ``mp.mpc``.

    Returns:
        The quotient above as an mpmath number.
    """
    z_c, tau_c = mp.mpc(z), mp.mpc(tau)

    # Nome handed to jtheta; its square is the base of the q-Pochhammer product.
    nome = mp.exp(mp.pi * 1j * tau_c)

    theta1 = mp.jtheta(1, mp.pi * z_c, nome)

    # Same factor order as the formula in the docstring.
    normaliser = 1j * (nome**0.25) * mp.exp(-1j * mp.pi * z_c) * mp.qp(nome**2)

    return theta1 / normaliser

In [4]:
# `mp` is the mpmath *module* in this notebook (import mpmath as mp), so the working
# precision lives on its context object `mp.mp`. Assigning `mp.dps` would only create
# an unused module attribute and leave the precision at its default.
mp.mp.dps = 30

def theta_q_safe(z, tau):
    """Forward unchanged to ``theta_q_jtheta(z, tau)`` and return its result.

    This name is also exposed to evaluated expression strings by
    ``get_log_derivative``.
    """
    return theta_q_jtheta(z, tau)

def q_theta(z, tau):
    """Forward unchanged to ``theta_q_safe(z, tau)`` and return its result.

    ``q_theta`` is the function name that appears in the expression strings
    evaluated by ``get_log_derivative``.
    """
    return theta_q_safe(z, tau)


def get_log_derivative(expr_str, z0, t_val, h=1e-5):
    """Numerically estimate f'(z0) / f(z0) for the expression in ``expr_str``.

    The expression is evaluated with ``z`` bound to the sample point, ``t`` bound
    to ``t_val``, ``I`` bound to ``1j``, and ``q_theta`` / ``theta_q_safe`` / ``mp``
    available. The derivative uses the central difference
    ``(f(z0 + h) - f(z0 - h)) / (2 * h)``.

    Args:
        expr_str: Python expression string in the names listed above.
        z0: point at which the logarithmic derivative is wanted.
        t_val: value substituted for ``t``.
        h: finite-difference step.

    Returns:
        The complex estimate, or ``None`` when ``|f(z0)| < 1e-12`` or when any
        evaluation raises.
    """
    # SECURITY NOTE: expr_str is passed to eval(). Emptying __builtins__ is not a
    # sandbox, so only evaluate expressions from trusted sources.
    namespace = {
        'q_theta': q_theta,
        'theta_q_safe': theta_q_safe,
        'z': None,
        't': t_val,
        'I': 1j,
        'mp': mp
    }

    def value_at(point):
        # Rebind z and evaluate the expression as a plain Python complex.
        namespace['z'] = point
        return complex(eval(expr_str, {"__builtins__": None}, namespace))

    try:
        centre = value_at(z0)

        # Stay away from zeros of f, where f'/f blows up.
        if abs(centre) < 1e-12:
            return None

        ahead = value_at(z0 + h)
        behind = value_at(z0 - h)

        # Central difference for f', divided by f(z0) to get the log-derivative.
        return (ahead - behind) / (2 * h * centre)

    except Exception:
        # Best effort: any evaluation failure marks this point as unusable.
        return None

def generate_random_tau():
    """Draw a random complex tau from NumPy's global RNG.

    The real part is uniform on [-0.5, 0.5) and is drawn first; the imaginary
    part is uniform on [0.5, 2.0) and is drawn second.

    Returns:
        A Python ``complex``.
    """
    bounds = ((-0.5, 0.5), (0.5, 2.0))
    re_tau, im_tau = (np.random.uniform(low, high) for (low, high) in bounds)
    return complex(re_tau, im_tau)

def _sample_log_derivative_gaps(expr_origin, expr_simple, attempt, num_points):
    """Sample points for one tau attempt; return (points, gaps) where both log-derivatives exist."""
    # The first attempt uses a fixed tau; retries draw a new one.
    t_val = 0.3 + 1.2j if attempt == 0 else generate_random_tau()

    # Per-attempt seed so the sample points are reproducible.
    np.random.seed(1107 + attempt * 100)
    z_samples = np.random.uniform(-0.5, 0.5, num_points) + \
                1j * np.random.uniform(-0.5, 0.5, num_points)

    points, gaps = [], []
    for z_val in z_samples:
        ld_origin = get_log_derivative(expr_origin, z_val, t_val)
        ld_simple = get_log_derivative(expr_simple, z_val, t_val)
        if ld_origin is not None and ld_simple is not None:
            points.append(z_val)
            gaps.append(ld_origin - ld_simple)
    return points, gaps


def verify_single_sample(expr_origin, expr_simple, threshold=1e-3, max_tau_retries=5,
                         num_points=10, min_valid_points=5):
    """Check numerically whether two expressions agree up to a linear log-derivative gap.

    For random points z, the difference of the logarithmic derivatives of the two
    expressions is computed and fitted by ``a*z + b`` (least squares). The sample
    passes when the mean absolute residual of that fit is below ``threshold``.

    Args:
        expr_origin: reference expression string (see ``get_log_derivative``).
        expr_simple: candidate expression string.
        threshold: maximum accepted mean absolute fit residual.
        max_tau_retries: how many tau values to try before giving up when too few
            sample points can be evaluated. Values <= 0 now yield ``False``
            instead of raising ``NameError``.
        num_points: number of random z points drawn per attempt.
        min_valid_points: minimum number of points at which both log-derivatives
            must be available for the fit to be attempted.

    Returns:
        Truthy when the fit passes; ``False`` when it fails, when too few points
        are valid, or when the fit itself raises.

    Side effects:
        Prints a diagnostic on failure. The global NumPy RNG state is restored on
        exit, so calling this no longer reseeds the caller's random stream.
    """
    valid_z, diff_values = [], []

    # Seeding is needed for reproducible points, but must not leak to the caller.
    saved_rng_state = np.random.get_state()
    try:
        for tau_attempt in range(max_tau_retries):
            valid_z, diff_values = _sample_log_derivative_gaps(
                expr_origin, expr_simple, tau_attempt, num_points
            )
            if len(diff_values) >= min_valid_points:
                break
    finally:
        np.random.set_state(saved_rng_state)

    if len(diff_values) < min_valid_points:
        print(f"[有效点不足，跳过 (已尝试 {max_tau_retries} 个不同的 tau 值)]")
        print(f"   Origin: {expr_origin}")
        print(f"   Simple: {expr_simple}")
        return False

    X = np.array(valid_z)
    y = np.array(diff_values)

    # Design matrix for the model y = a*z + b.
    A = np.column_stack((X, np.ones_like(X)))

    try:
        coeffs = np.linalg.lstsq(A, y, rcond=None)[0]

        mean_error = np.mean(np.abs(y - A @ coeffs))
        is_pass = mean_error < threshold

        if not is_pass:
            print(f"[拟合未通过] 线性拟合误差: {mean_error:.2e}")
            print(f"   Origin: {expr_origin}")
            print(f"   Simple: {expr_simple}")

        return is_pass

    except Exception as e:
        print(f"[拟合过程出错] {e}")
        return False

In [None]:
# Paired test files: line i of the input file is the expression fed to the model,
# line i of the "simple" file is the reference it is verified against.
FILE_TEST_INPUT = "./data/test/random_ns_nt/infix_origin.txt"
FILE_TEST_SIMPLE = "./data/test/random_ns_nt/infix_simple.txt"

with open(FILE_TEST_INPUT, 'r') as f:
    test_inputs = f.readlines()

with open(FILE_TEST_SIMPLE, 'r') as f:
    test_simples = f.readlines()

total_samples = len(test_inputs)
print(f"Total samples: {total_samples}")

# Fail before any model work is done rather than with an IndexError part-way
# through the run when the reference file is shorter than the input file.
if len(test_simples) < total_samples:
    raise ValueError(
        f"{FILE_TEST_SIMPLE} has {len(test_simples)} lines but "
        f"{FILE_TEST_INPUT} has {total_samples}"
    )

correct_count = 0
# Counts both failed verifications and samples that raised during processing.
error_count = 0

for n, (raw_input, raw_simple) in enumerate(zip(test_inputs, test_simples), start=1):
    test_input = raw_input.strip()
    simple_input = raw_simple.strip()

    try:
        # infix string -> sympy -> prefix tokens for the model
        expr_sp_origin = sp.S(test_input, locals=env.local_dict)
        expr_prefix_origin = env.sympy_to_prefix(expr_sp_origin)

        # model output (space-separated prefix tokens) -> infix -> sympy -> string
        after_train_str = generate_summary(expr_prefix_origin)
        after_train_prefix = after_train_str.split(" ")

        after_train_infix = env.prefix_to_infix(after_train_prefix)
        after_train_sp = env.infix_to_sympy(after_train_infix)

        predicted_str = str(after_train_sp)

        # Numerical check of the prediction against the reference expression.
        is_correct = verify_single_sample(simple_input, predicted_str)

        if is_correct:
            correct_count += 1
        else:
            error_count += 1

    except Exception as e:
        # Any failure while converting or generating counts as an error for this sample.
        error_count += 1
        print(f"Sample {n}: ❌ ERROR - {type(e).__name__}: {e}")

    # Running summary every 100 samples.
    if n % 100 == 0:
        accuracy = correct_count / n * 100
        print(f"\n=== Step {n} Summary ===")
        print(f"Correct: {correct_count}, Error: {error_count}")
        print(f"Cumulative Accuracy: {accuracy:.2f}%\n")

# Guard against an empty input file, which previously raised ZeroDivisionError.
final_accuracy = correct_count / total_samples * 100 if total_samples else 0.0
print(f"\n=== Final Summary ===")
print(f"Total: {total_samples}, Correct: {correct_count}, Error: {error_count}")
print(f"Final Accuracy: {final_accuracy:.2f}%")

Total samples: 10000

=== Step 100 Summary ===
Correct: 100, Error: 0
Cumulative Accuracy: 100.00%


=== Step 200 Summary ===
Correct: 200, Error: 0
Cumulative Accuracy: 100.00%


=== Step 300 Summary ===
Correct: 300, Error: 0
Cumulative Accuracy: 100.00%


=== Step 400 Summary ===
Correct: 400, Error: 0
Cumulative Accuracy: 100.00%


=== Step 500 Summary ===
Correct: 500, Error: 0
Cumulative Accuracy: 100.00%

[拟合未通过] 线性拟合误差: 2.58e+00
   Origin: q_theta(z/(7*t - 6), (5 - 6*t)/(7*t - 6))*q_theta(-z/(8*t - 1), (7*t - 1)/(8*t - 1))
   Simple: q_theta(-z/(t - 1), (6 - 7*t)/(t - 1))*q_theta(-z/(8*t - 1), (7*t - 1)/(8*t - 1))**2

=== Step 600 Summary ===
Correct: 599, Error: 1
Cumulative Accuracy: 99.83%


=== Step 700 Summary ===
Correct: 699, Error: 1
Cumulative Accuracy: 99.86%


=== Step 800 Summary ===
Correct: 799, Error: 1
Cumulative Accuracy: 99.88%

[拟合未通过] 线性拟合误差: 2.58e+00
   Origin: q_theta(z/(t - 3), (5 - 2*t)/(t - 3))*q_theta(z/(3*t - 1), (2*t - 1)/(3*t - 1))
   Simple: q_thet

In [None]:
# Manual spot check: the two expressions are the same q_theta factor, cubed in the
# first argument and to the first power in the second. The recorded output of this
# cell reports a failed linear fit and the value False.
verify_single_sample('q_theta(-z/(t - 1), -1/(t - 1))**3', 'q_theta(-z/(t - 1), -1/(t - 1))')

[拟合未通过] 线性拟合误差: 5.17e+00
   Origin: q_theta(-z/(t - 1), -1/(t - 1))**3
   Simple: q_theta(-z/(t - 1), -1/(t - 1))


False