**This notebook quickly demonstrates how the trained model reduces (simplifies) expressions built from the elliptic $\Gamma$ function.**

## 1. Import the essential libraries and set up the required environment

In [1]:
import linecache
from pathlib import Path

import sympy as sp
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

from src.utils import AttrDict
from src.envs import build_env


# Settings for the symbolic environment, written as keyword arguments and
# wrapped in an AttrDict before being handed to build_env.
params = AttrDict(dict(
    # environment parameters
    env_name='char_sp',
    int_base=10,
    balanced=False,
    positive=True,
    precision=10,
    n_variables=1,
    n_coefficients=0,
    leaf_probs='0.75,0,0.25,0',
    max_len=1024,
    max_int=5,
    max_ops=15,
    max_ops_G=15,
    clean_prefix_expr=True,
    rewrite_functions='',
    tasks='prim_fwd',
    operators='add:10,sub:3,mul:10,div:5,sqrt:4,pow2:4,pow3:2,pow4:1,pow5:1,ln:4,exp:4,sin:4,cos:4,tan:4,asin:1,acos:1,atan:1,sinh:1,cosh:1,tanh:1,asinh:1,acosh:1,atanh:1',
))

# The environment supplies local_dict, sympy_to_prefix, prefix_to_infix and
# infix_to_sympy, which the later cells use to convert expressions.
env = build_env(params)

## 2. Load the model and tokenizer

In [2]:
# Path to your trained model.
model_path = "./results/model_for_inference"
# Path to your trained tokenizer (usually the same as the model path).
tokenizer_path = model_path

In [3]:
# Load the tokenizer and the trained T5 model from the configured paths.
tokenizer = T5Tokenizer.from_pretrained(tokenizer_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)
print(f"max length: {model.config.n_positions}")

# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook also runs on machines without CUDA instead of failing in model.to().
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)
# Put the model in evaluation mode for generation.
model = model.eval()
def generate_summary(input_tokens, max_length=1024, num_beams=2):
    """Run the model on a list of tokens and decode its first generated sequence.

    Relies on the module-level ``tokenizer`` and ``model`` loaded above.

    Args:
        input_tokens: Sequence of word-level tokens (for example the prefix
            form of an expression); it is passed to the tokenizer with
            ``is_split_into_words=True``.
        max_length: Maximum length passed to ``model.generate``.
        num_beams: Number of beams passed to ``model.generate``.

    Returns:
        The first generated sequence decoded to a string, with special tokens
        skipped.
    """
    encoded = tokenizer(
        input_tokens,
        return_tensors="pt",
        is_split_into_words=True,
        padding=True,
        truncation=True,
    )
    # Move the tensors to wherever the model actually lives, so a single code
    # path serves both CPU and GPU.
    encoded = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    outputs = model.generate(
        input_ids=encoded["input_ids"],
        # Forward the mask built by the tokenizer so padded positions are
        # ignored; .get() keeps this working if the tokenizer returns none.
        attention_mask=encoded.get("attention_mask"),
        max_length=max_length,
        num_beams=num_beams,
        early_stopping=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

max length: 1024


## 3. Use the model for simplification

**You can use your own expression here; expressions can be generated with the script in `./data_process/datagenerate.py`.**

In [4]:
# The expression to simplify, written in infix notation.
expr_origin_infix = "egamma(z/(s - 1), (-7*s + t + 8)/(s - 1), (7*s - 8)/(s - 1))*egamma((t + z)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))*egamma((2*s + z - 1)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))*egamma((7*s + z - 4)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))*egamma((9*s + z - 5)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))*egamma((-7*s + t + z + 4)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))*egamma((-5*s + t + z + 3)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))*egamma((2*s + t + z - 1)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))/(egamma(z/(s - 1), (-7*s + t + 8)/(s - 1), (8*s - 9)/(s - 1))*egamma((2*t - z)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1)))"

# Parse the string with the environment's symbol table, then convert the
# SymPy expression into the prefix token list.
expr_sp_origin = sp.sympify(expr_origin_infix, locals=env.local_dict)
expr_origin_prefix = env.sympy_to_prefix(expr_sp_origin)

print("Original expression (infix):", expr_origin_infix)
print("Original expression (prefix):", expr_origin_prefix)

Original expression (infix): egamma(z/(s - 1), (-7*s + t + 8)/(s - 1), (7*s - 8)/(s - 1))*egamma((t + z)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))*egamma((2*s + z - 1)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))*egamma((7*s + z - 4)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))*egamma((9*s + z - 5)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))*egamma((-7*s + t + z + 4)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))*egamma((-5*s + t + z + 3)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))*egamma((2*s + t + z - 1)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))/(egamma(z/(s - 1), (-7*s + t + 8)/(s - 1), (8*s - 9)/(s - 1))*egamma((2*t - z)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1)))
Original expression (prefix): ['mul', 'pow', 'egamma', 'mul', 'z', 'pow', 'add', 'INT-', '1', 's', 'INT-', '1', 'mul', 'pow', 'add', 'INT-', '1', 's', 'INT-', '1', 'add', 'INT+', '8', 'add', 't', 'mul', 'INT-', 

In [5]:
# Ask the model for a simplified expression; the decoded output is a string of
# prefix tokens separated by whitespace.
after_train_str = generate_summary(expr_origin_prefix)
# split() with no argument collapses repeated whitespace and ignores leading or
# trailing blanks, so no empty tokens reach prefix_to_infix.
after_train_prefix = after_train_str.split()

# Convert the predicted prefix tokens back to infix and then to SymPy.
after_train_infix = env.prefix_to_infix(after_train_prefix)
after_train_sp = env.infix_to_sympy(after_train_infix)

predicted_str = str(after_train_sp)
print("Model prediction:", predicted_str)

Model prediction: egamma(z/(2*s - 1), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))


