In [None]:
import multiprocessing
from multiprocessing import Pool
from pathlib import Path
import numpy as np
import mpmath as mp
from tqdm import tqdm
import random


class GlobalConfig:
    """Settings shared by the elliptic-gamma identity verification code."""

    # Worker count; not referenced in this chunk (presumably for a process Pool).
    NUM_WORKERS_VERIFY = 8

    # Data locations. Not referenced in this chunk; presumably the input/target
    # expression files and the failure log of a batch run -- TODO confirm.
    DATA_DIR = Path("./data/test/random_nsnt")
    INPUT_FILE = DATA_DIR / "infix_origin.txt"
    TARGET_FILE = DATA_DIR / "infix_simple.txt"
    LOG_FILE = "verification_failures.log"

    # Not referenced in this chunk.
    TEST_SAMPLE_LIMIT = 100

    # Decimal precision that verify_single_pair assigns to mp.dps.
    MP_DPS = 100

    # Spacing of the five-point stencil z0 + k*h passed to check_at_z.
    # NOTE: this mpf is created when the class body runs, i.e. at whatever
    # mp precision is active at import time, not at MP_DPS.
    H_STEP = mp.mpf("0.07")

    # verify_single_pair fails a pair when |invariant - 1| exceeds this.
    INVARIANT_THRESHOLD = 1e-3

    # Number of fresh (t, s) samplings verify_single_pair attempts.
    MAX_TS_RETRIES = 5



def theta0_safe(z, tau):
    """Return theta0(z; tau) = (x; q)_inf * (q/x; q)_inf.

    Here x = exp(2*pi*i*z) and q = exp(2*pi*i*tau); the infinite q-Pochhammer
    symbols are evaluated with ``mp.qp``.
    """
    x, q = (mp.exp(2 * mp.pi * 1j * w) for w in (z, tau))
    left = mp.qp(x, q)
    right = mp.qp(q / x, q)
    return left * right

