**This notebook gives a quick demonstration of how a trained model simplifies (reduces) expressions built from the $q$-theta function.**

## 1. Import the required libraries and set up the environment

In [2]:
import linecache

import mpmath as mp
import numpy as np
import sympy as sp
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

from src.utils import AttrDict
from src.envs import build_env

# Settings for the symbolic environment used below to convert expressions between
# SymPy, infix and prefix forms.
# NOTE(review): presumably these must match the configuration the model was trained with — confirm.
params = AttrDict(dict(
    # environment parameters
    env_name='char_sp',
    int_base=10,
    balanced=False,
    positive=True,
    precision=10,
    n_variables=1,
    n_coefficients=0,
    leaf_probs='0.75,0,0.25,0',
    max_len=512,
    max_int=5,
    max_ops=15,
    max_ops_G=15,
    clean_prefix_expr=True,
    rewrite_functions='',
    tasks='prim_fwd',
    operators='add:10,sub:3,mul:10,div:5,sqrt:4,pow2:4,pow3:2,pow4:1,pow5:1,ln:4,exp:4,sin:4,cos:4,tan:4,asin:1,acos:1,atan:1,sinh:1,cosh:1,tanh:1,asinh:1,acosh:1,atanh:1',
))

# Environment object providing sympy_to_prefix / prefix_to_infix / infix_to_sympy.
env = build_env(params)



## 2. Load the model and tokenizer

In [3]:
model = "./results/random_ns_nt" # model path to your trained model
tokenizer_path="./results/random_ns_nt" # tokenizer path to your trained model (usually the same as model path)

In [4]:
tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
model = T5ForConditionalGeneration.from_pretrained(model)

device='cuda'
model=model.to(device)
model=model.eval()
def generate_summary(input_tokens):
    inputs = tokenizer(input_tokens, return_tensors="pt",is_split_into_words=True, padding=True, truncation=True)
    if device =='cuda':
        inputs = {k: v.to('cuda') for k, v in inputs.items()}
        outputs = model.generate(inputs['input_ids'], max_length=512, num_beams=2, early_stopping=True)
    else:
        outputs = model.generate(inputs.input_ids, max_length=512, num_beams=2, early_stopping=True)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

## 3. Use the model to simplify an expression

**You can use your own expression here. Expressions can be generated with the script `./data_process/datagenerate.py`.**

In [5]:
# Expression to simplify, written as a SymPy-parsable infix string in the variables z and t.
expr_origin_infix = "q_theta(z/(t - 1), (8*t - 9)/(t - 1))*q_theta(z/(t + 1), t/(t + 1))*q_theta(-z/(5*t - 9), (6*t - 11)/(5*t - 9))*q_theta(-z/(6*t + 11), (-t - 2)/(6*t + 11))*q_theta(z/(20*t - 9), (11*t - 5)/(20*t - 9))/(q_theta(-z/(4*t + 1), (-3*t - 1)/(4*t + 1))*q_theta(-z/(8*t + 1), (33*t + 4)/(8*t + 1))*q_theta(-z/(9*t - 4), (11*t - 5)/(9*t - 4))*q_theta(-z/(19*t + 34), (14*t + 25)/(19*t + 34)))"
# Parse with the environment's name table (env.local_dict) so names in the string
# presumably resolve to the environment's own symbols/functions (e.g. q_theta).
expr_sp_origin = sp.S(expr_origin_infix, locals=env.local_dict)
# Prefix token list: this is the form passed to the model in generate_summary below.
expr_origin_prefix = env.sympy_to_prefix(expr_sp_origin)
print("Original expression (infix):", expr_origin_infix)
print("Original expression (prefix):", expr_origin_prefix)

