In [1]:
import math

import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import os
import plotly.io as pio

# Plotly defaults: use the "plotly_white" template for every figure, and take the
# renderer from the PLOTLY_RENDERER environment variable ("notebook" if unset).
pio.templates.default = "plotly_white"
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")

# One seed and one seeded generator for the whole notebook, so that sampling
# cells are reproducible on a fresh kernel.
SEED = 7
rng = np.random.default_rng(SEED)

# Print arrays with 6 decimals and without scientific notation for small values.
np.set_printoptions(precision=6, suppress=True)


In [2]:
def normal_pdf(z: np.ndarray | float) -> np.ndarray:
    z = np.asarray(z, dtype=float)
    return np.exp(-0.5 * z * z) / np.sqrt(2.0 * np.pi)


def normal_cdf(z: np.ndarray | float) -> np.ndarray:
    """Approximate the standard normal CDF using NumPy only.

    The upper half (``z >= 0``) is computed as ``1 - pdf(|z|) * poly(t)`` with
    ``t = 1 / (1 + 0.2316419 * |z|)`` and a fifth-degree polynomial without a
    constant term; the lower half follows from symmetry. The coefficients match
    the classical Abramowitz & Stegun 26.2.17 formula, so this is an
    approximation (absolute error on the order of 1e-7), not an exact CDF.

    Parameters
    ----------
    z : array-like or float
        Point(s) of evaluation; converted to a float array.

    Returns
    -------
    np.ndarray
        Approximate ``P(Z <= z)`` with the shape of ``z``.
    """
    values = np.asarray(z, dtype=float)
    magnitude = np.abs(values)
    t = 1.0 / (1.0 + 0.2316419 * magnitude)

    # Horner evaluation from the highest-order coefficient down; the final
    # multiplication by t supplies the missing constant term (poly(0) == 0).
    acc = 1.330274429
    for coeff in (-1.821255978, 1.781477937, -0.356563782, 0.319381530):
        acc = coeff + t * acc
    poly = t * acc

    upper_half = 1.0 - normal_pdf(magnitude) * poly
    return np.where(values >= 0, upper_half, 1.0 - upper_half)


def normal_ppf(p: np.ndarray | float) -> np.ndarray:
    p = np.asarray(p, dtype=float)

    if np.any((p <= 0.0) | (p >= 1.0)):
        raise ValueError("normal_ppf expects 0 < p < 1")

    a = np.array(
        [
            -3.969683028665376e01,
            2.209460984245205e02,
            -2.759285104469687e02,
            1.383577518672690e02,
            -3.066479806614716e01,
            2.506628277459239e00,
        ]
    )
    b = np.array(
        [
            -5.447609879822406e01,
            1.615858368580409e02,
            -1.556989798598866e02,
            6.680131188771972e01,
            -1.328068155288572e01,
        ]
    )
    c = np.array(
        [
            -7.784894002430293e-03,
            -3.223964580411365e-01,
            -2.400758277161838e00,
            -2.549732539343734e00,
            4.374664141464968e00,
            2.938163982698783e00,
        ]
    )
    d = np.array(
        [
            7.784695709041462e-03,
            3.224671290700398e-01,
            2.445134137142996e00,
            3.754408661907416e00,
        ]
    )

    plow = 0.02425
    phigh = 1.0 - plow

    x = np.empty_like(p)

    lower = p < plow
    upper = p > phigh
    central = (~lower) & (~upper)

    if np.any(lower):
        q = np.sqrt(-2.0 * np.log(p[lower]))
        x[lower] = (
            (((((c[0] * q + c[1]) * q + c[2]) * q + c[3]) * q + c[4]) * q + c[5])
            / ((((d[0] * q + d[1]) * q + d[2]) * q + d[3]) * q + 1.0)
        )

    if np.any(upper):
        q = np.sqrt(-2.0 * np.log(1.0 - p[upper]))
        x[upper] = -(
            (((((c[0] * q + c[1]) * q + c[2]) * q + c[3]) * q + c[4]) * q + c[5])
            / ((((d[0] * q + d[1]) * q + d[2]) * q + d[3]) * q + 1.0)
        )

    if np.any(central):
        q = p[central] - 0.5
        r = q * q
        x[central] = (
            (((((a[0] * r + a[1]) * r + a[2]) * r + a[3]) * r + a[4]) * r + a[5])
            * q
            / (((((b[0] * r + b[1]) * r + b[2]) * r + b[3]) * r + b[4]) * r + 1.0)
        )

    return x