def elliptic_gamma_function(z, tau, sigma, max_terms=3000, max_shifts=1000):
    """Evaluate the elliptic gamma function Gamma(z; tau, sigma) with mpmath.

    The value comes from the trigonometric series

        log Gamma(z; tau, sigma) = -(i/2) * sum_{j>=1}
            sin(pi j (2z - tau - sigma)) / (j sin(pi j tau) sin(pi j sigma)),

    which converges for |Im(2z - tau - sigma)| < Im(tau) + Im(sigma).
    Periods with negative imaginary part are first reflected, then z is moved
    close to the centre of the convergence strip with the quasi-periodicity
    relations before the series is summed.

    Args:
        z, tau, sigma: Values accepted by ``mp.mpc``.
        max_terms: Maximum number of series terms to sum.
        max_shifts: Maximum number of period shifts of ``z``; if more would be
            needed, NaN is returned instead (this replaces unbounded recursion).

    Returns:
        An ``mpc``. NaN when Im(tau) or Im(sigma) is numerically zero, or when
        more than ``max_shifts`` shifts would be required.

    Raises:
        ZeroDivisionError: presumably when a theta0 factor used in a backward
            shift (or a reflected value) is exactly zero -- mpmath raises on
            division by an exact zero.
    """
    z, tau, sigma = mp.mpc(z), mp.mpc(tau), mp.mpc(sigma)
    dynamic_tol = mp.eps * 100

    # Reflection: Gamma(z; tau, sigma) = 1 / Gamma(z - tau; -tau, sigma),
    # and the same identity in sigma.
    if mp.im(tau) < -dynamic_tol:
        return 1 / elliptic_gamma_function(z - tau, -tau, sigma, max_terms, max_shifts)
    if mp.im(sigma) < -dynamic_tol:
        return 1 / elliptic_gamma_function(z - sigma, tau, -sigma, max_terms, max_shifts)

    # (Nearly) real periods: no convergent representation is available here.
    if abs(mp.im(tau)) < dynamic_tol or abs(mp.im(sigma)) < dynamic_tol:
        return mp.mpc("nan")

    # Quasi-periodicity: Gamma(z + sigma) = theta0(z; tau)   * Gamma(z)
    #                    Gamma(z + tau)   = theta0(z; sigma) * Gamma(z)
    # A shift by period p changes Im(2z - tau - sigma) by 2*Im(p). Shifting by
    # the period with the smaller imaginary part lands z at
    # |Im(2z - tau - sigma)| <= min(Im tau, Im sigma). That leaves a decay
    # margin of at least max(Im tau, Im sigma), instead of the near-zero margin
    # an "anywhere inside the strip" acceptance test allows.
    if mp.im(sigma) <= mp.im(tau):
        step, other = sigma, tau
    else:
        step, other = tau, sigma
    n_shift = int(mp.nint(mp.im(2 * z - tau - sigma) / (2 * mp.im(step))))
    if abs(n_shift) > max_shifts:
        return mp.mpc("nan")

    prefactor = mp.mpc(1)
    if n_shift > 0:
        # Gamma(z) = prod_{k=1..n} theta0(z - k*step; other) * Gamma(z - n*step)
        for k in range(1, n_shift + 1):
            prefactor *= theta0_safe(z - k * step, other)
    else:
        # Gamma(z) = Gamma(z + m*step) / prod_{k=0..m-1} theta0(z + k*step; other),
        # with m = -n_shift.
        for k in range(-n_shift):
            prefactor /= theta0_safe(z + k * step, other)
    z = z - n_shift * step

    A = mp.pi * (2 * z - tau - sigma)
    B = mp.pi * tau
    C = mp.pi * sigma
    abs_im_A = abs(mp.im(A))

    sum_result = mp.mpc(0)
    for j in range(1, max_terms + 1):
        j_mp = mp.mpf(j)
        sinB = mp.sin(j_mp * B)
        sinC = mp.sin(j_mp * C)
        if abs(sinB) < mp.eps or abs(sinC) < mp.eps:
            continue
        denom = j_mp * sinB * sinC
        sum_result += mp.sin(j_mp * A) / denom
        # Stop on an upper bound of the term (|sin(x + iy)| <= cosh(y)), not on
        # the term itself. sin(j*A) can vanish for an individual j (e.g. when
        # 2z - tau - sigma is real and rational), and testing the term alone
        # would end the sum while later terms are still significant.
        if mp.cosh(j_mp * abs_im_A) / abs(denom) < dynamic_tol:
            break

    return prefactor * mp.exp(-0.5j * sum_result)

def egamma(z, t, s):
    """Short alias of elliptic_gamma_function used inside evaluated expressions.

    ``t`` and ``s`` are forwarded as ``tau`` and ``sigma``; ``max_terms`` keeps
    its default.
    """
    return elliptic_gamma_function(z, tau=t, sigma=s)



In [4]:
def eval_expr_val(expr_str, z_val, t_val, s_val):
    """Evaluate an expression string in z, t, s at the given point.

    The expression can reference ``egamma``, ``z``, ``t``, ``s`` and ``mp``.

    Returns:
        The value, or None if evaluation raises, or the result is non-finite
        or exactly zero.
    """
    namespace = dict(egamma=egamma, z=z_val, t=t_val, s=s_val, mp=mp)
    try:
        # NOTE(review): eval() on expression text (presumably read from the data
        # files). Disabling builtins is not a sandbox -- only use with trusted input.
        result = eval(expr_str, {"__builtins__": None}, namespace)
        return result if mp.isfinite(result) and result != 0 else None
    except Exception:
        # Best-effort: any evaluation failure just marks the point as unusable.
        return None