In [6]:
# Before simplification: display the original expression as parsed by SymPy.
expr_sp_origin

egamma(z/(s - 1), (-7*s + t + 8)/(s - 1), (7*s - 8)/(s - 1))*egamma((t + z)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))*egamma((2*s + z - 1)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))*egamma((7*s + z - 4)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))*egamma((9*s + z - 5)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))*egamma((-7*s + t + z + 4)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))*egamma((-5*s + t + z + 3)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))*egamma((2*s + t + z - 1)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))/(egamma(z/(s - 1), (-7*s + t + 8)/(s - 1), (8*s - 9)/(s - 1))*egamma((2*t - z)/(4*s - 2), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1)))

In [7]:
# After simplification: display the model's prediction converted back by env.infix_to_sympy.
after_train_sp

egamma(z/(2*s - 1), (-7*s + t + 4)/(2*s - 1), (7*s - 4)/(2*s - 1))

## 4. Check the prediction

**Check the prediction using the fourth-order difference method**

In [8]:
import mpmath as mp
import random


def theta0_safe(z, tau):
    """Return ``qp(x, q) * qp(q / x, q)`` for x = exp(2*pi*i*z), q = exp(2*pi*i*tau).

    Both factors are evaluated with mpmath's q-Pochhammer function ``mp.qp``.
    """
    two_pi_i = 2 * mp.pi * 1j
    nome = mp.exp(two_pi_i * tau)
    point = mp.exp(two_pi_i * z)
    return mp.qp(point, nome) * mp.qp(nome / point, nome)