def _validate_crystalball_params(beta: float, m: float, scale: float) -> None:
    if beta <= 0:
        raise ValueError("beta must be > 0")
    if m <= 1:
        raise ValueError("m must be > 1")
    if scale <= 0:
        raise ValueError("scale must be > 0")


def crystalball_constants(beta: float, m: float) -> dict[str, float]:
    """Compute the constants of the standardized Crystal Ball distribution.

    Parameters
    ----------
    beta : float
        Join point parameter; the power-law tail covers ``z <= -beta``.
    m : float
        Power-law exponent of the tail.

    Returns
    -------
    dict[str, float]
        ``beta`` and ``m`` (as floats), the tail coefficients ``A`` and ``B``,
        the normalisation ``N``, the unnormalised tail integral ``tail_area``
        and the tail probability ``p_tail = N * tail_area``.

    Raises
    ------
    ValueError
        If ``beta <= 0`` or ``m <= 1``.
    OverflowError
        If ``(m / beta) ** m`` exceeds the float range (it is evaluated with
        Python floats, which raise on overflow).
    """
    beta = float(beta)
    m = float(m)

    for value, lower_bound, message in (
        (beta, 0, "beta must be > 0"),
        (m, 1, "m must be > 1"),
    ):
        if value <= lower_bound:
            raise ValueError(message)

    # Unnormalised Gaussian core evaluated at the join point z = -beta.
    gauss_at_join = np.exp(-0.5 * beta * beta)

    power_scale = (m / beta) ** m * gauss_at_join
    power_shift = m / beta - beta
    left_area = (m / (beta * (m - 1.0))) * gauss_at_join
    # Integral of exp(-z**2 / 2) over z > -beta, via the normal CDF at +beta.
    right_area = np.sqrt(2.0 * np.pi) * float(normal_cdf(beta))
    norm = 1.0 / (left_area + right_area)

    return dict(
        beta=beta,
        m=m,
        A=float(power_scale),
        B=float(power_shift),
        N=float(norm),
        tail_area=float(left_area),
        p_tail=float(norm * left_area),
    )


def crystalball_pdf(
    x: np.ndarray | float,
    beta: float,
    m: float,
    loc: float = 0.0,
    scale: float = 1.0,
) -> np.ndarray:
    """Evaluate the Crystal Ball density at ``x``.

    With ``z = (x - loc) / scale`` the density is a Gaussian core
    ``exp(-z**2 / 2)`` for ``z > -beta`` and a power-law tail
    ``A * (B - z) ** (-m)`` otherwise, both multiplied by ``N / scale``.

    Parameters
    ----------
    x : array-like or float
        Point(s) of evaluation.
    beta, m : float
        Shape parameters (``beta > 0``, ``m > 1``).
    loc, scale : float
        Location and scale (``scale > 0``).

    Returns
    -------
    np.ndarray
        Density values with the shape of ``x``.

    Raises
    ------
    ValueError
        If the parameters fail ``_validate_crystalball_params``.
    """
    _validate_crystalball_params(beta, m, scale)
    consts = crystalball_constants(beta, m)

    z = (np.asarray(x, dtype=float) - loc) / scale

    # Points that do not satisfy z > -beta (including NaN) take the tail branch.
    in_core = z > -consts["beta"]
    in_tail = ~in_core

    shape_vals = np.empty_like(z, dtype=float)
    shape_vals[in_core] = np.exp(-0.5 * z[in_core] * z[in_core])
    shape_vals[in_tail] = consts["A"] * (consts["B"] - z[in_tail]) ** (-consts["m"])

    return (consts["N"] / scale) * shape_vals


