# In [1]:
import numpy as np
from dataclasses import dataclass

# =========================
# Heston + Carr–Madan (FFT)
# =========================

@dataclass
class HestonParams:
    """Parameter bundle for the Heston stochastic-volatility model.

    The instance is a plain container: no validation or normalisation is
    performed on construction.

    Attributes
    ----------
    kappa : float
        Mean-reversion speed of the variance process.
    theta : float
        Long-run level the variance reverts to.
    sigma : float
        Volatility of the variance process ("vol of vol").
    rho : float
        Correlation between the asset and variance drivers.
    v0 : float
        Variance at time zero.
    """

    kappa: float
    theta: float
    sigma: float
    rho: float
    v0: float
def heston_cf(u, T, S0, r, q, p: HestonParams):
    """
    Heston characteristic function φ(u) = E[exp(i u ln S_T)] under Q.

    Written in the "Little Heston Trap" form, i.e. with exp(-d T) and
    g = (b - d) / (b + d).

    Parameters
    ----------
    u : scalar or numpy array
        Transform argument; may be complex-valued.
    T : float
        Time to maturity.
    S0 : float
        Spot price; enters through ln(S0).
    r, q : float
        Rates entering the drift term i u (r - q) T.
    p : HestonParams
        Model parameters (kappa, theta, sigma, rho, v0).

    Returns
    -------
    Complex value(s) of φ evaluated at ``u``.
    """
    tiny = 1e-15  # floor applied to near-zero denominators / log arguments

    iu = 1j * u
    sigma_sq = p.sigma ** 2
    log_spot = np.log(S0)
    kappa_theta = p.kappa * p.theta

    beta = p.kappa - p.rho * p.sigma * 1j * u
    disc = np.sqrt(beta * beta + sigma_sq * (iu + u * u))
    beta_minus_disc = beta - disc
    ratio = beta_minus_disc / (beta + disc)

    decay = np.exp(-disc * T)
    numer = 1 - ratio * decay
    denom = 1 - ratio
    # Replace values of negligible modulus so the division and log below
    # stay finite.
    numer = np.where(np.abs(numer) < tiny, tiny, numer)
    denom = np.where(np.abs(denom) < tiny, tiny, denom)

    # exponent = C + D * v0 + i u ln(S0)
    drift = iu * (r - q) * T
    C = drift + (kappa_theta / sigma_sq) * (
        beta_minus_disc * T - 2.0 * np.log(numer / denom)
    )
    D = (beta_minus_disc / sigma_sq) * ((1 - decay) / numer)
    return np.exp(C + D * p.v0 + iu * log_spot)

def _simpson_weights(N: int):
    """
    Simpson-style weights on an N-point uniform grid (N must be even).

    The first N-1 points carry the composite pattern 1, 4, 2, ..., 4, 1 and
    the final point keeps weight 1. The weights are unscaled: the caller
    multiplies by (step / 3).

    Raises
    ------
    ValueError
        If ``N`` is odd.
    """
    if N % 2:
        raise ValueError("N must be even for Simpson weights.")

    odd_interior = slice(1, N - 1, 2)
    even_interior = slice(2, N - 2, 2)

    weights = np.ones(N)
    weights[odd_interior] = 4
    weights[even_interior] = 2
    return weights

