In [1]:
import os
import sys
import numpy as np

# Project root: assumes the notebook runs from a directory two levels below it
# (the '../..' below) — adjust if the notebook is moved.
project_dir = os.path.abspath(os.path.join(os.getcwd(), '../..'))
src_dir = os.path.join(project_dir, 'src')
log_dir = os.path.join(project_dir, 'log')
fig_dir = os.path.join(project_dir, 'fig')

# Make the project's src/ importable. The membership guard keeps this cell
# idempotent: re-running it does not keep appending duplicate entries.
if src_dir not in sys.path:
    sys.path.append(src_dir)


from Dataset.cancer_dataset_for_LS import main_generate_cancer_matrices_for_LS, load_cancer_dataset_matrices_for_LS
from analysis.commons import data_normalize_by_features, concatenate_B_b

from estimator.NDIS import BasicNDISEstimator
from analysis.monotonicity_certifier_left import f_value_p, fprime_p
import mpmath as mp

In [2]:
def check_fprime_vs_fd(
    p: float,
    s: float,
    d: int,
    r: float,
    eps: float,
    h: float = 1e-4,
    workers_f: int = 1,
    workers_fp: int = 1,
    dps: int = 80,
    tail_sigma: float = 12.0,
    chunks_per_piece: int = 2,
    verbose: bool = True,
):
    """
    Compare the analytic derivative ``fprime_p`` with a central finite
    difference of ``f_value_p`` at the point ``p``.

    The finite difference is ``(f(p_plus) - f(p_minus)) / (p_plus - p_minus)``,
    where ``p_plus`` / ``p_minus`` are the double-precision abscissae actually
    passed to ``f_value_p``. Dividing by their true spacing (rather than the
    nominal ``2*h``) removes the error introduced when ``p ± h`` is rounded to
    ``float``.

    Parameters
    ----------
    p : float
        Evaluation point; must satisfy ``0 < p < 1 - s``.
    s : float
        Must satisfy ``0 < s < 1``.
    d : int
        Must satisfy ``d >= 2``.
    r : float
        Must satisfy ``r >= 1``.
    eps : float
        Must satisfy ``eps >= 0``.
    h : float
        Requested half-step (must be > 0). It is shrunk to at most half the
        distance from ``p`` to either end of ``(0, 1 - s)`` so that both
        stencil points stay inside the domain.
    workers_f : int
        ``workers`` argument forwarded to ``f_value_p``.
    workers_fp : int
        ``workers`` argument forwarded to ``fprime_p``.
    dps : int
        mpmath decimal precision. NOTE: this sets the *global* ``mp.mp.dps``
        and does not restore it afterwards.
    tail_sigma : float
        Forwarded to both ``f_value_p`` and ``fprime_p``.
    chunks_per_piece : int
        Forwarded to ``f_value_p`` only.
    verbose : bool
        If True (default), print a human-readable report.

    Returns
    -------
    dict
        The inputs plus ``h_eff`` (shrunk half-step), ``step`` (actual
        ``p_plus - p_minus`` used as the denominator), ``f_plus``, ``f_minus``,
        ``fd``, ``fprime``, ``abs_diff`` and ``rel_diff``. Note that
        ``rel_diff = abs_diff / max(1, |fprime|, |fd|)``, so it equals the
        absolute difference whenever both derivatives are below 1 in magnitude.

    Raises
    ------
    ValueError
        If an argument is outside its admissible range, or if ``h`` is so
        small that ``p + h`` and ``p - h`` round to the same double.
    """
    mp.mp.dps = int(dps)

    if not (0 < s < 1):
        raise ValueError("Require 0 < s < 1.")
    if not (0 < p < 1 - s):
        raise ValueError("Require 0 < p < 1 - s.")
    if r < 1:
        raise ValueError("Require r >= 1.")
    if d < 2:
        raise ValueError("Require d >= 2.")
    if eps < 0:
        raise ValueError("Require eps >= 0.")
    # Written as `not (h > 0)` so that NaN is rejected as well.
    if not (h > 0):
        raise ValueError("Require h > 0.")

    # Shrink h so that p ± h stays strictly inside (0, 1 - s).
    p_mp = mp.mpf(p)
    s_mp = mp.mpf(s)
    dist_left = p_mp
    dist_right = (1 - s_mp) - p_mp
    h_eff = min(mp.mpf(h), dist_left / 2, dist_right / 2)
    if h_eff <= 0:
        raise ValueError("p is too close to the boundary to use symmetric finite differences.")

    p_plus = float(p_mp + h_eff)
    p_minus = float(p_mp - h_eff)
    # The true spacing of the stencil after rounding to double precision.
    step = p_plus - p_minus
    if step <= 0:
        raise ValueError("h is too small: p + h and p - h round to the same double.")

    f_kwargs = dict(
        workers=workers_f,
        dps=dps,
        tail_sigma=tail_sigma,
        chunks_per_piece=chunks_per_piece,
    )
    f_plus = f_value_p(p_plus, s, d, r, eps, **f_kwargs)
    f_minus = f_value_p(p_minus, s, d, r, eps, **f_kwargs)

    # Central finite difference over the actually-evaluated abscissae.
    fd = (f_plus - f_minus) / step

    # Analytic derivative.
    fp = fprime_p(
        p=p,
        s=s,
        d=d,
        r=r,
        eps=eps,
        workers=workers_fp,
        dps=dps,
        tail_sigma=tail_sigma,
    )

    abs_diff = abs(fp - fd)
    rel_diff = abs_diff / max(1.0, abs(fp), abs(fd))

    if verbose:
        print("=== Check f'(p) vs finite difference ===")
        print(f"p         = {p}")
        print(f"s         = {s}")
        print(f"d         = {d}")
        print(f"r         = {r}")
        print(f"eps       = {eps}")
        print(f"h_eff     = {float(h_eff):.6e}")
        print(f"step      = {step:.6e}")
        print(f"f(p+h)    = {f_plus:.12e}")
        print(f"f(p-h)    = {f_minus:.12e}")
        print(f"fd approx = {fd:.12e}")
        print(f"fprime_p  = {fp:.12e}")
        print(f"|diff|    = {abs_diff:.3e}")
        print(f"rel diff  = {rel_diff:.3e}")

    return {
        "p": p,
        "s": s,
        "d": d,
        "r": r,
        "eps": eps,
        "h_eff": float(h_eff),
        "step": step,
        "f_plus": f_plus,
        "f_minus": f_minus,
        "fd": fd,
        "fprime": fp,
        "abs_diff": abs_diff,
        "rel_diff": rel_diff,
    }