def crystalball_cdf(
    x: np.ndarray | float,
    beta: float,
    m: float,
    loc: float = 0.0,
    scale: float = 1.0,
) -> np.ndarray:
    """Evaluate the Crystal Ball cumulative distribution function at ``x``.

    With ``z = (x - loc) / scale``:

    * tail (``z <= -beta``): ``N * A / (m - 1) * (B - z) ** (-(m - 1))``;
    * core (``z > -beta``): ``N * (tail_area + sqrt(2*pi) * (Phi(z) - Phi(-beta)))``,
      where ``Phi`` is the NumPy-only ``normal_cdf`` approximation.

    Parameters
    ----------
    x : array-like or float
        Point(s) of evaluation.
    beta, m : float
        Shape parameters (``beta > 0``, ``m > 1``).
    loc, scale : float
        Location and scale (``scale > 0``).

    Returns
    -------
    np.ndarray
        CDF values with the shape of ``x``.

    Raises
    ------
    ValueError
        If the parameters fail ``_validate_crystalball_params``.
    """
    _validate_crystalball_params(beta, m, scale)
    consts = crystalball_constants(beta, m)

    z = (np.asarray(x, dtype=float) - loc) / scale

    in_core = z > -consts["beta"]
    below_join = ~in_core

    cdf_vals = np.empty_like(z, dtype=float)

    # Power-law tail: closed-form integral of the tail density up to z.
    tail_coeff = consts["N"] * (consts["A"] / (consts["m"] - 1.0))
    cdf_vals[below_join] = tail_coeff * (consts["B"] - z[below_join]) ** (
        -(consts["m"] - 1.0)
    )

    # Gaussian core: full tail mass plus the Gaussian integral from -beta to z.
    phi_at_join = float(normal_cdf(-consts["beta"]))
    core_integral = np.sqrt(2.0 * np.pi) * (normal_cdf(z[in_core]) - phi_at_join)
    cdf_vals[in_core] = consts["N"] * (consts["tail_area"] + core_integral)

    return cdf_vals


def crystalball_ppf(
    u: np.ndarray | float,
    beta: float,
    m: float,
    loc: float = 0.0,
    scale: float = 1.0,
) -> np.ndarray:
    """Evaluate the Crystal Ball quantile function (inverse CDF) at ``u``.

    Probabilities below ``p_tail`` are inverted with the closed-form power-law
    expression; the rest are mapped to a standard normal probability and
    inverted with ``normal_ppf``.

    Parameters
    ----------
    u : array-like or float
        Probabilities. Values are clipped into the open interval (0, 1), so
        ``u = 0`` and ``u = 1`` return large finite quantiles instead of
        ``-inf`` / ``+inf``.
    beta, m : float
        Shape parameters (``beta > 0``, ``m > 1``).
    loc, scale : float
        Location and scale (``scale > 0``).

    Returns
    -------
    np.ndarray
        Quantiles ``loc + scale * z`` with the shape of ``u``.

    Raises
    ------
    ValueError
        If the parameters fail ``_validate_crystalball_params``.
    """
    _validate_crystalball_params(beta, m, scale)
    c = crystalball_constants(beta, m)

    # Smallest and largest doubles strictly inside (0, 1).
    open_lo = np.nextafter(0.0, 1.0)
    open_hi = np.nextafter(1.0, 0.0)

    u = np.asarray(u, dtype=float)
    u = np.clip(u, open_lo, open_hi)

    z = np.empty_like(u)
    tail_mask = u < c["p_tail"]

    if np.any(tail_mask):
        # Invert the tail CDF  u = N * A / (m - 1) * (B - z) ** (-(m - 1)).
        z[tail_mask] = c["B"] - (
            (c["N"] * c["A"]) / ((c["m"] - 1.0) * u[tail_mask])
        ) ** (1.0 / (c["m"] - 1.0))

    if np.any(~tail_mask):
        # Invert the core CDF  u = N * (tail_area + sqrt(2*pi) * (Phi(z) - Phi(-beta))).
        phi_minus_beta = float(normal_cdf(-c["beta"]))
        p = phi_minus_beta + (
            (u[~tail_mask] / c["N"] - c["tail_area"]) / np.sqrt(2.0 * np.pi)
        )
        # Floating-point rounding (and the approximate normal_cdf) can push p
        # onto or past 0 or 1 even though u itself was clipped, which would make
        # normal_ppf raise. Clip p into the same open interval; NaN passes
        # through unchanged.
        p = np.clip(p, open_lo, open_hi)
        z[~tail_mask] = normal_ppf(p)

    return loc + scale * z


def crystalball_rvs(
    beta: float,
    m: float,
    loc: float = 0.0,
    scale: float = 1.0,
    size: int | tuple[int, ...] = 1,
    rng: np.random.Generator | None = None,
) -> np.ndarray:
    """Draw Crystal Ball samples by inverse-transform sampling.

    Uniform variates from ``rng.random(size)`` are pushed through
    ``crystalball_ppf``.

    Parameters
    ----------
    beta, m : float
        Shape parameters.
    loc, scale : float
        Location and scale.
    size : int or tuple of int
        Output shape passed to ``rng.random``.
    rng : np.random.Generator, optional
        Source of randomness; a fresh, unseeded ``default_rng()`` is created
        when omitted, so results are then not reproducible.

    Returns
    -------
    np.ndarray
        Samples with shape ``size``.
    """
    generator = np.random.default_rng() if rng is None else rng
    uniforms = generator.random(size)
    return crystalball_ppf(uniforms, beta=beta, m=m, loc=loc, scale=scale)


