# In [None]:
import math
import time
import numpy as np

# ====== Use the package's own AAD ======
from aad_edge_pushing.aad.core.var import ADVar
from aad_edge_pushing.aad.core.tape import use_tape, global_tape
from aad_edge_pushing.edge_pushing.algo4_adjlist import algo4_adjlist
from aad_edge_pushing.aad.ops.transcendental import exp, log

import aad_edge_pushing.aad.ops.arithmetic as _arith_ops
import aad_edge_pushing.aad.ops.transcendental as _trans_ops
# Rebind the module-level name `global_tape` inside the arithmetic and
# transcendental op modules to the tape object imported above.
# NOTE(review): presumably those op modules look up `global_tape` as a module
# global when recording nodes; this keeps every op on the same tape that
# basket_ep_greeks resets per path — confirm against the package internals.
_arith_ops.global_tape = global_tape
_trans_ops.global_tape = global_tape

# ---------------- Softplus ----------------
def softplus_scalar(x: float, alpha: float = 20.0, eps: float = 1e-8) -> float:
    """Smooth approximation of max(x, 0) with sharpness ``alpha`` (plain floats).

    Evaluates softplus(alpha * x) / alpha in the overflow-safe form
    ``max(t, 0) + log1p(exp(-|t|))``. Here |t| is replaced by
    ``sqrt(t**2 + eps)`` and max(t, 0) by ``(t + |t|) / 2``, so every piece
    stays differentiable at t = 0.
    """
    scaled = alpha * x
    smooth_abs = math.sqrt(scaled * scaled + eps)
    smooth_relu = 0.5 * (scaled + smooth_abs)
    tail = math.log1p(math.exp(-smooth_abs))
    return (smooth_relu + tail) / alpha

def softplus_ad(x, alpha: float = 20.0, eps: float = 1e-8):
    """AD-recordable twin of ``softplus_scalar``.

    Builds softplus(alpha * x) / alpha from the package's ``exp``/``log`` and
    operator overloads, so that when ``x`` is an ADVar every step lands on the
    active tape. ``|t|`` is smoothed as ``(t**2 + eps) ** 0.5``. ``1.0 + exp(.)``
    is used instead of log1p, presumably because the AD package exposes no
    log1p — confirm.
    """
    scaled = alpha * x
    smooth_abs = (scaled * scaled + eps) ** 0.5
    smooth_relu = 0.5 * (scaled + smooth_abs)
    tail = log(1.0 + exp(-smooth_abs))
    return (smooth_relu + tail) / alpha

