# normix

JAX package for Generalized Hyperbolic distributions as exponential families.
Built on Equinox with Float64 precision throughout.
## Installation

```bash
pip install normix
```

Install optional plotting helpers with:

```bash
pip install "normix[plotting]"
```

For local development:

```bash
uv sync
pip install -e .
```

## Quick start

```python
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
from normix import GeneralizedHyperbolic
from normix.fitting.em import BatchEMFitter
# Fit GH distribution to data via EM
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (1000, 3))
model = GeneralizedHyperbolic.from_classical(
    mu=jnp.zeros(3), gamma=jnp.zeros(3),
    sigma=jnp.eye(3), p=-0.5, a=2.0, b=1.0,
)
result = BatchEMFitter(max_iter=100).fit(model, X)
# Evaluate log-density (batched via vmap)
log_p = jax.vmap(result.model.log_prob)(X)  # shape (1000,)
```
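Because distributions are Equinox modules (and hence JAX pytrees), the batched log-density should also compose with `jax.jit`; a small follow-up to the snippet above:

```python
# JIT-compile the batched log-density from the quick start.
batched_log_prob = jax.jit(jax.vmap(result.model.log_prob))
log_p = batched_log_prob(X)  # shape (1000,)
```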

## Distributions

Subordinator (mixing) distributions:

| Class | Parameters | Description |
|---|---|---|
| `Gamma` | `alpha`, `beta` | Shape α > 0, rate β > 0 |
| `InverseGamma` | `alpha`, `beta` | Shape α > 0, scale β > 0 |
| `InverseGaussian` | `mu`, `lam` | Mean μ > 0, shape λ > 0 |
| `GIG` / `GeneralizedInverseGaussian` | `p`, `a`, `b` | Generalized Inverse Gaussian |

Normal distribution:

| Class | Parameters | Description |
|---|---|---|
| `MultivariateNormal` | `mu`, `L_Sigma` | Mean μ, Cholesky factor `L_Sigma` of Σ |

Normal mean–variance mixtures:

| Class | Subordinator | Parameters |
|---|---|---|
| `VarianceGamma` | `Gamma` | `mu`, `gamma`, `L_Sigma`, `alpha`, `beta` |
| `NormalInverseGamma` | `InverseGamma` | `mu`, `gamma`, `L_Sigma`, `alpha`, `beta` |
| `NormalInverseGaussian` | `InverseGaussian` | `mu`, `gamma`, `L_Sigma`, `mu_ig`, `lam` |
| `GeneralizedHyperbolic` | `GIG` | `mu`, `gamma`, `L_Sigma`, `p`, `a`, `b` |
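As an illustration only: the field-style constructor pattern below is borrowed from the `Gamma` example later in this README, so whether `VarianceGamma` accepts exactly these keyword arguments is an assumption (`from_classical` is the documented route for `GeneralizedHyperbolic`).

```python
import jax.numpy as jnp
from normix import VarianceGamma  # assumed top-level export, like Gamma

# Hypothetical direct construction; field names follow the table,
# but the exact constructor signature is an assumption.
vg = VarianceGamma(
    mu=jnp.zeros(3),
    gamma=jnp.zeros(3),
    L_Sigma=jnp.eye(3),      # Cholesky factor of Σ
    alpha=jnp.array(2.0),
    beta=jnp.array(2.0),
)
```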

The `Joint*` classes (e.g. `JointGeneralizedHyperbolic`) model the full joint
distribution of the observation and its latent subordinator variable; the
marginal classes above integrate the subordinator out.
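To make the latent-variable structure concrete, here is a minimal pure-JAX sketch of the standard normal mean–variance mixture construction these classes represent (not normix's internal code): draw a subordinator W, then set X = μ + γW + √W · Lz with z standard normal.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def sample_mixture(key, mu, gamma, L_Sigma, sample_w, n):
    """Sample X = mu + gamma * W + sqrt(W) * (L_Sigma @ Z)."""
    k_w, k_z = jax.random.split(key)
    w = sample_w(k_w, (n,))                        # latent subordinator W
    z = jax.random.normal(k_z, (n, mu.shape[0]))
    x = mu + w[:, None] * gamma + jnp.sqrt(w)[:, None] * (z @ L_Sigma.T)
    return x, w                                    # the "joint" pair (X, W)

# Variance Gamma case: the subordinator is Gamma-distributed (unit scale here).
key = jax.random.PRNGKey(0)
sample_w = lambda k, shape: jax.random.gamma(k, 2.0, shape)
X, W = sample_mixture(key, jnp.zeros(3), jnp.zeros(3), jnp.eye(3), sample_w, 1000)
```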

## Exponential family interface

All univariate and joint distributions subclass `ExponentialFamily(eqx.Module)`:

```python
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
from normix import Gamma
X = jnp.array([1.0, 1.5, 2.0, 2.5])
dist = Gamma(alpha=jnp.array(2.0), beta=jnp.array(1.0))
# Log-density (single observation)
dist.log_prob(jnp.array(1.5))
# Three parametrizations
theta = dist.natural_params() # natural parameters θ
eta = dist.expectation_params() # expectation parameters η = E[t(X)]
I = dist.fisher_information() # Fisher information I(θ) = ∇²ψ(θ)
# Constructors
dist2 = Gamma.from_natural(theta)
dist3 = Gamma.from_expectation(eta)
dist4 = Gamma.fit_mle(X)            # η̂ = mean t(xᵢ)
```
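The three parametrizations are tied together by the log-partition function ψ: η = ∇ψ(θ) and I(θ) = ∇²ψ(θ). Independent of normix's API, a standalone pure-JAX sketch for the Gamma family (sufficient statistics t(x) = (log x, x), natural parameters θ = (α − 1, −β)) checks both identities with automatic differentiation:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import gammaln, digamma

jax.config.update("jax_enable_x64", True)

def log_partition(theta):
    # ψ(θ) for the Gamma family with t(x) = (log x, x), θ = (α - 1, -β)
    alpha = theta[0] + 1.0
    beta = -theta[1]
    return gammaln(alpha) - alpha * jnp.log(beta)

theta = jnp.array([1.0, -1.0])               # α = 2, β = 1
eta = jax.grad(log_partition)(theta)         # η = ∇ψ(θ) = (E[log X], E[X])
fisher = jax.hessian(log_partition)(theta)   # I(θ) = ∇²ψ(θ)

# Closed forms: E[log X] = digamma(α) - log β, E[X] = α/β
print(eta, digamma(2.0) - jnp.log(1.0), 2.0 / 1.0)
```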

## EM fitting

```python
import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
from normix import GeneralizedHyperbolic
from normix.fitting.em import BatchEMFitter
d = 3
X = ... # (n, d) data array
# Initialise from classical parameters
model = GeneralizedHyperbolic.from_classical(
    mu=jnp.zeros(d), gamma=jnp.zeros(d), sigma=jnp.eye(d),
    p=-0.5, a=2.0, b=1.0,
)
# Fit with hybrid CPU/JAX backend for maximum speed
fitter = BatchEMFitter(max_iter=200, tol=1e-6,
                       e_step_backend='cpu', m_step_backend='cpu')
result = fitter.fit(model, X)
fitted = result.model
```
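Continuing the snippet above, a quick sanity check that uses only the `log_prob` API from the quick start: the EM iterations should not decrease the average log-likelihood.

```python
# Mean log-likelihood before vs. after EM fitting.
ll_init = jax.vmap(model.log_prob)(X).mean()
ll_fit = jax.vmap(fitted.log_prob)(X).mean()
assert ll_fit >= ll_init
```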

## Bessel functions

`log_kv` evaluates the log of the modified Bessel function of the second kind K_v(z), with a custom JVP:

```python
import jax

jax.config.update("jax_enable_x64", True)
from normix import log_kv # or: from normix.utils.bessel import log_kv
# JIT-able, differentiable (backend='jax', default)
log_kv(v=0.5, z=2.0)
# Fast CPU path for EM hot path (not JIT-able)
log_kv(v=0.5, z=2.0, backend='cpu')
```
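Since the default `'jax'` backend is differentiable, `jax.grad` applies directly. As a quick check: the closed form K_{1/2}(z) = √(π/(2z)) e^{−z} gives d/dz log K_{1/2}(z) = −1 − 1/(2z), so the gradient at z = 2 should be −1.25.

```python
import jax
from normix import log_kv

jax.config.update("jax_enable_x64", True)

# Differentiate log K_v(z) with respect to z via the default JAX backend.
dlog_dz = jax.grad(lambda z: log_kv(v=0.5, z=z))(2.0)
print(dlog_dz)  # ≈ -1 - 1/(2*2) = -1.25
```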

## Project layout

```
normix/
├── exponential_family.py # ExponentialFamily base class
├── distributions/ # All distribution implementations
│ ├── gamma.py
│ ├── inverse_gamma.py
│ ├── inverse_gaussian.py
│ ├── generalized_inverse_gaussian.py
│ ├── normal.py
│ ├── variance_gamma.py
│ ├── normal_inverse_gamma.py
│ ├── normal_inverse_gaussian.py
│ └── generalized_hyperbolic.py
├── mixtures/ # Joint and marginal base classes
├── fitting/em.py # BatchEMFitter, EMResult
└── utils/
├── bessel.py # log_kv with custom JVP
├── constants.py # Shared numerical constants
├── plotting.py # Notebook helpers
└── validation.py # EM validation helpers
```

## Development

```bash
uv run pytest tests/   # run tests
uv run jupyter lab     # notebooks
make -C docs html      # build docs
```

## References

- Barndorff-Nielsen, O. E. (1977). Exponentially decreasing distributions for the logarithm of particle size.
- Eberlein, E., & Keller, U. (1995). Hyperbolic distributions in finance.

## License

MIT — see LICENSE.