In [3]:
def _gaussian_core_moments(beta: float, max_order: int = 4) -> list[float]:
    """Unnormalised raw moments of the Gaussian core on ``z > -beta``.

    Entry ``k`` is the integral of ``z**k * exp(-z**2 / 2)`` from ``-beta`` to
    infinity. Orders 0 and 1 are closed form; higher orders use the
    integration-by-parts recurrence
    ``I_k = (-beta)**(k-1) * exp(-beta**2 / 2) + (k - 1) * I_{k-2}``.

    Parameters
    ----------
    beta : float
        Join point parameter.
    max_order : int
        Highest moment order to compute.

    Returns
    -------
    list[float]
        ``max_order + 1`` values, indexed by moment order.
    """
    beta = float(beta)
    lower_limit = -beta
    # exp(-beta**2 / 2): both the order-1 moment and the boundary weight.
    boundary_weight = float(np.exp(-0.5 * beta * beta))

    moments = [0.0] * (max_order + 1)
    moments[0] = float(np.sqrt(2.0 * np.pi) * normal_cdf(beta))
    if max_order >= 1:
        moments[1] = boundary_weight

    for order in range(2, max_order + 1):
        boundary_term = (lower_limit ** (order - 1)) * boundary_weight
        moments[order] = boundary_term + (order - 1) * moments[order - 2]

    return moments


def _tail_raw_moment(beta: float, m: float, k: int) -> float:
    """Unnormalised ``k``-th raw moment of the power-law tail (``z <= -beta``).

    Integrates ``z**k * A * (B - z) ** (-m)`` over the tail by substituting
    ``t = B - z`` (so ``t`` starts at ``m / beta``) and expanding
    ``(B - t)**k`` with the binomial theorem.

    Parameters
    ----------
    beta, m : float
        Shape parameters.
    k : int
        Moment order.

    Returns
    -------
    float
        The tail contribution, or ``inf`` when ``m <= k + 1``. The ``inf`` is
        returned as a divergence marker with a positive sign regardless of
        ``k``.
    """
    beta = float(beta)
    m = float(m)

    if m <= k + 1:
        return float("inf")

    consts = crystalball_constants(beta, m)
    scale_a = consts["A"]
    shift_b = consts["B"]
    t_join = m / beta

    # Plain accumulation loop kept on purpose: the terms alternate in sign and
    # the summation order determines the rounded result.
    total = 0.0
    for j in range(k + 1):
        binom = math.comb(k, j)
        sign = (-1.0) ** j
        total += (
            binom * (shift_b ** (k - j)) * sign * (t_join ** (j - m + 1.0)) / (m - j - 1.0)
        )

    return float(scale_a * total)


def crystalball_raw_moments_standard(beta: float, m: float, max_order: int = 4) -> list[float]:
    """Raw moments ``E[Z**k]`` of the standardized Crystal Ball, ``k = 0..max_order``.

    Each moment is ``N * (tail part + Gaussian core part)``. An order whose tail
    part is reported as infinite by ``_tail_raw_moment`` is returned as ``inf``.

    Parameters
    ----------
    beta, m : float
        Shape parameters.
    max_order : int
        Highest moment order to compute.

    Returns
    -------
    list[float]
        ``max_order + 1`` values, indexed by moment order.
    """
    consts = crystalball_constants(beta, m)
    core_parts = _gaussian_core_moments(beta, max_order=max_order)
    norm = consts["N"]

    def _combined(order: int) -> float:
        # Normalised sum of the two pieces, or inf when the tail diverges.
        tail_part = _tail_raw_moment(beta, m, order)
        if np.isinf(tail_part):
            return float("inf")
        return norm * (tail_part + core_parts[order])

    return [_combined(order) for order in range(max_order + 1)]