# ================= EP: Gamma + Vanna + Volga =================
def basket_ep_greeks(S0, w, K, T, r, sigma, eps_steps, alpha=30.0, track_steps=False):
    """
    Compute price, Gamma(SS), Vanna(Sσ), Volga(σσ) via Edge-Pushing on a single tape.

    For every Monte-Carlo path the global tape is reset. The weighted terminal
    basket and the softplus-smoothed discounted payoff are then recorded, and
    ``algo4_adjlist`` is called once to get the Hessian of that path's price
    with respect to the stacked inputs [S_0..S_{n-1}, sigma_0..sigma_{n-1}].
    The Gamma / Vanna / Volga blocks are sliced out and averaged over paths.

    Parameters
    ----------
    S0, w, sigma : array-like, length n
        Spot prices, basket weights and volatilities (converted to float arrays).
    K, T, r : float
        Strike, maturity and rate; the payoff is discounted by exp(-r T).
    eps_steps : ndarray
        Indexed as ``eps_steps[p, t, i]`` (path, time step, asset). ``shape[0]``
        gives the path count and ``shape[1]`` the step count.
    alpha : float
        Softplus sharpness passed to ``softplus_ad``.
        NOTE(review): this default (30.0) differs from basket_fd_greeks (20.0);
        the driver always passes alpha explicitly.
    track_steps : bool
        False (default, compressed steps): Sum z_t for each asset per path first, building
        the graph only at the terminal value.
        True (step-by-step): Add increments of every time step to the graph, allowing
        backpropagation along time steps (results in a larger graph).

    Returns
    -------
    price, Gamma (n×n), Vanna (n×n), Volga (n×n), elapsed_ms
        Vanna[i, j] is the (S_i, sigma_j) block entry. ``elapsed_ms`` is the
        wall-clock time of the whole call (perf_counter).

    Notes
    -----
    An ``eps_steps`` with zero paths makes the final averaging divide by zero
    (ZeroDivisionError); zero steps fail earlier in ``T / n_steps``.
    """
    t_start = time.perf_counter()

    S0    = np.array(S0,    dtype=float)
    w     = np.array(w,     dtype=float)
    sigma = np.array(sigma, dtype=float)

    n           = len(S0)
    n_paths     = int(eps_steps.shape[0])
    n_steps     = int(eps_steps.shape[1])

    dt       = T / n_steps
    sqrt_dt  = math.sqrt(dt)
    discount = math.exp(-r * T)

    price_sum = 0.0
    H_SS_sum      = np.zeros((n, n), dtype=float)  # Gamma
    H_Ssig_sum    = np.zeros((n, n), dtype=float)  # Vanna
    H_sigsig_sum  = np.zeros((n, n), dtype=float)  # Volga

    for p in range(n_paths):
        # Fresh tape per path: each path's graph and Hessian are independent;
        # only the resulting numbers are accumulated below.
        global_tape.reset()

        with use_tape(global_tape):
            # Active variables: S and sigma
            S_vars   = [ADVar(float(S0[i]))     for i in range(n)]
            Sig_vars = [ADVar(float(sigma[i]))  for i in range(n)]

            # Active constants (promoted to ADVar to ensure recording)
            # NOTE: T_ad is only used by the compressed branch and dt_ad only by
            # the step-by-step branch; both are created unconditionally.
            r_ad    = ADVar(float(r),        requires_grad=False)
            T_ad    = ADVar(float(T),        requires_grad=False)
            dt_ad   = ADVar(float(dt),       requires_grad=False)
            sdt_ad  = ADVar(float(sqrt_dt),  requires_grad=False)
            K_ad    = ADVar(float(K),        requires_grad=False)
            disc_ad = ADVar(float(discount), requires_grad=False)
            half    = ADVar(0.5,             requires_grad=False)

            # Running sum for the weighted basket.
            # NOTE(review): created with the default requires_grad (not False like
            # the constants above); it is not part of `inputs`, so presumably it
            # does not affect the Hessian — confirm.
            basket_ad = ADVar(0.0)

            if not track_steps:
                # ---------- Compressed steps: Sum noise for each asset per path first ----------
                z_sum = eps_steps[p].sum(axis=0)  # shape (n,)
                for i in range(n):
                    w_ad = ADVar(float(w[i]),        requires_grad=False)
                    z_ad = ADVar(float(z_sum[i]),    requires_grad=False)
                    sig  = Sig_vars[i]
                    S_i  = S_vars[i]

                    # Total drift: (r - 0.5 σ^2) * T
                    drift_i     = (r_ad - half * sig * sig) * T_ad
                    # Total diffusion: σ * sqrt_dt * sum_t z_t
                    diffusion_i = sig * sdt_ad * z_ad

                    # Terminal growth factor c_i = S_T,i / S_0,i
                    const_i     = drift_i + diffusion_i
                    c_i         = exp(const_i)
                    basket_ad = basket_ad + (w_ad * (S_i * c_i))

            else:
                # ---------- Step-by-step: Add increments of every step to the graph ----------
                for i in range(n):
                    w_ad = ADVar(float(w[i]), requires_grad=False)
                    sig  = Sig_vars[i]
                    S_i  = S_vars[i]

                    # Log-return accumulator; one recorded increment per time step.
                    acc_i = ADVar(0.0)
                    for t in range(n_steps):
                        z_ad = ADVar(float(eps_steps[p, t, i]), requires_grad=False)
                        # Increment per step: (r - 0.5 σ^2)Δt + σ√Δt z_t
                        inc_t = (r_ad - half * sig * sig) * dt_ad + sig * sdt_ad * z_ad
                        acc_i = acc_i + inc_t
                    c_i = exp(acc_i)
                    basket_ad = basket_ad + (w_ad * (S_i * c_i))

            # Smoothed call payoff max(B - K, 0) ≈ softplus, then discount.
            payoff_ad = softplus_ad(basket_ad - K_ad, alpha=alpha, eps=1e-8)
            price_ad  = disc_ad * payoff_ad
            # NOTE(review): assumes `.val` holds the node's forward value.
            price_sum += float(price_ad.val)

            # Compute Hessian w.r.t [S; sigma] on the same tape
            inputs = S_vars + Sig_vars
            H_all  = algo4_adjlist(price_ad, inputs)  # shape (2n, 2n)

            # Slice and accumulate: rows/cols follow the order of `inputs`
            # (S first, then sigma). The (sigma, S) block is not read; presumably
            # it is the transpose of the Vanna block.
            H_SS_sum     += H_all[0:n,     0:n]
            H_Ssig_sum   += H_all[0:n,     n:2*n]
            H_sigsig_sum += H_all[n:2*n, n:2*n]

    # Monte-Carlo averages over paths.
    price = price_sum / n_paths
    Gamma = H_SS_sum     / n_paths
    Vanna = H_Ssig_sum   / n_paths
    Volga = H_sigsig_sum / n_paths
    elapsed_ms = (time.perf_counter() - t_start) * 1000.0
    return price, Gamma, Vanna, Volga, elapsed_ms