def heston_fft_calls(
    S0: float,
    T: float,
    r: float,
    q: float,
    p: HestonParams,
    N: int = 4096,      # power-of-two recommended; must be even
    eta: float = 0.25,  # frequency step Δv
    alpha: float = 1.5  # damping (>0)
):
    """
    Carr–Madan FFT for call prices across a K-grid.
    IMPORTANT: k = ln K (log-STRIKE), not ln(K/S0).

    The log-strike grid is k_j = -b + j*λ with λ = 2π / (N*eta) and
    b = N*λ/2 = π/eta, so it is centred on K = 1 regardless of S0.

    Parameters
    ----------
    S0 : float
        Spot price.
    T : float
        Time to maturity.
    r, q : float
        Rates passed to the characteristic function; ``r`` is also used for
        the discount factor exp(-r T).
    p : HestonParams
        Heston model parameters.
    N : int
        Number of grid points (even; power of two recommended for the FFT).
    eta : float
        Frequency step Δv; must be > 0.
    alpha : float
        Damping exponent of the modified call transform; must be > 0.

    Returns
    -------
    K : (N,) ascending strikes
    C : (N,) call prices for these K, floored at 0

    Raises
    ------
    ValueError
        If ``alpha <= 0`` or ``eta <= 0``, or (via ``_simpson_weights``) if
        ``N`` is odd.
    """
    # alpha == 0 makes the denominator vanish at v = 0 and alpha < 0 no longer
    # damps the call payoff; both used to yield NaN/zero prices silently.
    if alpha <= 0:
        raise ValueError("alpha must be > 0.")
    # A non-positive step flips or collapses both grids.
    if eta <= 0:
        raise ValueError("eta must be > 0.")

    # Simpson weights for the v-integral; computed first so an odd N is
    # rejected before the characteristic function is evaluated.
    w = _simpson_weights(N) * (eta / 3.0)

    n = np.arange(N)
    v = eta * n  # frequency grid

    # ψ(v) = e^{-rT} φ(v - i(α+1)) / [(α + iv)(α + iv + 1)]
    phi_shift = heston_cf(v - (alpha + 1) * 1j, T, S0, r, q, p)
    denom = alpha**2 + alpha - v**2 + 1j * (2 * alpha + 1) * v  # (α+iv)(α+iv+1)
    psi = np.exp(-r * T) * phi_shift / denom

    # FFT coupling: λ * eta = 2π / N
    lam = 2.0 * np.pi / (N * eta)   # Δk (log-strike step)
    b = 0.5 * N * lam               # half-width in k

    # exp(i b v) shifts the grid so that k_0 = -b.
    F = np.real(np.fft.fft(psi * np.exp(1j * b * v) * w))

    k = -b + n * lam                # k = ln K, strictly increasing since λ > 0
    K = np.exp(k)

    # Undo the damping and floor tiny negative values from numerical noise.
    calls = np.exp(-alpha * k) / np.pi * F
    return K, np.maximum(calls, 0.0)

def heston_fft_call_price(
    S0: float, K: float, T: float, r: float, q: float, p: HestonParams,
    N: int = 4096, eta: float = 0.25, alpha: float = 1.5
):
    """
    Price a call via FFT + linear interpolation on the K-grid.

    Parameters
    ----------
    S0, T, r, q, p, N, eta, alpha
        Forwarded unchanged to ``heston_fft_calls``.
    K : float or array-like
        Strike(s) at which to read the price off the FFT grid.

    Returns
    -------
    Call price(s): a NumPy float for scalar ``K``, an array for array-like
    ``K``. Strikes at or beyond either end of the grid receive the price at
    that end (flat extrapolation). The interpolation is linear in K, not in
    ln K.
    """
    K_grid, C_grid = heston_fft_calls(S0, T, r, q, p, N=N, eta=eta, alpha=alpha)
    # K_grid is ascending, as np.interp requires; np.interp clamps to the
    # end values outside the grid and accepts scalars and arrays alike.
    return np.interp(K, K_grid, C_grid)

def heston_fft_put_price(
    S0: float, K: float, T: float, r: float, q: float, p: HestonParams,
    N: int = 4096, eta: float = 0.25, alpha: float = 1.5
):
    """
    Put price obtained from the FFT call price through put–call parity:

        P = C - S0 * exp(-q T) + K * exp(-r T)

    All arguments are forwarded unchanged to ``heston_fft_call_price``.
    """
    call_value = heston_fft_call_price(
        S0, K, T, r, q, p, N=N, eta=eta, alpha=alpha
    )
    discounted_spot = S0 * np.exp(-q * T)
    discounted_strike = K * np.exp(-r * T)
    return call_value - discounted_spot + discounted_strike