def crystalball_stats(beta: float, m: float, loc: float = 0.0, scale: float = 1.0) -> dict[str, float]:
    """Mean, variance, skewness and excess kurtosis of a Crystal Ball distribution.

    The statistics are derived from the standardized raw moments. A statistic of
    order ``k`` is only computed when ``m > k + 1``; otherwise it is reported as
    ``inf`` and every higher-order statistic as ``nan``, with one exception:
    for ``m <= 2`` both the mean and the variance are reported as ``inf``.

    Parameters
    ----------
    beta, m : float
        Shape parameters (``beta > 0``, ``m > 1``).
    loc, scale : float
        Location and scale (``scale > 0``).

    Returns
    -------
    dict[str, float]
        Keys ``"mean"``, ``"var"``, ``"skew"`` and ``"kurtosis_excess"``.

    Raises
    ------
    ValueError
        If the parameters fail ``_validate_crystalball_params``.
    """
    _validate_crystalball_params(beta, m, scale)

    raw = crystalball_raw_moments_standard(beta, m, max_order=4)
    inf = float("inf")
    nan = float("nan")

    if m <= 2:
        return {"mean": inf, "var": inf, "skew": nan, "kurtosis_excess": nan}

    mean_z = raw[1]
    mean = loc + scale * mean_z
    if m <= 3:
        return {"mean": mean, "var": inf, "skew": nan, "kurtosis_excess": nan}

    var_z = raw[2] - mean_z**2
    var = (scale**2) * var_z
    if m <= 4:
        return {"mean": mean, "var": var, "skew": inf, "kurtosis_excess": nan}

    # Third central moment from raw moments; skewness is location/scale free.
    mu3 = raw[3] - 3 * mean_z * raw[2] + 2 * mean_z**3
    skew = mu3 / (var_z ** 1.5)
    if m <= 5:
        return {"mean": mean, "var": var, "skew": skew, "kurtosis_excess": inf}

    # Fourth central moment from raw moments.
    mu4 = raw[4] - 4 * mean_z * raw[3] + 6 * mean_z**2 * raw[2] - 3 * mean_z**4
    kurtosis_excess = mu4 / (var_z**2) - 3.0

    return {
        "mean": mean,
        "var": var,
        "skew": skew,
        "kurtosis_excess": kurtosis_excess,
    }


# Example: m = 5 sits on the `m <= 5` boundary of crystalball_stats, so mean,
# variance and skewness are finite while the excess kurtosis is reported as inf.
beta, m = 2.0, 5.0
stats_np = crystalball_stats(beta, m)
stats_np


{'mean': -0.0411654516887457,
 'var': 1.172424140432902,
 'skew': -0.8462097026771533,
 'kurtosis_excess': inf}

In [4]:
# Cross-check the NumPy-only moments against SciPy's Crystal Ball implementation.
from scipy import stats as sp_stats

beta, m = 2.0, 5.0
# moments="mvsk" requests mean, variance, skewness and kurtosis, in that order.
mean_s, var_s, skew_s, kurt_s = sp_stats.crystalball.stats(beta, m, moments="mvsk")
stats_np = crystalball_stats(beta, m)

# Display SciPy's values first and the NumPy-only values second for comparison.
(
    {"mean": mean_s, "var": var_s, "skew": skew_s, "kurtosis_excess": kurt_s},
    stats_np,
)


({'mean': -0.041165454536300744,
  'var': 1.1724241522428485,
  'skew': -0.8462097472699548,
  'kurtosis_excess': inf},
 {'mean': -0.0411654516887457,
  'var': 1.172424140432902,
  'skew': -0.8462097026771533,
  'kurtosis_excess': inf})

In [5]:
# Differential entropy of the standardized distribution (loc=0, scale=1).
beta, m = 2.0, 3.0
rv = sp_stats.crystalball(beta, m)

h = float(rv.entropy())

# Same shape parameters with a shift and a rescale by s.
s = 2.5
rv_scaled = sp_stats.crystalball(beta, m, loc=10.0, scale=s)
h_scaled = float(rv_scaled.entropy())

# Scaling adds log(s) to the entropy and shifting adds nothing, so the third
# element should reproduce h.
h, h_scaled, h_scaled - math.log(s)


(1.502838180880596, 2.419128912754751, 1.502838180880596)

In [6]:
# Evaluation grid in standardized units (loc=0, scale=1 defaults are used below).
xs = np.linspace(-8, 6, 800)