def elliptic_gamma_function(z, tau, sigma, max_terms=3000):
    """Numerically evaluate the elliptic gamma function at (z, tau, sigma) with mpmath.

    The arguments are first reduced so that tau and sigma have positive
    imaginary part, then z is shifted by multiples of sigma (using
    ``theta0_safe``) until a sine-series representation can be summed.

    Args:
        z, tau, sigma: Values accepted by ``mp.mpc``.
        max_terms: Upper bound on the number of series terms that are summed.

    Returns:
        An mpmath complex value. ``mp.mpc("nan")`` is returned when the
        imaginary part of tau or sigma is within the tolerance of zero.

    NOTE(review): the recursion has no explicit depth limit. Each shift changes
    ``current_dist`` by 2*Im(sigma), while the accepted strip is
    0.99*(Im(tau) + Im(sigma)) on either side of zero; when Im(tau) is much
    smaller than Im(sigma) the shift may step over the strip — confirm that
    this terminates for the parameter ranges in use.
    """
    z, tau, sigma = mp.mpc(z), mp.mpc(tau), mp.mpc(sigma)
    # Tolerance tied to the current working precision.
    dynamic_tol = mp.eps * 100

    # Negative imaginary part: flip the sign of that parameter and return the
    # reciprocal of the value at the shifted z (presumably the standard
    # reflection identity — verify against a reference).
    if mp.im(tau) < -dynamic_tol:
        return 1 / elliptic_gamma_function(z - tau, -tau, sigma, max_terms)
    if mp.im(sigma) < -dynamic_tol:
        return 1 / elliptic_gamma_function(z - sigma, tau, -sigma, max_terms)

    # (Near-)real tau or sigma is not handled; signal it with NaN.
    if abs(mp.im(tau)) < dynamic_tol or abs(mp.im(sigma)) < dynamic_tol:
        return mp.mpc("nan")

    width_limit = abs(mp.im(tau)) + abs(mp.im(sigma))
    current_dist = mp.im(2 * z - tau - sigma)

    # Sum the series only when z lies inside the strip, with a 1% safety margin.
    if abs(current_dist) < width_limit * 0.99:
        sum_result = mp.mpc(0)
        A = mp.pi * (2 * z - tau - sigma)
        B = mp.pi * tau
        C = mp.pi * sigma

        for j in range(1, max_terms + 1):
            j_mp = mp.mpf(j)
            sinB = mp.sin(j_mp * B)
            sinC = mp.sin(j_mp * C)
            # Skip terms whose denominator is numerically zero.
            if abs(sinB) < mp.eps or abs(sinC) < mp.eps:
                continue
            term = mp.sin(j_mp * A) / (j_mp * sinB * sinC)
            sum_result += term
            # Stop once the latest term is below the tolerance.
            if abs(term) < dynamic_tol:
                break

        return mp.exp(-0.5j * sum_result)

    # Outside the strip: shift z by sigma towards it and correct with theta0_safe.
    if current_dist > 0:
        return theta0_safe(z - sigma, tau) * elliptic_gamma_function(
            z - sigma, tau, sigma, max_terms
        )
    else:
        return elliptic_gamma_function(z + sigma, tau, sigma, max_terms) / theta0_safe(z, tau)

def egamma(z, t, s):
    """Short alias for ``elliptic_gamma_function`` with t as tau and s as sigma.

    This is the name that expression strings call when they are evaluated by
    ``eval_expr_val``.
    """
    return elliptic_gamma_function(z, tau=t, sigma=s)


def eval_expr_val(expr_str, z_val, t_val, s_val):
    """Evaluate an expression string numerically at the given z, t and s.

    The string may reference ``egamma``, ``z``, ``t``, ``s`` and ``mp``;
    builtins are disabled.

    Returns:
        The evaluated value, or None when evaluation raises or the value is
        non-finite or equal to zero.
    """
    namespace = {
        "egamma": egamma,
        "z": z_val,
        "t": t_val,
        "s": s_val,
        "mp": mp,
    }
    try:
        # NOTE(security): eval() runs arbitrary code; only pass expression
        # strings that come from a trusted source.
        result = eval(expr_str, {"__builtins__": None}, namespace)
        if mp.isfinite(result) and not result == 0:
            return result
        return None
    except Exception:
        return None


def fourth_diff_invariant(R_vals):
    """Combine five values as R0 * R4 * R2**6 / (R1**4 * R3**4).

    Only the first five entries of ``R_vals`` are read.

    Returns:
        The combined value, or None if the computation raises (for example
        when there are too few entries or a division by zero occurs).
    """
    try:
        outer_pair = R_vals[0] * R_vals[4]
        inner_pair = R_vals[1] ** 4 * R_vals[3] ** 4
        centre = R_vals[2] ** 6
        return outer_pair / inner_pair * centre
    except Exception:
        return None