def fourth_diff_invariant(R_vals):
    """Multiplicative fourth difference of five samples.

    Computes R0 * R4 * R2**6 / (R1**4 * R3**4), which is exp of the fourth
    finite difference of log R. For equally spaced samples it equals 1 when
    log R is a polynomial of degree <= 3 in z.

    Only the first five entries are used. Returns None if fewer are given or if
    the arithmetic raises (e.g. division by zero).
    """
    try:
        r0, r1, r2, r3, r4 = R_vals[:5]
        return r0 * r4 / (r1**4 * r3**4) * r2**6
    except Exception:
        return None

def find_simple_safe_points(
    origin_expr,
    simple_expr,
    num_z=6,
    max_trials=200,
    h=mp.mpf("0.07"),
    max_ts_trials=50,
    rng=None,
):
    """Randomly choose (t, s) and base points z0 where both expressions evaluate.

    For each of up to ``max_ts_trials`` parameter pairs, with
    Re(t), Re(s) ~ U(-0.8, 0.8) and Im(t), Im(s) ~ U(0.2, 0.9), up to
    ``max_trials`` base points z0 are drawn with Re(z0), Im(z0) ~ U(-0.6, 0.6).
    A z0 is kept when eval_expr_val returns a value (finite, non-zero) for both
    expressions at every stencil point z0 + k*h, k = 0..4.

    Args:
        origin_expr, simple_expr: Expression strings understood by eval_expr_val.
        num_z: Number of base points needed for a single (t, s) pair.
        max_trials: Number of z0 draws per (t, s) pair.
        h: Stencil spacing; should equal the step later passed to check_at_z.
        max_ts_trials: Number of (t, s) pairs to try (default 50).
        rng: Object with a ``uniform(a, b)`` method, e.g. ``random.Random(seed)``,
            for reproducible sampling. Defaults to the global ``random`` module.

    Returns:
        ``(t, s, z_list)`` as Python complex numbers, or ``(None, None, None)``
        if no pair produced ``num_z`` usable base points.
    """
    if rng is None:
        rng = random

    for _ in range(max_ts_trials):
        t_val = complex(rng.uniform(-0.8, 0.8), rng.uniform(0.2, 0.9))
        s_val = complex(rng.uniform(-0.8, 0.8), rng.uniform(0.2, 0.9))

        t_mp = mp.mpc(t_val)
        s_mp = mp.mpc(s_val)

        valid_z = []

        for _ in range(max_trials):
            z0 = complex(rng.uniform(-0.6, 0.6), rng.uniform(-0.6, 0.6))
            z0_mp = mp.mpc(z0)

            # Short-circuits on the first stencil point where either side fails.
            ok = all(
                eval_expr_val(origin_expr, z0_mp + k * h, t_mp, s_mp) is not None
                and eval_expr_val(simple_expr, z0_mp + k * h, t_mp, s_mp) is not None
                for k in range(5)
            )
            if ok:
                valid_z.append(z0)

            if len(valid_z) >= num_z:
                return t_val, s_val, valid_z

    return None, None, None

def check_at_z(origin_expr, simple_expr, z0, t_val, s_val, h):
    """Fourth-difference invariant of origin/simple on the stencil z0 + k*h.

    The ratio origin(z) / simple(z) is sampled at k = 0..4 with t and s fixed,
    and the five ratios are passed to fourth_diff_invariant.

    Returns:
        The invariant, or None if an expression fails to evaluate at a stencil
        point or a ratio is zero or non-finite.
    """

    def ratio_at(k):
        # Ratio at the k-th stencil point, or None when it is unusable.
        point = z0 + k * h
        numerator = eval_expr_val(origin_expr, point, t_val, s_val)
        denominator = eval_expr_val(simple_expr, point, t_val, s_val)
        if numerator is None or denominator is None:
            return None
        quotient = numerator / denominator
        return quotient if mp.isfinite(quotient) and quotient != 0 else None

    ratios = []
    for k in range(5):
        value = ratio_at(k)
        if value is None:
            return None
        ratios.append(value)

    return fourth_diff_invariant(ratios)



