In [1]:
import numpy as np

In [4]:
# Simulate a switching linear dynamical system (SLDS) producing toy weather data.
#
# Generative model, for t = 0 .. T + T_test - 1:
#   z_0 ~ Cat(starting_probs),     z_t ~ Cat(regime_transitions[z_{t-1}])
#   x_0 = 0,                       x_t = A[z_t] @ x_{t-1} + w_t,   w_t ~ N(0, Q[z_t])
#   y_t ~ N(mus[z_t] + C @ x_t, Sigma_obs[z_t])
# z -> `regimes` (discrete), x -> `states` (latent), y -> `train_data` (observed).
# The first T rows are intended for training; the final T_test rows are held out.
# Note: observations are Gaussian, so simulated rainfall can be negative.

SEED = 0
rng = np.random.default_rng(SEED)  # seeded so Restart & Run All reproduces the same data

T = 1000
T_test = 300
N_STEPS = T + T_test

K = 3  # number of regimes (here: 0 = stormy, 1 = normal, 2 = sunny/warm)
D_lat = 2  # latent dimension (here: normalized temperature-anomaly evolution and a
           # precipitation-intensity factor related to rainfall amounts)
D_obs = 2  # observation dimension (here: temperature, rainfall)

# Regime transition matrix, shape (K, K); row i is the distribution of z_t given z_{t-1} = i.
regime_transitions = np.array([
    [0.85, 0.15, 0.0],
    [0.05, 0.90, 0.05],
    [0.0, 0.10, 0.90],
])

# Per-regime observation means, shape (K, D_obs); columns are (temp, rainfall).
mus = np.array([
    [5.0, 20.0],
    [10.0, 5.0],
    [20.0, 0.0],
])

# Per-regime latent dynamics matrices, shape (K, D_lat, D_lat); applied once per step.
A = np.array([
    [[0.95, 0.02], [0.05, 0.92]],  # Stormy
    [[0.92, 0.0], [0.0, 0.88]],    # Normal
    [[0.98, 0.0], [0.0, 0.96]],    # Sunny (persistent)
])

# Initial regime distribution, shape (K,).
starting_probs = np.array([0.33, 0.33, 0.34])

# Per-regime observation noise covariances, shape (K, D_obs, D_obs).
Sigma_obs = np.array([
    [[4.0, 3.0], [3.0, 64.0]],      # Stormy: high rain variance (8 mm std), positive temp/rain cov
    [[2.25, 1.0], [1.0, 4.0]],      # Normal: moderate
    # Sunny: dry/hot, negative temp/rain covariance (corr = -0.5). An off-diagonal of -0.5
    # would give corr = -0.5 / sqrt(1.0 * 0.25) = -1, i.e. a singular (degenerate) covariance.
    [[1.0, -0.25], [-0.25, 0.25]],
])

# Per-regime latent process noise, shape (K, D_lat, D_lat); smaller for stable regimes.
Q = np.array([
    [[0.01, 0], [0, 0.04]],
    [[0.0025, 0], [0, 0.01]],
    [[0.0004, 0], [0, 0.001]],
])

# Emission matrix mapping the latent state to an observation offset, shape (D_obs, D_lat).
# Identity when D_obs == D_lat; np.eye(D_obs, D_lat) also stays well-defined otherwise.
C = np.eye(D_obs, D_lat)

# Sanity checks on the model parameters (fail fast if a value is edited inconsistently).
assert np.allclose(regime_transitions.sum(axis=1), 1.0), "transition rows must sum to 1"
assert np.isclose(starting_probs.sum(), 1.0), "starting_probs must sum to 1"
assert np.all(np.linalg.eigvalsh(Sigma_obs) > 0), "Sigma_obs must be positive definite"
assert np.all(np.linalg.eigvalsh(Q) > 0), "Q must be positive definite"

train_data = np.empty((N_STEPS, D_obs))  # observations y_t (train rows followed by test rows)
regimes = np.zeros(N_STEPS, dtype=int)   # regime index z_t in {0, ..., K - 1}
states = np.zeros((N_STEPS, D_lat))      # latent state x_t; x_0 stays at 0