def find_simple_safe_points(
    origin_expr,
    simple_expr,
    num_z=6,
    max_trials=200,
    h=mp.mpf("0.07")
):
    """Randomly search for parameters at which both expressions can be evaluated.

    Up to 50 random (t, s) pairs are drawn, each with real part in [-0.8, 0.8]
    and imaginary part in [0.2, 0.9]. For each pair, up to ``max_trials``
    random starting points z0 (real and imaginary parts in [-0.6, 0.6]) are
    tried. A z0 is kept when ``eval_expr_val`` returns a value for both
    expressions at all five nodes z0 + k*h, k = 0..4.

    Args:
        origin_expr: First expression string.
        simple_expr: Second expression string.
        num_z: Number of accepted starting points required.
        max_trials: Number of z0 candidates tried per (t, s) pair.
        h: Spacing between consecutive nodes.

    Returns:
        ``(t_val, s_val, valid_z)`` with Python complex values as soon as
        ``num_z`` starting points have been accepted for one (t, s) pair, or
        ``(None, None, None)`` if every pair fails.
    """

    def _stencil_is_safe(z_start, t_point, s_point):
        # Both expressions must yield a usable value at every node.
        for step in range(5):
            node = z_start + step * h
            origin_value = eval_expr_val(origin_expr, node, t_point, s_point)
            simple_value = eval_expr_val(simple_expr, node, t_point, s_point)
            if origin_value is None or simple_value is None:
                return False
        return True

    for _pair in range(50):
        t_val = complex(random.uniform(-0.8, 0.8), random.uniform(0.2, 0.9))
        s_val = complex(random.uniform(-0.8, 0.8), random.uniform(0.2, 0.9))
        t_mp, s_mp = mp.mpc(t_val), mp.mpc(s_val)

        valid_z = []
        for _trial in range(max_trials):
            z0 = complex(random.uniform(-0.6, 0.6), random.uniform(-0.6, 0.6))
            if _stencil_is_safe(mp.mpc(z0), t_mp, s_mp):
                valid_z.append(z0)

            if len(valid_z) >= num_z:
                return t_val, s_val, valid_z

    return None, None, None

def check_at_z(origin_expr, simple_expr, z0, t_val, s_val, h):
    """Compute the fourth-difference invariant of origin/simple starting at z0.

    The ratio of the two expressions is evaluated at the five nodes z0 + k*h,
    k = 0..4, and the ratios are passed to ``fourth_diff_invariant``.

    Returns:
        The invariant, or None when either expression cannot be evaluated at
        some node or a ratio is non-finite or zero.
    """
    ratios = []

    for step in range(5):
        node = z0 + step * h
        origin_value = eval_expr_val(origin_expr, node, t_val, s_val)
        simple_value = eval_expr_val(simple_expr, node, t_val, s_val)

        if origin_value is None or simple_value is None:
            return None

        ratio = origin_value / simple_value
        if not mp.isfinite(ratio) or ratio == 0:
            return None

        ratios.append(ratio)

    return fourth_diff_invariant(ratios)



def verify_single_pair(args):
    """Numerically check whether two expression strings agree in their z-dependence.

    For up to five random parameter draws, the fourth-difference invariant of
    the ratio origin/simple is computed at every safe starting point; the pair
    passes when every invariant lies within 1e-3 of 1.

    Args:
        args: Tuple ``(idx, simple_expr, origin_expr)``. ``idx`` is returned
            unchanged; the other two are expression strings understood by
            ``eval_expr_val``.

    Returns:
        ``(idx, passed, message)`` where ``passed`` is a bool and ``message``
        describes the outcome.

    Note:
        Sets mpmath's global working precision to 50 decimal digits.
    """
    idx, simple_expr, origin_expr = args
    mp.dps = 50

    # Create the step once, after raising the precision, and use the same value
    # for both the safe-point search and the check. Otherwise the search falls
    # back to its default step, which was built at definition-time precision,
    # and the screened nodes differ slightly from the checked ones.
    step = mp.mpf("0.07")
    tolerance = 1e-3
    max_attempts = 5

    for _attempt in range(max_attempts):
        try:
            t_val, s_val, z_list = find_simple_safe_points(
                origin_expr, simple_expr, h=step
            )
        except Exception:
            # Best effort: a failed search only costs one attempt.
            continue

        if t_val is None or not z_list:
            continue

        t_mp = mp.mpc(t_val)
        s_mp = mp.mpc(s_val)

        passed = True

        for z0 in z_list:
            invariant = check_at_z(
                origin_expr,
                simple_expr,
                mp.mpc(z0),
                t_mp,
                s_mp,
                step,
            )

            # One unusable or off-target starting point fails this attempt.
            if invariant is None or abs(invariant - 1) > tolerance:
                passed = False
                break

        if passed:
            return idx, True, "Fourth-diff invariant passed"

    return idx, False, "Fourth-diff invariant failed"



In [9]:
# verify_single_pair unpacks its tuple as (idx, simple_expr, origin_expr), so
# the simplified prediction goes before the original expression.
_, ok, msg = verify_single_pair((0, predicted_str, expr_origin_infix))
print(f"Verification result: {ok}")

Verification result: True