In [4]:
# Single-point check at p = 0.9 with a very small s.
res = check_fprime_vs_fd(
    p=0.9,
    s=1e-7,
    d=5,
    r=2.0,
    eps=0.5,
    h=1e-4,
    dps=80,
    tail_sigma=12.0,
    workers_f=1,
    workers_fp=1,
    chunks_per_piece=2,
)


=== Check f'(p) vs finite difference ===
p         = 0.9
s         = 1e-07
d         = 5
r         = 2.0
eps       = 0.5
h_eff     = 1.000000e-04
f(p+h)    = 4.535060562611e-01
f(p-h)    = 4.531037440660e-01
fd approx = 2.011560975378e+00
fprime_p  = 2.011560577061e+00
|diff|    = 3.983e-07
rel diff  = 1.980e-07


In [6]:
def sweep_check_fprime_vs_fd(
    ps = None,
    ss = None,
    ds = None,
    rs = None,
    eps_list = None,
    h: float = 1e-4,
    workers_f: int = 1,
    workers_fp: int = 1,
    dps: int = 80,
    tail_sigma: float = 12.0,
    chunks_per_piece: int = 2,
    verbose: bool = True,
):
    """
    Run ``check_fprime_vs_fd`` over a grid of (s, d, r, eps, p) values and
    report the largest discrepancies between ``fprime_p`` and the
    finite-difference derivative of ``f_value_p``.

    Parameters
    ----------
    ps : list of float, optional
        Fractions ``alpha`` in (0, 1); each is mapped to ``p = alpha * (1 - s)``.
        Values whose mapped ``p`` falls outside ``(0, 1 - s)`` are skipped.
    ss, ds, rs, eps_list : list, optional
        Grid values for ``s``, ``d``, ``r`` and ``eps``; built-in defaults are
        used when ``None``.
    h, workers_f, workers_fp, dps, tail_sigma, chunks_per_piece :
        Forwarded unchanged to ``check_fprime_vs_fd``.
    verbose : bool
        If True, show each per-point report followed by a one-line summary.
        If False, the per-point reports printed by ``check_fprime_vs_fd`` are
        suppressed and only the final sweep summary is printed.

    Returns
    -------
    dict
        ``max_abs_diff`` / ``max_rel_diff`` (each the true maximum over the
        grid, tracked independently), ``worst_case`` (result dict of the point
        with the largest ``rel_diff``) and ``worst_abs_case`` (result dict of
        the point with the largest ``abs_diff``). The case entries are None if
        no point had a positive difference.
    """
    import contextlib
    import io
    import itertools

    if ss is None:
        ss = [0.05, 0.1, 0.3]        # residual shares
    if ds is None:
        ds = [3, 5, 10]              # dimensions
    if rs is None:
        rs = [1.0, 2.0, 5.0]         # regularization
    if eps_list is None:
        eps_list = [0.1, 0.5, 1.0]   # epsilons
    if ps is None:
        ps = [0.1, 0.3, 0.5, 0.8]    # fractions of (1 - s)

    max_abs = 0.0
    max_rel = 0.0
    worst_case = None       # point with the largest rel_diff
    worst_abs_case = None   # point with the largest abs_diff

    # Same iteration order as nested loops over s, d, r, eps, alpha.
    for s, d, r, eps, alpha in itertools.product(ss, ds, rs, eps_list, ps):
        p = alpha * (1.0 - s)
        if not (0 < p < 1 - s):
            continue

        # check_fprime_vs_fd prints its own report; silence it unless verbose.
        quiet = contextlib.nullcontext() if verbose else contextlib.redirect_stdout(io.StringIO())
        with quiet:
            res = check_fprime_vs_fd(
                p=p,
                s=s,
                d=d,
                r=r,
                eps=eps,
                h=h,
                workers_f=workers_f,
                workers_fp=workers_fp,
                dps=dps,
                tail_sigma=tail_sigma,
                chunks_per_piece=chunks_per_piece,
            )

        abs_diff = res["abs_diff"]
        rel_diff = res["rel_diff"]

        if verbose:
            print(
                f"[p={p:.3f}, s={s:.3g}, d={d}, r={r:.1f}, eps={eps:.2f}] "
                f"|diff|={abs_diff:.3e}, rel={rel_diff:.3e}"
            )

        # Track each maximum independently; updating both together whenever
        # either grows would overwrite a larger max with a smaller value.
        if abs_diff > max_abs:
            max_abs = abs_diff
            worst_abs_case = res
        if rel_diff > max_rel:
            max_rel = rel_diff
            worst_case = res

    summary = {
        "max_abs_diff": max_abs,
        "max_rel_diff": max_rel,
        "worst_case": worst_case,
        "worst_abs_case": worst_abs_case,
    }

    def _where(case):
        """Format the grid coordinates of a result dict for printing."""
        return (
            f"p={case['p']}, s={case['s']}, d={case['d']}, "
            f"r={case['r']}, eps={case['eps']}"
        )

    print("\n=== Sweep summary ===")
    print(f"max |diff|  = {max_abs:.3e}")
    print(f"max rel diff = {max_rel:.3e}")
    if worst_case is not None:
        print(f"worst (rel) at {_where(worst_case)}")
    if worst_abs_case is not None:
        print(f"worst (abs) at {_where(worst_abs_case)}")

    return summary


