normix

JAX package for Generalized Hyperbolic distributions as exponential families.

Built on Equinox, with float64 precision throughout.

Installation

pip install normix

Install optional plotting helpers with:

pip install "normix[plotting]"

For local development:

uv sync
pip install -e .

Quick Start

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

from normix import GeneralizedHyperbolic
from normix.fitting.em import BatchEMFitter

# Fit GH distribution to data via EM
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (1000, 3))

model = GeneralizedHyperbolic.from_classical(
    mu=jnp.zeros(3), gamma=jnp.zeros(3),
    sigma=jnp.eye(3), p=-0.5, a=2.0, b=1.0,
)
result = BatchEMFitter(max_iter=100).fit(model, X)

# Evaluate log-density (batched via vmap)
log_p = jax.vmap(result.model.log_prob)(X)   # shape (1000,)
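
If the log-density will be evaluated repeatedly, the batched call can also be wrapped in jax.jit. A minimal sketch reusing the fitted model above (plain JAX, not a separate package API):

# Compile the batched log-density once, then reuse it
batched_log_prob = jax.jit(jax.vmap(result.model.log_prob))
log_p = batched_log_prob(X)                       # same (1000,) result, compiled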

Distributions

Univariate (exponential family)

| Class | Parameters | Description |
| --- | --- | --- |
| Gamma | alpha, beta | Shape α > 0, rate β > 0 |
| InverseGamma | alpha, beta | Shape α > 0, rate β > 0 |
| InverseGaussian | mu, lam | Mean μ > 0, shape λ > 0 |
| GIG / GeneralizedInverseGaussian | p, a, b | Generalized Inverse Gaussian |
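
For reference, assuming the common (p, a, b) parametrization used in the examples here, the GIG density is

$$
f(y) = \frac{(a/b)^{p/2}}{2 K_p(\sqrt{ab})}\, y^{p-1} \exp\!\Big(-\tfrac{1}{2}\big(a y + b/y\big)\Big), \qquad y > 0,
$$

where $K_p$ is the modified Bessel function of the second kind (see log_kv below).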

Multivariate

| Class | Parameters | Description |
| --- | --- | --- |
| MultivariateNormal | mu, L_Sigma | Mean μ, Cholesky factor L_Sigma of Σ |

Normal Variance-Mean Mixtures (marginal)

| Class | Subordinator | Parameters |
| --- | --- | --- |
| VarianceGamma | Gamma | mu, gamma, L_Sigma, alpha, beta |
| NormalInverseGamma | InverseGamma | mu, gamma, L_Sigma, alpha, beta |
| NormalInverseGaussian | InverseGaussian | mu, gamma, L_Sigma, mu_ig, lam |
| GeneralizedHyperbolic | GIG | mu, gamma, L_Sigma, p, a, b |
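
All of these share the normal variance-mean mixture construction $X = \mu + Y\gamma + \sqrt{Y}\, L_\Sigma Z$ with $Z \sim N(0, I)$ and $Y$ the subordinator. A minimal sketch of that construction in plain JAX (not a normix API), using a Gamma subordinator as in VarianceGamma; the parameter values are illustrative:

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

key_y, key_z = jax.random.split(jax.random.PRNGKey(0))
d, alpha, beta = 3, 2.0, 1.0
mu, gamma, L_Sigma = jnp.zeros(d), 0.1 * jnp.ones(d), jnp.eye(d)

y = jax.random.gamma(key_y, alpha) / beta           # subordinator draw, Y ~ Gamma(alpha, rate beta)
z = jax.random.normal(key_z, (d,))                  # standard normal draw
x = mu + gamma * y + jnp.sqrt(y) * (L_Sigma @ z)    # X | Y = y  ~  N(mu + y*gamma, y*Sigma)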

Joint distributions

The Joint* classes (e.g. JointGeneralizedHyperbolic) model the full joint $f(x,y)$ where Y is the mixing variable. They are exponential families and are used internally for the EM E-step.
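
Concretely, for a normal variance-mean mixture the joint density factors as

$$
f(x, y) = \mathcal{N}\!\big(x \mid \mu + y\,\gamma,\; y\,\Sigma\big)\, f_Y(y),
$$

so the E-step amounts to taking conditional expectations of the joint sufficient statistics given the observed $x$.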

Exponential Family API

All univariate and joint distributions subclass ExponentialFamily(eqx.Module):

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

from normix import Gamma

X = jnp.array([1.0, 1.5, 2.0, 2.5])
dist = Gamma(alpha=jnp.array(2.0), beta=jnp.array(1.0))

# Log-density (single observation)
dist.log_prob(jnp.array(1.5))

# Three parametrizations
theta = dist.natural_params()       # natural parameters θ
eta   = dist.expectation_params()   # expectation parameters η = E[t(X)]
I     = dist.fisher_information()   # Fisher information I(θ) = ∇²ψ(θ)

# Constructors
dist2 = Gamma.from_natural(theta)
dist3 = Gamma.from_expectation(eta)
dist4 = Gamma.fit_mle(X)           # η̂ = mean t(xᵢ)
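
These accessors follow the standard exponential family relations; as a quick reference for the notation in the comments above:

$$
f(x \mid \theta) = h(x)\, \exp\!\big(\theta^\top t(x) - \psi(\theta)\big), \qquad
\eta = \nabla\psi(\theta) = \mathbb{E}[t(X)], \qquad
I(\theta) = \nabla^2\psi(\theta),
$$

and fit_mle matches moments via $\hat\eta = \tfrac{1}{n}\sum_i t(x_i)$ (the comment on the last line above).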

EM Algorithm

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

from normix import GeneralizedHyperbolic
from normix.fitting.em import BatchEMFitter

d = 3
X = ...  # (n, d) data array

# Initialise from classical parameters
model = GeneralizedHyperbolic.from_classical(
    mu=jnp.zeros(d), gamma=jnp.zeros(d), sigma=jnp.eye(d),
    p=-0.5, a=2.0, b=1.0,
)

# Fit with hybrid CPU/JAX backend for maximum speed
fitter = BatchEMFitter(max_iter=200, tol=1e-6,
                       e_step_backend='cpu', m_step_backend='cpu')
result = fitter.fit(model, X)
fitted = result.model
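
To sanity-check a fit, the per-sample log-density of the fitted model can be evaluated exactly as in the Quick Start; for example, the average log-likelihood:

avg_ll = jnp.mean(jax.vmap(fitted.log_prob)(X))   # mean log-likelihood under the fitted model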

Bessel Functions

import jax

jax.config.update("jax_enable_x64", True)

from normix import log_kv        # or: from normix.utils.bessel import log_kv

# JIT-able, differentiable (backend='jax', default)
log_kv(v=0.5, z=2.0)

# Fast CPU implementation for the EM hot path (not JIT-able)
log_kv(v=0.5, z=2.0, backend='cpu')
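
Because the JAX backend is differentiable and JIT-able, it composes with jax.grad and jax.jit. A minimal sketch, with a closed-form check using the identity $K_{1/2}(z) = \sqrt{\pi/(2z)}\, e^{-z}$:

import jax.numpy as jnp

f = lambda z: log_kv(v=0.5, z=z)
jax.jit(f)(2.0)                                   # compiled evaluation
jax.grad(f)(2.0)                                  # d/dz log K_{1/2}(z)

# Closed form for v = 1/2: log K_{1/2}(z) = 0.5 * log(pi / (2 z)) - z
0.5 * jnp.log(jnp.pi / 4.0) - 2.0                 # should match f(2.0)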

Package Layout

normix/
├── exponential_family.py         # ExponentialFamily base class
├── distributions/                # All distribution implementations
│   ├── gamma.py
│   ├── inverse_gamma.py
│   ├── inverse_gaussian.py
│   ├── generalized_inverse_gaussian.py
│   ├── normal.py
│   ├── variance_gamma.py
│   ├── normal_inverse_gamma.py
│   ├── normal_inverse_gaussian.py
│   └── generalized_hyperbolic.py
├── mixtures/                     # Joint and marginal base classes
├── fitting/em.py                 # BatchEMFitter, EMResult
└── utils/
    ├── bessel.py                 # log_kv with custom JVP
    ├── constants.py              # Shared numerical constants
    ├── plotting.py               # Notebook helpers
    └── validation.py             # EM validation helpers

Development

uv run pytest tests/              # run tests
uv run jupyter lab                # notebooks
make -C docs html                 # build docs

References

  • Barndorff-Nielsen, O. E. (1977). Exponentially decreasing distributions for the logarithm of particle size.
  • Eberlein, E., & Keller, U. (1995). Hyperbolic distributions in finance.

License

MIT — see LICENSE.