# ================= FD: Gamma + Vanna + Volga =================
def basket_fd_greeks(S0, w, K, T, r, sigma, eps_steps,
                     alpha=20.0, h_rel_S=1e-4, h_rel_sig=1e-4,
                     relative_steps=False):
    """
    Finite-difference Gamma(SS), Vanna(Sσ), Volga(σσ).

    Prices the softplus-smoothed basket call with the compressed-step terminal
    formula S_T,i = S_i * exp((r - σ_i²/2) T + σ_i √Δt Σ_t z_{t,i}) on the
    supplied noise, then applies central-difference stencils.

    Parameters
    ----------
    S0, w, sigma : array-like, length n
        Spot prices, basket weights and volatilities.
    K, T, r : float
        Strike, maturity and rate; payoffs are discounted by exp(-r T).
    eps_steps : array-like, shape (n_paths, n_steps, n)
        Correlated normal increments, indexed [path, step, asset].
    alpha : float
        Softplus sharpness passed to ``softplus_scalar``.
    h_rel_S, h_rel_sig : float
        Bump sizes for S and sigma. By default they are applied as *absolute*
        bumps (the historical behaviour, despite ``rel`` in the names).
    relative_steps : bool, optional
        If True, each bump is scaled to ``h * max(|x|, 1)`` for its coordinate.
        With S = O(100) this keeps the second differences far above floating-point
        rounding noise.

    Returns
    -------
    price, Gamma (n×n), Vanna (n×n), Volga (n×n), elapsed_ms, n_evals
        ``n_evals`` = 1 + 8 n² is the number of full Monte-Carlo price evaluations.

    Raises
    ------
    ValueError
        If ``eps_steps`` is not 3-D with a last axis of length n, or contains
        no paths or no steps.
    """
    t_start = time.perf_counter()

    S0        = np.array(S0,    dtype=float)
    w         = np.array(w,     dtype=float)
    sigma     = np.array(sigma, dtype=float)
    eps_steps = np.asarray(eps_steps, dtype=float)

    n = len(S0)
    if eps_steps.ndim != 3 or eps_steps.shape[2] != n:
        raise ValueError(
            f"eps_steps must have shape (n_paths, n_steps, {n}); got {eps_steps.shape}"
        )
    n_paths = int(eps_steps.shape[0])
    n_steps = int(eps_steps.shape[1])
    if n_paths == 0 or n_steps == 0:
        raise ValueError("eps_steps must contain at least one path and one step")

    sqrt_dt  = math.sqrt(T / n_steps)
    discount = math.exp(-r * T)

    # Bump sizes: absolute by default (constant h for every coordinate).
    if relative_steps:
        hS   = h_rel_S   * np.maximum(np.abs(S0),    1.0)
        hSig = h_rel_sig * np.maximum(np.abs(sigma), 1.0)
    else:
        hS   = np.full(n, h_rel_S,   dtype=float)
        hSig = np.full(n, h_rel_sig, dtype=float)

    # The per-path noise sums do not depend on (S, sigma). Compute them once here
    # instead of inside every one of the 1 + 8 n^2 price evaluations.
    z_sums = [eps_steps[p].sum(axis=0) for p in range(n_paths)]

    def price_fn(S_vec, sig_vec):
        """Mean discounted smoothed payoff over all paths at (S_vec, sig_vec)."""
        total = 0.0
        for z_sum in z_sums:
            basket = 0.0
            for i in range(n):
                drift_i     = (r - 0.5 * (sig_vec[i] ** 2.0)) * T
                diffusion_i = sig_vec[i] * sqrt_dt * float(z_sum[i])
                ST_i = float(S_vec[i]) * math.exp(drift_i + diffusion_i)
                basket += float(w[i]) * ST_i
            payoff = softplus_scalar(basket - K, alpha=alpha, eps=1e-8)
            total += discount * payoff
        return total / n_paths

    def bumped(base, *moves):
        """Return a copy of ``base`` with ``base[k] += delta`` for each (k, delta)."""
        out = base.copy()
        for k, delta in moves:
            out[k] += delta
        return out

    # baseline
    f0 = price_fn(S0, sigma)

    # ---------- Gamma(SS) ----------
    Gamma = np.zeros((n, n))
    # diag: 3-point second difference
    for i in range(n):
        f_up = price_fn(bumped(S0, (i,  hS[i])), sigma)
        f_dn = price_fn(bumped(S0, (i, -hS[i])), sigma)
        Gamma[i, i] = (f_up - 2*f0 + f_dn) / (hS[i]**2)
    # off-diag: 4-point mixed stencil, filled symmetrically
    for i in range(n):
        for j in range(i + 1, n):
            f_pp = price_fn(bumped(S0, (i,  hS[i]), (j,  hS[j])), sigma)
            f_pm = price_fn(bumped(S0, (i,  hS[i]), (j, -hS[j])), sigma)
            f_mp = price_fn(bumped(S0, (i, -hS[i]), (j,  hS[j])), sigma)
            f_mm = price_fn(bumped(S0, (i, -hS[i]), (j, -hS[j])), sigma)
            val = (f_pp - f_pm - f_mp + f_mm) / (4.0 * hS[i] * hS[j])
            Gamma[i, j] = val
            Gamma[j, i] = val

    # ---------- Vanna(Sσ) ----------
    # Full n×n block (not symmetric): Vanna[i, j] = d²P / dS_i dσ_j.
    Vanna = np.zeros((n, n))
    for i in range(n):
        S_up = bumped(S0, (i,  hS[i]))
        S_dn = bumped(S0, (i, -hS[i]))
        for j in range(n):
            sig_up = bumped(sigma, (j,  hSig[j]))
            sig_dn = bumped(sigma, (j, -hSig[j]))

            f_pp = price_fn(S_up, sig_up)
            f_pm = price_fn(S_up, sig_dn)
            f_mp = price_fn(S_dn, sig_up)
            f_mm = price_fn(S_dn, sig_dn)

            Vanna[i, j] = (f_pp - f_pm - f_mp + f_mm) / (4.0 * hS[i] * hSig[j])

    # ---------- Volga(σσ) ----------
    Volga = np.zeros((n, n))
    # diag
    for j in range(n):
        f_up = price_fn(S0, bumped(sigma, (j,  hSig[j])))
        f_dn = price_fn(S0, bumped(sigma, (j, -hSig[j])))
        Volga[j, j] = (f_up - 2*f0 + f_dn) / (hSig[j]**2)
    # off-diag
    for j in range(n):
        for k in range(j + 1, n):
            f_pp = price_fn(S0, bumped(sigma, (j,  hSig[j]), (k,  hSig[k])))
            f_pm = price_fn(S0, bumped(sigma, (j,  hSig[j]), (k, -hSig[k])))
            f_mp = price_fn(S0, bumped(sigma, (j, -hSig[j]), (k,  hSig[k])))
            f_mm = price_fn(S0, bumped(sigma, (j, -hSig[j]), (k, -hSig[k])))
            val = (f_pp - f_pm - f_mp + f_mm) / (4.0 * hSig[j] * hSig[k])
            Volga[j, k] = val
            Volga[k, j] = val

    # Evaluation count: n_gamma + n_vanna + n_volga = 1 + 8 n^2 exactly
    n_gamma   = 1 + 2*n + 4*(n*(n-1)//2)
    n_vanna   = 4*n*n
    n_volga   = 2*n + 4*(n*(n-1)//2)
    n_evals   = n_gamma + n_vanna + n_volga

    elapsed_ms = (time.perf_counter() - t_start) * 1000.0
    return f0, Gamma, Vanna, Volga, elapsed_ms, n_evals

# ================= Utilities: correlation application =================
def apply_correlation_nonvec(Z, L):
    """
    Correlate independent normals across assets, one (path, step) row at a time.

    Every output row is ``Z[p, t, :].dot(L.T)``. This is the explicit-loop
    counterpart of ``Z @ L.T``. ``Z`` must be 3-D (paths, steps, assets); the
    result has the same shape and dtype as ``Z``.
    """
    out = np.empty_like(Z)
    L_transposed = L.T
    n_paths, n_steps, _ = Z.shape
    # np.ndindex walks (path, step) pairs in row-major order, like two nested loops.
    for path_idx, step_idx in np.ndindex(n_paths, n_steps):
        out[path_idx, step_idx, :] = Z[path_idx, step_idx, :].dot(L_transposed)
    return out

# ================= Test / Driver =================
if __name__ == "__main__":
    # Benchmark driver: for each asset count and softplus alpha, compute the
    # Greeks with Edge-Pushing and with finite differences on the SAME noise, then
    # report price/Hessian discrepancies and timings. Statements stay at top
    # level, so names like `results` remain available in the notebook afterwards.
    np.set_printoptions(precision=6, suppress=True, linewidth=120)

    print("\n" + "="*100)
    print(" "*30 + "EDGE-PUSHING (package) vs FINITE DIFFERENCE")
    print(" "*18 + "Basket Option Greeks (Gamma, Vanna, Volga)")
    print("="*100)

    # ---------- Global Parameters ----------
    K, T = 100.0, 1.0
    r = 0.02

    N_STEPS  = 50
    N_PATHS  = 100
    SEED     = 42

    # Asset counts & alpha list to test
    N_ASSETS_LIST = [40]         # Can be modified, e.g., [4, 8, 12, 20]
    ALPHAS        = [20.0, 25.0, 35.0] # Multiple softplus smoothing levels

    USE_VECTORIZED_CORR = False  # True uses Z @ L.T, False uses loop version
    TRACK_STEPS         = False  # False=compressed (fast); True=step-by-step (large graph)

    results = []

    # Helpers: Error metrics
    def fro_rel_err(A, B):
        """Frobenius norm of (A - B) relative to ||A||_F (A is the reference)."""
        fro = np.linalg.norm
        nrm = fro(A, 'fro') + 1e-16  # avoid division by zero for an all-zero A
        return fro(A - B, 'fro') / nrm

    def mean_abs_err(A, B):
        """Mean element-wise absolute difference between A and B."""
        return float(np.mean(np.abs(A - B)))

    for n_assets in N_ASSETS_LIST:
        # Same correlation structure for each n_assets: equi-correlated rho=0.3
        rho = 0.3 * np.ones((n_assets, n_assets))
        np.fill_diagonal(rho, 1.0)

        # Cholesky (add eps to ensure positive definiteness)
        eps = 1e-12
        R = np.array(rho, dtype=float) + eps * np.eye(n_assets)
        L = np.linalg.cholesky(R)

        # Base parameters vary with n_assets
        S0    = np.linspace(50.0, 150.0, n_assets)
        sigma = np.linspace(0.18, 0.45, n_assets)
        w     = np.ones(n_assets, dtype=float)
        w    /= w.sum()   # Uniform weights; can be changed to other patterns

        for alpha in ALPHAS:
            print("\n" + "-"*100)
            print(f"Experiment: n_assets = {n_assets}, alpha = {alpha:.1f}, "
                  f"paths = {N_PATHS}, steps = {N_STEPS}")
            print("-"*100)

            # Fix random seed to ensure EP / FD use the same paths
            rng = np.random.default_rng(SEED)
            Z = rng.standard_normal(size=(N_PATHS, N_STEPS, n_assets))

            if USE_VECTORIZED_CORR:
                eps_steps = Z @ L.T
            else:
                eps_steps = apply_correlation_nonvec(Z, L)

            # ---------- AAD Edge-Pushing ----------
            price_ep, H_ss_ep, H_sv_ep, H_vv_ep, time_ep = basket_ep_greeks(
                S0, w, K, T, r, sigma, eps_steps, alpha=alpha, track_steps=TRACK_STEPS
            )

            # ---------- Finite Difference ----------
            price_fd, H_ss_fd, H_sv_fd, H_vv_fd, time_fd, n_evals = basket_fd_greeks(
                S0, w, K, T, r, sigma, eps_steps, alpha=alpha,
                h_rel_S=1e-4, h_rel_sig=1e-4
            )

            # ---------- Accuracy comparison (EP used as the reference) ----------
            Hss_rel = fro_rel_err(H_ss_ep, H_ss_fd)
            Hsv_rel = fro_rel_err(H_sv_ep, H_sv_fd)
            Hvv_rel = fro_rel_err(H_vv_ep, H_vv_fd)

            Hss_mad = mean_abs_err(H_ss_ep, H_ss_fd)
            Hsv_mad = mean_abs_err(H_sv_ep, H_sv_fd)
            Hvv_mad = mean_abs_err(H_vv_ep, H_vv_fd)

            price_diff = abs(price_ep - price_fd)
            speedup_fd = time_fd / time_ep if time_ep > 0 else float('inf')

            # Print error & time for this experiment (matrices not printed)
            print(f"Price_EP   = {price_ep:.6f}")
            print(f"Price_FD   = {price_fd:.6f}")
            print(f"|ΔPrice|   = {price_diff:.6e}")
            print()
            print(f"Gamma  Frobenius rel err = {Hss_rel:.6e},  mean |Δ| = {Hss_mad:.6e}")
            print(f"Vanna  Frobenius rel err = {Hsv_rel:.6e},  mean |Δ| = {Hsv_mad:.6e}")
            print(f"Volga  Frobenius rel err = {Hvv_rel:.6e},  mean |Δ| = {Hvv_mad:.6e}")
            print()
            print(f"EP time   = {time_ep:.2f} ms")
            print(f"FD time   = {time_fd:.2f} ms  (≈ {n_evals} price evals)")
            print(f"FD / EP   = {speedup_fd:.2f} x slower than EP")

            results.append({
                "n_assets":   n_assets,
                "alpha":      alpha,
                "price_ep":   price_ep,
                "price_fd":   price_fd,
                "price_diff": price_diff,
                "gamma_rel":  Hss_rel,
                "vanna_rel":  Hsv_rel,
                "volga_rel":  Hvv_rel,
                "time_ep_ms": time_ep,
                "time_fd_ms": time_fd,
                "speedup":    speedup_fd,
                "n_evals":    n_evals,
            })

    # ========== Summary: Time vs Asset Count ==========
    print("\n" + "="*100)
    print("SUMMARY: Time vs Number of Assets (averaged over α)")
    print("="*100)
    print(f"{'n_assets':>8} | {'EP_time_ms(avg)':>15} | {'FD_time_ms(avg)':>15} | {'FD/EP(avg)':>12}")
    print("-"*100)

    for n_assets in N_ASSETS_LIST:
        # `row` rather than `r`, so it is not confused with the interest rate r.
        sub = [row for row in results if row["n_assets"] == n_assets]
        if not sub:
            continue  # no experiments ran for this size (e.g. ALPHAS is empty)
        ep_mean = np.mean([row["time_ep_ms"] for row in sub])
        fd_mean = np.mean([row["time_fd_ms"] for row in sub])
        ratio   = fd_mean / ep_mean if ep_mean > 0 else float('inf')
        print(f"{n_assets:8d} | {ep_mean:15.2f} | {fd_mean:15.2f} | {ratio:12.2f}")

    print("\n" + "="*100)
    print("QUALITATIVE CONCLUSION")
    print("="*100)
    print("• EP: Time scales roughly with n_assets, but one backprop yields the full Hessian.")
    # FD needs 1 + 8 n^2 evaluations and each one loops over all n assets per path,
    # so total FD time grows ~n^3, not n^2.
    print("• FD: Price eval count = 1 + 8 n^2, each eval costs O(n) per path; time scales ~cubically with asset count.")
    print("• Thus, the higher the dimension, the more significant EP's speed advantage over FD.")
    print("="*100 + "\n")