def verify_single_pair(args):
    """Numerically test that two expressions agree up to an exp(cubic in z) factor.

    Random (t, s) and base points are found with find_simple_safe_points. At
    each base point, the fourth-difference invariant of origin/simple must be
    within GlobalConfig.INVARIANT_THRESHOLD of 1. Up to
    GlobalConfig.MAX_TS_RETRIES samplings are tried.

    Args:
        args: Tuple ``(idx, simple_expr, origin_expr)``. A single tuple argument,
            presumably so the function can be mapped with a multiprocessing Pool.

    Returns:
        ``(idx, passed, message)``.

    Side effects:
        Sets the global ``mp.dps`` to ``GlobalConfig.MP_DPS`` (not restored).
    """
    idx, simple_expr, origin_expr = args
    mp.dps = GlobalConfig.MP_DPS
    h = GlobalConfig.H_STEP
    threshold = GlobalConfig.INVARIANT_THRESHOLD
    found_points = False

    for _attempt in range(GlobalConfig.MAX_TS_RETRIES):
        try:
            # Sample with the same step that check_at_z uses below, so every
            # validated stencil point is exactly one that is checked.
            t_val, s_val, z_list = find_simple_safe_points(origin_expr, simple_expr, h=h)
        except Exception:
            # Sampling is best-effort; an unexpected error just uses up one retry.
            continue

        if t_val is None or not z_list:
            continue
        found_points = True

        t_mp = mp.mpc(t_val)
        s_mp = mp.mpc(s_val)

        passed = True
        for z0 in z_list:
            invariant = check_at_z(origin_expr, simple_expr, mp.mpc(z0), t_mp, s_mp, h)
            # Phrased as "not <=" so that a NaN invariant counts as a failure.
            if invariant is None or not abs(invariant - 1) <= threshold:
                passed = False
                break

        if passed:
            return idx, True, "Fourth-diff invariant passed"

    if not found_points:
        # The invariant was never evaluated; don't report it as having failed.
        return idx, False, "No safe sample points found"
    return idx, False, "Fourth-diff invariant failed"


In [11]:
# Candidate expressions in the variables z, t, s. Each is evaluated by
# eval_expr_val, in which egamma(z, t, s) calls elliptic_gamma_function(z, t, s).
expr1 = "egamma(-2*z/(s + 8), (s + 7)/(s + 8), (-s - t - 7)/(s + 8))*egamma(-z/(t - 1), (-s - 10*t + 8)/(t - 1), (4 - 5*t)/(t - 1))/(egamma(-2*z/(s + 8), (-s - t - 7)/(s + 8), (s + 7)/(s + 8))*egamma(-z/(t - 1), (-s - 5*t + 4)/(t - 1), (4 - 5*t)/(t - 1))*egamma(-z/(t - 1), (s + 5*t - 4)/(t - 1), (-s - 10*t + 8)/(t - 1))*egamma((-7*s + t - z)/(6*s + 1), (-7*s + t)/(6*s + 1), s/(6*s + 1)))"
expr2 = "egamma(-z/(6*s + 1), (7*s - t)/(6*s + 1), s/(6*s + 1))" 
expr3 = "egamma(-z, (7*s - t)/(6*s + 1), s/(6*s + 1))"
expr4 = "egamma(z, t, s)**2"
expr5 = "egamma(z, s, t)"

# verify_single_pair unpacks its argument as (idx, simple_expr, origin_expr),
# so expr2 is the "simple" side and expr5 the "origin" side here.
# expr1, expr3 and expr4 are not used in this cell.
# The outcome depends on unseeded random sampling (find_simple_safe_points
# draws from the global `random` module), so repeated runs may differ.
_, ok, msg = verify_single_pair((0, expr2, expr5))
print(f"Verification result: {ok}, Message: {msg}")

Verification result: False, Message: Fourth-diff invariant failed