# Figure 1: vary beta with the tail exponent fixed at m = 3.
fig = go.Figure()
for beta in [1.0, 2.0, 3.0]:
    fig.add_trace(
        go.Scatter(
            x=xs,
            y=crystalball_pdf(xs, beta=beta, m=3.0),
            mode="lines",
            name=f"beta={beta}, m=3",
        )
    )
fig.update_layout(
    title="Crystal Ball PDF: varying beta (m fixed)",
    xaxis_title="x (standardized)",
    yaxis_title="pdf",
)
fig.show()

# Figure 2: vary m with the join point fixed at beta = 2.
# Note: the loop variables `beta` and `m` remain defined at module level after
# these loops (beta = 3.0, m = 12.0).
fig = go.Figure()
for m in [2.2, 3.0, 6.0, 12.0]:
    fig.add_trace(
        go.Scatter(
            x=xs,
            y=crystalball_pdf(xs, beta=2.0, m=m),
            mode="lines",
            name=f"beta=2, m={m}",
        )
    )
fig.update_layout(
    title="Crystal Ball PDF: varying m (beta fixed)",
    xaxis_title="x (standardized)",
    yaxis_title="pdf",
)
fig.show()


In [7]:
# PDF, CDF and a Monte Carlo check for one parameter choice (standardized).
beta, m = 2.0, 3.0

xs = np.linspace(-10, 6, 900)
pdf_vals = crystalball_pdf(xs, beta=beta, m=m)
cdf_vals = crystalball_cdf(xs, beta=beta, m=m)

fig = go.Figure()
fig.add_trace(go.Scatter(x=xs, y=pdf_vals, mode="lines", name="pdf"))
fig.update_layout(
    title=f"Crystal Ball PDF (beta={beta}, m={m})",
    xaxis_title="x (standardized)",
    yaxis_title="pdf",
)
fig.show()

fig = go.Figure()
fig.add_trace(go.Scatter(x=xs, y=cdf_vals, mode="lines", name="cdf"))
fig.update_layout(
    title=f"Crystal Ball CDF (beta={beta}, m={m})",
    xaxis_title="x (standardized)",
    yaxis_title="cdf",
    yaxis_range=[0, 1],
)
fig.show()

# Inverse-transform samples drawn with the notebook-level seeded generator `rng`;
# this advances the generator's state.
n = 60_000
samples = crystalball_rvs(beta=beta, m=m, size=n, rng=rng)

# Density-normalised histogram so the analytic PDF can be overlaid directly.
fig = px.histogram(
    x=samples,
    nbins=140,
    histnorm="probability density",
    title=f"Monte Carlo samples (n={n:,}) with PDF overlay",
)
fig.add_trace(go.Scatter(x=xs, y=pdf_vals, mode="lines", name="pdf"))
fig.update_layout(xaxis_title="x", yaxis_title="density")
fig.show()


In [8]:
# Compare the NumPy-only PDF/CDF with SciPy on a grid.
# (sp_stats is already imported by an earlier cell; repeating the import is harmless.)
from scipy import stats as sp_stats

beta, m = 2.0, 3.0
rv = sp_stats.crystalball(beta, m)

xs = np.linspace(-10, 6, 900)

pdf_scipy = rv.pdf(xs)
cdf_scipy = rv.cdf(xs)

# Worst-case absolute discrepancy over the grid. Presumably dominated by the
# polynomial normal_cdf approximation, which enters both through N — TODO confirm.
pdf_max_abs_err = float(np.max(np.abs(pdf_scipy - crystalball_pdf(xs, beta=beta, m=m))))
cdf_max_abs_err = float(np.max(np.abs(cdf_scipy - crystalball_cdf(xs, beta=beta, m=m))))

pdf_max_abs_err, cdf_max_abs_err


(2.6935221497659256e-08, 1.2512497056527128e-07)

In [9]:
# Fit example: generate data from a known Crystal Ball, then estimate parameters via SciPy MLE.
beta_true, m_true, loc_true, scale_true = 2.0, 4.0, 1.2, 0.8
rv_true = sp_stats.crystalball(beta_true, m_true, loc=loc_true, scale=scale_true)

# random_state=SEED makes the synthetic sample reproducible.
data = rv_true.rvs(size=4000, random_state=SEED)

# fit() is called without starting values or fixed parameters; the result is
# unpacked as (beta, m, loc, scale).
beta_hat, m_hat, loc_hat, scale_hat = sp_stats.crystalball.fit(data)
# Show the true parameters next to the estimates.
(beta_true, m_true, loc_true, scale_true), (beta_hat, m_hat, loc_hat, scale_hat)