# t = 0: draw the initial regime and emit an observation. Without this, row 0 of the
# np.empty buffer would stay uninitialised and be fed to the downstream model.
regimes[0] = rng.choice(K, p=starting_probs)
train_data[0] = rng.multivariate_normal(mus[regimes[0]] + C @ states[0], Sigma_obs[regimes[0]])

for t in range(1, N_STEPS):
    # Markov regime switch conditioned on the previous regime.
    k = rng.choice(K, p=regime_transitions[regimes[t - 1]])
    regimes[t] = k
    # Regime-specific linear latent update plus process noise.
    states[t] = A[k] @ states[t - 1] + rng.multivariate_normal(np.zeros(D_lat), Q[k])
    # Observation: regime mean shifted by the projected latent state, plus regime noise.
    train_data[t] = rng.multivariate_normal(mus[k] + C @ states[t], Sigma_obs[k])

In [6]:
import jax
import jax.numpy as jnp
import jax.random as jr
from dynamax.hidden_markov_model import GaussianHMM
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt

# Fit a K-state Gaussian HMM to the simulated observations and compare the inferred
# regimes with the true ones.
#
# Model note: the data come from a switching LDS (regime-dependent latent dynamics), which
# a Gaussian HMM cannot represent exactly. The HMM only sees per-regime Gaussian clusters,
# so some mislabelling (especially around regime changes) is expected.
#
# NOTE(review): the recorded error
#   AttributeError: module 'jax.interpreters.xla' has no attribute 'pytype_aval_mappings'
# looks like a version mismatch in the environment rather than a bug in this cell. It is
# commonly caused by an older tensorflow-probability (a dynamax dependency) paired with a
# newer jax. Upgrading tensorflow-probability, or pinning a jax version it supports, should
# resolve it -- confirm against the installed package versions.

key = jr.PRNGKey(0)

# Fit only on the first T steps; the final T_test steps are held out for evaluation.
obs_train = train_data[:T]
obs_test = train_data[T:]

# Standardize using training-split statistics only, so no test information leaks in.
scaler = StandardScaler()
train_scaled = scaler.fit_transform(obs_train)  # (T, D_obs), float64 (jax casts to float32 by default)
test_scaled = scaler.transform(obs_test)        # (T_test, D_obs)

hmm = GaussianHMM(num_states=K, emission_dim=D_obs)
# initialize(key, method, ..., emissions=...) returns (params, props). Passing the data
# positionally would bind it to `method`; k-means initialization uses it to seed the emissions.
params, props = hmm.initialize(key=key, method="kmeans", emissions=train_scaled)

# fit_em takes (params, props, emissions); props marks which parameters are trainable.
params, lls = hmm.fit_em(params, props, train_scaled, num_iters=50)

# Posterior marginals per time step; argmax picks the most probable regime at each step.
post = hmm.smoother(params, train_scaled)
regimes_hat = jnp.argmax(post.smoothed_probs, axis=-1)

# Held-out log-likelihood per time step as a simple generalization check.
test_ll = hmm.marginal_log_prob(params, test_scaled)
print(f"held-out log-likelihood per step: {float(test_ll) / len(test_scaled):.3f}")

fig, axes = plt.subplots(3, 1, figsize=(10, 8))
axes[0].plot(regimes[:T], label="true", alpha=0.6)
axes[0].plot(regimes_hat, label="inferred", alpha=0.6)
# HMM state labels are only identified up to permutation, so colours/levels may not match.
axes[0].set(title="Regimes (inferred labels are a permutation of the true ones)",
            xlabel="t", ylabel="regime")
axes[0].legend()
axes[1].plot(train_scaled)
axes[1].set(title="Scaled training data", xlabel="t", ylabel="standardized value")
axes[1].legend(["temp", "rainfall"])
axes[2].plot(lls)
axes[2].set(title="EM objective per iteration", xlabel="iteration", ylabel="log prob")
fig.tight_layout()
plt.show()




AttributeError: module 'jax.interpreters.xla' has no attribute 'pytype_aval_mappings'