Original expression (infix): q_theta(z/(t - 1), (8*t - 9)/(t - 1))*q_theta(z/(t + 1), t/(t + 1))*q_theta(-z/(5*t - 9), (6*t - 11)/(5*t - 9))*q_theta(-z/(6*t + 11), (-t - 2)/(6*t + 11))*q_theta(z/(20*t - 9), (11*t - 5)/(20*t - 9))/(q_theta(-z/(4*t + 1), (-3*t - 1)/(4*t + 1))*q_theta(-z/(8*t + 1), (33*t + 4)/(8*t + 1))*q_theta(-z/(9*t - 4), (11*t - 5)/(9*t - 4))*q_theta(-z/(19*t + 34), (14*t + 25)/(19*t + 34)))
Original expression (prefix): ['mul', 'pow', 'q_theta', 'mul', 'INT-', '1', 'mul', 'z', 'pow', 'add', 'INT+', '1', 'mul', 'INT+', '4', 't', 'INT-', '1', 'mul', 'pow', 'add', 'INT+', '1', 'mul', 'INT+', '4', 't', 'INT-', '1', 'add', 'INT-', '1', 'mul', 'INT-', '3', 't', 'INT-', '1', 'mul', 'pow', 'q_theta', 'mul', 'INT-', '1', 'mul', 'z', 'pow', 'add', 'INT+', '1', 'mul', 'INT+', '8', 't', 'INT-', '1', 'mul', 'pow', 'add', 'INT+', '1', 'mul', 'INT+', '8', 't', 'INT-', '1', 'add', 'INT+', '4', 'mul', 'INT+', '3', '3', 't', 'INT-', '1', 'mul', 'pow', 'q_theta', 'mul', 'INT-', '1', 'm

In [6]:
# Generate the simplified expression; the model output is a space-separated prefix token string.
after_train_str = generate_summary(expr_origin_prefix)
# split() with no argument ignores repeated, leading and trailing whitespace, so no empty
# tokens are handed to the prefix parser (split(" ") would produce them).
after_train_prefix = after_train_str.split()

# Prefix -> infix -> SymPy, so the prediction can be displayed and checked numerically.
after_train_infix = env.prefix_to_infix(after_train_prefix)
after_train_sp = env.infix_to_sympy(after_train_infix)

predicted_str = str(after_train_sp)
print("Model prediction:", predicted_str)

Model prediction: q_theta(-z/(t - 1), (7*t - 8)/(t - 1))*q_theta(-z/(5*t - 9), (t - 2)/(5*t - 9))/q_theta(-z/(8*t + 1), (9*t + 1)/(8*t + 1))


In [9]:
# Before simplification: the original expression as parsed by SymPy
expr_sp_origin

q_theta(z/(t - 1), (8*t - 9)/(t - 1))*q_theta(z/(t + 1), t/(t + 1))*q_theta(-z/(5*t - 9), (6*t - 11)/(5*t - 9))*q_theta(-z/(6*t + 11), (-t - 2)/(6*t + 11))*q_theta(z/(20*t - 9), (11*t - 5)/(20*t - 9))/(q_theta(-z/(4*t + 1), (-3*t - 1)/(4*t + 1))*q_theta(-z/(8*t + 1), (33*t + 4)/(8*t + 1))*q_theta(-z/(9*t - 4), (11*t - 5)/(9*t - 4))*q_theta(-z/(19*t + 34), (14*t + 25)/(19*t + 34)))

In [10]:
# After simplification: the model's prediction converted back to a SymPy expression
after_train_sp

q_theta(-z/(t - 1), (7*t - 8)/(t - 1))*q_theta(-z/(5*t - 9), (t - 2)/(5*t - 9))/q_theta(-z/(8*t + 1), (9*t + 1)/(8*t + 1))

## 4. Check the prediction

In [20]:
# First define some key functions

# Working precision of mpmath, in decimal digits. The setting lives on the context
# object `mp.mp`; assigning `mp.dps` (with `import mpmath as mp`) only creates an unused
# module attribute and leaves the precision at the default 15 digits.
mp.mp.dps = 30

def theta_q_jtheta(z, tau):
    """Evaluate the q-theta function numerically through mpmath's Jacobi theta_1.

    Computes

        jtheta(1, pi*z, Q) / (i * Q**(1/4) * exp(-i*pi*z) * (q; q)_inf)

    with nome Q = exp(i*pi*tau) and q = Q**2. By the Jacobi triple product this
    should equal (x; q)_inf * (q/x; q)_inf with x = exp(2*pi*i*z)
    (NOTE(review): derived from mpmath's theta_1 convention — confirm).

    Args:
        z: complex argument (anything ``mp.mpc`` accepts).
        tau: complex parameter; mpmath's jtheta requires |Q| < 1, i.e. Im(tau) > 0.

    Returns:
        An mpmath ``mpc`` value.
    """
    z_mp, tau_mp = mp.mpc(z), mp.mpc(tau)

    # Jacobi nome Q and the squared nome q = Q**2 used by the Euler product.
    nome = mp.exp(mp.pi * 1j * tau_mp)
    nome_sq = nome**2

    theta1 = mp.jtheta(1, mp.pi * z_mp, nome)
    # Normalising factor i * Q**(1/4) * exp(-i*pi*z); combined with (q; q)_inf below.
    scale = 1j * (nome**0.25) * mp.exp(-1j * mp.pi * z_mp)

    return theta1 / (scale * mp.qp(nome_sq))

def theta_q_safe(z, tau):
    """Forward to ``theta_q_jtheta``; also exposed by this name to eval'd expression strings."""
    result = theta_q_jtheta(z, tau)
    return result

def q_theta(z, tau):
    """Numerical ``q_theta``: the name used inside expression strings evaluated by get_log_derivative."""
    value = theta_q_safe(z, tau)
    return value


def get_log_derivative(expr_str, z0, t_val, h=1e-5, zero_tol=1e-12):
    """Numerically estimate the logarithmic derivative f'(z0) / f(z0) of an expression.

    ``expr_str`` is evaluated as a Python expression in which ``q_theta``,
    ``theta_q_safe``, ``z``, ``t``, ``I`` and ``mp`` are bound, plus the elementary
    names ``sqrt``, ``exp``, ``log``, ``pi`` and ``E`` that SymPy's ``str()`` may
    emit. The z-derivative uses the central difference (f(z0+h) - f(z0-h)) / (2h).

    Args:
        expr_str: expression string in the variables ``z`` and ``t``.
        z0: point at which to differentiate.
        t_val: value substituted for ``t``.
        h: finite-difference step.
        zero_tol: if |f(z0)| is below this, z0 is treated as a zero/singularity.

    Returns:
        The complex log-derivative estimate, or None if |f(z0)| < zero_tol or any
        step fails (syntax error, unknown name, numerical error). The failure mode
        is deliberately best-effort: callers simply skip such points.

    NOTE(review): the expression strings come from model output. ``eval`` with
    builtins disabled is not a real sandbox, so never feed this untrusted input.
    """
    context = {
        'q_theta': q_theta,
        'theta_q_safe': theta_q_safe,
        'z': None,
        't': t_val,
        'I': 1j,
        'mp': mp,
        # Elementary functions/constants that SymPy's printer may produce.
        'sqrt': mp.sqrt,
        'exp': mp.exp,
        'log': mp.log,
        'pi': mp.pi,
        'E': mp.e,
    }

    def _evaluate(code, z_val):
        # Evaluate the compiled expression at z = z_val and return a Python complex.
        context['z'] = z_val
        return complex(eval(code, {"__builtins__": None}, context))

    try:
        # Compile once: the same expression is evaluated at three points.
        code = compile(expr_str, "<expr>", "eval")

        f_0 = _evaluate(code, z0)
        if abs(f_0) < zero_tol:
            return None  # Avoid singularities/zeros

        f_plus = _evaluate(code, z0 + h)
        f_minus = _evaluate(code, z0 - h)

        # Central difference f'(z0) ≈ (f_plus - f_minus) / (2h), divided by f(z0).
        return (f_plus - f_minus) / (2 * h * f_0)
    except Exception:
        return None

def generate_random_tau():
    """Draw a random complex value for the parameter ``t`` from numpy's global RNG.

    Real part ~ U(-0.5, 0.5) and imaginary part ~ U(0.5, 2.0), drawn in that order.
    The result therefore lies in the upper half-plane.
    """
    return complex(np.random.uniform(-0.5, 0.5), np.random.uniform(0.5, 2.0))

def verify_single_sample(expr_origin, expr_simple, threshold=1e-3, max_tau_retries=5,
                         num_points=10, min_valid_points=5, seed=1107):
    """Numerically check that a simplified expression matches the original one.

    At random points z in the square [-0.5, 0.5] + i[-0.5, 0.5], the log-derivatives
    d/dz log f(z) of both expressions are computed (see ``get_log_derivative``).
    Their difference is then fitted with a linear function a*z + b by least squares.
    The check passes when the mean absolute residual of that fit is below
    ``threshold``, i.e. when the log-derivatives differ by at most a linear
    function of z.

    Args:
        expr_origin: original expression string (variables ``z`` and ``t``).
        expr_simple: candidate simplified expression string.
        threshold: maximum allowed mean absolute residual of the linear fit.
        max_tau_retries: number of values of ``t`` to try. The first is the fixed
            0.3 + 1.2j; the rest come from ``generate_random_tau``.
        num_points: number of z sample points per attempt.
        min_valid_points: an attempt is accepted once at least this many z points
            evaluate successfully for both expressions.
        seed: base seed; attempt k reseeds numpy's global RNG with ``seed + 100*k``.

    Returns:
        A Python bool: True if the check passes, False otherwise (including when
        too few points could be evaluated or the fit raised).

    Side effects:
        Reseeds numpy's global RNG and prints a diagnostic on failure.
    """
    # Initialised up front so max_tau_retries <= 0 reports failure instead of
    # raising UnboundLocalError.
    diff_values, valid_z = [], []

    for tau_attempt in range(max_tau_retries):
        t_val = 0.3 + 1.2j if tau_attempt == 0 else generate_random_tau()

        # NOTE(review): seeding happens after the random t is drawn, so the t of
        # attempt k comes from the stream seeded in attempt k-1 (still deterministic).
        np.random.seed(seed + tau_attempt * 100)
        z_samples = (np.random.uniform(-0.5, 0.5, num_points)
                     + 1j * np.random.uniform(-0.5, 0.5, num_points))

        diff_values, valid_z = [], []
        for z_val in z_samples:
            ld_origin = get_log_derivative(expr_origin, z_val, t_val)
            ld_simple = get_log_derivative(expr_simple, z_val, t_val)
            if ld_origin is not None and ld_simple is not None:
                diff_values.append(ld_origin - ld_simple)
                valid_z.append(z_val)

        if len(diff_values) >= min_valid_points:
            break

    if len(diff_values) < min_valid_points:
        print(f"[Insufficient valid points, skipping (Tried {max_tau_retries} different tau values)]")
        print(f"   Origin: {expr_origin}")
        print(f"   Simple: {expr_simple}")
        return False

    X = np.array(valid_z)
    y = np.array(diff_values)

    # Design matrix for the model  diff(z) ≈ a*z + b.
    A = np.column_stack((X, np.ones_like(X)))

    try:
        # Perform linear regression to check for constant or linear differences
        coeffs, *_ = np.linalg.lstsq(A, y, rcond=None)
        mean_error = np.mean(np.abs(y - A @ coeffs))

        # bool(): return a plain Python bool rather than numpy.bool_.
        is_pass = bool(mean_error < threshold)
        if not is_pass:
            print(f"[Fit Failed] Linear Fit Error: {mean_error:.2e}")
            print(f"   Origin: {expr_origin}")
            print(f"   Simple: {expr_simple}")
        return is_pass

    except Exception as e:
        print(f"[Error during fitting process] {e}")
        return False

In [21]:
# Numerical check: compares d/dz log f of the original and predicted expressions at sample
# points z; True means the difference is fitted by a linear function of z within the threshold.
verify_single_sample(expr_origin_infix, predicted_str)

True