In [7]:
# Sweep the default (s, d, r, eps, p) grid with the same h / dps / tail_sigma
# used in the single-point check.
summary = sweep_check_fprime_vs_fd(
    verbose=True,
    h=1e-4,
    dps=80,
    tail_sigma=12.0,
    workers_f=1,
    workers_fp=1,
)


=== Check f'(p) vs finite difference ===
p         = 0.095
s         = 0.05
d         = 3
r         = 1.0
eps       = 0.1
h_eff     = 1.000000e-04
f(p+h)    = 1.049095759008e-02
f(p-h)    = 1.044659872122e-02
fd approx = 2.217943443065e-01
fprime_p  = 2.217943704431e-01
|diff|    = 2.614e-08
rel diff  = 2.614e-08
[p=0.095, s=0.05, d=3, r=1.0, eps=0.10] |diff|=2.614e-08, rel=2.614e-08
=== Check f'(p) vs finite difference ===
p         = 0.285
s         = 0.05
d         = 3
r         = 1.0
eps       = 0.1
h_eff     = 1.000000e-04
f(p+h)    = 6.291988107330e-02
f(p-h)    = 6.285345032287e-02
fd approx = 3.321537521508e-01
fprime_p  = 3.321537513293e-01
|diff|    = 8.215e-10
rel diff  = 8.215e-10
[p=0.285, s=0.05, d=3, r=1.0, eps=0.10] |diff|=8.215e-10, rel=8.215e-10
=== Check f'(p) vs finite difference ===
p         = 0.475
s         = 0.05
d         = 3
r         = 1.0
eps       = 0.1
h_eff     = 1.000000e-04
f(p+h)    = 1.381925801315e-01
f(p-h)    = 1.380986649566e-01
fd approx = 4.695