((2.0, 4.0, 1.2, 0.8),
 (1.7406026024354304,
  6.239285901013936,
  1.168093594434854,
  0.7763910211945402))

In [10]:
# 10.1) Normal vs Crystal Ball via (approximate) likelihood ratio

# Reuses rv_true from the fit example; the same random_state is passed again.
data = rv_true.rvs(size=1500, random_state=SEED)

# Null model: Gaussian with its maximum-likelihood estimates (ddof=0 gives the
# MLE of the standard deviation).
mu0 = float(np.mean(data))
sig0 = float(np.std(data, ddof=0))
ll0 = float(np.sum(sp_stats.norm.logpdf(data, loc=mu0, scale=sig0)))

# Alternative model: Crystal Ball fitted by SciPy's default MLE.
beta1, m1, loc1, scale1 = sp_stats.crystalball.fit(data)
ll1 = float(np.sum(sp_stats.crystalball.logpdf(data, beta1, m1, loc=loc1, scale=scale1)))

# NOTE(review): in the recorded output ll1 equals ll0 to about 1e-11 and the
# statistic is slightly negative, i.e. the Crystal Ball fit did not improve on
# the Gaussian for this sample. Passing starting values to fit() may help —
# unverified.
lrt = 2.0 * (ll1 - ll0)
(ll0, ll1, lrt)


(-1985.9983338376114, -1985.9983338376333, -4.3655745685100555e-11)

In [11]:
# 10.2) Simple Bayesian grid posterior over (beta, m) with loc/scale fixed

# loc and scale are held at their true values; only (beta, m) are inferred.
loc_fix, scale_fix = loc_true, scale_true
data = rv_true.rvs(size=400, random_state=SEED)

betas = np.linspace(0.6, 4.0, 80)
ms = np.linspace(1.2, 10.0, 100)

# Broadcast data x ms x betas, then sum over the data axis:
# log_like has shape (len(ms), len(betas)).
log_like = sp_stats.crystalball.logpdf(
    data[:, None, None],
    betas[None, None, :],
    ms[None, :, None],
    loc=loc_fix,
    scale=scale_fix,
).sum(axis=0)

# Flat prior on the grid: the log-posterior equals the log-likelihood up to a
# constant. Subtract the maximum for numerical stability, building a NEW array
# so that log_like itself is not modified in place (the previous
# `log_post = log_like; log_post -= ...` aliased and overwrote it).
log_post = log_like - np.max(log_like)

post = np.exp(log_post)
post /= np.sum(post)

# Rows index ms and columns index betas, hence (i_map, j_map) -> (m, beta).
i_map, j_map = np.unravel_index(int(np.argmax(post)), post.shape)
beta_map = float(betas[j_map])
m_map = float(ms[i_map])

beta_map, m_map


(1.9341772151898735, 4.488888888888889)

In [12]:
# Heat map of the normalised grid posterior: columns of `post` run along `betas`
# (x axis) and rows along `ms` (y axis); origin="lower" puts small m at the bottom.
fig = px.imshow(
    post,
    x=betas,
    y=ms,
    origin="lower",
    aspect="auto",
    title="Grid posterior p(beta, m | data) (loc/scale fixed)",
    labels={"x": "beta", "y": "m", "color": "posterior"},
)
fig.show()


In [13]:
# 10.3) Generative modeling: simulate new data from the MAP parameters and compare

# Frozen distribution at the grid MAP (beta, m) with the fixed loc/scale.
rv_map = sp_stats.crystalball(beta_map, m_map, loc=loc_fix, scale=scale_fix)
# SEED + 1 rather than SEED, so the draw uses a different seed than the one
# that generated `data`.
synthetic = rv_map.rvs(size=len(data), random_state=SEED + 1)

# Overlay two density-normalised, semi-transparent histograms.
fig = go.Figure()
fig.add_trace(
    go.Histogram(
        x=data,
        histnorm="probability density",
        name="observed",
        opacity=0.6,
        nbinsx=50,
    )
)
fig.add_trace(
    go.Histogram(
        x=synthetic,
        histnorm="probability density",
        name="generated (MAP)",
        opacity=0.6,
        nbinsx=50,
    )
)
fig.update_layout(
    barmode="overlay",
    title="Generative check: observed vs generated (MAP)",
    xaxis_title="x",
    yaxis_title="density",
)
fig.show()
